586 lines
24 KiB
Python
586 lines
24 KiB
Python
import json
|
||
import os
|
||
import numpy as np
|
||
import torch
|
||
import faiss
|
||
from typing import List, Tuple
|
||
from sentence_transformers import SentenceTransformer
|
||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||
from sklearn.metrics.pairwise import cosine_similarity
|
||
import datetime
|
||
import re
|
||
import random
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from embedder_sbert_qavanin_285k import PersianVectorAnalyzer
|
||
#from normalizer import cleaning
|
||
from fastapi import FastAPI ,Header
|
||
from pydantic import BaseModel
|
||
# LLM Libs
|
||
from openai import OpenAI
|
||
from langchain_openai import ChatOpenAI # pip install -U langchain_openai
|
||
import requests
|
||
|
||
today = f'{datetime.datetime.now().year}{datetime.datetime.now().month}{datetime.datetime.now().day}'
|
||
|
||
chatbot = FastAPI()
|
||
origins = ["*"]
|
||
|
||
chatbot.add_middleware(
|
||
CORSMiddleware,
|
||
allow_origins=origins,
|
||
allow_credentials=True,
|
||
allow_methods=["*"],
|
||
allow_headers=["*"],
|
||
)
|
||
|
||
# -------------------
|
||
# مدلها و مسیر دادهsrc/app/qavanin-faiss/faiss_index_qavanin_285k_metadata.json
|
||
# -------------------/src/app/qavanin-faiss
|
||
EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
||
RERANKER_MODEL = "BAAI/bge-reranker-v2-m3"
|
||
FAISS_INDEX_PATH = "/src/app/qavanin-faiss/faiss_index_qavanin_285k.index"
|
||
FAISS_METADATA_PATH = "/src/app/qavanin-faiss/faiss_index_qavanin_285k_metadata.json"
|
||
|
||
RERANK_BATCH = int(os.environ.get("RERANK_BATCH", 256))
|
||
# print(f'RERANK_BATCH: {RERANK_BATCH}')
|
||
|
||
def get_key():
|
||
key = 'aa-fdh9d847ANcBxQCBTZD5hrrAdl0UrPEnJOScYmOncrkagYPf'
|
||
return key
|
||
|
||
def load_faiss_index(index_path: str, metadata_path: str):
|
||
"""بارگذاری ایندکس FAISS و متادیتا (لیست جملات + عناوین)."""
|
||
index = faiss.read_index(index_path)
|
||
|
||
with open(metadata_path, "r", encoding="utf-8") as f:
|
||
metadata = json.load(f)
|
||
|
||
content_list, ids, prefix_list = [], [], []
|
||
for item in metadata:
|
||
content_list.append(item["content"])
|
||
ids.append(item["id"])
|
||
prefix_list.append(item["prefix"])
|
||
|
||
return content_list, ids, prefix_list, index
|
||
|
||
def get_client():
|
||
url = "https://api.avalai.ir/v1"
|
||
# key = 'aa-4tvAEazUBovEN1i7i7tdl1PR93OaWXs6hMflR4oQbIIA4K7Z'
|
||
|
||
|
||
client = OpenAI(
|
||
api_key= get_key(), # با کلید واقعی خود جایگزین کنید
|
||
base_url= url, # آدرس پایه
|
||
)
|
||
|
||
return client
|
||
|
||
def llm_base_request(query):
|
||
# model = 'cf.gemma-3-12b-it'
|
||
model = 'gpt-4o-mini'
|
||
prompt = f'برای متن {query} زیر، عنوانی کوتاه که بین 3 تا 6 کلمه داشته باشد، انتخاب کن. غیر از عنوان، به هیچ وجه توضیح اضافه ای در قبل یا بعد آن اضافه نکن.'
|
||
client = get_client()
|
||
try:
|
||
messages.append({"role": "user", "content": prompt})
|
||
response = client.chat.completions.create(
|
||
messages = messages,
|
||
model= model) # "gpt-4o", "gpt-4o-mini", "deepseek-chat" , "gemini-2.0-flash", gemini-2.5-flash-lite
|
||
# gpt-4o : 500
|
||
# gpt-4o-mini : 34
|
||
# deepseek-chat: : 150
|
||
# gemini-2.0-flash : error
|
||
# cf.gemma-3-12b-it : 1
|
||
# gemini-2.5-flash-lite : 35 خیلی خوب
|
||
|
||
answer = response.choices[0].message.content
|
||
# پاسخ را هم به سابقه اضافه میکنیم
|
||
messages.append({"role": "assistant", "content": answer})
|
||
|
||
|
||
except Exception as error:
|
||
with open('./llm-answer/error-in-llm.txt', mode='a+', encoding='utf-8') as file:
|
||
error_message = f'\n\nquery: {query.strip()}\nerror:{error} \n-------------------------------\n'
|
||
file.write(error_message)
|
||
|
||
return ''
|
||
|
||
return answer
|
||
|
||
def llm_request(query, model):
|
||
|
||
if query == '':
|
||
return 'لطفا متن سوال را وارد نمائید'
|
||
|
||
client = get_client()
|
||
determine_refrence = """شناسه هر ماده قانون در ابتدای آن و با فرمت "id: {idvalue}" آمده است که id-value همان شناسه ماده است. بازای هربخش از پاسخی که تولید می شود، ضروری است شناسه ماده ای که در تدوین پاسخ از آن استفاده شده در انتهای پاراگراف یا جمله مربوطه با فرمت {idvalue} اضافه شود. همیشه idvalue با رشته "qs" شروع می شود"""
|
||
try:
|
||
messages.append({"role": "user", "content": query})
|
||
messages.append({"role": "user", "content": determine_refrence})
|
||
response = client.chat.completions.create(
|
||
messages = messages,
|
||
model= model) # "gpt-4o", "gpt-4o-mini", "deepseek-chat" , "gemini-2.0-flash", gemini-2.5-flash-lite
|
||
# gpt-4o : 500
|
||
# gpt-4o-mini : 34
|
||
# deepseek-chat: : 150
|
||
# gemini-2.0-flash : error
|
||
# cf.gemma-3-12b-it : 1
|
||
# gemini-2.5-flash-lite : 35 خیلی خوب
|
||
|
||
answer = response.choices[0].message.content
|
||
# پاسخ را هم به سابقه اضافه میکنیم
|
||
messages.append({"role": "assistant", "content": answer})
|
||
|
||
|
||
except Exception as error:
|
||
with open('./llm-answer/error-in-llm.txt', mode='a+', encoding='utf-8') as file:
|
||
error_message = f'\n\nquery: {query.strip()}\nerror:{error} \n-------------------------------\n'
|
||
file.write(error_message)
|
||
|
||
return 'با عرض پوزش؛ متاسفانه خطایی رخ داده است. لطفا لحظاتی دیگر دوباره تلاش نمائید'
|
||
|
||
return answer
|
||
|
||
class HybridRetrieverReranker:
|
||
__slots__ = (
|
||
"device", "content_list", "ids", "prefix_list", "N", "embedder", "faiss_index",
|
||
"vectorizer", "tfidf_matrix", "tokenizer", "reranker", "dense_alpha"
|
||
)
|
||
|
||
def __init__(self, content_list: List[str],ids: List[str], prefix_list: List[str], faiss_index,
|
||
dense_alpha: float = 0.6, device: str = None):
|
||
|
||
if device is None:
|
||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||
self.device = device
|
||
|
||
self.content_list = content_list
|
||
self.ids = ids
|
||
self.prefix_list = prefix_list
|
||
self.faiss_index = faiss_index
|
||
self.N = len(content_list)
|
||
|
||
# Dense
|
||
self.embedder = SentenceTransformer(EMBED_MODEL,cache_folder='/src/MODELS', device=self.device)
|
||
#self.embedder = SentenceTransformer(EMBED_MODEL, device=self.device)
|
||
|
||
# Sparse (مثل قبل برای حفظ خروجی)
|
||
self.vectorizer = TfidfVectorizer(
|
||
analyzer="word",
|
||
ngram_range=(1, 2),
|
||
token_pattern=r"(?u)\b[\w\u0600-\u06FF]{2,}\b",
|
||
)
|
||
self.tfidf_matrix = self.vectorizer.fit_transform(self.content_list)
|
||
|
||
# Reranker
|
||
self.tokenizer = AutoTokenizer.from_pretrained(RERANKER_MODEL,cache_dir='/src/MODELS', use_fast=True)
|
||
self.reranker = AutoModelForSequenceClassification.from_pretrained(
|
||
RERANKER_MODEL
|
||
).to(self.device)
|
||
# self.reranker = AutoModelForSeq2SeqLM.from_pretrained(RERANKER_MODEL).to(self.device)
|
||
# self.reranker.eval()
|
||
|
||
self.dense_alpha = float(dense_alpha)
|
||
|
||
# --- Dense (FAISS) ---
|
||
def dense_retrieve(self, query: str, top_k: int):
|
||
if top_k <= 0:
|
||
return [], np.array([], dtype=np.float32)
|
||
|
||
q_emb = self.embedder.encode(query, convert_to_numpy=True).astype(np.float32)
|
||
D, I = self.faiss_index.search(np.expand_dims(q_emb, axis=0), top_k)
|
||
|
||
return I[0].tolist(), D[0]
|
||
|
||
# --- Sparse ---
|
||
def sparse_retrieve(self, query: str, top_k: int):
|
||
if top_k <= 0:
|
||
return [], np.array([], dtype=np.float32)
|
||
k = min(top_k, self.N)
|
||
q_vec = self.vectorizer.transform([query])
|
||
sims = cosine_similarity(q_vec, self.tfidf_matrix).ravel()
|
||
idx = np.argpartition(-sims, kth=k-1)[:k]
|
||
idx = idx[np.argsort(-sims[idx], kind="mergesort")]
|
||
return idx.tolist(), sims[idx]
|
||
|
||
# --- Utils ---
|
||
@staticmethod
|
||
def _minmax_norm(arr: np.ndarray) -> np.ndarray:
|
||
if arr.size == 0:
|
||
return arr
|
||
a_min = arr.min()
|
||
a_max = arr.max()
|
||
rng = a_max - a_min
|
||
if rng < 1e-12:
|
||
return np.zeros_like(arr)
|
||
return (arr - a_min) / rng
|
||
|
||
def fuse(self, d_idx, d_scores, s_idx, s_scores, top_k=50, k_rrf=60):
|
||
"""
|
||
ادغام نتایج دو retriever (dense و sparse) با استفاده از Reciprocal Rank Fusion (RRF)
|
||
|
||
Args:
|
||
d_idx (list or np.ndarray): ایندکسهای نتایج dense retriever
|
||
d_scores (list or np.ndarray): نمرات dense retriever
|
||
s_idx (list or np.ndarray): ایندکسهای نتایج sparse retriever
|
||
s_scores (list or np.ndarray): نمرات sparse retriever
|
||
top_k (int): تعداد نتایج نهایی
|
||
k_rrf (int): ثابت در فرمول RRF برای کاهش تأثیر رتبههای پایینتر
|
||
|
||
Returns:
|
||
list: لیست ایندکسهای ادغامشده به ترتیب نمره
|
||
"""
|
||
combined = {}
|
||
|
||
# dense retriever
|
||
for rank, idx in enumerate(d_idx):
|
||
score = 1.0 / (k_rrf + rank)
|
||
combined[idx] = combined.get(idx, 0) + score
|
||
|
||
# sparse retriever
|
||
for rank, idx in enumerate(s_idx):
|
||
score = 1.0 / (k_rrf + rank)
|
||
combined[idx] = combined.get(idx, 0) + score
|
||
|
||
# مرتبسازی نهایی
|
||
sorted_items = sorted(combined.items(), key=lambda x: x[1], reverse=True)
|
||
cand_idx = [item[0] for item in sorted_items[:top_k]]
|
||
|
||
return cand_idx
|
||
|
||
def rerank(self, query: str, candidate_indices: List[int], passages: List[str], final_k: int) -> List[Tuple[int, float]]:
|
||
"""
|
||
Rerank candidate passages using a cross-encoder (e.g., MonoT5, MiniLM, etc.).
|
||
|
||
Args:
|
||
query (str): پرسش کاربر
|
||
candidate_indices (List[int]): ایندکسهای کاندیدا (از retriever)
|
||
passages (List[str]): کل جملات/پاراگرافها
|
||
final_k (int): تعداد نتایج نهایی
|
||
|
||
Returns:
|
||
List[Tuple[int, float]]: لیستی از (ایندکس، امتیاز) برای بهترین نتایج
|
||
"""
|
||
if final_k <= 0 or not candidate_indices:
|
||
return []
|
||
|
||
# آمادهسازی جفتهای (query, passage)
|
||
texts = [query] * len(candidate_indices)
|
||
pairs = passages
|
||
|
||
scores: List[float] = []
|
||
|
||
def _iter_batches(max_bs: int):
|
||
bs = max_bs
|
||
while bs >= 16: # حداقل batch_size
|
||
try:
|
||
with torch.inference_mode():
|
||
for start in range(0, len(pairs), bs):
|
||
batch_texts = texts[start:start + bs]
|
||
batch_pairs = pairs[start:start + bs]
|
||
inputs = self.tokenizer(
|
||
batch_texts,
|
||
batch_pairs,
|
||
padding=True,
|
||
truncation=True,
|
||
max_length=512,
|
||
return_tensors="pt",
|
||
).to(self.device)
|
||
|
||
logits = self.reranker(**inputs).logits.view(-1)
|
||
scores.extend(logits.detach().cpu().tolist())
|
||
return True
|
||
except torch.cuda.OutOfMemoryError:
|
||
if torch.cuda.is_available():
|
||
torch.cuda.empty_cache()
|
||
bs //= 2
|
||
return False
|
||
|
||
# اجرای reranking
|
||
success = _iter_batches(max_bs=64)
|
||
if not success:
|
||
raise RuntimeError("Reranker failed due to CUDA OOM, even with small batch size.")
|
||
|
||
# مرتبسازی نتایج بر اساس نمره
|
||
reranked = sorted(
|
||
zip(candidate_indices, scores),
|
||
key=lambda x: x[1],
|
||
reverse=True
|
||
)[:final_k]
|
||
|
||
return reranked
|
||
|
||
def get_passages(self, cand_idx, content_list):
|
||
passages = []
|
||
for idx in cand_idx:
|
||
passages.append(content_list[idx])
|
||
|
||
return passages
|
||
|
||
# --- Search (بدون تغییر) ---
|
||
def search(self, query: str, content_list, topk_dense=50, topk_sparse=50,
|
||
pre_rerank_k=50, final_k=10):
|
||
d_idx, d_scores = self.dense_retrieve(query, topk_dense)
|
||
s_idx, s_scores = self.sparse_retrieve(query, topk_sparse)
|
||
cand_idx = self.fuse(d_idx, d_scores, s_idx, s_scores, pre_rerank_k)
|
||
passages = self.get_passages(cand_idx, content_list)
|
||
reranked = self.rerank(query, cand_idx, passages, final_k)
|
||
|
||
return [{"idx": i, "content": self.content_list[i],"prefix": self.prefix_list[i], "rerank_score": score}
|
||
for i, score in reranked]
|
||
|
||
def single_query(query: str):
|
||
|
||
# query = cleaning(query)
|
||
retrived_sections_ids = []
|
||
retrived_sections = pipe.search(query, content_list, topk_dense=30, topk_sparse=30, pre_rerank_k=30, final_k=10)
|
||
final_similars = ''
|
||
for i, row in enumerate(retrived_sections, 1):
|
||
id_value = '{' + str(ids[row['idx']]) + '}'
|
||
result = f"id: {id_value} \n{row['prefix']} {row['content']}\n"
|
||
retrived_sections_ids.append(ids[row['idx']])
|
||
final_similars += ''.join(result)
|
||
|
||
return final_similars, retrived_sections_ids
|
||
|
||
def find_refrences(llm_answer: str) -> List[str]:
|
||
"""
|
||
شناسایی شناسه هایی که مدل زبانی، برای تهیه پاسخ از آنها استفاده کرده است
|
||
|
||
Args:
|
||
llm_answer(str): متنی که مدل زبانی تولید کرده است
|
||
|
||
Returns:
|
||
refrence_ids(List[str]): لیستی از شناسه های تشخیص داده شده
|
||
"""
|
||
pattern = r"\{[^\}]+\}"
|
||
refrence_ids = re.findall(pattern, llm_answer)
|
||
new_refrences_ids = []
|
||
for itm in refrence_ids:
|
||
refrence = itm.lstrip('{')
|
||
refrence = refrence.lstrip('}')
|
||
new_refrences_ids.append(refrence)
|
||
# refrence_ids = [item.lstrip('{').rstrip('}') for item in refrence_ids]
|
||
return refrence_ids
|
||
|
||
def replace_refrences(llm_answer: str, refrences_list:List[str]) -> List[str]:
|
||
"""
|
||
شناسایی شناسه هایی که مدل زبانی، برای تهیه پاسخ از آنها استفاده کرده است
|
||
|
||
Args:
|
||
llm_answer(str): متنی که مدل زبانی تولید کرده است
|
||
refrences_list(List[str]): لیست شناسه ماده های مورد استفاده در پاسخ مدل زبانی
|
||
Returns:
|
||
llm_answer(str), : متن بازسازی شده پاسخ مدل زبانی که شناسه ماده های مورد استفاده در آن، اصلاح شده است
|
||
"""
|
||
refrences = ''
|
||
for index, ref in enumerate(refrences_list,1):
|
||
# breakpoint()
|
||
llm_answer = llm_answer.replace(ref, f'[{index}]')
|
||
# id = ref.lstrip('{')
|
||
# id = id.rstrip('}')
|
||
# refrences += ''.join(f'[{index}] https://majles.tavasi.ir/entity/detail/view/qsection/{id}\n')
|
||
|
||
# llm_answer = f'{llm_answer}\n\nمنابع پاسخ:\n{refrences.strip()}'
|
||
return llm_answer.strip()
|
||
|
||
# load basic items
|
||
content_list, ids, prefix_list, faiss_index = load_faiss_index(FAISS_INDEX_PATH, FAISS_METADATA_PATH)
|
||
pipe = HybridRetrieverReranker(content_list, ids, prefix_list, faiss_index, dense_alpha=0.6)
|
||
# query preprocess and normalize
|
||
normalizer_obj = PersianVectorAnalyzer()
|
||
|
||
messages = [
|
||
{"role": "system", "content": "تو یک دستیار خبره در زمینه حقوق و قوانین مرتبط به آن هستی و می توانی متون حقوقی را به صورت دقیق توضیح بدهی . پاسخ ها باید الزاما به زبان فارسی باشد. پاسخ ها فقط از متون قانونی که در پرامپت وجود دارد استخراج شود."},
|
||
]
|
||
|
||
models = ["gemini-2.5-flash-lite", "gpt-4o-mini"]
|
||
|
||
def save_result(chat_obj: object) -> bool:
|
||
# index result in elastic
|
||
pass
|
||
|
||
def run_chatbot(query:str, chat_id:str):
|
||
prompt_status = True
|
||
status_text = 'لطفا متن سوال را وارد نمائید'
|
||
if query == '':
|
||
prompt_status = False
|
||
|
||
start_time = (datetime.datetime.now())
|
||
|
||
# در صورتی که وضعیت پرامپت معتبر باشد، وارد فرایند شو
|
||
if prompt_status:
|
||
result_passages_text, result_passages_ids = single_query(query)
|
||
end_retrive = datetime.datetime.now()
|
||
print('-'*40)
|
||
retrive_duration = (end_retrive - start_time).total_seconds()
|
||
print(f'retrive duration: {str(retrive_duration)}')
|
||
|
||
prompt = f'برای پرسش "{query}" از میان مواد قانونی "{result_passages_text}" .پاسخ مناسب و دقیق را استخراج کن. درصورتی که مطلبی مرتبط با پرسش در متن پیدا نشد، فقط پاسخ بده: "متاسفانه در منابع، پاسخی پیدا نشد!"'
|
||
|
||
llm_model = ''
|
||
for model in models:
|
||
try:
|
||
llm_model = model
|
||
llm_answer = llm_request(prompt, model)
|
||
except Exception as error:
|
||
error = f'model: {model} \n{error}\n\n'
|
||
prompt_status = False
|
||
status_text = 'با عرض پوزش، سرویس موقتا در دسترس نیست. لطفا دقایقی دیگر دوباره تلاش نمائید!'
|
||
|
||
else:
|
||
chat_obj = {
|
||
'id' : chat_id, # str
|
||
'title' : '', # str
|
||
'user_id' : '',
|
||
'user_query' : query, # str
|
||
'model_key' : llm_model, # str
|
||
'retrived_passage' : result_passages_text, # str
|
||
'retrived_ref_ids' : result_passages_ids, # list[obj]
|
||
'prompt_type' : 'question-answer', # str
|
||
'retrived_duration' : retrive_duration, # str
|
||
'llm_duration' : '0', # str
|
||
'full_duration' : '0', # str
|
||
'time_create' : str(start_time), # str
|
||
'used_ref_ids' : [], # list[str]
|
||
'prompt_answer' : '', # str
|
||
'status_text' : status_text,
|
||
'status' : prompt_status, # or False # bool
|
||
}
|
||
|
||
# آبجکت ایجاد شده با بازگردان
|
||
return chat_obj, status_text
|
||
|
||
llm_answer_duration = (datetime.datetime.now() - end_retrive).total_seconds()
|
||
print(f'llm answer duration: {str(llm_answer_duration)}')
|
||
|
||
used_refrences_in_answer = find_refrences(llm_answer)
|
||
llm_answer = replace_refrences(llm_answer, used_refrences_in_answer)
|
||
|
||
full_prompt_duration = (datetime.datetime.now() - start_time).total_seconds()
|
||
print(f'full prompt duration: {full_prompt_duration}')
|
||
print('~'*40)
|
||
|
||
status_text ='پاسخ با موفقیت ایجاد شد'
|
||
|
||
title = llm_base_request(query)
|
||
if title == '':
|
||
title = query[0:15]
|
||
|
||
chat_obj = {
|
||
'id' : chat_id, # str
|
||
'title' : title, # str
|
||
'user_id' : '',
|
||
'user_query' : query, # str
|
||
'model_key' : llm_model, # str
|
||
'retrived_passage' : result_passages_text, # str
|
||
'retrived_ref_ids' : result_passages_ids, # list[obj]
|
||
'prompt_type' : 'question-answer', # str
|
||
'retrived_duration' : retrive_duration, # str
|
||
'llm_duration' : llm_answer_duration, # str
|
||
'full_duration' : full_prompt_duration, # str
|
||
'time_create' : str(start_time), # str
|
||
'used_ref_ids' : used_refrences_in_answer, # list[str]
|
||
'prompt_answer' : llm_answer, # str
|
||
'status_text' : status_text, # str
|
||
'status' : True, # or False # bool
|
||
}
|
||
prev_chat_data = []
|
||
with open('./llm-answer/chat-messages.json', mode='r', encoding='utf-8') as file:
|
||
prev_chat_data = json.load(file)
|
||
prev_chat_data.append(chat_obj)
|
||
|
||
with open('./llm-answer/chat-messages.json', mode='w', encoding='utf-8') as output:
|
||
json.dump(prev_chat_data, output, ensure_ascii=False, indent=2)
|
||
|
||
# save_result(chat_obj)
|
||
|
||
# ایجاد آبجکت بازگشتی به فرانت
|
||
# chat_obj.pop('retrived_passage')
|
||
# chat_obj.pop('prompt_type')
|
||
|
||
return chat_obj
|
||
|
||
@chatbot.post("/credit_refresh")
|
||
def credit_refresh():
|
||
"""
|
||
Returns remained credit
|
||
"""
|
||
url = "https://api.avalai.ir/user/credit"
|
||
headers = {
|
||
"Content-Type": "application/json",
|
||
"Authorization": f"Bearer {get_key()}"
|
||
}
|
||
remained_credit = requests.get(url, headers=headers)
|
||
|
||
with open('./llm-answer/credit.txt','w') as file:
|
||
file.write(str(remained_credit.json()['remaining_irt']))
|
||
return str(remained_credit.json()['remaining_irt'])
|
||
|
||
def create_chat_id():
|
||
date = str((datetime.datetime.now())).replace(' ','-').replace(':','').replace('.','-')
|
||
print('date ', date )
|
||
chat_id = f'{date}-{random.randint(100000, 999999)}'
|
||
print('chat_id ', chat_id )
|
||
return chat_id
|
||
|
||
print('#'*19)
|
||
print('-Chatbot is Ready!!!!!-')
|
||
print('#'*19)
|
||
|
||
# تعریف مدل دادهها برای درخواستهای API
|
||
class Query(BaseModel):
|
||
query: str
|
||
# مسیر API برای اجرا کردن run_chatbot
|
||
@chatbot.post("/run_chatbot")
|
||
def run_chat(query: Query):
|
||
print('query ', query )
|
||
chat_id = create_chat_id()
|
||
print('query.query ', query.query )
|
||
answer = run_chatbot(query.query, chat_id)
|
||
credit_refresh()
|
||
|
||
return {"answer": answer}
|
||
|
||
# uvicorn src.app:app --reload
|
||
|
||
if __name__ == "__main__":
|
||
|
||
# query = 'در قانون حمایت از خانواده و جوانی جمعیت چه خدماتی در نظر گرفته شده است؟'
|
||
while True:
|
||
query = input('enter your qustion:')
|
||
if query == '':
|
||
print('لطفا متن سوال را وارد نمائید')
|
||
continue
|
||
start = (datetime.datetime.now())
|
||
# result = test_dataset()
|
||
result = single_query(query)
|
||
end_retrive = datetime.datetime.now()
|
||
print('-'*40)
|
||
print(f'retrive duration: {(end_retrive - start).total_seconds()}')
|
||
|
||
prompt = f'برای پرسش "{query}" از میان مواد قانونی "{result}" .پاسخ مناسب و دقیق را استخراج کن. درصورتی که مطلبی مرتبط با پرسش در متن پیدا نشد، فقط پاسخ بده: "متاسفانه در منابع، پاسخی پیدا نشد!"'
|
||
llm_answer = llm_request(prompt)
|
||
|
||
print('-'*40)
|
||
print(f'llm duration: {(datetime.datetime.now() - end_retrive).total_seconds()}')
|
||
|
||
refrences = ''
|
||
recognized_refrences = find_refrences(llm_answer)
|
||
llm_answer = replace_refrences(llm_answer, recognized_refrences)
|
||
|
||
with open('./llm-answer/result.txt', mode='a+', encoding='utf-8') as file:
|
||
result_message = f'متن پرامپت: {query.strip()}\n\nپاسخ: {llm_answer} \n----------------------------------------------------------\n'
|
||
file.write(result_message)
|
||
|
||
with open('./llm-answer/passages.txt', mode='a+', encoding='utf-8') as file:
|
||
result_message = f'متن پرامپت: {query.strip()}\n\مواد مشابه: {result} \n----------------------------------------------------------\n'
|
||
file.write(result_message)
|
||
|
||
|
||
|
||
print('----------------------------------------------------------')
|
||
print(f'full duration: {(datetime.datetime.now() - start).total_seconds()}')
|
||
print('----------------------------------------------------------')
|
||
print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
|
||
|