437 lines
17 KiB
Python
437 lines
17 KiB
Python
# بسم الله
|
||
|
||
|
||
"""
|
||
این سورس نوشته شده تا از فایل اکسل ، سوالات مربوط به نهج البلاغه رو به مدل ارائه بده
|
||
و پاسخ های ارائه شده توسط مدل (امبدینگ) رو با پاسخ موجود در فایل اکسل مقایسه میکنه
|
||
تا مشخص کنه که پاسخ صحیح در پاسخ های مدل هست یا خیر
|
||
نهایتا مشخص میکنه که مدل به چندتا از سوالات پرسیده شده پاسخ صحیح داده
|
||
"""
|
||
|
||
|
||
|
||
|
||
import torch
|
||
import faiss
|
||
import sqlite3
|
||
import datetime
|
||
import numpy as np
|
||
import pandas as pd
|
||
import data_model as dm
|
||
from typing import List, Tuple
|
||
from normalizer import cleaning
|
||
from sentence_transformers import SentenceTransformer
|
||
from sklearn.metrics.pairwise import cosine_similarity
|
||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||
|
||
import os
|
||
os.environ['HF_HUB_OFFLINE'] = '1'
|
||
|
||
|
||
# models_list = [
|
||
# "jinaai/jina-embeddings-v5-text-small",
|
||
# "BAAI/bge-m3",
|
||
# "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
|
||
# "intfloat/multilingual-e5-small"
|
||
# ]
|
||
|
||
# One entry per embedding model under evaluation:
#   model_name      - model identifier handed to SentenceTransformer
#   model_faiss     - path of the FAISS index file used with that model
#   samples_numbers - result-list sizes (final_k) to evaluate
#   normal_embedder - False makes HybridRetrieverReranker load the model with
#                     trust_remote_code=True and default_task='retrieval'
option_list = [
    {
        "model_name": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        "model_faiss": "./data-faiss/faiss_index_nahj.index",
        "samples_numbers": [10, 100, 150, 200, 300],
        "normal_embedder": True,
    },
    {
        "model_name": "intfloat/multilingual-e5-small",
        "model_faiss": "./data-faiss/faiss_index_multilingual-e5-small.index",
        "samples_numbers": [10, 100, 150, 200, 300],
        "normal_embedder": True,
    },
    {
        "model_name": "jinaai/jina-embeddings-v5-text-small",
        "model_faiss": "./data-faiss/faiss_index_nahj_jina-embeddings-v5-text-small.index",
        "samples_numbers": [10, 100, 150, 200, 300],
        "normal_embedder": False,
    },
    {
        "model_name": "BAAI/bge-m3",
        "model_faiss": "./data-faiss/faiss_index_nahj_bge_m3.index",
        "samples_numbers": [10, 100, 150, 200, 300],
        "normal_embedder": True,
    },
]

# Module-level flag read by HybridRetrieverReranker.__init__; the evaluation
# loop at the bottom of the file overwrites it for every option.
normal_embedder = True

# Default embedding model. Other models used with this script:
#   sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
#   BAAI/bge-m3
#   intfloat/multilingual-e5-small
#   jinaai/jina-embeddings-v5-text-small
MODEL_PATH = "intfloat/multilingual-e5-small"

# RERANKER_MODEL = "BAAI/bge-reranker-v2-m3"

# Default FAISS index. Other index files used with this script:
#   ./data-faiss/faiss_index_nahj.index
#   ./data-faiss/faiss_index_nahj_bge_m3.index
#   ./data-faiss/faiss_index_multilingual-e5-small.index
#   ./data-faiss/faiss_index_nahj_jina-embeddings-v5-text-small.index
FAISS_INDEX_PATH = "./data-faiss/faiss_index_multilingual-e5-small.index"

FAISS_METADATA_PATH = "./data-faiss/faiss_index_nahj_metadata.json"

# SQLite connection and cursor shared by the evaluation loop.
conn = sqlite3.connect('./db/nahj.db')
cursor = conn.cursor()

