nahj/reranker.py

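"""Two-stage semantic search over precomputed sentence embeddings.

Stage 1: embed the query with a multilingual SentenceTransformer and pick the
top_k nearest dataset sentences by cosine similarity.
Stage 2: rescore those candidates with the BGE cross-encoder reranker and
return the final_k best sentences.
"""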

import json
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np

EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
RERANKER_MODEL = "BAAI/bge-reranker-v2-m3"
DATA_PATH = "./output/sentences_vector.json"
# -----------------------------
# data fetching
# -----------------------------
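# The JSON file is assumed (from the access pattern below) to map an id to an
# object holding the raw sentence and its precomputed embedding, e.g.:
#   {"0": {"sentence": "...", "embeddings": [0.12, -0.03, ...]}, ...}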
with open(DATA_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

sentences = []
emb_list = []
for item in data:
    if "sentence" in data[item] and "embeddings" in data[item] and isinstance(data[item]["embeddings"], list):
        sentences.append(data[item]["sentence"])
        emb_list.append(data[item]["embeddings"])

if not sentences:
    raise ValueError("No sentences/embeddings were found in the file.")

# Convert to float32 so the matrix matches the SentenceTransformer output dtype
emb_matrix = np.asarray(emb_list, dtype=np.float32)
# -----------------------------
# device configuration
# -----------------------------
device_str = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_str)
# -----------------------------
# loading models
# -----------------------------
embedder = SentenceTransformer(EMBED_MODEL, device=device_str)
tokenizer = AutoTokenizer.from_pretrained(RERANKER_MODEL)
reranker = AutoModelForSequenceClassification.from_pretrained(RERANKER_MODEL).to(device)
reranker.eval()  # inference only

# Dataset embeddings as a tensor on the same device as the models
embeddings_tensor = torch.from_numpy(emb_matrix).to(device)  # (N, D) float32
# -----------------------------
# main function
# -----------------------------
def get_top_sentences(query: str, top_k: int = 20, final_k: int = 5):
    # 1) embed the query (on the same device as the embedder)
    query_emb = embedder.encode(query, convert_to_tensor=True)
    if query_emb.device != device:
        query_emb = query_emb.to(device)

    # 2) cosine similarity against all dataset embeddings: (1, N) -> (N,)
    sim_scores = util.cos_sim(query_emb, embeddings_tensor).squeeze(0)

    # effective k (the dataset may hold fewer than top_k sentences)
    k = min(top_k, sim_scores.size(0))

    # 3) pick the k most similar sentences
    topk_scores, topk_indices = torch.topk(sim_scores, k=k, largest=True, sorted=True)
    idx_list = topk_indices.tolist()  # flat list of ints
    candidate_sentences = [sentences[i] for i in idx_list]

    # 4) rerank the top-k candidates with the cross-encoder
    pairs = [(query, sent) for sent in candidate_sentences]
    inputs = tokenizer(pairs, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = reranker(**inputs).logits.view(-1)

    # clamp final_k to the number of candidates
    final_k = min(final_k, len(candidate_sentences))

    # 5) keep the final_k sentences with the highest rerank scores
    best_idx = torch.topk(logits, k=final_k, largest=True, sorted=True).indices.tolist()
    final_sentences = [candidate_sentences[i] for i in best_idx]
    return final_sentences
# -----------------------------
# execution
# -----------------------------
if __name__ == "__main__":
    # Example queries (in Persian); the commented one is kept for reference.
    # q = "فرصت های خوب را نباید از دست داد"  # "Good opportunities must not be missed"
    q = "انسان در فتنه ها باید چگونه عملی کند؟"  # "How should one act in times of turmoil?"

    results = get_top_sentences(q, top_k=20, final_k=5)

    results_string = ''
    for item in results:
        results_string += "- " + item + '\n'
    print(results_string)
    print()