#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Hybrid Retrieval + Reranker Pipeline (Debuggable Version)
---------------------------------------------------------
This version is prepared for running and debugging line by line (pdb or an IDE).
- Dense retriever: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
- Sparse retriever: TF-IDF
- Fusion: weighted sum
- Reranker: BAAI/bge-reranker-v2-m3

How to run under the debugger:
    python -m pdb hybrid_retrieval_reranker_debug.py
"""
import json
from typing import List, Tuple, Dict

import numpy as np
import torch
from sentence_transformers import SentenceTransformer, util
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# -------------------
# Models and data path
# -------------------
# Multilingual sentence-embedding model used by the dense retriever.
EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# Cross-encoder used for the final reranking stage.
RERANKER_MODEL = "BAAI/bge-reranker-v2-m3"
# JSON file with precomputed records carrying "sentence" and "embeddings" keys
# (the schema load_dataset() expects).
DATA_PATH = "./output/sentences_vector.json"
def load_dataset(path: str) -> Tuple[List[str], np.ndarray]:
    """Load the sentence/embedding dataset from a JSON file.

    Records missing a non-empty "sentence" or a list/tuple "embeddings"
    field are silently skipped.

    Returns:
        (sentences, emb_matrix): sentence texts and a float32 matrix whose
        rows are the corresponding embeddings.

    Raises:
        ValueError: when no usable record is found.
    """
    with open(path, "r", encoding="utf-8") as handle:
        records = json.load(handle)

    # A dict-shaped file is treated as a collection of its values.
    if isinstance(records, dict):
        records = list(records.values())

    pairs = [
        (rec.get("sentence"), rec.get("embeddings"))
        for rec in records
    ]
    sentences = [s for s, e in pairs if s and isinstance(e, (list, tuple))]
    vectors = [e for s, e in pairs if s and isinstance(e, (list, tuple))]

    if not sentences:
        raise ValueError("Dataset invalid. Needs 'sentence' + 'embeddings'.")

    return sentences, np.asarray(vectors, dtype=np.float32)
class HybridRetrieverReranker:
    """Hybrid dense (sentence-transformer) + sparse (TF-IDF) retriever with a
    cross-encoder reranking stage.

    Pipeline: dense and sparse retrieval run independently, scores are
    min-max normalised and fused with a weighted sum, and the fused top
    candidates are reranked by a sequence-classification cross-encoder.
    """

    def __init__(self, sentences: List[str], emb_matrix: np.ndarray,
                 dense_alpha: float = 0.6, device: str = None):
        """
        Args:
            sentences: corpus sentences, aligned row-wise with emb_matrix.
            emb_matrix: precomputed embeddings, shape (N, dim), float32.
            dense_alpha: weight of the dense score in fusion
                (the sparse score gets 1 - dense_alpha).
            device: torch device string; auto-selects CUDA when available.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        self.sentences = sentences
        self.emb_matrix = emb_matrix
        self.N = len(sentences)

        # Dense retriever: query encoder + corpus embeddings on the device.
        self.embedder = SentenceTransformer(EMBED_MODEL, device=self.device)
        self.embeddings_tensor = torch.from_numpy(self.emb_matrix).to(self.device)

        # Sparse retriever: word 1-2 gram TF-IDF. The token pattern also
        # matches Arabic/Persian characters (U+0600-U+06FF), min length 2.
        self.vectorizer = TfidfVectorizer(
            analyzer="word",
            ngram_range=(1, 2),
            token_pattern=r"(?u)\b[\w\u0600-\u06FF]{2,}\b"
        )
        self.tfidf_matrix = self.vectorizer.fit_transform(self.sentences)

        # Cross-encoder reranker.
        self.tokenizer = AutoTokenizer.from_pretrained(RERANKER_MODEL)
        self.reranker = AutoModelForSequenceClassification.from_pretrained(
            RERANKER_MODEL
        ).to(self.device)

        self.dense_alpha = float(dense_alpha)

    def dense_retrieve(self, query: str, top_k: int):
        """Return (indices, scores) of the top_k most cosine-similar sentences."""
        q_emb = self.embedder.encode(query, convert_to_tensor=True).to(self.device)
        similars = util.cos_sim(q_emb, self.embeddings_tensor).squeeze(0)
        top_scores, top_idx = torch.topk(similars, k=min(top_k, self.N))
        return top_idx.tolist(), top_scores.detach().cpu().numpy()

    def sparse_retrieve(self, query: str, top_k: int):
        """Return (indices, scores) of the top_k sentences by TF-IDF cosine sim.

        FIX: top_k is clamped to the corpus size and guarded against <= 0.
        np.argpartition raises when kth >= array size, so the previous
        unconditional `kth=top_k-1` crashed whenever top_k >= N (the dense
        path already clamps with min(top_k, self.N)).
        """
        q_vec = self.vectorizer.transform([query])
        sims = cosine_similarity(q_vec, self.tfidf_matrix).ravel()
        k = min(top_k, sims.size)
        if k <= 0:
            return [], sims[:0]
        if k < sims.size:
            idx = np.argpartition(-sims, kth=k - 1)[:k]
        else:
            idx = np.arange(sims.size)
        idx = idx[np.argsort(-sims[idx])]
        return idx.tolist(), sims[idx]

    @staticmethod
    def _minmax_norm(arr: np.ndarray) -> np.ndarray:
        """Min-max scale arr to [0, 1]; constant or empty arrays map to zeros."""
        if arr.size == 0:
            return arr
        a_min, a_max = arr.min(), arr.max()
        if a_max - a_min < 1e-12:
            return np.zeros_like(arr)
        return (arr - a_min) / (a_max - a_min)

    def fuse(self, d_idx, d_scores, s_idx, s_scores, pre_rerank_k: int):
        """Weighted-sum fusion of normalised dense and sparse scores.

        Returns up to pre_rerank_k candidate indices sorted by fused score.
        An index found by only one retriever contributes 0 for the other
        component. Score arrays may be numpy arrays or plain sequences.
        """
        d_norm = self._minmax_norm(np.asarray(d_scores, dtype=np.float32))
        s_norm = self._minmax_norm(np.asarray(s_scores, dtype=np.float32))

        d_map = {i: d for i, d in zip(d_idx, d_norm)}
        s_map = {i: s for i, s in zip(s_idx, s_norm)}

        fused = []
        for i in set(d_idx) | set(s_idx):
            score = self.dense_alpha * d_map.get(i, 0.0) + (1 - self.dense_alpha) * s_map.get(i, 0.0)
            fused.append((i, score))

        fused.sort(key=lambda x: x[1], reverse=True)
        return [i for i, _ in fused[:pre_rerank_k]]

    def rerank(self, query: str, candidate_indices: List[int], final_k: int):
        """Score (query, sentence) pairs with the cross-encoder.

        Returns the final_k best candidates as (index, score) tuples sorted
        by descending score.
        """
        pairs = [(query, self.sentences[i]) for i in candidate_indices]
        scores = []
        # Batches of 16 pairs keep the forward-pass memory bounded.
        for batch in [pairs[i:i + 16] for i in range(0, len(pairs), 16)]:
            inputs = self.tokenizer(batch, padding=True, truncation=True,
                                    max_length=512, return_tensors="pt").to(self.device)
            with torch.no_grad():
                logits = self.reranker(**inputs).logits.view(-1)
            scores.extend(logits.cpu().tolist())

        items = sorted(zip(candidate_indices, scores), key=lambda x: x[1], reverse=True)
        return items[:final_k]

    def search(self, query: str, topk_dense=50, topk_sparse=50,
               pre_rerank_k=50, final_k=5):
        """Full pipeline: dense + sparse retrieval -> fusion -> rerank.

        Returns a list of {"idx", "sentence", "rerank_score"} dicts.
        """
        d_idx, d_scores = self.dense_retrieve(query, topk_dense)
        # import pdb; pdb.set_trace()  # <- optional stop after dense retrieval

        s_idx, s_scores = self.sparse_retrieve(query, topk_sparse)
        cand_idx = self.fuse(d_idx, d_scores, s_idx, s_scores, pre_rerank_k)
        # import pdb; pdb.set_trace()  # <- optional stop after fusion

        reranked = self.rerank(query, cand_idx, final_k)
        # import pdb; pdb.set_trace()  # <- optional stop after rerank

        return [{"idx": i, "sentence": self.sentences[i], "rerank_score": score}
                for i, score in reranked]
def main():
    """Run one demo query through the hybrid pipeline and print the top hits."""
    query = "انسان در فتنه ها باید چگونه عملی کند؟"
    corpus, vectors = load_dataset(DATA_PATH)

    pipeline = HybridRetrieverReranker(corpus, vectors, dense_alpha=0.6)
    hits = pipeline.search(query, topk_dense=30, topk_sparse=30, pre_rerank_k=30, final_k=5)

    print("\nTop results:")
    for rank, hit in enumerate(hits, 1):
        print(f"{rank}. [score={hit['rerank_score']:.4f}] {hit['sentence']}")
if __name__ == "__main__":
    import datetime

    # Run the pipeline four times and print each run's wall-clock duration.
    # (Previously this was the same timing snippet copy-pasted four times
    # with ad-hoc variables start/time2/time3/time4/time5.) The first run
    # includes model loading; later runs show warm timing.
    prev = datetime.datetime.now()
    for _ in range(4):
        main()
        now = datetime.datetime.now()
        print(now - prev)
        prev = now