#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Hybrid Retrieval + Reranker Pipeline (Debuggable Version)
---------------------------------------------------------
این نسخه برای اجرا و دیباگ خط به خط (pdb یا IDE) آماده شده است.
- Dense retriever: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
- Sparse retriever: TF-IDF
- Fusion: weighted sum
- Reranker: BAAI/bge-reranker-v2-m3
نحوه اجرا در حالت دیباگر:
python -m pdb hybrid_retrieval_reranker_debug.py
"""
import json
import numpy as np
import torch
from typing import List, Tuple, Dict
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# -------------------
# Models and data path
# -------------------
EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
RERANKER_MODEL = "BAAI/bge-reranker-v2-m3"
DATA_PATH = "./output/sentences_vector.json"
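
# Expected JSON layout of DATA_PATH, inferred from load_dataset() below
# (the file may instead be a dict whose values have the same record shape):
#
#   [
#     {"sentence": "some text ...", "embeddings": [0.013, -0.208, ...]},
#     ...
#   ]
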
def load_dataset(path: str) -> Tuple[List[str], np.ndarray]:
    """Load sentences and embeddings (float32)."""
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    # If the file is a dict, convert it to a list of records.
    if isinstance(raw, dict):
        raw = list(raw.values())
    sentences, emb_list = [], []
    for it in raw:
        sent = it.get("sentence")
        emb = it.get("embeddings")
        if sent and isinstance(emb, (list, tuple)):
            sentences.append(sent)
            emb_list.append(emb)
    if not sentences:
        raise ValueError("Dataset invalid. Needs 'sentence' + 'embeddings'.")
    emb_matrix = np.asarray(emb_list, dtype=np.float32)
    return sentences, emb_matrix
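
def build_dataset(sentences: List[str], path: str = DATA_PATH) -> None:
    """Hypothetical helper (not part of the original pipeline): encode sentences
    with EMBED_MODEL and write them in the JSON layout that load_dataset()
    expects. A minimal sketch for reproducing the input file."""
    embedder = SentenceTransformer(EMBED_MODEL)
    vectors = embedder.encode(sentences, convert_to_numpy=True)
    records = [{"sentence": s, "embeddings": v.tolist()}
               for s, v in zip(sentences, vectors)]
    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False)
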
class HybridRetrieverReranker:
    def __init__(self, sentences: List[str], emb_matrix: np.ndarray,
                 dense_alpha: float = 0.6, device: str = None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.sentences = sentences
        self.emb_matrix = emb_matrix
        self.N = len(sentences)
        # Dense retriever: precomputed sentence embeddings moved to the target device.
        self.embedder = SentenceTransformer(EMBED_MODEL, device=self.device)
        self.embeddings_tensor = torch.from_numpy(self.emb_matrix).to(self.device)
        # Sparse retriever: word-level TF-IDF; the token pattern also matches
        # Arabic/Persian letters (U+0600-U+06FF), minimum token length 2.
        self.vectorizer = TfidfVectorizer(
            analyzer="word",
            ngram_range=(1, 2),
            token_pattern=r"(?u)\b[\w\u0600-\u06FF]{2,}\b"
        )
        self.tfidf_matrix = self.vectorizer.fit_transform(self.sentences)
        # Cross-encoder reranker.
        self.tokenizer = AutoTokenizer.from_pretrained(RERANKER_MODEL)
        self.reranker = AutoModelForSequenceClassification.from_pretrained(
            RERANKER_MODEL
        ).to(self.device)
        self.dense_alpha = float(dense_alpha)
    def dense_retrieve(self, query: str, top_k: int):
        q_emb = self.embedder.encode(query, convert_to_tensor=True).to(self.device)
        similars = util.cos_sim(q_emb, self.embeddings_tensor).squeeze(0)
        top_scores, top_idx = torch.topk(similars, k=min(top_k, self.N))
        return top_idx.tolist(), top_scores.detach().cpu().numpy()
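    # Note: util.cos_sim normalizes both sides internally, so the embeddings
    # stored in the JSON file do not need to be unit-normalized in advance.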
    def sparse_retrieve(self, query: str, top_k: int):
        q_vec = self.vectorizer.transform([query])
        sims = cosine_similarity(q_vec, self.tfidf_matrix).ravel()
        # Clamp k so argpartition does not fail when top_k >= corpus size.
        k = min(top_k, sims.size)
        idx = np.argpartition(-sims, kth=k - 1)[:k]
        idx = idx[np.argsort(-sims[idx])]
        return idx.tolist(), sims[idx]
    @staticmethod
    def _minmax_norm(arr: np.ndarray) -> np.ndarray:
        if arr.size == 0:
            return arr
        a_min, a_max = arr.min(), arr.max()
        if a_max - a_min < 1e-12:
            return np.zeros_like(arr)
        return (arr - a_min) / (a_max - a_min)
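    # Worked example: _minmax_norm(np.array([0.2, 0.5, 0.8])) returns
    # (arr - 0.2) / 0.6 = array([0.0, 0.5, 1.0]); constant arrays map to zeros.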
    def fuse(self, d_idx, d_scores, s_idx, s_scores, pre_rerank_k: int):
        d_norm = self._minmax_norm(d_scores.astype(np.float32))
        s_norm = self._minmax_norm(s_scores.astype(np.float32))
        d_map = {i: d for i, d in zip(d_idx, d_norm)}
        s_map = {i: s for i, s in zip(s_idx, s_norm)}
        fused = []
        for i in set(d_idx) | set(s_idx):
            score = (self.dense_alpha * d_map.get(i, 0.0)
                     + (1 - self.dense_alpha) * s_map.get(i, 0.0))
            fused.append((i, score))
        fused.sort(key=lambda x: x[1], reverse=True)
        return [i for i, _ in fused[:pre_rerank_k]]
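    # Example with dense_alpha = 0.6: a candidate with normalized dense score 0.9
    # and normalized sparse score 0.5 fuses to 0.6*0.9 + 0.4*0.5 = 0.74, while a
    # dense-only hit with score 0.9 fuses to 0.6*0.9 = 0.54.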
    def rerank(self, query: str, candidate_indices: List[int], final_k: int):
        pairs = [(query, self.sentences[i]) for i in candidate_indices]
        scores = []
        # Score (query, candidate) pairs with the cross-encoder in batches of 16.
        for batch in [pairs[i:i + 16] for i in range(0, len(pairs), 16)]:
            inputs = self.tokenizer(batch, padding=True, truncation=True,
                                    max_length=512, return_tensors="pt").to(self.device)
            with torch.no_grad():
                logits = self.reranker(**inputs).logits.view(-1)
            scores.extend(logits.cpu().tolist())
        items = sorted(zip(candidate_indices, scores), key=lambda x: x[1], reverse=True)
        return items[:final_k]
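    # Note: rerank scores are raw cross-encoder logits (higher = more relevant).
    # If calibrated scores in [0, 1] are wanted, applying torch.sigmoid to the
    # logits is the usual normalization for BGE rerankers.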
    def search(self, query: str, topk_dense=50, topk_sparse=50,
               pre_rerank_k=50, final_k=5):
        d_idx, d_scores = self.dense_retrieve(query, topk_dense)
        # import pdb; pdb.set_trace()  # <- breakpoint after dense retrieval
        s_idx, s_scores = self.sparse_retrieve(query, topk_sparse)
        cand_idx = self.fuse(d_idx, d_scores, s_idx, s_scores, pre_rerank_k)
        # import pdb; pdb.set_trace()  # <- breakpoint after fusion
        reranked = self.rerank(query, cand_idx, final_k)
        # import pdb; pdb.set_trace()  # <- breakpoint after reranking
        return [{"idx": i, "sentence": self.sentences[i], "rerank_score": score}
                for i, score in reranked]
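
# An alternative to the weighted min-max fusion above is Reciprocal Rank Fusion
# (RRF), which combines rank positions instead of scores and so needs no score
# normalization. A minimal sketch, not used by the pipeline above; k=60 is the
# constant commonly cited for RRF.
def rrf_fuse(d_idx: List[int], s_idx: List[int],
             pre_rerank_k: int = 50, k: int = 60) -> List[int]:
    scores: Dict[int, float] = {}
    for ranked in (d_idx, s_idx):
        for rank, i in enumerate(ranked):
            scores[i] = scores.get(i, 0.0) + 1.0 / (k + rank + 1)
    fused = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    return [i for i, _ in fused[:pre_rerank_k]]
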
def main():
    # Query (Persian): "How should a person act in times of turmoil (fitna)?"
    query = "انسان در فتنه ها باید چگونه عملی کند؟"
    sentences, emb_matrix = load_dataset(DATA_PATH)
    pipe = HybridRetrieverReranker(sentences, emb_matrix, dense_alpha=0.6)
    results = pipe.search(query, topk_dense=30, topk_sparse=30, pre_rerank_k=30, final_k=5)
    print("\nTop results:")
    for i, r in enumerate(results, 1):
        print(f"{i}. [score={r['rerank_score']:.4f}] {r['sentence']}")
if __name__ == "__main__":
    import datetime

    # Run the full pipeline four times and print the elapsed time of each run;
    # the first run may be slower (model downloads and cold caches).
    prev = datetime.datetime.now()
    for _ in range(4):
        main()
        now = datetime.datetime.now()
        print(now - prev)
        prev = now