# Read the Excel file and keep the (sermon, question) columns of its first
# 101 rows as a list of two-element lists.
df = pd.read_excel('./sources/nahj_question.xlsx')
answers_and_questions = df.head(101)[['خطبه', 'سوال']].values.tolist()
|
||
|
||
# خواندن تمامی متادیتا ها از دیتابیس
|
||
def read_records():
    """Load every record from the data model and split it into columns.

    Returns:
        A 10-tuple of parallel lists, in this order: ids, part_ids,
        context_ids, large_titles, normalized_sentences, titles, urls,
        arabic_texts, interpretation_links, types.
    """
    # One bucket per record field, keyed by the field name read from each record.
    buckets = {
        field: []
        for field in (
            'id', 'context_id', 'part_id', 'title', 'large_title',
            'normalized_sentence', 'url', 'types', 'arabic_text',
            'interpretation_links',
        )
    }

    for record in dm.get_all_data():
        for field, bucket in buckets.items():
            bucket.append(record[field])

    return (
        buckets['id'],
        buckets['part_id'],
        buckets['context_id'],
        buckets['large_title'],
        buckets['normalized_sentence'],
        buckets['title'],
        buckets['url'],
        buckets['arabic_text'],
        buckets['interpretation_links'],
        buckets['types'],
    )
|
||
|
||
|
||
def load_faiss_index(index_path: str, metadata_path: str):
    """Read a FAISS index from disk and pair it with the record columns.

    Args:
        index_path: Path of the FAISS index file to read.
        metadata_path: Accepted for interface compatibility; not used here,
            because the record columns come from ``read_records()``.

    Returns:
        The ten lists returned by ``read_records()`` followed by the loaded
        FAISS index, as one flat 11-tuple.
    """
    loaded_index = faiss.read_index(index_path)
    record_columns = read_records()
    return (*record_columns, loaded_index)
