import numpy as np import torch, orjson, faiss, re from typing import List from sentence_transformers import SentenceTransformer from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity from FlagEmbedding import FlagReranker from pathlib import Path # nlist = 2048 # quantizer = faiss.IndexFlatIP(dim) # index = faiss.IndexIVFFlat(quantizer, dim, nlist) # index.train(embeddings) # index.add(embeddings) class InitHybridRetrieverReranker: def __init__( self, embeder_path, reranker_path, dict_content: List[dict], faiss_index, dense_alpha: float = 0.6, device: str = None, cache_dir="/src/MODELS", batch_size=512, ): if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" self.device = device self.dense_alpha = dense_alpha # =============================== # تبدیل ورودی فقط یک بار # =============================== self.content_list = [x["content"] for x in dict_content] self.ids_list = [x["id"] for x in dict_content] self.N = len(self.content_list) self.faiss_index = faiss_index # Dense embedder self.embedder = SentenceTransformer( local_files_only=True, model_name_or_path=embeder_path, cache_folder=cache_dir, device=self.device, similarity_fn_name="cosine", ) # TF-IDF self.vectorizer = TfidfVectorizer( analyzer="word", ngram_range=(1, 2), token_pattern=r"(?u)\b[\w\u0600-\u06FF]{2,}\b", ) self.tfidf_matrix = self.vectorizer.fit_transform(self.content_list) # Reranker self.reranker = FlagReranker( model_name_or_path=reranker_path, local_files_only=True, use_fp16=True, devices=device, cache_dir=cache_dir, batch_size=batch_size, normalize=True, # max_length=1024, # trust_remote_code=False, # query_max_length= ) print("RAG Ready — Retriever + Reranker Loaded") # ================================ # Dense Search (FAISS) # ================================ async def dense_retrieve(self, query: str, top_k: int): if top_k <= 0: return [], np.array([]) emb = self.embedder.encode(query, convert_to_numpy=True).astype(np.float32) D, I = self.faiss_index.search(emb.reshape(1, -1), top_k) return I[0], D[0] # ================================ # Sparse Search (TF-IDF) # ================================ async def sparse_retrieve(self, query: str, top_k: int): if top_k <= 0: return [], np.array([]) q_vec = self.vectorizer.transform([query]) sims = cosine_similarity(q_vec, self.tfidf_matrix)[0] k = min(top_k, len(sims)) idx = np.argpartition(-sims, k - 1)[:k] idx = idx[np.argsort(-sims[idx], kind="mergesort")] return idx, sims[idx] # ================================ # Reciprocal Rank Fusion # ================================ async def fuse(self, d_idx, d_scores, s_idx, s_scores, top_k=50, k_rrf=60): combined = {} for rank, idx in enumerate(d_idx): combined[idx] = combined.get(idx, 0) + 1.0 / (k_rrf + rank) for rank, idx in enumerate(s_idx): combined[idx] = combined.get(idx, 0) + 1.0 / (k_rrf + rank) sorted_items = sorted(combined.items(), key=lambda x: x[1], reverse=True) return [i[0] for i in sorted_items[:top_k]] # ================================ # Rerank # ================================ async def rerank(self, query: str, cand_idx: List[int], final_k: int = 10): if not cand_idx: return [] passages = [self.content_list[i] for i in cand_idx] pairs = [[query, p] for p in passages] scores = self.reranker.compute_score(pairs, normalize=True, max_length=512) if isinstance(scores, float): scores = [scores] idx_score = list(zip(cand_idx, scores)) idx_score.sort(key=lambda x: x[1], reverse=True) return idx_score[:final_k] # ================================ # Main Search Function # ================================ async def search_base( self, query: str, topk_dense=50, topk_sparse=50, pre_rerank_k=50, final_k=10, ): d_idx, d_scores = await self.dense_retrieve(query, topk_dense) s_idx, s_scores = await self.sparse_retrieve(query, topk_sparse) cand_idx = await self.fuse(d_idx, d_scores, s_idx, s_scores, pre_rerank_k) final_rank = await self.rerank(query, cand_idx, final_k) # =============================== # خروجی سریع و تمیز # =============================== return [ { "id": self.ids_list[idx], "content": self.content_list[idx], "score": score, } for idx, score in final_rank ] def load_orjson(path: str | Path): path = Path(path) with path.open("rb") as f: # باید باینری باز بشه برای orjson return orjson.loads(f.read()) def save_orjson(path, data): with open(path, "wb") as f: f.write( orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS) ) WEB_LINK = "https://majles.tavasi.ir/entity/detail/view/qsection/" # ref = f"[«{i}»](https://majles.tavasi.ir/entity/detail/view/qsection/{idx})" def get_in_form(title: str, sections: list, max_len: int = 4000): chunks = [] current = f"برای پرسش: {title}\n\n" ref_text = "«منبع»" for i, data in enumerate(sections, start=1): sec_text = data.get("content", "") idx = data.get("id") # ساخت ref کامل ref = f"[{ref_text}]({WEB_LINK}{idx})" # متن کامل آیتم block = f"{i}: {sec_text}\n{ref}\n\n" # اگر با اضافه شدن این آیتم از حد مجاز عبور می‌کنیم → شروع چانک جدید if len(current) + len(block) > max_len: chunks.append(current.rstrip()) current = "" current += block # آخرین چانک را هم اضافه کن if current.strip(): chunks.append(current.rstrip()) return chunks def format_answer_bale(answer_text: str, sources: list, max_len: int = 4000): """ answer_text: متن خروجی مدل که داخلش عبارت‌های مثل (منبع: qs2117427) وجود دارد sources: مثل ['qs2117427'] """ ref_text = "«منبع»" def make_link(src): return f"[{ref_text}]({WEB_LINK}{src})" # الگو برای تشخیص هر پرانتز که شامل یک یا چند کد باشد # مثلا: (qs123) یا (qs123, qs456, qs789) pattern = r"\((?:منبع[:: ]+)?([a-zA-Z0-9_, ]+)\)" def replace_source(m): content = m.group(1) codes = [c.strip() for c in content.split(",")] # جداسازی چند کد links = [make_link(code) for code in codes] full_match = m.group(0) # if "منبع" in full_match: # print(f'Found explicit source(s): {links}') # else: # print(f'Found implicit source(s): {links}') return ", ".join(links) # جایگزینی همه کدها با لینک‌هایشان # جایگزینی در متن answer_text = re.sub(pattern, replace_source, answer_text) # اگر طول کمتر از max_len بود → تمام if len(answer_text) <= max_len: return [answer_text] # تقسیم متن اگر طول زیاد شد chunks = [] current = "" sentences = answer_text.split(". ") for sentence in sentences: st = sentence.strip() if not st.endswith("."): st += "." if len(current) + len(st) > max_len: chunks.append(current.strip()) current = "" current += st + " " if current.strip(): chunks.append(current.strip()) return chunks def get_user_prompt(query: str): """ get a query and prepare a prompt to generate title based on that """ title_prompt = f"برای متن {query} یک عنوان با معنا که بین 3 تا 6 کلمه داشته باشد، در قالب یک رشته متن ایجاد کن. سبک و لحن عنوان، حقوقی و کاملا رسمی باشد. عنوان تولید شده کاملا ساده و بدون هیچ مارک داون یا علائم افزوده ای باشد. غیر از عنوان، به هیچ وجه توضیح اضافه ای در قبل یا بعد آن اضافه نکن." return title_prompt def format_knowledge_block(knowledge): lines = [] for item in knowledge: _id = item.get("id", "unknown") _content = item.get("content", "") lines.append(f"- ({_id}) { _content }") return "\n".join(lines) def get_user_prompt2(obj): query = obj.query knowledge = obj.knowledge prompt = f""" شما باید تنها بر اساس اطلاعات ارائه شده پاسخ بدهید و هیچ دانشی خارج از آنها استفاده نکنید. ### پرسش: {query} ### اسناد قابل استناد: {format_knowledge_block(knowledge)} ### دستور تولید خروجی: - پاسخی کاملاً دقیق، تحلیلی و مفهومی ایجاد کن - لحن رسمی و حقوقی باشد - اگر پاسخ نیاز به ترکیب چند سند دارد، آنها را ادغام کن - اگر داده‌ها کافی نبود، این موضوع را شفاف اعلام کن اما اطلاعات مرتبط را همچنان ارائه بده """ return prompt def get_user_prompt3(query, knowledge_json): sys = f"""Answer the following based ONLY on the knowledge: Query: {query} Knowledge: {knowledge_json}""" return sys def load_faiss_index(index_path: str, metadata_path: str): """بارگذاری ایندکس FAISS و متادیتا (لیست جملات + عناوین).""" index = faiss.read_index(index_path) metadata = load_orjson(metadata_path) metadata = [ { "id": item["id"], "content": item["content"], "prefix": item["prefix"], } for item in metadata ] return metadata, index