Flair_NER/ner_proccess.py

234 lines
9.6 KiB
Python
Raw Normal View History

2024-09-18 16:35:06 +00:00
# -*- coding: utf-8 -*-
"""Flair_NER_Inference .ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1e-Q1bzMvm1mtRuxwnZBeXRfb-E39hxKu
"""
from general_functions import normalize_content
from funcs import separated_date_format_finder
from flair.data import Sentence
from flair.models import SequenceTagger
from flair.nn import Classifier
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
from transformers import AutoTokenizer
from flair.embeddings import TransformerWordEmbeddings
# from find_law import find_related_law
#from datetime import datetime
# from train import model
trained_model = 'unknown'
#model = "./taggers/final-model.pt"
model = "/home/gpu/tnlp/jokar/Flair_NER/taggers/final-model.pt"
print('model read')
tagger = SequenceTagger.load(model)
print('tagger initialized')
def save_to_file(result):
with open('./data/test_result.txt', 'a+', encoding='utf-8') as file:
previous_result = ''
try:
previous_result = file.read()
except:
pass
file.write(previous_result
+ '\n' + 50*'*'
+ '\n' + result
+ '\n' + 50*'*' + '\n')
def read_file():
with open('./data/law.txt', 'r', encoding='utf-8') as file:
text = ''
try:
text = str(file.read())
except:
pass
return text
# import nltk
import re
# from num2words import num2words
def find_ner_values_in_text(text, ner_values):
text_temp = text
text_tokens = text.split()
ner_obj = []
difference = 0
for raw_item in ner_values:
raw_ner = raw_item['value']
ner = re.findall(r'"(.*?)"', raw_ner)[0]
if ner == ')' or ner == '(' or ner == '/' or ner == 'قانون تغییر' or ner == 'قانون' or ner == '.' or ner == '':
continue
ner_parts = raw_ner.split(ner)[1]
ner_parts = ner_parts.lstrip('"/')
ner_type = ner_parts.strip()
ner_score = raw_item['score'].strip()
ner_type = ner_type.strip()
ner_score = ner_score.strip()
ner = normalize_content(ner)
# پیدا کردن موجودیت نامدار بالا در متن
matched_ner = [(m.start(), m.end()) for m in re.finditer(re.escape(ner), text_temp)]
if matched_ner:
matched_ner_start = matched_ner[0][0]
matched_ner_end = matched_ner[0][1]
before_ner_text = ''
if matched_ner_start > 1:
before_ner_text = text_temp[0:matched_ner_start-1]
difference = len(before_ner_text.split())
ner_start_token = difference
ner_end_token = len(ner.split()) + difference
#after_ner_text = text_temp[matched_ner_end:]
#$after_ner_text_tokens = [text_tokens[t] for t in range (ner_end_token, len(text_tokens))]
#after_ner_text = ' '.join(after_ner_text_tokens)
if ner_end_token > len(text_tokens):
ner_start_token -= 1
ner_end_token -= 1
ner_tokens = [text_tokens[t] for t in range (ner_start_token,ner_end_token)]
# برای جلوگیری از خطای منطقی در هنگامی که مقدار
# ner
#ما بیشتر از یکبار در متن وجود دارد، موجودیت بررسی شده را با کاراکتر های خنثی جایگزین می کنیم
for t in range (ner_start_token,ner_end_token):
text_tokens[t] = '#####'
text_temp = ' '.join(text_tokens)
text_temp = text_temp.strip()
if matched_ner_start == 0:
difference = len(ner.split())
#region Evaluate NER Format
# law_id = 0
if ner_type == 'HALFREFERENCE':
ner_type = 'H_REF'
if ner_type == 'REFERENCE':
ner_type = 'REF'
if not (ner.strip()).startswith('قانون'):
continue
##################################
# پیدا کردن شناسه متناظر با این قانون
# law_id = find_related_law(ner.strip())
##################################
if ner_type == 'DATE2':# تاریخ در سه توکن
# # در این فرمت از تاریخ اگر تعداد توکن های تاریخ، برابر با سه نباشد، تشخیص ماشین اشتباه بوده
# if len(ner_tokens) != 3:
# continue
date_ner = ' '.join(ner_tokens).strip()
# بررسی فرمت صحیح تاریخ با رگولار
result = separated_date_format_finder(date_ner)
if not result:
continue
if ner_type == 'DATE3':# تاریخ در یک توکن
date_ner = ' '.join(ner_tokens).strip()
# بررسی فرمت صحیح تاریخ با رگولار
result = separated_date_format_finder(date_ner)
if not result:
continue
#endregion
ner_obj.append({
'ner_value' : ner.strip(),
'ner_start_token': ner_start_token,
'ner_end_token' : ner_end_token,
'ner_key' : ner_type.strip(),
'ner_score' : float(ner_score.strip()),
#'ner_tokens' : ner_tokens,
})
# if law_id != 0:
# ner_obj[len(ner_obj)-1]['ner_law_id']= law_id
return ner_obj
def inference_main(trained_model,input_sentence):
try:
proccess_result = True, ''
# if(input_sentence == ''):
# input_sentence = read_file()
input_sentence = normalize_content(input_sentence)
result = []
# if len(input_sentence) > 511 :
# sentence_parts = input_sentence.split('.')
sentence_parts = split_sentence(input_sentence)
for part in sentence_parts:
sentence = Sentence(part)
tagger.predict(sentence)
for span in sentence.get_spans():
result.append(span)
# else:
# sentence = Sentence(input_sentence)
# tagger.predict(sentence)
# for span in sentence.get_spans():
# result.append(span)
final_result = ''
#result_header = 100*'#' + '\n' + 'Model Name: ' + trained_model + '\n' + 'Found Entity Count: ' + str(len(result)) + '\n' + 'inference time: ' + str(datetime.now()) + '\n' + 100*'#'
ner_values = []
if result:
for item in result:
value = item.annotation_layers['ner'][0].labeled_identifier
score = round(item.score, 2)
score = str(score)
final_result = final_result + '\n' + value + ' /%/ ' + score
ner_values.append({
'value':value,
'score':score
})
ner_obj_list = find_ner_values_in_text(input_sentence, ner_values)
except Exception as error:
proccess_result = False , error.args[0]
ner_obj_list = []
return ner_obj_list, input_sentence, proccess_result
# تابع بازگشتی برای تقسیم متن به تکه های کوچکتر از 512 کاراکتر
def split_sentence(input_sentence):
# تعریف یک لیست داخلی برای نگهداری بخش‌های تقسیم شده
parts = []
# کاراکترهایی که بر اساس آنها به ترتیب، یک متن را به زیرمتن های کوچک تر تبدیل می کنیم
separators = ['\n', '.', ':', '،']
# تابع بازگشتی
def recursive_split(sentence):
2024-12-01 15:03:40 +00:00
# اگر تعداد توکن های متن پاس داده شده کمتر یا برابر با 511 کاراکتر باشد، آن را به لیست اضافه کن
if len(sentence.split()) <= 256:
2024-09-18 16:35:06 +00:00
if sentence != '':
parts.append(sentence)
return
# تلاش برای استفاده از جداکننده‌های مختلف
for separator in separators:
if separator in sentence:
# تقسیم رشته با استفاده از جداکننده‌ی فعلی
split_parts = sentence.split(separator)
new_sentence = []
for part in split_parts:
new_sentence.append(part)
# بررسی اینکه آیا همه بخش‌ها به اندازه کافی کوچک شده‌اند
for part in new_sentence:
# print(len(part))
2024-12-01 15:03:40 +00:00
if len(part.split()) <= 256:
2024-09-18 16:35:06 +00:00
if part == '':
continue
parts.append(part)
else:
recursive_split(part)
return
# اگر هیچ جداکننده‌ای کار نکرد، رشته را به دو نیمه تقسیم کن
# mid_point = len(sentence) // 2
# recursive_split(sentence[:mid_point])
# recursive_split(sentence[mid_point:])
# شروع تقسیم بازگشتی
recursive_split(input_sentence)
return parts