234 lines
9.6 KiB
Python
234 lines
9.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Flair_NER_Inference .ipynb
|
|
|
|
Automatically generated by Colaboratory.
|
|
|
|
Original file is located at
|
|
https://colab.research.google.com/drive/1e-Q1bzMvm1mtRuxwnZBeXRfb-E39hxKu
|
|
"""
|
|
from general_functions import normalize_content
|
|
from funcs import separated_date_format_finder
|
|
from flair.data import Sentence
|
|
from flair.models import SequenceTagger
|
|
from flair.nn import Classifier
|
|
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
|
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
|
import torch
|
|
from transformers import AutoTokenizer
|
|
from flair.embeddings import TransformerWordEmbeddings
|
|
# from find_law import find_related_law
|
|
#from datetime import datetime
|
|
|
|
# from train import model
|
|
trained_model = 'unknown'
|
|
|
|
#model = "./taggers/final-model.pt"
|
|
model = "/home/gpu/tnlp/jokar/Flair_NER/taggers/final-model.pt"
|
|
print('model read')
|
|
tagger = SequenceTagger.load(model)
|
|
print('tagger initialized')
|
|
|
|
def save_to_file(result):
|
|
with open('./data/test_result.txt', 'a+', encoding='utf-8') as file:
|
|
previous_result = ''
|
|
try:
|
|
previous_result = file.read()
|
|
except:
|
|
pass
|
|
file.write(previous_result
|
|
+ '\n' + 50*'*'
|
|
+ '\n' + result
|
|
+ '\n' + 50*'*' + '\n')
|
|
|
|
def read_file():
|
|
with open('./data/law.txt', 'r', encoding='utf-8') as file:
|
|
text = ''
|
|
try:
|
|
text = str(file.read())
|
|
except:
|
|
pass
|
|
return text
|
|
|
|
|
|
# import nltk
|
|
import re
|
|
# from num2words import num2words
|
|
|
|
def find_ner_values_in_text(text, ner_values):
|
|
text_temp = text
|
|
text_tokens = text.split()
|
|
ner_obj = []
|
|
difference = 0
|
|
for raw_item in ner_values:
|
|
raw_ner = raw_item['value']
|
|
ner = re.findall(r'"(.*?)"', raw_ner)[0]
|
|
if ner == ')' or ner == '(' or ner == '/' or ner == 'قانون تغییر' or ner == 'قانون' or ner == '.' or ner == '':
|
|
continue
|
|
ner_parts = raw_ner.split(ner)[1]
|
|
ner_parts = ner_parts.lstrip('"/')
|
|
ner_type = ner_parts.strip()
|
|
ner_score = raw_item['score'].strip()
|
|
ner_type = ner_type.strip()
|
|
ner_score = ner_score.strip()
|
|
ner = normalize_content(ner)
|
|
# پیدا کردن موجودیت نامدار بالا در متن
|
|
matched_ner = [(m.start(), m.end()) for m in re.finditer(re.escape(ner), text_temp)]
|
|
if matched_ner:
|
|
matched_ner_start = matched_ner[0][0]
|
|
matched_ner_end = matched_ner[0][1]
|
|
before_ner_text = ''
|
|
|
|
if matched_ner_start > 1:
|
|
before_ner_text = text_temp[0:matched_ner_start-1]
|
|
difference = len(before_ner_text.split())
|
|
|
|
ner_start_token = difference
|
|
ner_end_token = len(ner.split()) + difference
|
|
#after_ner_text = text_temp[matched_ner_end:]
|
|
#$after_ner_text_tokens = [text_tokens[t] for t in range (ner_end_token, len(text_tokens))]
|
|
#after_ner_text = ' '.join(after_ner_text_tokens)
|
|
if ner_end_token > len(text_tokens):
|
|
ner_start_token -= 1
|
|
ner_end_token -= 1
|
|
ner_tokens = [text_tokens[t] for t in range (ner_start_token,ner_end_token)]
|
|
|
|
# برای جلوگیری از خطای منطقی در هنگامی که مقدار
|
|
# ner
|
|
#ما بیشتر از یکبار در متن وجود دارد، موجودیت بررسی شده را با کاراکتر های خنثی جایگزین می کنیم
|
|
for t in range (ner_start_token,ner_end_token):
|
|
text_tokens[t] = '#####'
|
|
text_temp = ' '.join(text_tokens)
|
|
text_temp = text_temp.strip()
|
|
if matched_ner_start == 0:
|
|
difference = len(ner.split())
|
|
#region Evaluate NER Format
|
|
# law_id = 0
|
|
if ner_type == 'HALFREFERENCE':
|
|
ner_type = 'H_REF'
|
|
if ner_type == 'REFERENCE':
|
|
ner_type = 'REF'
|
|
if not (ner.strip()).startswith('قانون'):
|
|
continue
|
|
##################################
|
|
# پیدا کردن شناسه متناظر با این قانون
|
|
# law_id = find_related_law(ner.strip())
|
|
##################################
|
|
if ner_type == 'DATE2':# تاریخ در سه توکن
|
|
# # در این فرمت از تاریخ اگر تعداد توکن های تاریخ، برابر با سه نباشد، تشخیص ماشین اشتباه بوده
|
|
# if len(ner_tokens) != 3:
|
|
# continue
|
|
date_ner = ' '.join(ner_tokens).strip()
|
|
# بررسی فرمت صحیح تاریخ با رگولار
|
|
result = separated_date_format_finder(date_ner)
|
|
if not result:
|
|
continue
|
|
if ner_type == 'DATE3':# تاریخ در یک توکن
|
|
date_ner = ' '.join(ner_tokens).strip()
|
|
# بررسی فرمت صحیح تاریخ با رگولار
|
|
result = separated_date_format_finder(date_ner)
|
|
if not result:
|
|
continue
|
|
#endregion
|
|
ner_obj.append({
|
|
'ner_value' : ner.strip(),
|
|
'ner_start_token': ner_start_token,
|
|
'ner_end_token' : ner_end_token,
|
|
'ner_key' : ner_type.strip(),
|
|
'ner_score' : float(ner_score.strip()),
|
|
#'ner_tokens' : ner_tokens,
|
|
})
|
|
# if law_id != 0:
|
|
# ner_obj[len(ner_obj)-1]['ner_law_id']= law_id
|
|
|
|
return ner_obj
|
|
|
|
|
|
def inference_main(trained_model,input_sentence):
|
|
try:
|
|
proccess_result = True, ''
|
|
# if(input_sentence == ''):
|
|
# input_sentence = read_file()
|
|
input_sentence = normalize_content(input_sentence)
|
|
result = []
|
|
|
|
# if len(input_sentence) > 511 :
|
|
# sentence_parts = input_sentence.split('.')
|
|
sentence_parts = split_sentence(input_sentence)
|
|
for part in sentence_parts:
|
|
sentence = Sentence(part)
|
|
tagger.predict(sentence)
|
|
for span in sentence.get_spans():
|
|
result.append(span)
|
|
# else:
|
|
# sentence = Sentence(input_sentence)
|
|
# tagger.predict(sentence)
|
|
# for span in sentence.get_spans():
|
|
# result.append(span)
|
|
|
|
final_result = ''
|
|
#result_header = 100*'#' + '\n' + 'Model Name: ' + trained_model + '\n' + 'Found Entity Count: ' + str(len(result)) + '\n' + 'inference time: ' + str(datetime.now()) + '\n' + 100*'#'
|
|
ner_values = []
|
|
if result:
|
|
for item in result:
|
|
value = item.annotation_layers['ner'][0].labeled_identifier
|
|
score = round(item.score, 2)
|
|
score = str(score)
|
|
final_result = final_result + '\n' + value + ' /%/ ' + score
|
|
ner_values.append({
|
|
'value':value,
|
|
'score':score
|
|
})
|
|
|
|
ner_obj_list = find_ner_values_in_text(input_sentence, ner_values)
|
|
except Exception as error:
|
|
proccess_result = False , error.args[0]
|
|
ner_obj_list = []
|
|
return ner_obj_list, input_sentence, proccess_result
|
|
|
|
# تابع بازگشتی برای تقسیم متن به تکه های کوچکتر از 512 کاراکتر
|
|
def split_sentence(input_sentence):
|
|
# تعریف یک لیست داخلی برای نگهداری بخشهای تقسیم شده
|
|
parts = []
|
|
# کاراکترهایی که بر اساس آنها به ترتیب، یک متن را به زیرمتن های کوچک تر تبدیل می کنیم
|
|
separators = ['\n', '.', ':', '،']
|
|
|
|
# تابع بازگشتی
|
|
def recursive_split(sentence):
|
|
# اگر طول جمله کمتر یا برابر با 511 کاراکتر باشد، آن را به لیست اضافه کن
|
|
if len(sentence) <= 511:
|
|
if sentence != '':
|
|
parts.append(sentence)
|
|
return
|
|
|
|
# تلاش برای استفاده از جداکنندههای مختلف
|
|
for separator in separators:
|
|
if separator in sentence:
|
|
# تقسیم رشته با استفاده از جداکنندهی فعلی
|
|
split_parts = sentence.split(separator)
|
|
new_sentence = []
|
|
|
|
for part in split_parts:
|
|
new_sentence.append(part)
|
|
|
|
# بررسی اینکه آیا همه بخشها به اندازه کافی کوچک شدهاند
|
|
|
|
for part in new_sentence:
|
|
# print(len(part))
|
|
if len(part) <= 511:
|
|
if part == '':
|
|
continue
|
|
parts.append(part)
|
|
else:
|
|
recursive_split(part)
|
|
return
|
|
|
|
# اگر هیچ جداکنندهای کار نکرد، رشته را به دو نیمه تقسیم کن
|
|
# mid_point = len(sentence) // 2
|
|
# recursive_split(sentence[:mid_point])
|
|
# recursive_split(sentence[mid_point:])
|
|
|
|
# شروع تقسیم بازگشتی
|
|
recursive_split(input_sentence)
|
|
|
|
return parts
|