# Viewer metadata: 83 lines, 2.9 KiB, Python
import json
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
import torch
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
from transformers import AutoTokenizer
|
|
from transformers import AutoModel # for pytorch
|
|
from transformers import TFAutoModelForTokenClassification # for tensorflow
|
|
from transformers import pipeline
|
|
import os
|
|
from datasets import Dataset, load_from_disk
|
|
|
|
|
|
# Progress marker printed once the imports above have finished.
print('start')
|
|
#---
|
|
# NOTE: works around json being unable to serialize NumPy values (e.g. np.float32) on its own.
|
|
class NumpyFloatValuesEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy values into built-in Python types.

    The standard encoder raises ``TypeError`` for NumPy scalars such as
    ``np.float32`` and for ``np.ndarray``. Pass this class through the ``cls``
    argument to serialize them::

        json.dumps(data, cls=NumpyFloatValuesEncoder)
    """

    def default(self, obj):
        """Return a JSON-serializable equivalent of ``obj``.

        Args:
            obj: A value the base encoder could not serialize by itself.

        Returns:
            ``int`` for NumPy integers, ``float`` for NumPy floats of any
            width, ``bool`` for NumPy booleans, and a (nested) ``list`` for
            arrays.

        Raises:
            TypeError: If ``obj`` is none of the supported NumPy types; raised
                by the base class implementation.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        # np.floating is the common base of np.float16/32/64, so a separate
        # np.float32 check is unnecessary.
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is neither an np.integer nor a Python bool.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
|
|
#json.dumps(d, cls=NumpyFloatValuesEncoder)
|
|
#----
|
|
|
|
|
|
from FlagEmbedding import FlagReranker
|
|
import os
|
|
# If the model download needs authentication, provide HUGGING_FACE_HUB_TOKEN
# through the process environment; never hard-code a token in this file.
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

# Shared reranker instance used by search_rerank().
# Setting use_fp16 to True speeds up computation with a slight performance degradation.
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
|
|
|
|
|
|
|
|
|
|
def search_rerank(rule, sim_rules, rule_ids, rerank_k:int=4):
    """Score candidate rules against ``rule`` and return the best ones.

    Args:
        rule: The query rule text.
        sim_rules: Candidate rule texts to compare with ``rule``.
        rule_ids: Identifiers aligned position-by-position with ``sim_rules``.
        rerank_k: Maximum number of candidates to return.

    Returns:
        A list of ``[score, candidate_rule, candidate_rule_id]`` entries sorted
        by descending score and truncated to ``rerank_k`` items. Returns an
        empty list when there are no candidates.
    """
    pairs = [[rule, candidate] for candidate in sim_rules]
    if not pairs:
        # Nothing to score; avoid calling the model with an empty batch.
        return []

    # The scores map into 0-1 by set "normalize=True", which will apply sigmoid function to the score
    scores = reranker.compute_score(pairs, normalize=True)
    # NOTE(review): some FlagEmbedding releases appear to return a bare number
    # instead of a list when given a single pair - confirm against the
    # installed version. Wrapping keeps zip() below working in either case.
    if isinstance(scores, (int, float)):
        scores = [scores]

    # Sort on the score only; sorted() is stable, so ties keep input order.
    ranked = sorted(zip(scores, pairs, rule_ids), key=lambda item: item[0], reverse=True)
    return [[score, pair[1], candidate_id] for score, pair, candidate_id in ranked[:rerank_k]]
|
|
|
|
|
|
#---
|
|
# Number of reranked candidates kept per rule; also embedded in the output filename.
k = 10
related_data = []

print('loading data')
model_name_or_path = "similar_20_oil_BAAI_bge-m3"
# Alternative input produced with a different embedding model:
#model_name_or_path = "similar_20_oil_jinaai_jina-embeddings-v3"

# Context manager guarantees the input file is closed even if parsing fails.
with open(f'./mj/{model_name_or_path}.json', "r", encoding='utf-8') as content_file:
    oil_data = json.load(content_file)

for qan in tqdm(oil_data):
    query_id = qan["id"]  # named query_id to avoid shadowing the builtin id()
    rule_id = qan["rule_id"]
    rule = qan["rule"]

    # Collect the retrieved candidates, skipping the entry that is the query itself.
    retrieved_rule_ids = []
    retrieved_rules = []
    for cand_id, cand_rule_id, cand_rule in zip(
        qan["retrieved_ids"], qan["retrieved_rule_ids"], qan["retrieved_rules"]
    ):
        if cand_id == query_id:
            continue
        retrieved_rule_ids.append(cand_rule_id)
        retrieved_rules.append(cand_rule)

    # Use k rather than a literal so the cut-off always matches the filename.
    reranked = search_rerank(rule, retrieved_rules, retrieved_rule_ids, k)
    related_data.append({"rule_id":rule_id, "rule": rule, "retrieved_rules": reranked})

remained = len(related_data)
print(remained)

##########
filename = "./mj/reranked_{}_oil_{}.json".format(k,model_name_or_path)
# Serialize before opening the output so a serialization error cannot leave a
# truncated file behind.
serialized = json.dumps(related_data, ensure_ascii=False, cls=NumpyFloatValuesEncoder, indent=4)
with open(filename, "w", encoding='utf-8') as similarity_file:
    similarity_file.write(serialized)

print('end')