import asyncio
import json
import os
import numpy as np
import torch
import faiss
from typing import List, Tuple
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import datetime
import re
import random
from fastapi.middleware.cors import CORSMiddleware
from util.embedder_sbert_qavanin_285k import PersianVectorAnalyzer
# from util.normalizer import cleaning
from fastapi import FastAPI, Header
from pydantic import BaseModel
# LLM libs
from openai import OpenAI
from langchain_openai import ChatOpenAI  # pip install -U langchain_openai
import requests
# from FlagEmbedding import FlagReranker  # deldar-reranker-v2
import aiofiles
from openai import AsyncOpenAI

LLM_URL = "http://172.16.29.102:8001/v1/"


# -------------------
# Models and data paths (/src/app/data/qavanin-faiss)
# -------------------
EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
RERANKER_MODEL = "BAAI/bge-reranker-v2-m3"
FAISS_INDEX_PATH = "/src/app/data/qavanin-faiss/faiss_index_qavanin_285k.index"
FAISS_METADATA_PATH = "/src/app/data/qavanin-faiss/faiss_index_qavanin_285k_metadata.json"

RERANK_BATCH = int(os.environ.get("RERANK_BATCH", 256))
# print(f'RERANK_BATCH: {RERANK_BATCH}')

# Developer instruction (Persian): each legal article starts with "id: {idvalue}", and every
# part of the generated answer must cite the id of the article it used; ids start with "qs".
determine_refrence = '''شناسه هر ماده قانونی در ابتدای آن و با فرمت "id: {idvalue}" آمده است که id-value همان شناسه ماده است. بازای هربخش از پاسخی که تولید می شود، ضروری است شناسه ماده ای که در ایجاد پاسخ از آن استفاده شده، در انتهای پاراگراف یا جمله مربوطه با فرمت {idvalue} اضافه شود. همیشه idvalue با رشته "qs" شروع می شود'''

messages = [
    # {
    #     "role": "system",
    #     "content": "تو یک دستیار خبره در زمینه حقوق و قوانین مرتبط به آن هستی و می توانی متون حقوقی را به صورت دقیق توضیح بدهی . پاسخ ها باید الزاما به زبان فارسی باشد. پاسخ ها فقط از متون قانونی که در پرامپت وجود دارد استخراج شود. پاسخ تولید شده باید کاملا ساده و بدون هیچ مارک داون یا علائم افزوده ای باشد. لحن متن باید رسمی باشد.",
    # },
    {"role": "developer", "content": determine_refrence},
]

models = ["gpt-4o-mini", "gemini-2.5-flash-lite", "deepseek-chat"]
normalizer_obj = PersianVectorAnalyzer()
pipe = None
content_list, ids, prefix_list, faiss_index = [], [], [], []

path_log = './data/llm-answer/'

async def get_key():
    key = 'aa-fdh9d847ANcBxQCBTZD5hrrAdl0UrPEnJOScYmOncrkagYPf'
    return key

def load_faiss_index(index_path: str, metadata_path: str):
    """Load the FAISS index and its metadata (lists of passages, ids and titles)."""
    index = faiss.read_index(index_path)

    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)

    content_list, ids, prefix_list = [], [], []
    for item in metadata:
        content_list.append(item["content"])
        ids.append(item["id"])
        prefix_list.append(item["prefix"])

    return content_list, ids, prefix_list, index

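# Illustrative only: judging from the loop above, the metadata file is assumed to be a
# JSON array of objects with at least these fields (sample values are made up):
# [
#   {"id": "qs0000001", "prefix": "ماده ۱ ...", "content": "متن ماده ..."},
#   ...
# ]
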
async def get_client():
    url = "https://api.avalai.ir/v1"
    # key = 'aa-4tvAEazUBovEN1i7i7tdl1PR93OaWXs6hMflR4oQbIIA4K7Z'

    client = OpenAI(
        api_key=await get_key(),
        base_url=url,  # base URL of the avalai OpenAI-compatible gateway
    )

    return client

async def llm_base_request(system_prompt, user_prompt):
    client = await get_client()  # get_client is an async coroutine
    base_messages = []
    try:
        if system_prompt:
            base_messages.append({
                "role": "system",
                "content": system_prompt
            })

        base_messages.append({
            "role": "user",
            "content": user_prompt
        })
        for model in models:
            # synchronous call to the OpenAI-compatible endpoint; only the first model
            # in `models` is tried because of the unconditional break below
            response = client.chat.completions.create(
                messages=base_messages,
                model=model
            )
            answer = response.choices[0].message.content
            cost = response.estimated_cost['irt']
            break

    except Exception as error:
        # log the failure asynchronously and return an empty answer
        async with aiofiles.open(path_log + 'error-in-llm.txt', mode='a+', encoding='utf-8') as file:
            error_message = f'\n\nquery: {user_prompt.strip()}\nerror:{error} \n------------------------------\n'
            await file.write(error_message)

        return '', 0

    return answer, cost

async def llm_base_request2(system_prompt, user_prompt):
    # Legacy variant of llm_base_request: here system_prompt is expected to already be a
    # complete message dict ({"role": ..., "content": ...}). Made async so that the
    # coroutine get_client() is actually awaited instead of returning a coroutine object.
    client = await get_client()
    base_messages = []
    try:
        if system_prompt:
            base_messages.append(system_prompt)
        base_messages.append({
            "role": "user",
            "content": user_prompt
        })
        for model in models:
            response = client.chat.completions.create(
                messages=base_messages,
                model=model)
            answer = response.choices[0].message.content
            cost = response.estimated_cost['irt']
            break

    except Exception as error:
        with open(path_log + 'error-in-llm.txt', mode='a+', encoding='utf-8') as file:
            error_message = f'\n\nquery: {user_prompt.strip()}\nerror:{error} \n-------------------------------\n'
            file.write(error_message)

        return '', 0

    return answer, cost

async def oss_base_request(sys_prompt, user_prompt):
    base_messages = []
    try:
        if sys_prompt:
            base_messages.append({
                "role": "system",
                "content": sys_prompt
            })

        base_messages.append({
            "role": "user",
            "content": user_prompt
        })
        response = await process_item(base_messages, reasoning_effort='low', temperature=0.1, max_tokens=40)

        if response[0]:
            answer = response[1]
        else:
            answer = ''
        cost = 0

    except Exception as error:
        # log the failure asynchronously and return an empty answer
        async with aiofiles.open(path_log + 'error-in-llm.txt', mode='a+', encoding='utf-8') as file:
            error_message = f'\n\nquery: {user_prompt.strip()}\nerror:{error} \n------------------------------\n'
            await file.write(error_message)

        return '', 0

    return answer, cost

async def oss_request(query):

    if query == '':
        return 'لطفا متن سوال را وارد نمائید', 0

    try:
        messages.append({"role": "user", "content": query})
        print('final prompt request attempt oss')
        response = await process_item(messages=messages, reasoning_effort='low')  # reasoning_effort='high'
        # print(response)
        if response[0]:
            answer = response[1]
        else:
            answer = 'متاسفانه پاسخی دریافت نشد'
        cost_prompt = 0
        # optionally append the answer to the chat history
        # messages.append({"role": "assistant", "content": answer})

        response_dict = {}
        response_dict['output'] = str(response)
        async with aiofiles.open(path_log + 'messages.json', mode='w', encoding='utf-8') as output:
            await output.write(json.dumps(response_dict, ensure_ascii=False, indent=2))
        print('oss response created')
        async with aiofiles.open(path_log + 'chat-objs.txt', mode='a+', encoding='utf-8') as file:
            response_value = '0'
            await file.write(response_value)  # estimated cost (always 0 for the local OSS model)

    except Exception as error:
        print('error-in-llm.txt writing ...')
        async with aiofiles.open(path_log + 'error-in-llm.txt', mode='a+', encoding='utf-8') as file:
            error_message = f'\n\nquery: {query.strip()}\nerror:{error} \n-------------------------------\n'
            await file.write(error_message)

        return 'با عرض پوزش؛ متاسفانه خطایی رخ داده است. لطفا لحظاتی دیگر دوباره تلاش نمائید', 0
    print('================')
    print(f'len messages: {len(messages)}')
    print('================')
    return answer, cost_prompt

async def llm_request(query, model):

    if query == '':
        return 'لطفا متن سوال را وارد نمائید', 0

    client = await get_client()
    try:
        messages.append({"role": "user", "content": query})
        response = client.chat.completions.create(
            messages=messages,
            model=model,
            temperature=0.3)  # "gpt-4o", "gpt-4o-mini", "deepseek-chat", "gemini-2.0-flash", "gemini-2.5-flash-lite"
        # observed per-request cost (IRT, rough):
        # gpt-4o : 500
        # gpt-4o-mini : 34
        # deepseek-chat : 150
        # gemini-2.0-flash : error
        # cf.gemma-3-12b-it : 1
        # gemini-2.5-flash-lite : 35 (very good)

        answer = response.choices[0].message.content
        cost_prompt = response.estimated_cost['irt']
        # optionally append the answer to the chat history
        # messages.append({"role": "assistant", "content": answer})
        response_dict = {}
        response_dict['output'] = str(response)
        async with aiofiles.open(path_log + 'messages.json', mode='w', encoding='utf-8') as output:
            await output.write(json.dumps(response_dict, ensure_ascii=False, indent=2))
        print('llm response created')
        async with aiofiles.open(path_log + 'chat-objs.txt', mode='a+', encoding='utf-8') as file:
            response_value = f"{response.estimated_cost['irt']}\n-------------------------------\n\n"
            await file.write(response_value)  # estimated cost

    except Exception as error:
        print('error-in-llm.txt writing ...')
        async with aiofiles.open(path_log + 'error-in-llm.txt', mode='a+', encoding='utf-8') as file:
            error_message = f'\n\nquery: {query.strip()}\nerror:{error} \n-------------------------------\n'
            await file.write(error_message)

        return 'با عرض پوزش؛ متاسفانه خطایی رخ داده است. لطفا لحظاتی دیگر دوباره تلاش نمائید', 0
    print('================')
    print(f'len messages: {len(messages)}')
    print('================')
    return answer, cost_prompt

class HybridRetrieverReranker:
    __slots__ = (
        "device", "content_list", "ids", "prefix_list", "N", "embedder", "faiss_index",
        "vectorizer", "tfidf_matrix", "tokenizer", "reranker", "dense_alpha"
    )

    def __init__(self, content_list: List[str], ids: List[str], prefix_list: List[str], faiss_index,
                 dense_alpha: float = 0.6, device: str = None):

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        self.content_list = content_list
        self.ids = ids
        self.prefix_list = prefix_list
        self.faiss_index = faiss_index
        self.N = len(content_list)

        # Dense retriever (sentence-transformers embedder + FAISS index)
        self.embedder = SentenceTransformer(EMBED_MODEL, cache_folder='/src/MODELS', device=self.device)
        # self.embedder = SentenceTransformer(EMBED_MODEL, device=self.device)

        # Sparse retriever (kept as before so the output stays unchanged)
        self.vectorizer = TfidfVectorizer(
            analyzer="word",
            ngram_range=(1, 2),
            token_pattern=r"(?u)\b[\w\u0600-\u06FF]{2,}\b",
        )
        self.tfidf_matrix = self.vectorizer.fit_transform(self.content_list)

        # Cross-encoder reranker
        self.tokenizer = AutoTokenizer.from_pretrained(RERANKER_MODEL, cache_dir='/src/MODELS', use_fast=True)
        # self.reranker = FlagReranker(RERANKER_MODEL, cache_dir="/src/MODELS", use_fp16=True)
        self.reranker = AutoModelForSequenceClassification.from_pretrained(
            RERANKER_MODEL
        ).to(self.device)

        # retained for compatibility; fuse() below uses RRF and does not read this weight
        self.dense_alpha = float(dense_alpha)

    # --- Dense (FAISS) ---
    def dense_retrieve(self, query: str, top_k: int):
        if top_k <= 0:
            return [], np.array([], dtype=np.float32)

        q_emb = self.embedder.encode(query, convert_to_numpy=True).astype(np.float32)
        D, I = self.faiss_index.search(np.expand_dims(q_emb, axis=0), top_k)

        return I[0].tolist(), D[0]

    # --- Sparse ---
    def sparse_retrieve(self, query: str, top_k: int):
        if top_k <= 0:
            return [], np.array([], dtype=np.float32)
        k = min(top_k, self.N)
        q_vec = self.vectorizer.transform([query])
        sims = cosine_similarity(q_vec, self.tfidf_matrix).ravel()
        idx = np.argpartition(-sims, kth=k-1)[:k]
        idx = idx[np.argsort(-sims[idx], kind="mergesort")]
        return idx.tolist(), sims[idx]

    # --- Utils ---
    @staticmethod
    def _minmax_norm(arr: np.ndarray) -> np.ndarray:
        if arr.size == 0:
            return arr
        a_min = arr.min()
        a_max = arr.max()
        rng = a_max - a_min
        if rng < 1e-12:
            return np.zeros_like(arr)
        return (arr - a_min) / rng

    def fuse(self, d_idx, d_scores, s_idx, s_scores, top_k=50, k_rrf=60):
        """
        Fuse the results of the two retrievers (dense and sparse) with Reciprocal Rank Fusion (RRF).

        Args:
            d_idx (list or np.ndarray): result indices from the dense retriever
            d_scores (list or np.ndarray): dense scores (unused by RRF, kept for signature compatibility)
            s_idx (list or np.ndarray): result indices from the sparse retriever
            s_scores (list or np.ndarray): sparse scores (unused by RRF, kept for signature compatibility)
            top_k (int): number of fused results to return
            k_rrf (int): RRF constant that damps the contribution of lower ranks

        Returns:
            list: fused indices, ordered by RRF score
        """
        combined = {}

        # dense retriever
        for rank, idx in enumerate(d_idx):
            score = 1.0 / (k_rrf + rank)
            combined[idx] = combined.get(idx, 0) + score

        # sparse retriever
        for rank, idx in enumerate(s_idx):
            score = 1.0 / (k_rrf + rank)
            combined[idx] = combined.get(idx, 0) + score

        # final sort
        sorted_items = sorted(combined.items(), key=lambda x: x[1], reverse=True)
        cand_idx = [item[0] for item in sorted_items[:top_k]]

        return cand_idx

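    # Worked example for the RRF formula above (illustrative numbers, not from the data):
    # with k_rrf=60, a passage ranked 1st by the dense retriever (rank 0) and 3rd by the
    # sparse retriever (rank 2) scores 1/60 + 1/62 ≈ 0.0328, while a passage found by only
    # one retriever at rank 0 scores 1/60 ≈ 0.0167, so passages both retrievers agree on
    # rise to the top regardless of the two retrievers' raw score scales.
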
    def rerank2(self, query: str, candidate_indices: List[int], passages: List[str], final_k: int = 4):
        # NOTE: this variant relies on FlagReranker.compute_score from the FlagEmbedding
        # package (see the commented-out import and reranker above); it is not compatible
        # with the AutoModelForSequenceClassification reranker loaded in __init__.
        z_results = [[query, sentence] for sentence in passages]
        # "normalize=True" maps the scores into 0-1 by applying a sigmoid
        scores = self.reranker.compute_score(z_results, normalize=True)
        s_results = sorted(zip(scores, z_results, candidate_indices), key=lambda x: x[0], reverse=True)
        s_results2 = s_results[:final_k]
        results = [[i[0], i[1][1], i[2]] for i in s_results2]
        print('%' * 50)
        print(results)
        with open(path_log + 'reranker-result.txt', mode='a+', encoding='utf-8') as file:
            for item in results:
                file.write(f'{item}\n')
        print('%' * 50)
        return results

    def rerank(self, query: str, candidate_indices: List[int], passages: List[str], final_k: int) -> List[Tuple[int, float]]:
        """
        Rerank candidate passages with the cross-encoder reranker.

        Args:
            query (str): user query
            candidate_indices (List[int]): candidate indices (from the retrievers)
            passages (List[str]): candidate passages, aligned with candidate_indices
            final_k (int): number of results to keep

        Returns:
            List[Tuple[int, float]]: (index, score) pairs for the best results
        """
        if final_k <= 0 or not candidate_indices:
            return []

        # prepare (query, passage) pairs
        texts = [query] * len(candidate_indices)
        pairs = passages

        scores: List[float] = []

        def _iter_batches(max_bs: int):
            bs = max_bs
            while bs >= 16:  # minimum batch size
                try:
                    scores.clear()  # discard partial scores from a previous (OOM-failed) attempt
                    with torch.inference_mode():
                        for start in range(0, len(pairs), bs):
                            batch_texts = texts[start:start + bs]
                            batch_pairs = pairs[start:start + bs]
                            inputs = self.tokenizer(
                                batch_texts,
                                batch_pairs,
                                padding=True,
                                truncation=True,
                                max_length=512,
                                return_tensors="pt",
                            ).to(self.device)

                            logits = self.reranker(**inputs).logits.view(-1)
                            scores.extend(logits.detach().cpu().tolist())
                    return True
                except torch.cuda.OutOfMemoryError:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    bs //= 2
            return False

        # run reranking, halving the batch size on CUDA OOM
        success = _iter_batches(max_bs=64)
        if not success:
            raise RuntimeError("Reranker failed due to CUDA OOM, even with small batch size.")

        # sort results by score
        reranked = sorted(
            zip(candidate_indices, scores),
            key=lambda x: x[1],
            reverse=True
        )[:final_k]

        return reranked

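    # Note on scores: BAAI/bge-reranker-v2-m3 is a cross-encoder whose sequence-classification
    # head emits a single relevance logit per (query, passage) pair; the logits are only
    # compared against each other here (higher = more relevant), so no sigmoid or other
    # normalization is applied before sorting.
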
    def get_passages(self, cand_idx, content_list):
        passages = []
        for idx in cand_idx:
            passages.append(content_list[idx])

        return passages

    # --- Search ---
    def search_base(self, query: str, content_list, topk_dense=50, topk_sparse=50,
                    pre_rerank_k=50, final_k=10):
        start_time = datetime.datetime.now()

        # dense candidates from the embedder + FAISS index
        d_idx, d_scores = self.dense_retrieve(query, topk_dense)
        dense_retrieve_end = datetime.datetime.now()
        # print(f'dense_retrieve_duration: {(dense_retrieve_end - start_time).total_seconds()}')

        # sparse candidates from the TF-IDF matrix
        s_idx, s_scores = self.sparse_retrieve(query, topk_sparse)
        sparse_retrieve_end = datetime.datetime.now()
        # print(f'sparse_retrieve_duration: {(sparse_retrieve_end - dense_retrieve_end).total_seconds()}')
        cand_idx = self.fuse(d_idx, d_scores, s_idx, s_scores, pre_rerank_k)
        fuse_end = datetime.datetime.now()
        # print(f'fuse_duration: {(fuse_end - sparse_retrieve_end).total_seconds()}')
        passages = self.get_passages(cand_idx, content_list)
        get_passages_end = datetime.datetime.now()
        # print(f'get_passages_duration: {(get_passages_end - fuse_end).total_seconds()}')
        reranked = self.rerank(query, cand_idx, passages, final_k)  # or rerank2 with FlagReranker
        rerank_end = datetime.datetime.now()
        # print(f'rerank_duration: {(rerank_end - get_passages_end).total_seconds()}')
        return [{"idx": i, "content": self.content_list[i], "prefix": self.prefix_list[i], "rerank_score": score}
                for i, score in reranked]


async def search_query(query: str):

    # query = cleaning(query)
    retrived_sections_ids = []

    retrived_sections = pipe.search_base(query, content_list, topk_dense=100, topk_sparse=100, pre_rerank_k=100, final_k=10)
    final_similars = ''
    for i, row in enumerate(retrived_sections, 1):
        id_value = '{' + str(ids[row['idx']]) + '}'
        result = f"id: {id_value} \n{row['prefix']} {row['content']}\n"
        retrived_sections_ids.append(ids[row['idx']])
        final_similars += ''.join(result)

    return final_similars, retrived_sections_ids

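# Illustrative only: each retrieved section is serialized by search_query as
#   id: {qs0000001}
#   <prefix> <content>
# and the blocks are concatenated into the single string embedded in the LLM prompt,
# while the raw ids are returned separately for logging.
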
async def find_refrences(llm_answer: str) -> List[str]:
    """
    Extract the article ids the language model cited while building its answer.

    Args:
        llm_answer (str): the text produced by the language model

    Returns:
        refrence_ids (List[str]): the detected ids, with the surrounding braces stripped
    """
    pattern = r"\{[^\}]+\}"
    # pattern = r"(?:\{([^\}]+)\}|【([^】]+)】)"
    refrence_ids = re.findall(pattern, llm_answer)
    refrence_ids = [item.lstrip('{').rstrip('}') for item in refrence_ids]

    return refrence_ids

async def replace_refrences(llm_answer: str, refrences_list: List[str]) -> str:
    """
    Rewrite the ids cited by the language model so that they render as numbered links.

    Args:
        llm_answer (str): the text produced by the language model
        refrences_list (List[str]): the article ids used in the model's answer
    Returns:
        llm_answer (str): the rebuilt answer, with each cited id replaced by a numbered link
    """
    for index, ref in enumerate(refrences_list, 1):
        new_ref = '{' + str(ref) + '}'
        llm_answer = llm_answer.replace(new_ref, f'[«{str(index)}»](https://majles.tavasi.ir/entity/detail/view/qsection/{ref}) ')
        # refrences += ''.join(f'[{index}] https://majles.tavasi.ir/entity/detail/view/qsection/{ref}\n')

    # llm_answer = f'{llm_answer}\n\nمنابع پاسخ:\n{refrences.strip()}'
    return llm_answer.strip()

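# Illustrative example (made-up id): if the model's answer contains "{qs1234567}",
# find_refrences returns ['qs1234567'] and replace_refrences rewrites the marker as
# [«1»](https://majles.tavasi.ir/entity/detail/view/qsection/qs1234567) so the
# front-end can render it as a numbered source link.
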
def initial_model():
    global pipe
    global content_list, ids, prefix_list, faiss_index

    if not pipe:
        # load the FAISS index, metadata and the hybrid retriever/reranker pipeline
        content_list, ids, prefix_list, faiss_index = load_faiss_index(FAISS_INDEX_PATH, FAISS_METADATA_PATH)
        pipe = HybridRetrieverReranker(content_list, ids, prefix_list, faiss_index, dense_alpha=0.6)
    # query preprocess and normalize


def save_result(chat_obj: object) -> bool:
    # index the result in Elasticsearch (not implemented yet)
    pass

async def get_title_user_prompt(query: str):
    """
    Take a query and build the prompt used to generate a title for it.
    """
    title_prompt = f'برای متن {query} یک عنوان با معنا که بین 3 تا 6 کلمه داشته باشد، در قالب یک رشته متن ایجاد کن. سبک و لحن عنوان، حقوقی و کاملا رسمی باشد. عنوان تولید شده کاملا ساده و بدون هیچ مارک داون یا علائم افزوده ای باشد. غیر از عنوان، به هیچ وجه توضیح اضافه ای در قبل یا بعد آن اضافه نکن.'
    return title_prompt


async def get_title_system_prompt():
    """
    Return the system prompt used for title generation.
    """
    title_system_prompt = 'تو یک دستیار حقوقی هستی و می توانی متون و سوالات حقوقی را به زبان ساده و دقیق توضیح بدهی.'
    return title_system_prompt


async def ask_chatbot_avalai(query: str, chat_id: str):
    print('ask avalai func')
    prompt_status = True
    llm_model = ''
    llm_answer = ''
    cost_prompt = 0
    cost_title = 0
    status_text = 'لطفا متن سوال را وارد نمائید'
    if query == '':
        prompt_status = False

    # proceed only when the prompt is valid
    if prompt_status:

        before_title_time = datetime.datetime.now()
        title_system_prompt = await get_title_system_prompt()
        title_user_prompt = await get_title_user_prompt(query)
        title, cost_title = await llm_base_request(title_system_prompt, title_user_prompt)
        # title, cost_title = await oss_base_request(title_system_prompt, title_user_prompt)
        if not title:
            title = query

        title_prompt_duration = (datetime.datetime.now() - before_title_time).total_seconds()

        if title == '':
            title = ' '.join(query.split()[0:10])

        start_time = datetime.datetime.now()
        result_passages_text, result_passages_ids = await search_query(query)
        end_retrive = datetime.datetime.now()
        print('-' * 40)
        print(f'title_prompt_duration: {title_prompt_duration}')
        retrive_duration = (end_retrive - start_time).total_seconds()
        print(f'retrive duration: {str(retrive_duration)}')

        prompt = f'''برای پرسش "{query}" از میان متون قانونی زیر، پاسخ مناسب و دقیق را استخراج کن.
متون قانونی:
"{result_passages_text}"
'''

        for model in models:
            print('credit calculate')
            before_prompt_credit = await credit_refresh()
            print('credit calculate finish')

            llm_model = model
            print(f'using model: {model}')
            try:
                llm_answer, cost_prompt = await llm_request(prompt, model)
                # llm_answer, cost_prompt = await oss_request(prompt)
                break
            except Exception as error:
                print(f'error in ask-chatbot-avalai model:{model}')
                after_prompt_credit = await credit_refresh()
                prompt_cost = int(before_prompt_credit) - int(after_prompt_credit)
                error = f'model: {model} \n{error}\n\n'
                print('+++++++++++++++++')
                print(f'llm-error.txt writing error: {error}')
                print('+++++++++++++++++')
                async with aiofiles.open(path_log + 'llm-error.txt', mode='a+', encoding='utf-8') as file:
                    await file.write(error)
                prompt_status = False
                status_text = 'با عرض پوزش، سرویس موقتا در دسترس نیست. لطفا دقایقی دیگر دوباره تلاش نمائید!'
                continue

    # if the prompt is invalid, return an object with the default values below
    else:
        chat_obj = {
            'id'                : chat_id,            # str
            'title'             : '',                 # str
            'user_id'           : '',
            'user_query'        : query,              # str
            'model_key'         : llm_model,          # str
            'retrived_passage'  : '',                 # str
            'retrived_ref_ids'  : '',                 # list[obj]
            'prompt_type'       : 'question-answer',  # str
            'retrived_duration' : '',                 # str
            'llm_duration'      : '0',                # str
            'full_duration'     : '0',                # str
            'cost_prompt'       : str(cost_prompt),   # str
            'cost_title'        : str(cost_title),    # str
            'cost_total'        : str(cost_prompt + cost_title),  # str
            'time_create'       : str(datetime.datetime.now()),   # str (start_time is never set on this path)
            'used_ref_ids'      : [],                 # list[str]
            'prompt_answer'     : '',                 # str
            'status_text'       : status_text,
            'status'            : prompt_status,      # bool
        }

        # return the created object
        return chat_obj, status_text

    llm_answer_duration = (datetime.datetime.now() - end_retrive).total_seconds()
    print(f'llm answer duration: {str(llm_answer_duration)}')

    used_refrences_in_answer = await find_refrences(llm_answer)
    llm_answer = await replace_refrences(llm_answer, used_refrences_in_answer)

    full_prompt_duration = (datetime.datetime.now() - start_time).total_seconds()
    print(f'full prompt duration: {full_prompt_duration}')
    print('~' * 40)

    status_text = 'پاسخ با موفقیت ایجاد شد'

    print(f'cost_prompt: {cost_prompt}')
    print(f'cost_title: {cost_title}')
    chat_obj = {
        'id'                : chat_id,                   # str
        'title'             : title,                     # str
        'user_id'           : '',
        'user_query'        : query,                     # str
        'model_key'         : llm_model,                 # str
        'retrived_passage'  : result_passages_text,      # str
        'retrived_ref_ids'  : result_passages_ids,       # list[obj]
        'prompt_type'       : 'question-answer',         # str
        'retrived_duration' : retrive_duration,          # str
        'llm_duration'      : llm_answer_duration,       # str
        'full_duration'     : full_prompt_duration,      # str
        'cost_prompt'       : str(cost_prompt),          # str
        'cost_title'        : str(cost_title),           # str
        'cost_total'        : str(cost_prompt + cost_title),  # str
        'time_create'       : str(start_time),           # str
        'used_ref_ids'      : used_refrences_in_answer,  # list[str]
        'prompt_answer'     : llm_answer,                # str
        'status_text'       : status_text,               # str
        'status'            : True,                      # bool
    }
    prev_chat_data = []
    number = 1
    try:
        async with aiofiles.open(path_log + f'chat-messages{number}.json', mode='r', encoding='utf-8') as file:
            content = await file.read()
            prev_chat_data = json.loads(content)
    except Exception:
        number += 1

    prev_chat_data.append(chat_obj)
    async with aiofiles.open(path_log + f'chat-messages{number}.json', mode='w', encoding='utf-8') as output:
        await output.write(json.dumps(prev_chat_data, ensure_ascii=False, indent=2))

    async with aiofiles.open(path_log + f'chat-messages-answer{number}.txt', mode='a+', encoding='utf-8') as output:
        await output.write(f'{chat_obj}\n+++++++++++++++++++++++++++\n')

    # save_result(chat_obj)

    # build the object returned to the front-end
    # chat_obj.pop('retrived_passage')
    # chat_obj.pop('prompt_type')

    print('~' * 40)

    return chat_obj


async def ask_chatbot(query: str, chat_id: str):
    print('ask oss func')
    prompt_status = True
    llm_model = 'gpt.oss.120b'
    llm_answer = ''
    cost_prompt = 0
    cost_title = 0
    status_text = 'لطفا متن سوال را وارد نمائید'
    if query == '':
        prompt_status = False

    # proceed only when the prompt is valid
    if prompt_status:

        before_title_time = datetime.datetime.now()
        title_system_prompt = await get_title_system_prompt()
        title_user_prompt = await get_title_user_prompt(query)
        title = ''
        # title, cost_title = await llm_base_request(title_system_prompt, title_user_prompt)
        title, cost_title = await oss_base_request(title_system_prompt, title_user_prompt)
        if not title:
            title = query

        title_prompt_duration = (datetime.datetime.now() - before_title_time).total_seconds()
        print('-' * 40)
        print(f'title_prompt_duration: {title_prompt_duration}')

        if title == '':
            title = ' '.join(query.split()[0:10])

        start_time = datetime.datetime.now()
        result_passages_text, result_passages_ids = await search_query(query)
        end_retrive = datetime.datetime.now()
        retrive_duration = (end_retrive - start_time).total_seconds()
        print(f'retrive duration: {str(retrive_duration)}')

        prompt = f'''برای پرسش "{query}" از میان متون قانونی زیر، پاسخ مناسب و دقیق را استخراج کن.
متون قانونی:
"{result_passages_text}"
'''
        try:
            llm_answer, cost_prompt = await oss_request(prompt)

        except Exception as error:
            # after_prompt_credit = credit_refresh()
            # prompt_cost = int(before_prompt_credit) - int(after_prompt_credit)
            error = f'model: gpt.oss.120b \n{error}\n\n'
            print('+++++++++++++++++')
            print(f'llm-error.txt writing error: {error}')
            print('+++++++++++++++++')
            async with aiofiles.open(path_log + 'llm-error.txt', mode='a+', encoding='utf-8') as file:
                await file.write(error)
            prompt_status = False
            status_text = 'با عرض پوزش، سرویس موقتا در دسترس نیست. لطفا دقایقی دیگر دوباره تلاش نمائید!'

    # if the prompt is invalid, return an object with the default values below
    else:
        chat_obj = {
            'id'                : chat_id,            # str
            'title'             : '',                 # str
            'user_id'           : '',
            'user_query'        : query,              # str
            'model_key'         : llm_model,          # str
            'retrived_passage'  : '',                 # str
            'retrived_ref_ids'  : '',                 # list[obj]
            'prompt_type'       : 'question-answer',  # str
            'retrived_duration' : '',                 # str
            'llm_duration'      : '0',                # str
            'full_duration'     : '0',                # str
            'cost_prompt'       : str(cost_prompt),   # str
            'cost_title'        : str(cost_title),    # str
            'cost_total'        : str(cost_prompt + cost_title),  # str
            'time_create'       : str(datetime.datetime.now()),   # str (start_time is never set on this path)
            'used_ref_ids'      : [],                 # list[str]
            'prompt_answer'     : '',                 # str
            'status_text'       : status_text,
            'status'            : prompt_status,      # bool
        }

        # return the created object
        return chat_obj, status_text

    llm_answer_duration = (datetime.datetime.now() - end_retrive).total_seconds()
    print(f'llm answer duration: {str(llm_answer_duration)}')

    # the OSS model sometimes cites ids as 【qs...】; normalize them to the {qs...} format first
    llm_answer = llm_answer.replace('【', '{')
    llm_answer = llm_answer.replace('】', '}')
    used_refrences_in_answer = await find_refrences(llm_answer)
    llm_answer = await replace_refrences(llm_answer, used_refrences_in_answer)

    full_prompt_duration = (datetime.datetime.now() - start_time).total_seconds()
    print(f'full prompt duration: {full_prompt_duration}')
    print('~' * 40)

    status_text = 'پاسخ با موفقیت ایجاد شد'

    chat_obj = {
        'id'                : chat_id,                   # str
        'title'             : title,                     # str
        'user_id'           : '',
        'user_query'        : query,                     # str
        'model_key'         : llm_model,                 # str
        'retrived_passage'  : result_passages_text,      # str
        'retrived_ref_ids'  : result_passages_ids,       # list[obj]
        'prompt_type'       : 'question-answer',         # str
        'retrived_duration' : retrive_duration,          # str
        'llm_duration'      : llm_answer_duration,       # str
        'full_duration'     : full_prompt_duration,      # str
        'cost_prompt'       : str(cost_prompt),          # str
        'cost_title'        : str(cost_title),           # str
        'cost_total'        : str(cost_prompt + cost_title),  # str
        'time_create'       : str(start_time),           # str
        'used_ref_ids'      : used_refrences_in_answer,  # list[str]
        'prompt_answer'     : llm_answer,                # str
        'status_text'       : status_text,               # str
        'status'            : True,                      # bool
    }

    prev_chat_data = []
    number = 1
    try:
        async with aiofiles.open(path_log + f'chat-messages{number}.json', mode='r', encoding='utf-8') as file:
            content = await file.read()
            prev_chat_data = json.loads(content)
    except Exception:
        number += 1

    prev_chat_data.append(chat_obj)
    async with aiofiles.open(path_log + f'chat-messages{number}.json', mode='w', encoding='utf-8') as output:
        await output.write(json.dumps(prev_chat_data, ensure_ascii=False, indent=2))

    # async with aiofiles.open(path_log + f'chat-messages-answer{number}.txt', mode='a+', encoding='utf-8') as output:
    #     await output.write(f'{chat_obj}\n+++++++++++++++++++++++++++\n')

    full_prompt_duration = (datetime.datetime.now() - start_time).total_seconds()
    print(f'aiofiles duration: {full_prompt_duration}')
    print('~' * 40)

    # save_result(chat_obj)

    # build the object returned to the front-end
    # chat_obj.pop('retrived_passage')
    # chat_obj.pop('prompt_type')

    print('~' * 40)

    return chat_obj


async def process_item(messages, reasoning_effort='medium', temperature=0.4, top_p=0.9, max_tokens=2048):
    """
    Send a chat request to the locally hosted gpt-oss-120b endpoint (LLM_URL).

    Returns a (success, message) tuple so callers can check response[0] before
    using the generated text in response[1].
    """
    try:
        async with AsyncOpenAI(base_url=LLM_URL, api_key="EMPTY") as client:

            model_name = 'gpt-oss-120b'

            response = await client.chat.completions.parse(
                model=model_name,
                messages=messages,
                temperature=temperature,            # 0-1
                top_p=top_p,                        # 0-1
                reasoning_effort=reasoning_effort,  # low, medium, high
                # max_tokens=max_tokens,            # up to ~128K
                stop=None,
            )

            if response and response.choices:
                response_message = response.choices[0].message.content
                return True, response_message

            # the endpoint answered but produced no choices
            return False, ''

    except Exception as e:
        response_message = 'error in llm response generation!'
        print('!!!!!!!!!!!!!!!!!!!!!!!!!')
        print(e)
        print('!!!!!!!!!!!!!!!!!!!!!!!!!')
        return False, response_message

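# Illustrative usage of process_item (mirrors oss_request / oss_base_request above);
# the message content here is a made-up placeholder:
#   ok, text = await process_item(
#       [{"role": "user", "content": "یک سوال حقوقی نمونه"}],
#       reasoning_effort='low',
#   )
#   if ok:
#       print(text)
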
async def credit_refresh():
    """
    Return the remaining avalai credit (IRT) and append it to the credit log.
    """
    url = "https://api.avalai.ir/user/credit"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {await get_key()}"
    }
    # NOTE: requests.get is a blocking call inside an async function
    remained_credit = requests.get(url, headers=headers)
    remained_credit_value = str(remained_credit.json()['remaining_irt'])

    async with aiofiles.open(path_log + 'credit.txt', mode='a+', encoding='utf-8') as file:
        await file.write(f'{remained_credit_value}\n')

    return remained_credit_value


# request body model for the API endpoints
class Query(BaseModel):
    query: str


initial_model()
# uvicorn src.app:app --reload

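# The FastAPI app object referenced by the `uvicorn src.app:app` comment above is not
# defined anywhere in this file, although FastAPI, Header, CORSMiddleware and the Query
# model are all imported/declared for it. The wiring below is an illustrative sketch
# consistent with those imports; the endpoint path, header name and CORS policy are
# assumptions, not taken from the original source.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/chat")
async def chat_endpoint(body: Query, chat_id: str = Header(default="")):
    # delegate to the local gpt-oss pipeline; ask_chatbot_avalai is the hosted-API variant
    return await ask_chatbot(body.query, chat_id)
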
if __name__ == "__main__":
    # Simple interactive CLI loop; the retrieval and LLM helpers are coroutines,
    # so each call is driven with asyncio.run.

    # query = 'در قانون حمایت از خانواده و جوانی جمعیت چه خدماتی در نظر گرفته شده است؟'
    while True:
        query = input('enter your question: ')
        if query == '':
            print('لطفا متن سوال را وارد نمائید')
            continue
        start = datetime.datetime.now()
        # result = test_dataset()
        result, result_ids = asyncio.run(search_query(query))
        end_retrive = datetime.datetime.now()
        print('-' * 40)
        print(f'retrive duration: {(end_retrive - start).total_seconds()}')

        prompt = f'برای پرسش "{query}" از میان مواد قانونی "{result}" .پاسخ مناسب و دقیق را استخراج کن. درصورتی که مطلبی مرتبط با پرسش در متن پیدا نشد، فقط پاسخ بده: "متاسفانه در منابع، پاسخی پیدا نشد!"'
        # llm_request expects an explicit model; the first entry of `models` is used here
        llm_answer, llm_cost = asyncio.run(llm_request(prompt, models[0]))

        print('-' * 40)
        print(f'llm duration: {(datetime.datetime.now() - end_retrive).total_seconds()}')

        recognized_refrences = asyncio.run(find_refrences(llm_answer))
        llm_answer = asyncio.run(replace_refrences(llm_answer, recognized_refrences))

        print('-' * 40)
        print(f'replace_refrences duration: {(datetime.datetime.now() - end_retrive).total_seconds()}')

        with open(path_log + 'result.txt', mode='a+', encoding='utf-8') as file:
            result_message = f'متن پرامپت: {query.strip()}\n\nپاسخ: {llm_answer} \n----------------------------------------------------------\n'
            file.write(result_message)

        with open(path_log + 'passages.txt', mode='a+', encoding='utf-8') as file:
            result_message = f'متن پرامپت: {query.strip()}\n\nمواد مشابه: {result} \n----------------------------------------------------------\n'
            file.write(result_message)

        print('-' * 40)
        print(f'file write duration: {(datetime.datetime.now() - end_retrive).total_seconds()}')

        print('----------------------------------------------------------')
        print(f'full duration: {(datetime.datetime.now() - start).total_seconds()}')
        print('----------------------------------------------------------')
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')