ner documentation
This commit is contained in:
parent
c21d1bd22a
commit
46bced10ec
66
train.py
66
train.py
|
@ -1,10 +1,7 @@
|
||||||
learning_rate = 0.65e-4 # 0.65e-4 - 0.4e-4
|
LEARNING_RATE = 0.65e-4 # 0.65e-4 - 0.4e-4
|
||||||
mini_batch_size = 8
|
MINI_BATCH_SIZE = 8
|
||||||
max_epochs = 100
|
MAX_EPOCHS = 100
|
||||||
|
|
||||||
from funcs import save_to_file_by_address
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import datetime
|
import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from flair.data import Corpus
|
from flair.data import Corpus
|
||||||
|
@ -15,23 +12,9 @@ from flair.trainers import ModelTrainer
|
||||||
from flair.models import SequenceTagger
|
from flair.models import SequenceTagger
|
||||||
from flair.embeddings import TransformerDocumentEmbeddings
|
from flair.embeddings import TransformerDocumentEmbeddings
|
||||||
|
|
||||||
#model = os.getcwd() + "\\data\\final-model.pt"
|
def save_to_file_by_address(file_address, content):
|
||||||
#model = os.getcwd() + "/data/HooshvareLab--distilbert-fa-zwnj-base-ner" # مدل اولیه که تست شد و تا حدود 70 درصد در آخرین آموزش خوب جواب می داد
|
with open(file_address, 'a+', encoding='utf-8') as file:
|
||||||
#model = os.getcwd() + "/data/distilbert-base-multilingual-cased-tavasi"
|
file.write(content)
|
||||||
# model = "HooshvareLab/bert-fa-base-uncased-ner-peyma"
|
|
||||||
# model = "PooryaPiroozfar/Flair-Persian-NER" # 111111111111111
|
|
||||||
|
|
||||||
## ---------------------------------------------------------
|
|
||||||
## --- آخرین کار مورد استفاده در سامانه قانون یار از این آموزش دیده است
|
|
||||||
#model = "orgcatorg/xlm-v-base-ner" # بهترین توکنایزر فارسی ***********************
|
|
||||||
## ---------------------------------------------------------
|
|
||||||
# model = AutoModel.from_pretrained("/home/gpu/HFHOME/hub/models--orgcatorg--xlm-v-base-ner")
|
|
||||||
|
|
||||||
#model = "pourmand1376/NER_Farsi" #
|
|
||||||
#model = "HooshvareLab/bert-base-parsbert-ner-uncased" # **** خوب جواب داد
|
|
||||||
#model = "SeyedAli/Persian-Text-NER-Bert-V1" # ***** خیلی خوب جواب داد
|
|
||||||
#model = "HooshvareLab/bert-base-parsbert-peymaner-uncased" # جالب نبود!
|
|
||||||
#model = "HooshvareLab/bert-base-parsbert-armanner-uncased" # جالب نبود!
|
|
||||||
|
|
||||||
def digit_correct(input_num):
|
def digit_correct(input_num):
|
||||||
if input_num <10:
|
if input_num <10:
|
||||||
|
@ -47,7 +30,9 @@ def main_train(model):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
time = datetime.datetime.now()
|
time = datetime.datetime.now()
|
||||||
|
# ایجاد فرمتی برای نام مدل نهایی که با کمک تاریخ روز ساخته می شود
|
||||||
model_title = f"{time.year}-{digit_correct(time.month)}-{digit_correct(time.day)}--{digit_correct(time.hour)}-{digit_correct(time.minute)}-{digit_correct(time.second)}--{model}".replace('/','--')
|
model_title = f"{time.year}-{digit_correct(time.month)}-{digit_correct(time.day)}--{digit_correct(time.hour)}-{digit_correct(time.minute)}-{digit_correct(time.second)}--{model}".replace('/','--')
|
||||||
|
|
||||||
print(f'\nMODEL:: {model}\n')
|
print(f'\nMODEL:: {model}\n')
|
||||||
|
|
||||||
# define dataset columns
|
# define dataset columns
|
||||||
|
@ -107,9 +92,9 @@ def main_train(model):
|
||||||
# begin training data
|
# begin training data
|
||||||
try:
|
try:
|
||||||
result = trainer.fine_tune(f"./taggers/{model_title}",
|
result = trainer.fine_tune(f"./taggers/{model_title}",
|
||||||
learning_rate= learning_rate,
|
learning_rate= LEARNING_RATE,
|
||||||
mini_batch_size= mini_batch_size,
|
mini_batch_size= MINI_BATCH_SIZE,
|
||||||
max_epochs= max_epochs
|
max_epochs= MAX_EPOCHS
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(str(e.args[0]))
|
print(str(e.args[0]))
|
||||||
|
@ -119,11 +104,12 @@ def main_train(model):
|
||||||
# plot training log to evaluate process
|
# plot training log to evaluate process
|
||||||
try:
|
try:
|
||||||
from train_log_plotter import plot_diagram
|
from train_log_plotter import plot_diagram
|
||||||
plot_diagram(model_title)
|
result = plot_diagram(model_title)
|
||||||
|
print(result[1])
|
||||||
|
|
||||||
except:
|
except:
|
||||||
print('log diagram failed due to error!')
|
print('log diagram failed due to error!')
|
||||||
|
|
||||||
|
|
||||||
print('fine-tune operation finished')
|
print('fine-tune operation finished')
|
||||||
|
|
||||||
operation_time = datetime.datetime.now()
|
operation_time = datetime.datetime.now()
|
||||||
|
@ -157,7 +143,7 @@ def main_train(model):
|
||||||
F1 Score: {result}
|
F1 Score: {result}
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n'''
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n'''
|
||||||
|
|
||||||
hyperparameters = f"""learning_rate: {learning_rate} - mini_batch_size: {mini_batch_size} - max_epochs: {max_epochs}"""
|
hyperparameters = f"""LEARNING_RATE: {LEARNING_RATE} - MINI_BATCH_SIZE: {MINI_BATCH_SIZE} - MAX_EPOCHS: {MAX_EPOCHS}"""
|
||||||
|
|
||||||
final_result = f"""Model Name: {model}
|
final_result = f"""Model Name: {model}
|
||||||
Fine-Tune Parameters: {hyperparameters}
|
Fine-Tune Parameters: {hyperparameters}
|
||||||
|
@ -171,29 +157,15 @@ def main_train(model):
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
models = """
|
|
||||||
HooshvareLab/bert-base-parsbert-ner-uncased
|
models = ["HooshvareLab/bert-fa-base-uncased-ner-peyma"]
|
||||||
HooshvareLab/bert-fa-base-uncased-ner-peyma
|
|
||||||
HooshvareLab/bert-base-parsbert-armanner-uncased
|
|
||||||
HooshvareLab/bert-fa-base-uncased-ner-arman
|
|
||||||
HooshvareLab/bert-base-parsbert-peymaner-uncased
|
|
||||||
"""
|
|
||||||
models = """
|
|
||||||
HooshvareLab/bert-fa-base-uncased-ner-peyma
|
|
||||||
"""
|
|
||||||
# HooshvareLab/distilbert-fa-zwnj-base-ner
|
|
||||||
models_with_error= """
|
|
||||||
nicolauduran45/affilgood-ner-multilingual-v2 - error
|
|
||||||
Amirmerfan/bert-base-uncased-persian-ner-50k-base - error
|
|
||||||
AliFartout/Roberta-fa-en-ner - error
|
|
||||||
"""
|
|
||||||
model = 'HooshvareLab/bert-fa-base-uncased-ner-peyma'
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# model = 'HooshvareLab/bert-fa-base-uncased-ner-peyma'
|
# model = 'HooshvareLab/bert-fa-base-uncased-ner-peyma'
|
||||||
# main_train(model)
|
# main_train(model)
|
||||||
|
|
||||||
# iterate models to train
|
# iterate models to train
|
||||||
for model in models.split('\n'):
|
for model in models:
|
||||||
if model == '':
|
if model == '':
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
32
train_comments.txt
Normal file
32
train_comments.txt
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
#model = os.getcwd() + "\\data\\final-model.pt"
|
||||||
|
#model = os.getcwd() + "/data/HooshvareLab--distilbert-fa-zwnj-base-ner" # مدل اولیه که تست شد و تا حدود 70 درصد در آخرین آموزش خوب جواب می داد
|
||||||
|
#model = os.getcwd() + "/data/distilbert-base-multilingual-cased-tavasi"
|
||||||
|
# model = "HooshvareLab/bert-fa-base-uncased-ner-peyma"
|
||||||
|
# model = "PooryaPiroozfar/Flair-Persian-NER" # 111111111111111
|
||||||
|
|
||||||
|
## ---------------------------------------------------------
|
||||||
|
## --- آخرین کار مورد استفاده در سامانه قانون یار از این آموزش دیده است
|
||||||
|
#model = "orgcatorg/xlm-v-base-ner" # بهترین توکنایزر فارسی ***********************
|
||||||
|
## ---------------------------------------------------------
|
||||||
|
# model = AutoModel.from_pretrained("/home/gpu/HFHOME/hub/models--orgcatorg--xlm-v-base-ner")
|
||||||
|
|
||||||
|
#model = "pourmand1376/NER_Farsi" #
|
||||||
|
#model = "HooshvareLab/bert-base-parsbert-ner-uncased" # **** خوب جواب داد
|
||||||
|
#model = "SeyedAli/Persian-Text-NER-Bert-V1" # ***** خیلی خوب جواب داد
|
||||||
|
#model = "HooshvareLab/bert-base-parsbert-peymaner-uncased" # جالب نبود!
|
||||||
|
#model = "HooshvareLab/bert-base-parsbert-armanner-uncased" # جالب نبود!
|
||||||
|
|
||||||
|
# HooshvareLab/distilbert-fa-zwnj-base-ner
|
||||||
|
models_with_error= """
|
||||||
|
nicolauduran45/affilgood-ner-multilingual-v2 - error
|
||||||
|
Amirmerfan/bert-base-uncased-persian-ner-50k-base - error
|
||||||
|
AliFartout/Roberta-fa-en-ner - error
|
||||||
|
"""
|
||||||
|
|
||||||
|
models = """
|
||||||
|
HooshvareLab/bert-base-parsbert-ner-uncased
|
||||||
|
HooshvareLab/bert-fa-base-uncased-ner-peyma
|
||||||
|
HooshvareLab/bert-base-parsbert-armanner-uncased
|
||||||
|
HooshvareLab/bert-fa-base-uncased-ner-arman
|
||||||
|
HooshvareLab/bert-base-parsbert-peymaner-uncased
|
||||||
|
"""
|
|
@ -1,20 +1,30 @@
|
||||||
|
"""
|
||||||
|
این سورس جهت خواندن فایل لاگ آموزش مدل و ترسیم نمودار پیشرفت بر اساس فاکتور LOSS ایجاد شده است
|
||||||
|
"""
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import os
|
import os
|
||||||
import csv
|
import csv
|
||||||
|
|
||||||
def find_newest_file(directory):
|
# def find_newest_file(directory):
|
||||||
# دریافت لیست فایلها در دایرکتوری
|
# # دریافت لیست فایلها در دایرکتوری
|
||||||
files = os.listdir(directory)
|
# files = os.listdir(directory)
|
||||||
# بررسی اینکه آیا دایرکتوری خالی است یا خیر
|
# # بررسی اینکه آیا دایرکتوری خالی است یا خیر
|
||||||
if not files:
|
# if not files:
|
||||||
return None # اگر دایرکتوری خالی باشد، مقدار None بازگردانده میشود
|
# return None # اگر دایرکتوری خالی باشد، مقدار None بازگردانده میشود
|
||||||
|
|
||||||
# ایجاد مسیر کامل برای فایلها و پیدا کردن فایل جدیدتر با استفاده از max
|
# # ایجاد مسیر کامل برای فایلها و پیدا کردن فایل جدیدتر با استفاده از max
|
||||||
full_paths = [os.path.join(directory, file) for file in files]
|
# full_paths = [os.path.join(directory, file) for file in files]
|
||||||
newest_file = max(full_paths, key=os.path.getctime) # بر اساس زمان ایجاد (creation time)
|
# newest_file = max(full_paths, key=os.path.getctime) # بر اساس زمان ایجاد (creation time)
|
||||||
return newest_file
|
# return newest_file
|
||||||
|
|
||||||
def save_diagram(progress_data, model_title):
|
def generate_diagram(progress_data:list, model_title:str):
|
||||||
|
"""
|
||||||
|
ایجاد و ذخیره دیاگرام بر اساس داده های مربوط به آموزش مدل
|
||||||
|
|
||||||
|
Args:
|
||||||
|
progress_data(list): لیست فاکتور loss در ایپاک های آموزش مدل
|
||||||
|
model_title(str): نام مدل آموزش دیده
|
||||||
|
"""
|
||||||
# آرایه دادههایی که قرار است در دیاگرام ترسیم شود
|
# آرایه دادههایی که قرار است در دیاگرام ترسیم شود
|
||||||
data = progress_data
|
data = progress_data
|
||||||
|
|
||||||
|
@ -43,7 +53,14 @@ def save_diagram(progress_data, model_title):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def read_log(file_path):
|
def read_log(file_path:str):
|
||||||
|
"""
|
||||||
|
خواندن محتوای فایل loss
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path(str): آدرس فایل loss
|
||||||
|
progress(list[tuple]): لیستی شامل شماره ایپاک و loss متناظر با آن
|
||||||
|
"""
|
||||||
# read loss file
|
# read loss file
|
||||||
with open(file_path, mode="r") as file:
|
with open(file_path, mode="r") as file:
|
||||||
tsv_reader = csv.reader(file, delimiter="\t")
|
tsv_reader = csv.reader(file, delimiter="\t")
|
||||||
|
@ -61,12 +78,41 @@ def read_log(file_path):
|
||||||
return progress
|
return progress
|
||||||
|
|
||||||
|
|
||||||
def plot_diagram(model_title):
|
def plot_diagram(model_title:str):
|
||||||
|
"""
|
||||||
|
ترسیم نمودار بر اساس لاگ فاکتور LOSS
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_title(str): نام مدلی که آموزش داده شده است
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: شامل یک کلید بولین که وضعیت موفقیت عملیات را نشان می دهد و نیز توضیحاتی در مورد نتیجه عملیات
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = True, ''
|
||||||
|
|
||||||
|
# تنظیم آدرس فایل loss مربوط به این مدل
|
||||||
loss_log = f'./taggers/{model_title}/loss.tsv'
|
loss_log = f'./taggers/{model_title}/loss.tsv'
|
||||||
|
# خواندن فایل loss مربوط به مدل
|
||||||
|
try:
|
||||||
progress = read_log(loss_log)
|
progress = read_log(loss_log)
|
||||||
save_diagram(progress, model_title)
|
except:
|
||||||
|
result = False, "Error: please check model's path!"
|
||||||
|
return result
|
||||||
|
|
||||||
|
# ایجاد و ذخیره دیاگرام در پوشه مربوط به مدل
|
||||||
|
generate_diagram(progress, model_title)
|
||||||
print('loss diagram saved!')
|
print('loss diagram saved!')
|
||||||
|
|
||||||
|
result = True, 'loss diagram generated and saved!'
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model_title = 'HooshvareLab--distilbert-fa-zwnj-base-ner--2025-7-20--3-41-58'
|
# تست سورس بر اساس نام مدل
|
||||||
plot_diagram(model_title)
|
# چنین مدلی باید در پوشه taggers وجود داشته باشد
|
||||||
|
model_title = '2025-07-21--17-51-49--HooshvareLab--bert-fa-base-uncased-ner-peyma'
|
||||||
|
|
||||||
|
result = plot_diagram(model_title)
|
||||||
|
print(result[1])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user