Flair_NER/inference.py

323 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""Flair_NER_Inference .ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1e-Q1bzMvm1mtRuxwnZBeXRfb-E39hxKu
"""
from general_functions import normalize_content
from funcs import separated_date_format_finder
from flair.data import Sentence
from flair.models import SequenceTagger
from flair.nn import Classifier
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
from transformers import AutoTokenizer
from flair.embeddings import TransformerWordEmbeddings
from find_law import find_related_law
# from train import model
trained_model = 'unknown'
model = "./taggers/final-model.pt"
print('model read')
tagger = SequenceTagger.load(model)
print('tagger initialized')
def save_to_file(result):
with open('./data/test_result.txt', 'a+', encoding='utf-8') as file:
previous_result = ''
try:
previous_result = file.read()
except:
pass
file.write(previous_result
+ '\n' + 50*'*'
+ '\n' + result
+ '\n' + 50*'*' + '\n')
def read_file():
with open('./data/law.txt', 'r', encoding='utf-8') as file:
text = ''
try:
text = str(file.read())
except:
pass
return text
# import nltk
import re
# from num2words import num2words
def extract_quoted_values(content):
ner_vlaue = re.findall(r'"(.*?)"', content)
return ner_vlaue
def convert_persian_numbers_to_english(text):
persian_numbers = {'۰': '0', '۱': '1', '۲': '2', '۳': '3', '۴': '4', '۵': '5', '۶': '6', '۷': '7', '۸': '8', '۹': '9'}
for persian, english in persian_numbers.items():
text = text.replace(persian, english)
return text
def convert_numbers_to_persian(text):
persian_numbers = {'0': '۰', '1': '۱', '2': '۲', '3': '۳', '4': '۴', '5': '۵', '6': '۶', '7': '۷', '8': '۸', '9': '۹'}
for english, persian in persian_numbers.items():
text = text.replace(english, persian)
return text
def find_quoted_values_in_text(text, ner_values):
"""
این تابع مقادیر استخراج شده را در یک متن دیگر جستجو می‌کند و اگر یافت شدند، محل آنها را نشان می‌دهد.
:param text: متنی که می‌خواهید در آن جستجو انجام شود.
:param quoted_values: لیستی از مقادیر استخراج شده که می‌خواهید در متن جستجو کنید.
"""
#tokens = nltk.word_tokenize(text)
tokens = text.split()
ner_token_index_list = []
for value in ner_values:
#value_tokens = nltk.word_tokenize(value)
# print(value[0])
value_tokens = value[0].split()
value_len = len(value_tokens)
found = False
for i in range(len(tokens) - value_len + 1):
if tokens[i:i + value_len] == value_tokens:
print(f'{i + 1}:{i + value_len}|{value[0]}')
ner_token_index_list.append(f'{i + 1}:{i + value_len}|{value[0]}')
found = True
break
if not found:
print(f'"{value[0]}" در جمله یافت نشد.')
return ner_token_index_list
def find_ner_values_in_text(text, ner_values):
text_temp = text
text_tokens = text.split()
ner_obj = []
difference = 0
for raw_item in ner_values:
raw_ner = raw_item['value']
ner = re.findall(r'"(.*?)"', raw_ner)[0]
# جلوگیری از مقادیر نامعتبر در خروجی
if ner == ')' or ner == '(' or ner == '/' or ner == 'قانون تغییر' or ner == 'قانون':
continue
ner_parts = raw_ner.split(ner)[1]
ner_parts = ner_parts.lstrip('"/')
ner_type = ner_parts.strip()
ner_score = raw_item['score'].strip()
ner_type = ner_type.strip()
ner_score = ner_score.strip()
ner = normalize_content(ner)
# پیدا کردن موجودیت نامدار بالا در متن
matched_ner = [(m.start(), m.end()) for m in re.finditer(re.escape(ner), text_temp)]
if matched_ner:
matched_ner_start = matched_ner[0][0]
matched_ner_end = matched_ner[0][1]
before_ner_text = ''
if matched_ner_start > 1:
before_ner_text = text_temp[0:matched_ner_start-1]
difference = len(before_ner_text.split())
ner_start_token = difference
ner_end_token = len(ner.split()) + difference
#after_ner_text = text_temp[matched_ner_end:]
#$after_ner_text_tokens = [text_tokens[t] for t in range (ner_end_token, len(text_tokens))]
#after_ner_text = ' '.join(after_ner_text_tokens)
ner_tokens = [text_tokens[t] for t in range (ner_start_token,ner_end_token)]
# برای جلوگیری از خطای منطقی در هنگامی که مقدار
# ner
#ما بیشتر از یکبار در متن وجود دارد، موجودیت بررسی شده را با کاراکتر های خنثی جایگزین می کنیم
for t in range (ner_start_token,ner_end_token):
text_tokens[t] = '#####'
text_temp = ' '.join(text_tokens)
text_temp = text_temp.strip()
if matched_ner_start == 0:
difference = len(ner.split())
#region Evaluate NER Format
# law_id = 0
if ner_type == 'HALFREFERENCE':
ner_type = 'H_REF'
if ner_type == 'REFERENCE':
ner_type = 'REF'
if not (ner.strip()).startswith('قانون'):
continue
##################################
# پیدا کردن شناسه متناظر با این قانون
# law_id = find_related_law(ner.strip())
##################################
if ner_type == 'DATE2':# تاریخ در سه توکن
# # در این فرمت از تاریخ اگر تعداد توکن های تاریخ، برابر با سه نباشد، تشخیص ماشین اشتباه بوده
# if len(ner_tokens) != 3:
# continue
date_ner = ' '.join(ner_tokens).strip()
# بررسی فرمت صحیح تاریخ با رگولار
result = separated_date_format_finder(date_ner)
if not result:
continue
if ner_type == 'DATE3':# تاریخ در یک توکن
date_ner = ' '.join(ner_tokens).strip()
# بررسی فرمت صحیح تاریخ با رگولار
result = separated_date_format_finder(date_ner)
if not result:
continue
#endregion
ner_obj.append({
'ner_value' : ner.strip(),
'ner_start_token': ner_start_token,
'ner_end_token' : ner_end_token,
# 'ner_tokens' : ner_tokens,
'ner_key' : ner_type.strip(),
'ner_score' : ner_score.strip()
})
# if law_id != 0:
# ner_obj[len(ner_obj)-1]['ner_law_id']= law_id
return ner_obj
def ner_values_token_finder(ner_values):
# ner_value_list = []
# for ner_value in ner_values:
# ner_value_list.append(extract_quoted_values(ner_value['value']))
text = read_file()
text = normalize_content(text)
# تبدیل اعداد فارسی به اعداد انگلیسی
#text = convert_persian_numbers_to_english(text)
return find_ner_values_in_text(text, ner_values)
def inference_main(trained_model,input_sentence):
if(input_sentence == ''):
input_sentence = read_file()
input_sentence = normalize_content(input_sentence)
# p = len(input_sentence)
# print(p)
# # نام مدل
# model_name = "./data/xlm-roberta-base.pt"
# # بارگذاری توکنایزر
# tokenizer = AutoTokenizer.from_pretrained(model)
# # جمله مورد نظر
# sentence = input_sentence
# # توکن‌سازی جمله
# inputs = tokenizer(sentence, return_tensors="pt")
# from transformers import AutoModel
# # بارگذاری مدل
# model = AutoModel.from_pretrained(model_name)
# # اجرای مدل
# outputs = model(**inputs)
# hidden_states = outputs.last_hidden_state
#model = "./data/final-model.pt"
# model = "./data/final-model_01.pt"
'''model = "./taggers/final-model.pt"
# tokenizer = AutoTokenizer.from_pretrained('./taggers')
# tokens = tokenizer.tokenize(input_sentence)
# for index, token in enumerate(tokens):
# print(index + ' - ' + token)
#model = "/home/gpu/tnlp/jokar/Models/HooshvareLab--bert-base-parsbert-ner-uncased/train 01/final-model.pt"
# model = "./data/HooshvareLab--distilbert-fa-zwnj-base-ner"
#model = "./jokar/Models/HooshvareLab-bert-fa-base-uncased-finetuned-2-pt"
print('model read')
# embeddings = TransformerWordEmbeddings(allow_long_sentences=True )
# tagger = SequenceTagger(embeddings=embeddings)
tagger = SequenceTagger.load(model)
print('tagger initialized')'''
# tokenizer = AutoTokenizer.from_pretrained('./data')
# model = AutoModelForTokenClassification.from_pretrained("./data", num_labels=5)
result = []
# if(len(input_sentence>511)):
# sentence = Sentence(input_sentence)
# tagger.predict(sentence)
# for span in sentence.get_spans():
# result.append(span)
if len(input_sentence) > 511 :
sentence_parts = input_sentence.split('.')
for part in sentence_parts:
sentence = Sentence(part)
tagger.predict(sentence)
for span in sentence.get_spans():
result.append(span)
else:
sentence = Sentence(input_sentence)
tagger.predict(sentence)
for span in sentence.get_spans():
result.append(span)
from datetime import datetime
final_result = ''
result_header = 100*'#' + '\n' + 'Model Name: ' + trained_model + '\n' + 'Found Entity Count: ' + str(len(result)) + '\n' + 'inference time: ' + str(datetime.now()) + '\n' + 100*'#'
ner_values = []
if result:
for item in result:
value = item.annotation_layers['ner'][0].labeled_identifier
score = round(item.score, 2)
score = str(score)
final_result = final_result + '\n' + value + ' /%/ ' + score
ner_values.append({
'value':value,
'score':score
})
# text = read_file()
text = input_sentence
text = normalize_content(text)
ner_obj_list = find_ner_values_in_text(text, ner_values)
ner_addresses = ''
for ner_val in ner_obj_list:
ner_addresses = ner_addresses + '\n' + str(ner_val)
ner_addresss = '\n'+ '$'*70 + '\n' + ner_addresses + '\n' + '$'*70
save_to_file(result_header + final_result + ner_addresss)
return final_result
# tagger: SequenceTagger = SequenceTagger.load("./data/final-model.pt")
# # tagger: SequenceTagger = SequenceTagger.load(model)
# tagger: SequenceTagger = SequenceTagger.load("xlm-roberta-base")
# # sentence = Sentence(input_sentence)
# load the NER tagger
# tagger = Classifier.load('ner')
# tagger : Classifier = Classifier.load("sentiment")
#tagger = Classifier.load('sentiment')
# # tagger.predict(sentence)
# # for span in sentence.get_spans():
# # print(span)
# print()
# print(' <--- Predict Operation Finished! ---> ')
# print()
if __name__ == "__main__":
text = """ماده ۹ - کلیه دستگاههای اجرائی موضوع ماده (۵) قانون مدیریت خدمات کشوری مصوب 8 /7 /1386 با اصلاحات و الحاقات بعدی و ماده (۵) قانون محاسبات عمومی کشور مصوب 1 /6 /1366 با اصلاحات و الحاقات بعدی و نیروهای مسلح جمهوری اسلامی ایران (موضوع ماده (۲) قانون استخدام نیروی انتظامی جمهوری اسلامی ایران مصوب 20 /12 /1382 با اصلاحات و الحاقات بعدی)، مکلفند عوارض و بهای خدمات شهرداری ها و دهیاری های موضوع این قانون را همه ساله حداکثر تا پایان سال مالی به شهرداری یا دهیاری مربوط واریز کنند. ذی حساب و رئیس دستگاه مربوط، در تاریخ ۱۴۰۱/۵/۳ مسؤول حسن اجرای قانون خانواده و جوانی جمعیت که در مهرماه سال 1401 تصویب شده می باشند.
در روز نوزدهم دی ماه سال یکهزار وسیصد و نود و سه برای اولین بار مسئله جمعیت به صورت جدی مورد مطالعه شد. در مورخه 1314.1.17 نیز این مسئله توسط مجلس ملی به صحن آورده شد. هم چنین شورای نگهبان و سازمان محیط زیست و نیز سازمان جوانان هلال احمر در مورد قانون صیانت از کاربران در فضای مجازی با توجه به قانون هوای پاک که در مجلس شورای اسلامی و سازمان محیط زیست به تصویب رسیده مسئول هستند.
این قوانین توسط محمد سرمدی برای شهرهای مشهد و سمنان پیشنهاد داده شد. اعتبار لازم برای اجرای این طرح، بالغ بر سه هزار میلیارد تومان برآورد شده که 30 درصد از آن معادل یک ملیارد تومان در هفته اول پس از تصویب، واریز خواهد شد.
"""
print("do inf ... ")
inference_main('pourmand1376/NER_Farsi',text)