97 lines
3.5 KiB
Python
97 lines
3.5 KiB
Python
import json
|
||
import torch
|
||
from sentence_transformers import SentenceTransformer, util
|
||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||
import numpy as np
|
||
|
||
# Hugging Face id of the sentence-embedding model used to encode queries.
EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

# Hugging Face id of the sequence-classification model used for reranking.
RERANKER_MODEL = "BAAI/bge-reranker-v2-m3"

# JSON file holding the sentences together with their stored embeddings.
DATA_PATH = "./output/sentences_vector.json"
|
||
|
||
# -----------------------------
|
||
# data fetching
|
||
# -----------------------------
|
||
# Load the stored records; the loop below indexes `data` by each key it yields.
with open(DATA_PATH, mode="r", encoding="utf-8") as fh:
    data = json.load(fh)

# Parallel lists: sentences[i] is the text whose vector is emb_list[i].
sentences = []
emb_list = []
for key in data:
    record = data[key]
    # Skip records that lack either field or whose embedding is not a list.
    if "sentence" not in record or "embeddings" not in record:
        continue
    if not isinstance(record["embeddings"], list):
        continue
    sentences.append(record["sentence"])
    emb_list.append(record["embeddings"])

# Fail early: nothing downstream works without at least one usable record.
if not sentences:
    raise ValueError("هیچ جمله/امبدینگی در فایل یافت نشد.")

# Convert to float32 so the stored vectors match the SentenceTransformer output.
emb_matrix = np.asarray(emb_list, dtype=np.float32)
|
||
|
||
# -----------------------------
|
||
# device configuration
|
||
# -----------------------------
|
||
# Prefer the GPU when CUDA is available. Both forms are kept: the plain string
# is passed to SentenceTransformer, the torch.device is used for `.to(...)`.
device_str = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_str)
|
||
|
||
# -----------------------------
|
||
# loading models
|
||
# -----------------------------
|
||
# Embedding model for incoming queries, placed on the selected device.
embedder = SentenceTransformer(EMBED_MODEL, device=device_str)

# Tokenizer and sequence-classification model used to rerank candidates.
tokenizer = AutoTokenizer.from_pretrained(RERANKER_MODEL)
reranker = AutoModelForSequenceClassification.from_pretrained(RERANKER_MODEL)
reranker = reranker.to(device)

# Dataset embeddings as a tensor on the same device: (N, D) float32.
embeddings_tensor = torch.from_numpy(emb_matrix)
embeddings_tensor = embeddings_tensor.to(device)
|
||
|
||
# -----------------------------
|
||
# main function
|
||
# -----------------------------
|
||
def get_top_sentences(query: str, top_k: int = 20, final_k: int = 5):
    """Return the stored sentences most relevant to ``query``.

    Retrieval runs in two stages. First, cosine similarity between the query
    embedding and the precomputed dataset embeddings selects up to ``top_k``
    candidates. Then the reranker model scores every ``(query, candidate)``
    pair and the ``final_k`` highest-scoring candidates are returned.

    Args:
        query: Query text.
        top_k: Number of candidates kept after the similarity stage; clamped
            to the number of stored sentences.
        final_k: Number of sentences returned after reranking; clamped to the
            number of candidates.

    Returns:
        A list of sentences ordered from highest to lowest reranker score.
        The list is empty when ``top_k`` or ``final_k`` is not positive.
    """
    # A non-positive k selects nothing. Return before any model runs:
    # torch.topk rejects a negative k, and k == 0 would hand the tokenizer
    # and reranker an empty batch.
    if top_k <= 0 or final_k <= 0:
        return []

    # 1) Embed the query; `.to(device)` is a no-op when it is already there.
    query_emb = embedder.encode(query, convert_to_tensor=True).to(device)

    # 2) Cosine similarity: the result is (1, N), squeezed to (N,).
    sim_scores = util.cos_sim(query_emb, embeddings_tensor).squeeze(0)

    # 3) Keep the k most similar sentences; k cannot exceed the dataset size.
    k = min(top_k, sim_scores.size(0))
    candidate_indices = torch.topk(
        sim_scores, k=k, largest=True, sorted=True
    ).indices.tolist()  # flat list of ints
    candidate_sentences = [sentences[i] for i in candidate_indices]

    # 4) Rerank the candidates by scoring each (query, sentence) pair.
    pairs = [(query, sent) for sent in candidate_sentences]
    inputs = tokenizer(
        pairs,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        logits = reranker(**inputs).logits.view(-1)

    # 5) Return the best candidates by reranker score, capped at what exists.
    final_k = min(final_k, len(candidate_sentences))
    best_idx = torch.topk(
        logits, k=final_k, largest=True, sorted=True
    ).indices.tolist()
    return [candidate_sentences[i] for i in best_idx]
|
||
|
||
# -----------------------------
|
||
# execution
|
||
# -----------------------------
|
||
if __name__ == "__main__":
    # Sample query for a quick manual check of the retrieval pipeline.
    q = "انسان در فتنه ها باید چگونه عملی کند؟"
    results = get_top_sentences(q, top_k=20, final_k=5)

    # Render one "- sentence" bullet per line.
    results_string = "".join(f"- {item}\n" for item in results)

    print(results_string)
    print()
|