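# Flair_NER/find_law.py
#
# Find the law whose title (caption) embedding is most similar to a detected
# law mention, using cosine similarity of sentence embeddings.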

from sentence_transformers import SentenceTransformer, util
# from inference import inference_main
from funcs import read_from_json
#from general_functions import normalize_content
#model_path = './paraphrase-multilingual-mpnet-base-v2-1401-07-30'
#model_path = '/home/gpu/NLP/MLM/MODELS/training_stsbenchmark-HooshvareLab-bert-fa-base-uncased-finetuned-2-pt-2024-02-20_16-55-15'
def find_similarity(value_1, value_2):
    """Compute the cosine similarity of two embedding vectors.

    Each vector is wrapped in a one-element list so that ``util.cos_sim``
    sees a batch of size one on each side. Per the sentence-transformers
    docs, the result is therefore a 1x1 similarity matrix (a tensor), not a
    plain Python float.

    Args:
        value_1: First embedding. NOTE(review): assumed to be a single 1-D
            vector (e.g. a list of floats such as ``Tensor.tolist()`` yields).
        value_2: Second embedding; same assumption as ``value_1``.

    Returns:
        The value of ``util.cos_sim([value_1], [value_2])``.
    """
    # Treat each vector as a batch of one so cos_sim compares exactly one pair.
    batch_a, batch_b = [value_1], [value_2]
    return util.cos_sim(batch_a, batch_b)
def get_embedding(text, model=None):
    """Encode *text* into an embedding with a SentenceTransformer-style model.

    Args:
        text: The input passed to ``model.encode``. In this file it is a
            single string.
        model: Optional encoder exposing ``encode(text, convert_to_tensor=True)``.
            Defaults to the module-level ``encoder`` global. That global is
            only created under ``if __name__ == "__main__"``, so code that
            imports this module should pass ``model`` explicitly (or set
            ``encoder`` itself) to avoid a NameError.

    Returns:
        Whatever ``model.encode(text, convert_to_tensor=True)`` returns. For a
        SentenceTransformer, this is a tensor.
    """
    if model is None:
        # Backward-compatible default: fall back to the script-level global.
        model = encoder
    return model.encode(text, convert_to_tensor=True)
def find_related_law(detected_value, laws=None):
    """Return the law whose caption embedding is most similar to *detected_value*.

    Args:
        detected_value: Text of a detected law mention, e.g. 'قانون خانواده'.
        laws: Optional iterable of law records. Each record is a dict with at
            least the keys 'id', 'caption' and 'caption_embedding'. Defaults
            to the module-level ``law_dict`` global, which this file only
            loads under ``if __name__ == "__main__"``.

    Returns:
        A dict ``{'law_id': ..., 'similarity': ..., 'caption': ...}`` for the
        best match. 'similarity' is the value ``find_similarity`` returned.
        On ties, the earliest law in *laws* wins.

    Raises:
        IndexError: if there are no laws to compare against.
    """
    if laws is None:
        laws = law_dict

    text = pre_process(detected_value).strip()
    # Drop a leading "قانون" ("law") title so the query is consistent with
    # the caption embeddings stored in the JSON. Remove the exact prefix:
    # str.lstrip('قانون') strips any of the characters ق/ا/ن/و, which
    # mangles captions such as 'ازدواج' -> 'زدواج'.
    title_prefix = 'قانون'
    if text.startswith(title_prefix):
        text = text[len(title_prefix):].strip()

    # Loop-invariant: embed the query once and convert it once.
    query_embedding = get_embedding(text).tolist()

    candidates = (
        {
            'law_id': law['id'],
            'similarity': find_similarity(query_embedding, law['caption_embedding']),
            'caption': law['caption'],
        }
        for law in laws
    )
    # max() returns the first maximal item. This matches the original
    # stable "sort descending, take [0]" without the O(n log n) sort.
    found_law = max(candidates, key=lambda item: item['similarity'], default=None)
    if found_law is None:
        raise IndexError('find_related_law: no laws to compare against (law list is empty)')

    print(found_law['caption'])
    return found_law
def pre_process(text):
    """Prepare raw text before it is embedded.

    This is currently a pass-through that returns its argument unchanged.
    The ``normalize_content`` normalization is disabled; its import at the
    top of the file is commented out.
    """
    processed = text
    return processed
if __name__ == "__main__":
    import os

    # Sentence-encoder checkpoint. Set FIND_LAW_MODEL_PATH to use a different
    # model without editing this file; otherwise the original path is used.
    model_path = os.environ.get(
        'FIND_LAW_MODEL_PATH',
        '/home/gpu/tnlp/jokar/Models/HooshvareLab-bert-fa-base-uncased-finetuned-2-pt',
    )
    # get_embedding() and find_related_law() read these two names as
    # module-level globals by default.
    encoder = SentenceTransformer(model_path)
    law_dict = read_from_json('./data/law_title.json')

    # find_related_law() already prints the matched caption, so it is not
    # printed a second time here.
    found_law = find_related_law('قانون خانواده')
    print()