nahj_rag/Automatic_answer_checker.py
2026-04-30 19:16:50 +03:30

437 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# بسم الله
"""
این سورس نوشته شده تا از فایل اکسل ، سوالات مربوط به نهج البلاغه رو به مدل ارائه بده
و پاسخ های ارائه شده توسط مدل (امبدینگ) رو با پاسخ موجود در فایل اکسل مقایسه میکنه
تا مشخص کنه که پاسخ صحیح در پاسخ های مدل هست یا خیر
نهایتا مشخص میکنه که مدل به چندتا از سوالات پرسیده شده پاسخ صحیح داده
"""
import torch
import faiss
import sqlite3
import datetime
import numpy as np
import pandas as pd
import data_model as dm
from typing import List, Tuple
from normalizer import cleaning
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
import os
os.environ['HF_HUB_OFFLINE'] = '1'
# models_list = [
# "jinaai/jina-embeddings-v5-text-small",
# "BAAI/bge-m3",
# "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
# "intfloat/multilingual-e5-small"
# ]
# Evaluation matrix: one entry per embedding model under test.
#   model_name       - SentenceTransformer model identifier
#   model_faiss      - FAISS index file searched with that model
#   samples_numbers  - result-list sizes (final_k) to evaluate
#   normal_embedder  - False makes HybridRetrieverReranker load the model with
#                      trust_remote_code and a 'retrieval' default task
option_list = [
    {
        "model_name": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        "model_faiss": "./data-faiss/faiss_index_nahj.index",
        "samples_numbers": [10, 100, 150, 200, 300],
        "normal_embedder": True,
    },
    {
        "model_name": "intfloat/multilingual-e5-small",
        "model_faiss": "./data-faiss/faiss_index_multilingual-e5-small.index",
        "samples_numbers": [10, 100, 150, 200, 300],
        "normal_embedder": True,
    },
    {
        "model_name": "jinaai/jina-embeddings-v5-text-small",
        "model_faiss": "./data-faiss/faiss_index_nahj_jina-embeddings-v5-text-small.index",
        "samples_numbers": [10, 100, 150, 200, 300],
        "normal_embedder": False,
    },
    {
        "model_name": "BAAI/bge-m3",
        "model_faiss": "./data-faiss/faiss_index_nahj_bge_m3.index",
        "samples_numbers": [10, 100, 150, 200, 300],
        "normal_embedder": True,
    },
]

# Stand-alone defaults; the evaluation loop at the bottom of the file
# overrides the model, the index and `normal_embedder` for every entry of
# `option_list`.
normal_embedder = True
MODEL_PATH = "intfloat/multilingual-e5-small"
FAISS_INDEX_PATH = "./data-faiss/faiss_index_multilingual-e5-small.index"
FAISS_METADATA_PATH = "./data-faiss/faiss_index_nahj_metadata.json"
# The reranker is currently disabled; the model it previously referenced was
# "BAAI/bge-reranker-v2-m3".

# Database connection shared by the evaluation loop.
conn = sqlite3.connect('./db/nahj.db')
cursor = conn.cursor()

# Read the Excel file: the first 101 rows, as [sermon number, question] pairs.
df = pd.read_excel('./sources/nahj_question.xlsx')
answers_and_questions = (
    df[['خطبه', 'سوال']]
    .head(101)
    .values
    .tolist()
)
def read_records():
    """Read every metadata record and return it as ten parallel lists.

    Records come from ``dm.get_all_data()``; each one is indexed by the field
    names listed below.

    Returns:
        A tuple of lists, in this order: ids, part_ids, context_ids,
        large_titles, normalized_sentences, titles, urls, arabic_texts,
        Interpretation_links, types. Position ``i`` of every list belongs to
        the same record.
    """
    # Field names in the order the columns are returned.
    field_order = (
        'id',
        'part_id',
        'context_id',
        'large_title',
        'normalized_sentence',
        'title',
        'url',
        'arabic_text',
        'interpretation_links',
        'types',
    )
    columns = {field: [] for field in field_order}
    for record in dm.get_all_data():
        for field in field_order:
            columns[field].append(record[field])
    return tuple(columns[field] for field in field_order)
