rag_qavanin_api/chatbot.py
2025-09-28 09:24:27 +00:00

586 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import numpy as np
import torch
import faiss
from typing import List, Tuple
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import datetime
import re
import random
from fastapi.middleware.cors import CORSMiddleware
from embedder_sbert_qavanin_285k import PersianVectorAnalyzer
#from normalizer import cleaning
from fastapi import FastAPI ,Header
from pydantic import BaseModel
# LLM Libs
from openai import OpenAI
from langchain_openai import ChatOpenAI # pip install -U langchain_openai
import requests
today = f'{datetime.datetime.now().year}{datetime.datetime.now().month}{datetime.datetime.now().day}'
chatbot = FastAPI()
origins = ["*"]
chatbot.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# -------------------
# مدل‌ها و مسیر دادهsrc/app/qavanin-faiss/faiss_index_qavanin_285k_metadata.json
# -------------------/src/app/qavanin-faiss
EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
RERANKER_MODEL = "BAAI/bge-reranker-v2-m3"
FAISS_INDEX_PATH = "/src/app/qavanin-faiss/faiss_index_qavanin_285k.index"
FAISS_METADATA_PATH = "/src/app/qavanin-faiss/faiss_index_qavanin_285k_metadata.json"
RERANK_BATCH = int(os.environ.get("RERANK_BATCH", 256))
# print(f'RERANK_BATCH: {RERANK_BATCH}')
def get_key():
key = 'aa-fdh9d847ANcBxQCBTZD5hrrAdl0UrPEnJOScYmOncrkagYPf'
return key
def load_faiss_index(index_path: str, metadata_path: str):
"""بارگذاری ایندکس FAISS و متادیتا (لیست جملات + عناوین)."""
index = faiss.read_index(index_path)
with open(metadata_path, "r", encoding="utf-8") as f:
metadata = json.load(f)
content_list, ids, prefix_list = [], [], []
for item in metadata:
content_list.append(item["content"])
ids.append(item["id"])
prefix_list.append(item["prefix"])
return content_list, ids, prefix_list, index
def get_client():
url = "https://api.avalai.ir/v1"
# key = 'aa-4tvAEazUBovEN1i7i7tdl1PR93OaWXs6hMflR4oQbIIA4K7Z'
client = OpenAI(
api_key= get_key(), # با کلید واقعی خود جایگزین کنید
base_url= url, # آدرس پایه
)
return client
def llm_base_request(query):
# model = 'cf.gemma-3-12b-it'
model = 'gpt-4o-mini'
prompt = f'برای متن {query} زیر، عنوانی کوتاه که بین 3 تا 6 کلمه داشته باشد، انتخاب کن. غیر از عنوان، به هیچ وجه توضیح اضافه ای در قبل یا بعد آن اضافه نکن.'
client = get_client()
try:
messages.append({"role": "user", "content": prompt})
response = client.chat.completions.create(
messages = messages,
model= model) # "gpt-4o", "gpt-4o-mini", "deepseek-chat" , "gemini-2.0-flash", gemini-2.5-flash-lite
# gpt-4o : 500
# gpt-4o-mini : 34
# deepseek-chat: : 150
# gemini-2.0-flash : error
# cf.gemma-3-12b-it : 1
# gemini-2.5-flash-lite : 35 خیلی خوب
answer = response.choices[0].message.content
# پاسخ را هم به سابقه اضافه می‌کنیم
messages.append({"role": "assistant", "content": answer})
except Exception as error:
with open('./llm-answer/error-in-llm.txt', mode='a+', encoding='utf-8') as file:
error_message = f'\n\nquery: {query.strip()}\nerror:{error} \n-------------------------------\n'
file.write(error_message)
return ''
return answer
def llm_request(query, model):
if query == '':
return 'لطفا متن سوال را وارد نمائید'
client = get_client()
determine_refrence = """شناسه هر ماده قانون در ابتدای آن و با فرمت "id: {idvalue}" آمده است که id-value همان شناسه ماده است. بازای هربخش از پاسخی که تولید می شود، ضروری است شناسه ماده ای که در تدوین پاسخ از آن استفاده شده در انتهای پاراگراف یا جمله مربوطه با فرمت {idvalue} اضافه شود. همیشه idvalue با رشته "qs" شروع می شود"""
try:
messages.append({"role": "user", "content": query})
messages.append({"role": "user", "content": determine_refrence})
response = client.chat.completions.create(
messages = messages,
model= model) # "gpt-4o", "gpt-4o-mini", "deepseek-chat" , "gemini-2.0-flash", gemini-2.5-flash-lite
# gpt-4o : 500
# gpt-4o-mini : 34
# deepseek-chat: : 150
# gemini-2.0-flash : error
# cf.gemma-3-12b-it : 1
# gemini-2.5-flash-lite : 35 خیلی خوب
answer = response.choices[0].message.content
# پاسخ را هم به سابقه اضافه می‌کنیم
messages.append({"role": "assistant", "content": answer})
except Exception as error:
with open('./llm-answer/error-in-llm.txt', mode='a+', encoding='utf-8') as file:
error_message = f'\n\nquery: {query.strip()}\nerror:{error} \n-------------------------------\n'
file.write(error_message)
return 'با عرض پوزش؛ متاسفانه خطایی رخ داده است. لطفا لحظاتی دیگر دوباره تلاش نمائید'
return answer
class HybridRetrieverReranker:
__slots__ = (
"device", "content_list", "ids", "prefix_list", "N", "embedder", "faiss_index",
"vectorizer", "tfidf_matrix", "tokenizer", "reranker", "dense_alpha"
)
def __init__(self, content_list: List[str],ids: List[str], prefix_list: List[str], faiss_index,
dense_alpha: float = 0.6, device: str = None):
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
self.content_list = content_list
self.ids = ids
self.prefix_list = prefix_list
self.faiss_index = faiss_index
self.N = len(content_list)
# Dense
self.embedder = SentenceTransformer(EMBED_MODEL,cache_folder='/src/MODELS', device=self.device)
#self.embedder = SentenceTransformer(EMBED_MODEL, device=self.device)
# Sparse (مثل قبل برای حفظ خروجی)
self.vectorizer = TfidfVectorizer(
analyzer="word",
ngram_range=(1, 2),
token_pattern=r"(?u)\b[\w\u0600-\u06FF]{2,}\b",
)
self.tfidf_matrix = self.vectorizer.fit_transform(self.content_list)
# Reranker
self.tokenizer = AutoTokenizer.from_pretrained(RERANKER_MODEL,cache_dir='/src/MODELS', use_fast=True)
self.reranker = AutoModelForSequenceClassification.from_pretrained(
RERANKER_MODEL
).to(self.device)
# self.reranker = AutoModelForSeq2SeqLM.from_pretrained(RERANKER_MODEL).to(self.device)
# self.reranker.eval()
self.dense_alpha = float(dense_alpha)
# --- Dense (FAISS) ---
def dense_retrieve(self, query: str, top_k: int):
if top_k <= 0:
return [], np.array([], dtype=np.float32)
q_emb = self.embedder.encode(query, convert_to_numpy=True).astype(np.float32)
D, I = self.faiss_index.search(np.expand_dims(q_emb, axis=0), top_k)
return I[0].tolist(), D[0]
# --- Sparse ---
def sparse_retrieve(self, query: str, top_k: int):
if top_k <= 0:
return [], np.array([], dtype=np.float32)
k = min(top_k, self.N)
q_vec = self.vectorizer.transform([query])
sims = cosine_similarity(q_vec, self.tfidf_matrix).ravel()
idx = np.argpartition(-sims, kth=k-1)[:k]
idx = idx[np.argsort(-sims[idx], kind="mergesort")]
return idx.tolist(), sims[idx]
# --- Utils ---
@staticmethod
def _minmax_norm(arr: np.ndarray) -> np.ndarray:
if arr.size == 0:
return arr
a_min = arr.min()
a_max = arr.max()
rng = a_max - a_min
if rng < 1e-12:
return np.zeros_like(arr)
return (arr - a_min) / rng
def fuse(self, d_idx, d_scores, s_idx, s_scores, top_k=50, k_rrf=60):
"""
ادغام نتایج دو retriever (dense و sparse) با استفاده از Reciprocal Rank Fusion (RRF)
Args:
d_idx (list or np.ndarray): ایندکس‌های نتایج dense retriever
d_scores (list or np.ndarray): نمرات dense retriever
s_idx (list or np.ndarray): ایندکس‌های نتایج sparse retriever
s_scores (list or np.ndarray): نمرات sparse retriever
top_k (int): تعداد نتایج نهایی
k_rrf (int): ثابت در فرمول RRF برای کاهش تأثیر رتبه‌های پایین‌تر
Returns:
list: لیست ایندکس‌های ادغام‌شده به ترتیب نمره
"""
combined = {}
# dense retriever
for rank, idx in enumerate(d_idx):
score = 1.0 / (k_rrf + rank)
combined[idx] = combined.get(idx, 0) + score
# sparse retriever
for rank, idx in enumerate(s_idx):
score = 1.0 / (k_rrf + rank)
combined[idx] = combined.get(idx, 0) + score
# مرتب‌سازی نهایی
sorted_items = sorted(combined.items(), key=lambda x: x[1], reverse=True)
cand_idx = [item[0] for item in sorted_items[:top_k]]
return cand_idx
def rerank(self, query: str, candidate_indices: List[int], passages: List[str], final_k: int) -> List[Tuple[int, float]]:
"""
Rerank candidate passages using a cross-encoder (e.g., MonoT5, MiniLM, etc.).
Args:
query (str): پرسش کاربر
candidate_indices (List[int]): ایندکس‌های کاندیدا (از retriever)
passages (List[str]): کل جملات/پاراگراف‌ها
final_k (int): تعداد نتایج نهایی
Returns:
List[Tuple[int, float]]: لیستی از (ایندکس، امتیاز) برای بهترین نتایج
"""
if final_k <= 0 or not candidate_indices:
return []
# آماده‌سازی جفت‌های (query, passage)
texts = [query] * len(candidate_indices)
pairs = passages
scores: List[float] = []
def _iter_batches(max_bs: int):
bs = max_bs
while bs >= 16: # حداقل batch_size
try:
with torch.inference_mode():
for start in range(0, len(pairs), bs):
batch_texts = texts[start:start + bs]
batch_pairs = pairs[start:start + bs]
inputs = self.tokenizer(
batch_texts,
batch_pairs,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt",
).to(self.device)
logits = self.reranker(**inputs).logits.view(-1)
scores.extend(logits.detach().cpu().tolist())
return True
except torch.cuda.OutOfMemoryError:
if torch.cuda.is_available():
torch.cuda.empty_cache()
bs //= 2
return False
# اجرای reranking
success = _iter_batches(max_bs=64)
if not success:
raise RuntimeError("Reranker failed due to CUDA OOM, even with small batch size.")
# مرتب‌سازی نتایج بر اساس نمره
reranked = sorted(
zip(candidate_indices, scores),
key=lambda x: x[1],
reverse=True
)[:final_k]
return reranked
def get_passages(self, cand_idx, content_list):
passages = []
for idx in cand_idx:
passages.append(content_list[idx])
return passages
# --- Search (بدون تغییر) ---
def search(self, query: str, content_list, topk_dense=50, topk_sparse=50,
pre_rerank_k=50, final_k=10):
d_idx, d_scores = self.dense_retrieve(query, topk_dense)
s_idx, s_scores = self.sparse_retrieve(query, topk_sparse)
cand_idx = self.fuse(d_idx, d_scores, s_idx, s_scores, pre_rerank_k)
passages = self.get_passages(cand_idx, content_list)
reranked = self.rerank(query, cand_idx, passages, final_k)
return [{"idx": i, "content": self.content_list[i],"prefix": self.prefix_list[i], "rerank_score": score}
for i, score in reranked]
def single_query(query: str):
# query = cleaning(query)
retrived_sections_ids = []
retrived_sections = pipe.search(query, content_list, topk_dense=30, topk_sparse=30, pre_rerank_k=30, final_k=10)
final_similars = ''
for i, row in enumerate(retrived_sections, 1):
id_value = '{' + str(ids[row['idx']]) + '}'
result = f"id: {id_value} \n{row['prefix']} {row['content']}\n"
retrived_sections_ids.append(ids[row['idx']])
final_similars += ''.join(result)
return final_similars, retrived_sections_ids
def find_refrences(llm_answer: str) -> List[str]:
"""
شناسایی شناسه هایی که مدل زبانی، برای تهیه پاسخ از آنها استفاده کرده است
Args:
llm_answer(str): متنی که مدل زبانی تولید کرده است
Returns:
refrence_ids(List[str]): لیستی از شناسه های تشخیص داده شده
"""
pattern = r"\{[^\}]+\}"
refrence_ids = re.findall(pattern, llm_answer)
new_refrences_ids = []
for itm in refrence_ids:
refrence = itm.lstrip('{')
refrence = refrence.lstrip('}')
new_refrences_ids.append(refrence)
# refrence_ids = [item.lstrip('{').rstrip('}') for item in refrence_ids]
return refrence_ids
def replace_refrences(llm_answer: str, refrences_list:List[str]) -> List[str]:
"""
شناسایی شناسه هایی که مدل زبانی، برای تهیه پاسخ از آنها استفاده کرده است
Args:
llm_answer(str): متنی که مدل زبانی تولید کرده است
refrences_list(List[str]): لیست شناسه ماده های مورد استفاده در پاسخ مدل زبانی
Returns:
llm_answer(str), : متن بازسازی شده پاسخ مدل زبانی که شناسه ماده های مورد استفاده در آن، اصلاح شده است
"""
refrences = ''
for index, ref in enumerate(refrences_list,1):
# breakpoint()
llm_answer = llm_answer.replace(ref, f'[{index}]')
# id = ref.lstrip('{')
# id = id.rstrip('}')
# refrences += ''.join(f'[{index}] https://majles.tavasi.ir/entity/detail/view/qsection/{id}\n')
# llm_answer = f'{llm_answer}\n\nمنابع پاسخ‌:\n{refrences.strip()}'
return llm_answer.strip()
# load basic items
content_list, ids, prefix_list, faiss_index = load_faiss_index(FAISS_INDEX_PATH, FAISS_METADATA_PATH)
pipe = HybridRetrieverReranker(content_list, ids, prefix_list, faiss_index, dense_alpha=0.6)
# query preprocess and normalize
normalizer_obj = PersianVectorAnalyzer()
messages = [
{"role": "system", "content": "تو یک دستیار خبره در زمینه حقوق و قوانین مرتبط به آن هستی و می توانی متون حقوقی را به صورت دقیق توضیح بدهی . پاسخ ها باید الزاما به زبان فارسی باشد. پاسخ ها فقط از متون قانونی که در پرامپت وجود دارد استخراج شود."},
]
models = ["gemini-2.5-flash-lite", "gpt-4o-mini"]
def save_result(chat_obj: object) -> bool:
# index result in elastic
pass
def run_chatbot(query:str, chat_id:str):
prompt_status = True
status_text = 'لطفا متن سوال را وارد نمائید'
if query == '':
prompt_status = False
start_time = (datetime.datetime.now())
# در صورتی که وضعیت پرامپت معتبر باشد، وارد فرایند شو
if prompt_status:
result_passages_text, result_passages_ids = single_query(query)
end_retrive = datetime.datetime.now()
print('-'*40)
retrive_duration = (end_retrive - start_time).total_seconds()
print(f'retrive duration: {str(retrive_duration)}')
prompt = f'برای پرسش "{query}" از میان مواد قانونی "{result_passages_text}" .پاسخ مناسب و دقیق را استخراج کن. درصورتی که مطلبی مرتبط با پرسش در متن پیدا نشد، فقط پاسخ بده: "متاسفانه در منابع، پاسخی پیدا نشد!"'
llm_model = ''
for model in models:
try:
llm_model = model
llm_answer = llm_request(prompt, model)
except Exception as error:
error = f'model: {model} \n{error}\n\n'
prompt_status = False
status_text = 'با عرض پوزش، سرویس موقتا در دسترس نیست. لطفا دقایقی دیگر دوباره تلاش نمائید!'
else:
chat_obj = {
'id' : chat_id, # str
'title' : '', # str
'user_id' : '',
'user_query' : query, # str
'model_key' : llm_model, # str
'retrived_passage' : result_passages_text, # str
'retrived_ref_ids' : result_passages_ids, # list[obj]
'prompt_type' : 'question-answer', # str
'retrived_duration' : retrive_duration, # str
'llm_duration' : '0', # str
'full_duration' : '0', # str
'time_create' : str(start_time), # str
'used_ref_ids' : [], # list[str]
'prompt_answer' : '', # str
'status_text' : status_text,
'status' : prompt_status, # or False # bool
}
# آبجکت ایجاد شده با بازگردان
return chat_obj, status_text
llm_answer_duration = (datetime.datetime.now() - end_retrive).total_seconds()
print(f'llm answer duration: {str(llm_answer_duration)}')
used_refrences_in_answer = find_refrences(llm_answer)
llm_answer = replace_refrences(llm_answer, used_refrences_in_answer)
full_prompt_duration = (datetime.datetime.now() - start_time).total_seconds()
print(f'full prompt duration: {full_prompt_duration}')
print('~'*40)
status_text ='پاسخ با موفقیت ایجاد شد'
title = llm_base_request(query)
if title == '':
title = query[0:15]
chat_obj = {
'id' : chat_id, # str
'title' : title, # str
'user_id' : '',
'user_query' : query, # str
'model_key' : llm_model, # str
'retrived_passage' : result_passages_text, # str
'retrived_ref_ids' : result_passages_ids, # list[obj]
'prompt_type' : 'question-answer', # str
'retrived_duration' : retrive_duration, # str
'llm_duration' : llm_answer_duration, # str
'full_duration' : full_prompt_duration, # str
'time_create' : str(start_time), # str
'used_ref_ids' : used_refrences_in_answer, # list[str]
'prompt_answer' : llm_answer, # str
'status_text' : status_text, # str
'status' : True, # or False # bool
}
prev_chat_data = []
with open('./llm-answer/chat-messages.json', mode='r', encoding='utf-8') as file:
prev_chat_data = json.load(file)
prev_chat_data.append(chat_obj)
with open('./llm-answer/chat-messages.json', mode='w', encoding='utf-8') as output:
json.dump(prev_chat_data, output, ensure_ascii=False, indent=2)
# save_result(chat_obj)
# ایجاد آبجکت بازگشتی به فرانت
# chat_obj.pop('retrived_passage')
# chat_obj.pop('prompt_type')
return chat_obj
@chatbot.post("/credit_refresh")
def credit_refresh():
"""
Returns remained credit
"""
url = "https://api.avalai.ir/user/credit"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {get_key()}"
}
remained_credit = requests.get(url, headers=headers)
with open('./llm-answer/credit.txt','w') as file:
file.write(str(remained_credit.json()['remaining_irt']))
return str(remained_credit.json()['remaining_irt'])
def create_chat_id():
date = str((datetime.datetime.now())).replace(' ','-').replace(':','').replace('.','-')
print('date ', date )
chat_id = f'{date}-{random.randint(100000, 999999)}'
print('chat_id ', chat_id )
return chat_id
print('#'*19)
print('-Chatbot is Ready!!!!!-')
print('#'*19)
# تعریف مدل داده‌ها برای درخواست‌های API
class Query(BaseModel):
query: str
# مسیر API برای اجرا کردن run_chatbot
@chatbot.post("/run_chatbot")
def run_chat(query: Query):
print('query ', query )
chat_id = create_chat_id()
print('query.query ', query.query )
answer = run_chatbot(query.query, chat_id)
credit_refresh()
return {"answer": answer}
# uvicorn src.app:app --reload
if __name__ == "__main__":
# query = 'در قانون حمایت از خانواده و جوانی جمعیت چه خدماتی در نظر گرفته شده است؟'
while True:
query = input('enter your qustion:')
if query == '':
print('لطفا متن سوال را وارد نمائید')
continue
start = (datetime.datetime.now())
# result = test_dataset()
result = single_query(query)
end_retrive = datetime.datetime.now()
print('-'*40)
print(f'retrive duration: {(end_retrive - start).total_seconds()}')
prompt = f'برای پرسش "{query}" از میان مواد قانونی "{result}" .پاسخ مناسب و دقیق را استخراج کن. درصورتی که مطلبی مرتبط با پرسش در متن پیدا نشد، فقط پاسخ بده: "متاسفانه در منابع، پاسخی پیدا نشد!"'
llm_answer = llm_request(prompt)
print('-'*40)
print(f'llm duration: {(datetime.datetime.now() - end_retrive).total_seconds()}')
refrences = ''
recognized_refrences = find_refrences(llm_answer)
llm_answer = replace_refrences(llm_answer, recognized_refrences)
with open('./llm-answer/result.txt', mode='a+', encoding='utf-8') as file:
result_message = f'متن پرامپت: {query.strip()}\n\nپاسخ: {llm_answer} \n----------------------------------------------------------\n'
file.write(result_message)
with open('./llm-answer/passages.txt', mode='a+', encoding='utf-8') as file:
result_message = f'متن پرامپت: {query.strip()}\n\مواد مشابه: {result} \n----------------------------------------------------------\n'
file.write(result_message)
print('----------------------------------------------------------')
print(f'full duration: {(datetime.datetime.now() - start).total_seconds()}')
print('----------------------------------------------------------')
print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')