import numpy as np
import torch, orjson, faiss, re
from typing import List
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from FlagEmbedding import FlagReranker
from pathlib import Path

# nlist = 2048
# quantizer = faiss.IndexFlatIP(dim)
# index = faiss.IndexIVFFlat(quantizer, dim, nlist)
# index.train(embeddings)
# index.add(embeddings)
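

# The commented lines above are the recipe for building the IVF index. Below is a
# minimal, self-contained sketch of that recipe (not part of the original pipeline;
# `embeddings` and `nlist` are assumptions). Vectors are L2-normalized so inner
# product behaves like cosine similarity, matching the IndexFlatIP quantizer.
def build_ivf_index(embeddings: np.ndarray, nlist: int = 2048) -> faiss.Index:
    """Train and fill an IVF index over normalized embeddings (sketch only)."""
    emb = np.array(embeddings, dtype=np.float32)  # copy as float32; normalize_L2 works in place
    faiss.normalize_L2(emb)  # cosine similarity via inner product
    dim = emb.shape[1]
    quantizer = faiss.IndexFlatIP(dim)
    index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)
    index.train(emb)  # learn the nlist coarse centroids
    index.add(emb)  # add all vectors
    return index


# faiss.write_index(build_ivf_index(embeddings), "index.faiss") could then persist
# the result for load_faiss_index() at the bottom of this module.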


class InitHybridRetrieverReranker:
    def __init__(
        self,
        embeder_path,
        reranker_path,
        dict_content: List[dict],
        faiss_index,
        dense_alpha: float = 0.6,
        device: str = None,
        cache_dir="/src/MODELS",
        batch_size=512,
    ):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        self.dense_alpha = dense_alpha  # stored but unused; fusion below is rank-based (RRF)

        # ===============================
        # Convert the input only once
        # ===============================
        self.content_list = [x["content"] for x in dict_content]
        self.ids_list = [x["id"] for x in dict_content]
        self.N = len(self.content_list)
        self.faiss_index = faiss_index

        # Dense embedder
        self.embedder = SentenceTransformer(
            local_files_only=True,
            model_name_or_path=embeder_path,
            cache_folder=cache_dir,
            device=self.device,
            similarity_fn_name="cosine",
        )

        # TF-IDF
        self.vectorizer = TfidfVectorizer(
            analyzer="word",
            ngram_range=(1, 2),
            token_pattern=r"(?u)\b[\w\u0600-\u06FF]{2,}\b",
        )
        self.tfidf_matrix = self.vectorizer.fit_transform(self.content_list)

        # Reranker
        self.reranker = FlagReranker(
            model_name_or_path=reranker_path,
            local_files_only=True,
            use_fp16=True,
            devices=device,
            cache_dir=cache_dir,
            batch_size=batch_size,
            normalize=True,
            # max_length=1024,
            # trust_remote_code=False,
            # query_max_length=
        )

        print("RAG Ready — Retriever + Reranker Loaded")

    # ================================
    # Dense Search (FAISS)
    # ================================
    async def dense_retrieve(self, query: str, top_k: int):
        if top_k <= 0:
            return [], np.array([])

        emb = self.embedder.encode(query, convert_to_numpy=True).astype(np.float32)

        # FAISS expects a 2-D batch; D holds the scores, I the row indices.
        D, I = self.faiss_index.search(emb.reshape(1, -1), top_k)
        return I[0], D[0]

    # ================================
    # Sparse Search (TF-IDF)
    # ================================
    async def sparse_retrieve(self, query: str, top_k: int):
        if top_k <= 0:
            return [], np.array([])

        q_vec = self.vectorizer.transform([query])
        sims = cosine_similarity(q_vec, self.tfidf_matrix)[0]

        # Partial-select the top k scores, then order just those k by score.
        k = min(top_k, len(sims))
        idx = np.argpartition(-sims, k - 1)[:k]
        idx = idx[np.argsort(-sims[idx], kind="mergesort")]

        return idx, sims[idx]

    # ================================
    # Reciprocal Rank Fusion
    # ================================
    async def fuse(self, d_idx, d_scores, s_idx, s_scores, top_k=50, k_rrf=60):
        # RRF is rank-based: each list contributes 1 / (k_rrf + rank) per document,
        # so the raw dense/sparse scores are not used here.
        combined = {}

        for rank, idx in enumerate(d_idx):
            combined[idx] = combined.get(idx, 0) + 1.0 / (k_rrf + rank)

        for rank, idx in enumerate(s_idx):
            combined[idx] = combined.get(idx, 0) + 1.0 / (k_rrf + rank)

        sorted_items = sorted(combined.items(), key=lambda x: x[1], reverse=True)

        return [i[0] for i in sorted_items[:top_k]]

    # ================================
    # Rerank
    # ================================
    async def rerank(self, query: str, cand_idx: List[int], final_k: int = 10):
        if not cand_idx:
            return []

        passages = [self.content_list[i] for i in cand_idx]
        pairs = [[query, p] for p in passages]

        scores = self.reranker.compute_score(pairs, normalize=True, max_length=512)

        # compute_score returns a bare float for a single pair; normalize to a list.
        if isinstance(scores, float):
            scores = [scores]

        idx_score = list(zip(cand_idx, scores))
        idx_score.sort(key=lambda x: x[1], reverse=True)

        return idx_score[:final_k]

    # ================================
    # Main Search Function
    # ================================
    async def search_base(
        self,
        query: str,
        topk_dense=50,
        topk_sparse=50,
        pre_rerank_k=50,
        final_k=10,
    ):
        d_idx, d_scores = await self.dense_retrieve(query, topk_dense)
        s_idx, s_scores = await self.sparse_retrieve(query, topk_sparse)

        cand_idx = await self.fuse(d_idx, d_scores, s_idx, s_scores, pre_rerank_k)
        final_rank = await self.rerank(query, cand_idx, final_k)

        # ===============================
        # Fast, clean output
        # ===============================
        return [
            {
                "id": self.ids_list[idx],
                "content": self.content_list[idx],
                "score": score,
            }
            for idx, score in final_rank
        ]


def load_orjson(path: str | Path):
    path = Path(path)
    with path.open("rb") as f:  # must be opened in binary mode for orjson
        return orjson.loads(f.read())


def save_orjson(path, data):
    with open(path, "wb") as f:
        f.write(
            orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS)
        )


WEB_LINK = "https://majles.tavasi.ir/entity/detail/view/qsection/"
# ref = f"[«{i}»](https://majles.tavasi.ir/entity/detail/view/qsection/{idx})"


def get_in_form(title: str, sections: list, max_len: int = 4000):
    chunks = []
    current = f"برای پرسش: {title}\n\n"
    ref_text = "«منبع»"

    for i, data in enumerate(sections, start=1):
        sec_text = data.get("content", "")
        idx = data.get("id")

        # Build the full reference link
        ref = f"[{ref_text}]({WEB_LINK}{idx})"
        # Full text of this item
        block = f"{i}: {sec_text}\n{ref}\n\n"

        # If adding this item would exceed the limit -> start a new chunk
        if len(current) + len(block) > max_len:
            chunks.append(current.rstrip())
            current = ""

        current += block

    # Append the last chunk as well
    if current.strip():
        chunks.append(current.rstrip())

    return chunks
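

# Illustrative only (not in the original module): each chunk produced by
# get_in_form looks roughly like
#
#   برای پرسش: <title>
#
#   1: <section content>
#   [«منبع»](https://majles.tavasi.ir/entity/detail/view/qsection/<id>)
#
# Only the first chunk carries the title prefix; a new chunk is started as soon
# as appending the next item would push the current one past max_len characters.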


def format_answer_bale(answer_text: str, sources: list, max_len: int = 4000):
    """
    answer_text: the model's output text, containing phrases such as (منبع: qs2117427)
    sources: e.g. ['qs2117427']
    """
    ref_text = "«منبع»"

    def make_link(src):
        return f"[{ref_text}]({WEB_LINK}{src})"

    # Pattern matching any parenthesis that contains one or more codes,
    # e.g. (qs123) or (qs123, qs456, qs789)
    pattern = r"\((?:منبع[:: ]+)?([a-zA-Z0-9_, ]+)\)"

    def replace_source(m):
        content = m.group(1)
        codes = [c.strip() for c in content.split(",")]  # split out multiple codes
        links = [make_link(code) for code in codes]
        full_match = m.group(0)
        # if "منبع" in full_match:
        #     print(f'Found explicit source(s): {links}')
        # else:
        #     print(f'Found implicit source(s): {links}')
        return ", ".join(links)  # replace every code with its link

    # Replace the citations in the text
    answer_text = re.sub(pattern, replace_source, answer_text)

    # If the text fits within max_len -> done
    if len(answer_text) <= max_len:
        return [answer_text]

    # Otherwise split the text into chunks
    chunks = []
    current = ""

    sentences = answer_text.split(". ")
    for sentence in sentences:
        st = sentence.strip()
        if not st.endswith("."):
            st += "."

        if len(current) + len(st) > max_len:
            chunks.append(current.strip())
            current = ""

        current += st + " "

    if current.strip():
        chunks.append(current.strip())

    return chunks
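

# Illustrative only (not in the original module): with the default WEB_LINK, an
# inline citation such as "(منبع: qs2117427)" is rewritten into a markdown-style link:
#
#   format_answer_bale("... (منبع: qs2117427)", ["qs2117427"])
#   -> ["... [«منبع»](https://majles.tavasi.ir/entity/detail/view/qsection/qs2117427)"]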


def get_user_prompt(query: str):
    """
    Take a query and prepare a prompt that asks for a title to be generated for it.
    """
    title_prompt = f"برای متن {query} یک عنوان با معنا که بین 3 تا 6 کلمه داشته باشد، در قالب یک رشته متن ایجاد کن. سبک و لحن عنوان، حقوقی و کاملا رسمی باشد. عنوان تولید شده کاملا ساده و بدون هیچ مارک داون یا علائم افزوده ای باشد. غیر از عنوان، به هیچ وجه توضیح اضافه ای در قبل یا بعد آن اضافه نکن."
    return title_prompt


def format_knowledge_block(knowledge):
    lines = []
    for item in knowledge:
        _id = item.get("id", "unknown")
        _content = item.get("content", "")
        lines.append(f"- ({_id}) {_content}")
    return "\n".join(lines)


def get_user_prompt2(obj):
    query = obj.query
    knowledge = obj.knowledge

    prompt = f"""
شما باید تنها بر اساس اطلاعات ارائه شده پاسخ بدهید و هیچ دانشی خارج از آنها استفاده نکنید.

### پرسش:
{query}

### اسناد قابل استناد:
{format_knowledge_block(knowledge)}

### دستور تولید خروجی:
- پاسخی کاملاً دقیق، تحلیلی و مفهومی ایجاد کن
- لحن رسمی و حقوقی باشد
- اگر پاسخ نیاز به ترکیب چند سند دارد، آنها را ادغام کن
- اگر دادهها کافی نبود، این موضوع را شفاف اعلام کن اما اطلاعات مرتبط را همچنان ارائه بده
"""
    return prompt


def get_user_prompt3(query, knowledge_json):
    sys = f"""Answer the following based ONLY on the knowledge:

Query:
{query}

Knowledge:
{knowledge_json}"""
    return sys


def load_faiss_index(index_path: str, metadata_path: str):
    """Load the FAISS index and its metadata (the list of passages plus titles)."""
    index = faiss.read_index(index_path)

    metadata = load_orjson(metadata_path)

    metadata = [
        {
            "id": item["id"],
            "content": item["content"],
            "prefix": item["prefix"],
        }
        for item in metadata
    ]
    return metadata, index
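

# A minimal usage sketch (not part of the original module). The index, metadata,
# model paths, and the sample query below are placeholders/assumptions; search_base
# is async, so it is driven here with asyncio.run.
if __name__ == "__main__":
    import asyncio

    metadata, index = load_faiss_index("index.faiss", "metadata.json")
    rag = InitHybridRetrieverReranker(
        embeder_path="/src/MODELS/embedder",  # assumed local model path
        reranker_path="/src/MODELS/reranker",  # assumed local model path
        dict_content=metadata,
        faiss_index=index,
    )

    results = asyncio.run(rag.search_base("ماده ۵ قانون کار", final_k=5))
    for hit in results:
        print(hit["id"], round(hit["score"], 3))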