def load_faiss_index(index_path: str, metadata_path: str):
    """Load a FAISS index from disk together with the record columns.

    Args:
        index_path: Path of the FAISS index file to read.
        metadata_path: Accepted for interface compatibility; not used here,
            the record columns come from ``read_records()`` instead.

    Returns:
        An 11-tuple: the ten lists returned by ``read_records()`` followed by
        the loaded FAISS index.
    """
    index = faiss.read_index(index_path)
    record_columns = read_records()
    return (*record_columns, index)
class HybridRetrieverReranker:
    """Hybrid retriever: dense FAISS search fused with sparse TF-IDF search.

    A query is embedded with a SentenceTransformer model and searched in a
    prebuilt FAISS index (dense side), and matched against a TF-IDF matrix
    fitted on the corpus sentences (sparse side). The two rankings are merged
    with reciprocal-rank fusion. A cross-encoder reranking step exists
    (``rerank``) but the reranker model is not loaded by ``__init__``, so
    ``search`` does not use it.

    The per-record lists (ids, part_ids, ...) are parallel: position ``i`` of
    each list describes the sentence at position ``i``.
    """

    __slots__ = (
        "device", "ids", "part_ids", "context_ids", "large_titles",
        "sentences", "titles", "urls", "arabic_texts",
        "Interpretation_links", "types", "N", "embedder", "faiss_index",
        "vectorizer", "tfidf_matrix", "tokenizer", "reranker", "dense_alpha",
    )

    def __init__(self, ids: List[str], part_ids: List[str], context_ids: List[str],
                 large_titles: List[str], sentences: List[str], titles: List[str],
                 urls: List[str], arabic_texts: List[str],
                 Interpretation_links: List[str], types: List[str], faiss_index,
                 dense_alpha: float = 0.6, device: str = None, model_path=None):
        """Store the corpus columns, load the embedder and fit the TF-IDF model.

        Args:
            ids, part_ids, context_ids, large_titles, sentences, titles, urls,
                arabic_texts, Interpretation_links, types: Parallel per-record
                lists.
            faiss_index: Index object whose ``search(vectors, k)`` is used for
                dense retrieval.
            dense_alpha: Stored as a float; the rank-based fusion in ``fuse``
                does not read it.
            device: Torch device string; defaults to "cuda" when available,
                otherwise "cpu".
            model_path: SentenceTransformer model to load. ``None`` means the
                module-level ``MODEL_PATH``, resolved at call time.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if model_path is None:
            model_path = MODEL_PATH

        self.device = device
        self.ids = ids
        self.part_ids = part_ids
        self.context_ids = context_ids
        self.large_titles = large_titles
        self.sentences = sentences
        self.titles = titles
        self.urls = urls
        self.arabic_texts = arabic_texts
        self.Interpretation_links = Interpretation_links
        self.types = types
        self.faiss_index = faiss_index
        self.N = len(sentences)

        # --- Dense embedder ---
        print("Loading SentenceTransformer model ...")
        # `normal_embedder` is a module-level flag set by the evaluation loop.
        if not normal_embedder:
            self.embedder = SentenceTransformer(
                model_path,
                trust_remote_code=True,
                model_kwargs={'default_task': 'retrieval'},
                device=self.device,
            )
        else:
            self.embedder = SentenceTransformer(model_path, device=self.device)

        # --- Sparse (TF-IDF) ---
        self.vectorizer = TfidfVectorizer(
            analyzer="word",
            ngram_range=(1, 2),
            token_pattern=r"(?u)\b[\w\u0600-\u06FF]{2,}\b",
        )
        self.tfidf_matrix = self.vectorizer.fit_transform(self.sentences)

        # --- Reranker (disabled) ---
        # No reranker model is loaded; keep the slots defined so `rerank`
        # can report that clearly instead of failing on an unset slot.
        self.tokenizer = None
        self.reranker = None

        self.dense_alpha = float(dense_alpha)

    def _embed_sentence(self, text: str) -> np.ndarray:
        """Return the L2-normalized float32 embedding of ``text``."""
        embedding = self.embedder.encode(
            text,
            convert_to_numpy=True,
            normalize_embeddings=True,  # important for cosine similarity
        )
        return embedding.astype(np.float32)

    def dense_retrieve(self, query: str, top_k: int):
        """Search the FAISS index for the ``top_k`` nearest sentences.

        Returns:
            ``(indices, scores)``: a list of sentence positions and the
            matching score array, best first. May hold fewer than ``top_k``
            entries.
        """
        if top_k <= 0:
            return [], np.array([], dtype=np.float32)
        q_emb = self._embed_sentence(query)
        D, I = self.faiss_index.search(np.expand_dims(q_emb, axis=0), top_k)
        labels, distances = I[0], D[0]
        # FAISS pads the result with label -1 when it finds fewer than top_k
        # neighbours; a -1 would later index the *last* sentence by accident.
        found = labels >= 0
        return labels[found].tolist(), distances[found]

    def sparse_retrieve(self, query: str, top_k: int):
        """Return the ``top_k`` sentences by TF-IDF cosine similarity.

        Returns:
            ``(indices, scores)``, best first; at most ``self.N`` entries.
        """
        if top_k <= 0 or self.N == 0:
            return [], np.array([], dtype=np.float32)
        k = min(top_k, self.N)
        q_vec = self.vectorizer.transform([query])
        sims = cosine_similarity(q_vec, self.tfidf_matrix).ravel()
        # Select the k best in O(N), then order only those k (stable sort).
        idx = np.argpartition(-sims, kth=k - 1)[:k]
        idx = idx[np.argsort(-sims[idx], kind="mergesort")]
        return idx.tolist(), sims[idx]

    @staticmethod
    def _minmax_norm(arr: np.ndarray) -> np.ndarray:
        """Scale ``arr`` to [0, 1]; all zeros when the values are (nearly) equal."""
        if arr.size == 0:
            return arr
        a_min = arr.min()
        rng = arr.max() - a_min
        if rng < 1e-12:
            return np.zeros_like(arr)
        return (arr - a_min) / rng

    def fuse(self, d_idx, d_scores, s_idx, s_scores, top_k=50, k_rrf=60):
        """Merge the dense and sparse rankings with reciprocal-rank fusion.

        Every appearance of an index at 0-based ``rank`` contributes
        ``1 / (k_rrf + rank)``. Only rank positions matter: ``d_scores`` and
        ``s_scores`` are accepted for interface compatibility and not used.

        Returns:
            Up to ``top_k`` indices ordered by descending fused score.
        """
        combined = {}
        for ranking in (d_idx, s_idx):
            for rank, idx in enumerate(ranking):
                combined[idx] = combined.get(idx, 0) + 1.0 / (k_rrf + rank)
        ordered = sorted(combined.items(), key=lambda pair: pair[1], reverse=True)
        return [idx for idx, _ in ordered[:top_k]]

    def rerank(self, query: str, candidate_indices: List[int], passages: List[str], final_k: int) -> List[Tuple[int, float]]:
        """Score (query, passage) pairs with the reranker and keep the best.

        Batches start at 64 pairs; on a CUDA out-of-memory error the batch
        size is halved and scoring restarts from scratch, down to 16.

        Returns:
            Up to ``final_k`` ``(candidate_index, score)`` tuples, best first.

        Raises:
            RuntimeError: If no reranker/tokenizer is loaded, or if scoring
                still runs out of memory at the smallest batch size.
        """
        if final_k <= 0 or not candidate_indices:
            return []

        tokenizer = getattr(self, "tokenizer", None)
        reranker = getattr(self, "reranker", None)
        if tokenizer is None or reranker is None:
            raise RuntimeError("Reranker model is not loaded.")

        queries = [query] * len(candidate_indices)
        batch_size = 64
        while batch_size >= 16:
            # Fresh list per attempt: scores from a failed attempt must not
            # leak into the retry, or they stop lining up with the candidates.
            scores: List[float] = []
            try:
                with torch.inference_mode():
                    for start in range(0, len(passages), batch_size):
                        stop = start + batch_size
                        inputs = tokenizer(
                            queries[start:stop],
                            passages[start:stop],
                            padding=True,
                            truncation=True,
                            max_length=512,
                            return_tensors="pt",
                        ).to(self.device)
                        logits = reranker(**inputs).logits.view(-1)
                        scores.extend(logits.detach().cpu().tolist())
            except torch.cuda.OutOfMemoryError:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                batch_size //= 2
                continue
            ranked = sorted(
                zip(candidate_indices, scores),
                key=lambda pair: pair[1],
                reverse=True,
            )
            return ranked[:final_k]

        raise RuntimeError("Reranker failed due to CUDA OOM.")

    def get_passages(self, cand_idx, sentences_list):
        """Return the entries of ``sentences_list`` at positions ``cand_idx``."""
        return [sentences_list[idx] for idx in cand_idx]

    def search(self, query: str, sentence_list, topk_dense=50, topk_sparse=50,
               pre_rerank_k=50, final_k=10):
        """Run dense + sparse retrieval and return the fused top results.

        Args:
            query: Query text.
            sentence_list: Not used while reranking is disabled.
            topk_dense: Candidates requested from the FAISS index.
            topk_sparse: Candidates requested from the TF-IDF side.
            pre_rerank_k: Ignored; with reranking disabled the fused list is
                cut directly to ``final_k``.
            final_k: Number of results to return.

        Returns:
            A list of ``{"idx": position, "content": sentence}`` dicts, best
            first.
        """
        d_idx, d_scores = self.dense_retrieve(query, topk_dense)
        s_idx, s_scores = self.sparse_retrieve(query, topk_sparse)
        # No rerank stage follows, so fusion itself produces the final cut.
        pre_rerank_k = final_k
        cand_idx = self.fuse(d_idx, d_scores, s_idx, s_scores, pre_rerank_k)
        return [
            {"idx": i, "content": self.sentences[i]}
            for i in cand_idx
        ]
def single_query(query: str , samples_number):
    """Clean ``query``, search with the global ``pipe`` and collect the hits.

    Args:
        query: Raw question text.
        samples_number: Passed as ``final_k``; decides how many results the
            model returns.

    Returns:
        A tuple ``(id_list, final_similars_text, retrived_sections_list)``:
        the record ids of the hits, one per line; a numbered, human-readable
        listing of the hits; and the hit dicts, each extended in place with
        ``part_id`` and ``url``.
    """
    cleaned_query = cleaning(query)
    hits = pipe.search(
        cleaned_query,
        sentences,
        topk_dense=100,
        topk_sparse=100,
        pre_rerank_k=100,
        final_k=samples_number,
    )

    id_lines = []
    listing_parts = []
    enriched_hits = []
    for position, hit in enumerate(hits, start=1):
        sentence_pos = hit['idx']
        hit["part_id"] = part_ids[sentence_pos]
        hit["url"] = urls[sentence_pos]
        title_value = '{' + str(large_titles[sentence_pos]) + '}'
        listing_parts.append(f"{position}. عنوان: {title_value}\n{hit['content']}\n\n")
        id_lines.append(f"{ids[sentence_pos]}\n")
        enriched_hits.append(hit)

    return "".join(id_lines), "".join(listing_parts), enriched_hits
def get_passages_by_paragraphs(retrived_sections_list):
    """Rebuild the full paragraph text for each retrieved section.

    For every item, all records sharing the item's ``part_id`` are joined in
    ascending ``id`` order (the sentence order inside the paragraph), under
    the ``large_title`` of the last such record in the dataset.

    Args:
        retrived_sections_list: Hit dicts carrying a ``part_id`` key.

    Returns:
        One string with a ``"<title>:\\n<sentence>. <sentence>. ..."`` block
        per item, blocks separated by a blank line. Items whose ``part_id``
        has no record in the dataset are skipped.
    """
    # Assumes part_id values are hashable (they are only compared for
    # equality in the data) - verify against dm.get_all_data().
    wanted_parts = {item['part_id'] for item in retrived_sections_list}
    if not wanted_parts:
        return ''

    # Single pass over the dataset instead of one full scan per item.
    rows_by_part = {}
    title_by_part = {}
    for row in dm.get_all_data():
        part_id = row['part_id']
        if part_id in wanted_parts:
            rows_by_part.setdefault(part_id, {})[row['id']] = row
            title_by_part[part_id] = row['large_title']

    blocks = []
    for item in retrived_sections_list:
        part_id = item['part_id']
        part_rows = rows_by_part.get(part_id)
        if not part_rows:
            # No record for this part: nothing to rebuild, and no title to
            # print it under.
            continue
        # Order by sentence id within the paragraph.
        body = ''.join(
            f"{part_rows[row_id]['normalized_sentence']}. "
            for row_id in sorted(part_rows)
        )
        paragraph = f"{title_by_part[part_id]}:\n{body}"
        blocks.append(f"{paragraph.strip()}\n\n")
    return ''.join(blocks)
# Main search entry point
def bale_search(query,samples_number=10):
    """Retrieve the best matches for ``query`` and return their record ids.

    Also rebuilds the paragraphs of the hits (the result is not used) and
    prints how long retrieval and the whole call took.

    Args:
        query: Question text.
        samples_number: Number of results to retrieve.

    Returns:
        The hit ids as strings, one list element per hit, followed by a
        trailing empty string produced by the final newline.
    """
    separator = '---------------------------------------------'
    now = datetime.datetime.now

    start = now()
    id_list, _listing, sections = single_query(query, samples_number)
    get_passages_by_paragraphs(sections)
    end_retrive = now()

    print('-'*40)
    print(f'retrive duration: {(end_retrive - start).total_seconds()}')
    print(separator)
    print(f'full duration: {(now() - start).total_seconds()}')
    print(separator)
    print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
    return id_list.split("\n")
# Evaluation driver: for every configured model and every result-list size,
# ask each question from the Excel sheet and count how often the expected
# sermon appears among the retrieved records.
final_result = ''
for option in option_list:
    # HybridRetrieverReranker.__init__ reads this module-level flag to decide
    # how the embedding model is loaded, so set it before building `pipe`.
    normal_embedder = bool(option['normal_embedder'])

    (ids, part_ids, context_ids, large_titles, sentences, titles, urls,
     arabic_texts, Interpretation_links, types, faiss_index) = load_faiss_index(
        option['model_faiss'], FAISS_METADATA_PATH)
    pipe = HybridRetrieverReranker(
        ids, part_ids, context_ids, large_titles, sentences, titles, urls,
        arabic_texts, Interpretation_links, types, faiss_index,
        dense_alpha=0.6, model_path=option['model_name'])

    final_result += "\n--------------------------------------------------------\n"
    final_result += f"model : {option['model_name']}\n"

    for num in option['samples_numbers']:
        final_answers_and_questions = []
        true_n = 0
        sen_n = 0
        for s in answers_and_questions:
            # Main loop: take a question and its expected answer from the
            # Excel rows, send the question to the model and verify the answer.
            sen_n += 1
            print(sen_n)
            sermon_number = int(s[0])
            answer = f"خطبه {sermon_number}"
            question = s[1]
            bot_answers = bale_search(question, samples_number=num)

            find = False
            for _id in bot_answers:
                if _id == '':
                    continue
                cursor.execute("SELECT * FROM speeches WHERE id = ?", (_id,))
                row = cursor.fetchone()
                if row is None:
                    # An id without a row cannot match; do not crash the run.
                    continue
                # Column 3 of the row is compared with the expected title
                # - presumably the title column; verify against the schema.
                if row[3] == answer:
                    find = True
                    break  # one match is enough; skip the remaining lookups
            if find:
                true_n += 1
            # Record the outcome in a new row; appending to `s` would grow
            # the shared question rows on every pass over them.
            final_answers_and_questions.append([sermon_number, question, find])

        final_result += "_____________________\n"
        final_result += f"findes speechs number: {true_n}/{len(answers_and_questions)}\n"
        final_result += f"samples number : {num}\n"
        final_result += "--------------------------------------------------------\n"

# Print how many of the asked questions were answered correctly.
print(final_result)
conn.close()
# print("--------------------------------------------------------")
# print(f"model : {option['model_name']}")
# print(f"findes speechs number: {true_n}/{len(answers_and_questions)}")
# print(f"samples number : {num}")
# print("--------------------------------------------------------")