oil_domain_works/oil_domain_reranking_03.py
2025-07-03 09:59:12 +03:30

83 lines
2.9 KiB
Python

import json
from tqdm import tqdm
import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer
from transformers import AutoModel # for pytorch
from transformers import TFAutoModelForTokenClassification # for tensorflow
from transformers import pipeline
import os
from datasets import Dataset, load_from_disk
# Progress marker: printed once the imports above have finished loading.
print('start')
#---
# NOTE: json.dumps cannot serialize NumPy types (e.g. np.float32 scores); the encoder below converts them.
class NumpyFloatValuesEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy values into native Python types.

    The standard ``json`` module rejects NumPy scalars such as ``np.float32``
    and NumPy arrays. This encoder maps floating scalars to ``float``, integer
    scalars to ``int``, boolean scalars to ``bool`` and arrays to (nested)
    lists. Usage::

        json.dumps(d, cls=NumpyFloatValuesEncoder)
    """

    def default(self, obj):
        # np.floating covers float16/32/64, so no separate np.float32 check.
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        # np.bool_ is not a subclass of Python bool, so json rejects it.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Unknown type: defer to the base class, which raises TypeError.
        return super().default(obj)
#json.dumps(d, cls=NumpyFloatValuesEncoder)
#----
from FlagEmbedding import FlagReranker
import os
# Hugging Face credentials, if needed, must come from the environment
# (e.g. `huggingface-cli login` or the HF_TOKEN variable) -- never commit tokens.
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
# Cross-encoder loaded once at module level; search_rerank() reads this global.
# use_fp16=True speeds up computation with a slight performance degradation.
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
def search_rerank(rule, sim_rules, rule_ids, rerank_k: int = 4):
    """Re-score candidate rules against ``rule`` with the global ``reranker``.

    Args:
        rule: Query text.
        sim_rules: Candidate texts to score against ``rule``.
        rule_ids: Identifiers aligned index-by-index with ``sim_rules``.
        rerank_k: Maximum number of top-scoring candidates to return.

    Returns:
        Up to ``rerank_k`` entries ``[score, candidate_text, candidate_id]``,
        sorted by descending score (ties keep input order). Returns ``[]``
        when there are no candidates.

    Raises:
        ValueError: If ``sim_rules`` and ``rule_ids`` differ in length.
    """
    if len(sim_rules) != len(rule_ids):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError(
            f"sim_rules ({len(sim_rules)}) and rule_ids ({len(rule_ids)}) "
            "must have the same length"
        )
    if not sim_rules:
        # Nothing to rank; also avoids handing compute_score an empty batch.
        return []
    pairs = [[rule, candidate] for candidate in sim_rules]
    # normalize=True applies a sigmoid, mapping the raw scores into 0-1.
    scores = reranker.compute_score(pairs, normalize=True)
    # compute_score may return a bare float (not a list) for a single pair;
    # wrap it so zip() below also works with one candidate.
    if np.ndim(scores) == 0:
        scores = [scores]
    ranked = sorted(
        zip(scores, sim_rules, rule_ids),
        key=lambda item: item[0],
        reverse=True,
    )
    return [[score, candidate, cand_id] for score, candidate, cand_id in ranked[:rerank_k]]
#---
# Rerank each query's retrieved rules and write the top-k results to
# ./mj/reranked_{k}_oil_{model_name_or_path}.json.
k = 10  # candidates kept per query; also part of the output filename
related_data = []
print('loading data')
model_name_or_path = "similar_20_oil_BAAI_bge-m3"
#model_name_or_path = "similar_20_oil_jinaai_jina-embeddings-v3"
with open(f'./mj/{model_name_or_path}.json', "r", encoding='utf-8') as content_file:
    oil_data = json.load(content_file)
for qan in tqdm(oil_data):
    query_id = qan["id"]
    rule_id = qan["rule_id"]
    rule = qan["rule"]
    retrieved_rule_ids = []
    retrieved_rules = []
    for cand_id, cand_rule_id, cand_rule in zip(
        qan["retrieved_ids"], qan["retrieved_rule_ids"], qan["retrieved_rules"]
    ):
        if cand_id == query_id:
            continue  # skip the query's match with itself
        retrieved_rule_ids.append(cand_rule_id)
        retrieved_rules.append(cand_rule)
    # Pass k (not a literal) so the kept count always matches the filename.
    reranked = search_rerank(rule, retrieved_rules, retrieved_rule_ids, k)
    related_data.append({"rule_id": rule_id, "rule": rule, "retrieved_rules": reranked})
remained = len(related_data)
print(remained)
##########
filename = "./mj/reranked_{}_oil_{}.json".format(k, model_name_or_path)
with open(filename, "w", encoding='utf-8') as similarity_file:
    json.dump(related_data, similarity_file, ensure_ascii=False, cls=NumpyFloatValuesEncoder, indent=4)
print('end')