|
||
|
||
class HybridRetrieverReranker:
    """Hybrid retriever: dense FAISS search plus sparse TF-IDF search, fused with RRF.

    The cross-encoder reranking stage is currently disabled: ``__init__`` does
    not load ``tokenizer`` / ``reranker``, so ``rerank`` cannot be used until
    those two attributes are assigned. ``search`` returns the fused candidates
    directly.
    """

    __slots__ = (
        "device", "ids", "part_ids", "context_ids", "large_titles", "sentences",
        "titles", "urls", "arabic_texts", "Interpretation_links", "types", "N",
        "embedder", "faiss_index", "vectorizer", "tfidf_matrix", "tokenizer",
        "reranker", "dense_alpha",
    )

    def __init__(self, ids: List[str], part_ids: List[str], context_ids: List[str], large_titles: List[str], sentences: List[str], titles: List[str], urls: List[str], arabic_texts: List[str], Interpretation_links: List[str], types: List[str], faiss_index,
                 dense_alpha: float = 0.6, device: str = None, model_path = MODEL_PATH):
        """Store the record columns, load the embedder and fit the TF-IDF matrix.

        Args:
            ids, part_ids, context_ids, large_titles, sentences, titles, urls,
                arabic_texts, Interpretation_links, types: Parallel record
                columns; ``sentences`` is the text indexed by TF-IDF.
            faiss_index: Index object whose ``search`` method is used for
                dense retrieval.
            dense_alpha: Stored as a float; not used by the current RRF fusion.
            device: Torch device name; defaults to "cuda" when available,
                otherwise "cpu".
            model_path: Model identifier passed to ``SentenceTransformer``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        self.ids = ids
        self.part_ids = part_ids
        self.context_ids = context_ids
        self.large_titles = large_titles
        self.sentences = sentences
        self.titles = titles
        self.urls = urls
        self.arabic_texts = arabic_texts
        self.Interpretation_links = Interpretation_links
        self.types = types
        self.faiss_index = faiss_index
        self.N = len(sentences)

        # --- Dense embedder ---
        print("Loading SentenceTransformer model ...")

        # `normal_embedder` is a module-level flag set by the caller before
        # construction; when it is off the model is loaded with remote code
        # enabled and the 'retrieval' default task.
        if not normal_embedder:
            self.embedder = SentenceTransformer(
                model_path,
                trust_remote_code=True,
                model_kwargs={'default_task': 'retrieval'},
                device=self.device,
            )
        else:
            self.embedder = SentenceTransformer(model_path, device=self.device)

        # --- Sparse (TF-IDF) ---
        self.vectorizer = TfidfVectorizer(
            analyzer="word",
            ngram_range=(1, 2),
            token_pattern=r"(?u)\b[\w\u0600-\u06FF]{2,}\b",
        )
        self.tfidf_matrix = self.vectorizer.fit_transform(self.sentences)

        self.dense_alpha = float(dense_alpha)

    def _embed_sentence(self, text: str) -> np.ndarray:
        """Encode ``text`` with the embedder and return it as float32."""
        embedding = self.embedder.encode(
            text,
            convert_to_numpy=True,
            normalize_embeddings=True,  # important for cosine similarity
        )
        return embedding.astype(np.float32)

    def dense_retrieve(self, query: str, top_k: int):
        """Search the FAISS index for ``query``.

        Returns:
            ``(indices, scores)``: a list of row indices and the matching
            score array. Both are empty when ``top_k <= 0``.
        """
        if top_k <= 0:
            return [], np.array([], dtype=np.float32)
        q_emb = self._embed_sentence(query)
        D, I = self.faiss_index.search(np.expand_dims(q_emb, axis=0), top_k)
        labels, distances = I[0], D[0]
        # FAISS pads the result with label -1 when it finds fewer than top_k
        # neighbours. Left in place, -1 would later index the LAST sentence
        # and show up as a bogus hit, so the padding is dropped here.
        found = labels >= 0
        return labels[found].tolist(), distances[found]

    # --- Sparse ---
    def sparse_retrieve(self, query: str, top_k: int):
        """Return the ``top_k`` sentences most similar to ``query`` under TF-IDF.

        Returns:
            ``(indices, scores)`` sorted by descending cosine similarity.
            Both are empty when ``top_k <= 0``.
        """
        if top_k <= 0:
            return [], np.array([], dtype=np.float32)
        k = min(top_k, self.N)
        q_vec = self.vectorizer.transform([query])
        sims = cosine_similarity(q_vec, self.tfidf_matrix).ravel()
        # Select the k best without a full sort, then order only those k.
        idx = np.argpartition(-sims, kth=k - 1)[:k]
        idx = idx[np.argsort(-sims[idx], kind="mergesort")]
        return idx.tolist(), sims[idx]

    # --- Normalization ---
    @staticmethod
    def _minmax_norm(arr: np.ndarray) -> np.ndarray:
        """Scale ``arr`` to [0, 1]; an (almost) constant array maps to zeros."""
        if arr.size == 0:
            return arr
        a_min = arr.min()
        a_max = arr.max()
        rng = a_max - a_min
        if rng < 1e-12:
            return np.zeros_like(arr)
        return (arr - a_min) / rng

    # --- Fusion (RRF) ---
    def fuse(self, d_idx, d_scores, s_idx, s_scores, top_k=50, k_rrf=60):
        """Merge the dense and sparse rankings with reciprocal rank fusion.

        Each list contributes ``1 / (k_rrf + rank)`` per item, with ``rank``
        starting at 0. ``d_scores`` and ``s_scores`` are accepted for interface
        compatibility and are not used.

        Returns:
            The indices of the ``top_k`` highest fused scores, best first.
        """
        combined = {}
        for ranking in (d_idx, s_idx):
            for rank, idx in enumerate(ranking):
                combined[idx] = combined.get(idx, 0) + 1.0 / (k_rrf + rank)
        ordered = sorted(combined.items(), key=lambda pair: pair[1], reverse=True)
        return [idx for idx, _ in ordered[:top_k]]

    # --- Rerank ---
    def rerank(self, query: str, candidate_indices: List[int], passages: List[str], final_k: int) -> List[Tuple[int, float]]:
        """Score ``(query, passage)`` pairs with the reranker and keep the best.

        Requires ``self.tokenizer`` and ``self.reranker`` to have been assigned;
        ``__init__`` currently does not load them.

        Args:
            query: The query text, paired with every passage.
            candidate_indices: Row indices, parallel to ``passages``.
            passages: Passage texts to score.
            final_k: Number of results to keep.

        Returns:
            Up to ``final_k`` ``(index, score)`` tuples, highest score first.
            Empty when ``final_k <= 0`` or there are no candidates.

        Raises:
            RuntimeError: If CUDA runs out of memory even after the batch size
                has been halved below 16.
        """
        if final_k <= 0 or not candidate_indices:
            return []

        queries = [query] * len(candidate_indices)
        batch_size = 64
        while batch_size >= 16:
            # Start from an empty list on every attempt: an attempt that runs
            # out of memory part-way must not leave partial scores behind,
            # otherwise the retry would misalign scores with candidates.
            scores: List[float] = []
            try:
                with torch.inference_mode():
                    for start in range(0, len(passages), batch_size):
                        stop = start + batch_size
                        inputs = self.tokenizer(
                            queries[start:stop],
                            passages[start:stop],
                            padding=True,
                            truncation=True,
                            max_length=512,
                            return_tensors="pt",
                        ).to(self.device)
                        logits = self.reranker(**inputs).logits.view(-1)
                        scores.extend(logits.detach().cpu().tolist())
            except torch.cuda.OutOfMemoryError:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                batch_size //= 2
                continue

            ranked = sorted(
                zip(candidate_indices, scores),
                key=lambda pair: pair[1],
                reverse=True,
            )
            return ranked[:final_k]

        raise RuntimeError("Reranker failed due to CUDA OOM.")

    def get_passages(self, cand_idx, sentences_list):
        """Return the entries of ``sentences_list`` at the positions in ``cand_idx``."""
        return [sentences_list[idx] for idx in cand_idx]

    # --- Search ---
    def search(self, query: str, sentence_list, topk_dense=50, topk_sparse=50,
               pre_rerank_k=50, final_k=10):
        """Run dense and sparse retrieval and return the fused top results.

        Args:
            query: Query text.
            sentence_list: Accepted for interface compatibility; not used
                while reranking is disabled.
            topk_dense: Number of candidates requested from the FAISS index.
            topk_sparse: Number of candidates requested from TF-IDF.
            pre_rerank_k: Ignored; with reranking disabled the fusion is cut
                directly to ``final_k``.
            final_k: Maximum number of results returned.

        Returns:
            A list of ``{"idx": row_index, "content": sentence}`` dicts.
        """
        d_idx, d_scores = self.dense_retrieve(query, topk_dense)
        s_idx, s_scores = self.sparse_retrieve(query, topk_sparse)

        # No rerank stage, so fuse straight down to the requested size.
        cand_idx = self.fuse(d_idx, d_scores, s_idx, s_scores, final_k)

        return [
            {"idx": i, "content": self.sentences[i]}
            for i in cand_idx
        ]
|
||
|
||
|
||
|
||
def single_query(query: str , samples_number):
    """Clean ``query``, run the hybrid search and format the hits.

    Relies on the module-level ``pipe``, ``sentences``, ``part_ids``, ``urls``,
    ``large_titles`` and ``ids`` set up by the evaluation loop.

    Args:
        query: Raw question text.
        samples_number: Passed as ``final_k``, i.e. how many results the
            search returns.

    Returns:
        ``(id_list, text, hits)``: the record ids of the hits, one per line
        with a trailing newline each; a numbered, human-readable listing of
        the hits; and the hit dicts themselves, each extended in place with
        ``"part_id"`` and ``"url"``.
    """
    query = cleaning(query)

    # final_k decides how many answers the search returns.
    hits = pipe.search(query, sentences, topk_dense=100, topk_sparse=100, pre_rerank_k=100, final_k=samples_number)

    id_lines = []
    listing_parts = []
    enriched_hits = []
    for position, hit in enumerate(hits, start=1):
        hit["part_id"] = part_ids[hit['idx']]
        hit["url"] = urls[hit['idx']]

        heading = '{' + str(large_titles[hit['idx']]) + '}'
        listing_parts.append(f"{position}. عنوان: {heading}\n{hit['content']}\n\n")
        id_lines.append(f"{ids[hit['idx']]}\n")
        enriched_hits.append(hit)

    return ''.join(id_lines), ''.join(listing_parts), enriched_hits
|
||
|
||
|
||
def get_passages_by_paragraphs(retrived_sections_list):
    """Rebuild the full paragraph text for every retrieved section.

    For each hit, all records sharing its ``part_id`` are collected, ordered
    by record ``id`` (the sentence order inside the paragraph) and joined into
    one paragraph headed by the part's ``large_title``.

    Args:
        retrived_sections_list: Hit dicts, each carrying a ``"part_id"`` key.

    Returns:
        The paragraphs concatenated into one string, each followed by a blank
        line. Hits whose ``part_id`` matches no record are skipped; previously
        they raised an unbound-variable error or silently reused the previous
        hit's title.
    """
    # Group the dataset by part once instead of rescanning it for every hit.
    rows_by_part = {}
    title_by_part = {}
    for row in dm.get_all_data():
        part_id = row['part_id']
        rows_by_part.setdefault(part_id, {})[row['id']] = row
        # The last record of a part determines the title, as before.
        title_by_part[part_id] = row['large_title']

    paragraphs = []
    for item in retrived_sections_list:
        part_id = item['part_id']
        part_rows = rows_by_part.get(part_id)
        if not part_rows:
            continue

        # Sort by id to restore the order of the sentences in the paragraph.
        body = ''.join(
            f"{part_rows[row_id]['normalized_sentence']}. "
            for row_id in sorted(part_rows)
        )
        paragraph = f"{title_by_part[part_id]}:\n{body}"
        paragraphs.append(f"{paragraph.strip()}\n\n")

    return ''.join(paragraphs)
|
||
|
||
# متد اصلی جست و جو
|
||
def bale_search(query,samples_number=10):
    """Run one search for ``query`` and report how long it took.

    Args:
        query: Question text.
        samples_number: Number of results requested from the search.

    Returns:
        The ids of the retrieved records as a list of strings, obtained by
        splitting the newline-terminated id text on "\\n" (so the list ends
        with an empty string).
    """
    started_at = datetime.datetime.now()
    id_text, _listing, sections = single_query(query,samples_number)
    # The reconstructed paragraphs are not used here; the call is kept so the
    # measured retrieval time still includes it.
    get_passages_by_paragraphs(sections)

    retrieval_seconds = (datetime.datetime.now() - started_at).total_seconds()
    print('-'*40)
    print(f'retrive duration: {retrieval_seconds}')

    separator = '---------------------------------------------'
    print(separator)
    print(f'full duration: {(datetime.datetime.now() - started_at).total_seconds()}')
    print(separator)
    print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
    return id_text.split("\n")
|
||
|
||
|
||
# Evaluation driver: for every model option and every result-list size, ask
# each question from the Excel sheet and count how often the expected sermon
# title appears among the retrieved records.
final_result = ''

for option in option_list:

    # HybridRetrieverReranker.__init__ reads this module-level flag.
    normal_embedder = bool(option['normal_embedder'])

    # These names are module globals on purpose: single_query() reads them.
    (ids, part_ids, context_ids, large_titles, sentences, titles, urls,
     arabic_texts, Interpretation_links, types, faiss_index) = load_faiss_index(option['model_faiss'], FAISS_METADATA_PATH)

    pipe = HybridRetrieverReranker(ids, part_ids, context_ids, large_titles, sentences, titles, urls, arabic_texts, Interpretation_links, types, faiss_index, dense_alpha=0.6, model_path=option['model_name'])

    final_result += "\n--------------------------------------------------------\n"
    final_result += f"model : {option['model_name']}\n"

    for num in option['samples_numbers']:

        final_answers_and_questions = []
        true_n = 0
        # Main loop: take each question/answer pair from the Excel data,
        # send the question to the retriever and verify its answer.
        for sen_n, s in enumerate(answers_and_questions, start=1):
            print(sen_n)
            speech_number = int(s[0])
            answer = f"خطبه {speech_number}"
            question = s[1]
            bot_answers = bale_search(question, samples_number=num)

            find = False
            for _id in bot_answers:
                if _id == '':
                    continue
                cursor.execute("SELECT * FROM speeches WHERE id = ?", (_id,))
                record = cursor.fetchone()
                # An id missing from the table counts as a miss, not a crash.
                # Column 3 is compared with the expected title.
                if record is not None and record[3] == answer:
                    find = True
                    break  # one match is enough; skip the remaining lookups

            if find:
                true_n += 1
            # Record a fresh row instead of mutating the shared input row,
            # which is reused for every sample size and every model.
            final_answers_and_questions.append([speech_number, question, find])

        final_result += "_____________________\n"
        final_result += f"findes speechs number: {true_n}/{len(answers_and_questions)}\n"
        final_result += f"samples number : {num}\n"

    final_result += "--------------------------------------------------------\n"

# Print how many of the asked questions were answered correctly.
print(final_result)
|
||
# print("--------------------------------------------------------")
|
||
# print(f"model : {option['model_name']}")
|
||
# print(f"findes speechs number: {true_n}/{len(answers_and_questions)}")
|
||
# print(f"samples number : {num}")
|
||
# print("--------------------------------------------------------")
|
||
|
||
|
||
|
||
|
||
|