rag_qavanin_api/routers/chatbot_handler.py
2025-11-29 20:11:27 +00:00

338 lines
10 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import torch, orjson, faiss, re
from typing import List
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from FlagEmbedding import FlagReranker
from pathlib import Path
# nlist = 2048
# quantizer = faiss.IndexFlatIP(dim)
# index = faiss.IndexIVFFlat(quantizer, dim, nlist)
# index.train(embeddings)
# index.add(embeddings)
class InitHybridRetrieverReranker:
def __init__(
self,
embeder_path,
reranker_path,
dict_content: List[dict],
faiss_index,
dense_alpha: float = 0.6,
device: str = None,
cache_dir="/src/MODELS",
batch_size=512,
):
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
self.dense_alpha = dense_alpha
# ===============================
# تبدیل ورودی فقط یک بار
# ===============================
self.content_list = [x["content"] for x in dict_content]
self.ids_list = [x["id"] for x in dict_content]
self.N = len(self.content_list)
self.faiss_index = faiss_index
# Dense embedder
self.embedder = SentenceTransformer(
local_files_only=True,
model_name_or_path=embeder_path,
cache_folder=cache_dir,
device=self.device,
similarity_fn_name="cosine",
)
# TF-IDF
self.vectorizer = TfidfVectorizer(
analyzer="word",
ngram_range=(1, 2),
token_pattern=r"(?u)\b[\w\u0600-\u06FF]{2,}\b",
)
self.tfidf_matrix = self.vectorizer.fit_transform(self.content_list)
# Reranker
self.reranker = FlagReranker(
model_name_or_path=reranker_path,
local_files_only=True,
use_fp16=True,
devices=device,
cache_dir=cache_dir,
batch_size=batch_size,
normalize=True,
# max_length=1024,
# trust_remote_code=False,
# query_max_length=
)
print("RAG Ready — Retriever + Reranker Loaded")
# ================================
# Dense Search (FAISS)
# ================================
async def dense_retrieve(self, query: str, top_k: int):
if top_k <= 0:
return [], np.array([])
emb = self.embedder.encode(query, convert_to_numpy=True).astype(np.float32)
D, I = self.faiss_index.search(emb.reshape(1, -1), top_k)
return I[0], D[0]
# ================================
# Sparse Search (TF-IDF)
# ================================
async def sparse_retrieve(self, query: str, top_k: int):
if top_k <= 0:
return [], np.array([])
q_vec = self.vectorizer.transform([query])
sims = cosine_similarity(q_vec, self.tfidf_matrix)[0]
k = min(top_k, len(sims))
idx = np.argpartition(-sims, k - 1)[:k]
idx = idx[np.argsort(-sims[idx], kind="mergesort")]
return idx, sims[idx]
# ================================
# Reciprocal Rank Fusion
# ================================
async def fuse(self, d_idx, d_scores, s_idx, s_scores, top_k=50, k_rrf=60):
combined = {}
for rank, idx in enumerate(d_idx):
combined[idx] = combined.get(idx, 0) + 1.0 / (k_rrf + rank)
for rank, idx in enumerate(s_idx):
combined[idx] = combined.get(idx, 0) + 1.0 / (k_rrf + rank)
sorted_items = sorted(combined.items(), key=lambda x: x[1], reverse=True)
return [i[0] for i in sorted_items[:top_k]]
# ================================
# Rerank
# ================================
async def rerank(self, query: str, cand_idx: List[int], final_k: int = 10):
if not cand_idx:
return []
passages = [self.content_list[i] for i in cand_idx]
pairs = [[query, p] for p in passages]
scores = self.reranker.compute_score(pairs, normalize=True, max_length=512)
if isinstance(scores, float):
scores = [scores]
idx_score = list(zip(cand_idx, scores))
idx_score.sort(key=lambda x: x[1], reverse=True)
return idx_score[:final_k]
# ================================
# Main Search Function
# ================================
async def search_base(
self,
query: str,
topk_dense=50,
topk_sparse=50,
pre_rerank_k=50,
final_k=10,
):
d_idx, d_scores = await self.dense_retrieve(query, topk_dense)
s_idx, s_scores = await self.sparse_retrieve(query, topk_sparse)
cand_idx = await self.fuse(d_idx, d_scores, s_idx, s_scores, pre_rerank_k)
final_rank = await self.rerank(query, cand_idx, final_k)
# ===============================
# خروجی سریع و تمیز
# ===============================
return [
{
"id": self.ids_list[idx],
"content": self.content_list[idx],
"score": score,
}
for idx, score in final_rank
]
def load_orjson(path: str | Path):
path = Path(path)
with path.open("rb") as f: # باید باینری باز بشه برای orjson
return orjson.loads(f.read())
def save_orjson(path, data):
with open(path, "wb") as f:
f.write(
orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS)
)
WEB_LINK = "https://majles.tavasi.ir/entity/detail/view/qsection/"
# ref = f"[«{i}»](https://majles.tavasi.ir/entity/detail/view/qsection/{idx})"
def get_in_form(title: str, sections: list, max_len: int = 4000):
chunks = []
current = f"برای پرسش: {title}\n\n"
ref_text = "«منبع»"
for i, data in enumerate(sections, start=1):
sec_text = data.get("content", "")
idx = data.get("id")
# ساخت ref کامل
ref = f"[{ref_text}]({WEB_LINK}{idx})"
# متن کامل آیتم
block = f"{i}: {sec_text}\n{ref}\n\n"
# اگر با اضافه شدن این آیتم از حد مجاز عبور می‌کنیم → شروع چانک جدید
if len(current) + len(block) > max_len:
chunks.append(current.rstrip())
current = ""
current += block
# آخرین چانک را هم اضافه کن
if current.strip():
chunks.append(current.rstrip())
return chunks
def format_answer_bale(answer_text: str, sources: list, max_len: int = 4000):
"""
answer_text: متن خروجی مدل که داخلش عبارت‌های مثل (منبع: qs2117427) وجود دارد
sources: مثل ['qs2117427']
"""
ref_text = "«منبع»"
def make_link(src):
return f"[{ref_text}]({WEB_LINK}{src})"
# الگو برای تشخیص هر پرانتز که شامل یک یا چند کد باشد
# مثلا: (qs123) یا (qs123, qs456, qs789)
pattern = r"\((?:منبع[: ]+)?([a-zA-Z0-9_, ]+)\)"
def replace_source(m):
content = m.group(1)
codes = [c.strip() for c in content.split(",")] # جداسازی چند کد
links = [make_link(code) for code in codes]
full_match = m.group(0)
# if "منبع" in full_match:
# print(f'Found explicit source(s): {links}')
# else:
# print(f'Found implicit source(s): {links}')
return ", ".join(links) # جایگزینی همه کدها با لینک‌هایشان
# جایگزینی در متن
answer_text = re.sub(pattern, replace_source, answer_text)
# اگر طول کمتر از max_len بود → تمام
if len(answer_text) <= max_len:
return [answer_text]
# تقسیم متن اگر طول زیاد شد
chunks = []
current = ""
sentences = answer_text.split(". ")
for sentence in sentences:
st = sentence.strip()
if not st.endswith("."):
st += "."
if len(current) + len(st) > max_len:
chunks.append(current.strip())
current = ""
current += st + " "
if current.strip():
chunks.append(current.strip())
return chunks
def get_user_prompt(query: str):
"""
get a query and prepare a prompt to generate title based on that
"""
title_prompt = f"برای متن {query} یک عنوان با معنا که بین 3 تا 6 کلمه داشته باشد، در قالب یک رشته متن ایجاد کن. سبک و لحن عنوان، حقوقی و کاملا رسمی باشد. عنوان تولید شده کاملا ساده و بدون هیچ مارک داون یا علائم افزوده ای باشد. غیر از عنوان، به هیچ وجه توضیح اضافه ای در قبل یا بعد آن اضافه نکن."
return title_prompt
def format_knowledge_block(knowledge):
lines = []
for item in knowledge:
_id = item.get("id", "unknown")
_content = item.get("content", "")
lines.append(f"- ({_id}) { _content }")
return "\n".join(lines)
def get_user_prompt2(obj):
query = obj.query
knowledge = obj.knowledge
prompt = f"""
شما باید تنها بر اساس اطلاعات ارائه شده پاسخ بدهید و هیچ دانشی خارج از آنها استفاده نکنید.
### پرسش:
{query}
### اسناد قابل استناد:
{format_knowledge_block(knowledge)}
### دستور تولید خروجی:
- پاسخی کاملاً دقیق، تحلیلی و مفهومی ایجاد کن
- لحن رسمی و حقوقی باشد
- اگر پاسخ نیاز به ترکیب چند سند دارد، آنها را ادغام کن
- اگر داده‌ها کافی نبود، این موضوع را شفاف اعلام کن اما اطلاعات مرتبط را همچنان ارائه بده
"""
return prompt
def get_user_prompt3(query, knowledge_json):
sys = f"""Answer the following based ONLY on the knowledge:
Query:
{query}
Knowledge:
{knowledge_json}"""
return sys
def load_faiss_index(index_path: str, metadata_path: str):
"""بارگذاری ایندکس FAISS و متادیتا (لیست جملات + عناوین)."""
index = faiss.read_index(index_path)
metadata = load_orjson(metadata_path)
metadata = [
{
"id": item["id"],
"content": item["content"],
"prefix": item["prefix"],
}
for item in metadata
]
return metadata, index