diff --git a/train.py b/train.py index e8d91b7..3a74a71 100644 --- a/train.py +++ b/train.py @@ -1,10 +1,7 @@ -learning_rate = 0.65e-4 # 0.65e-4 - 0.4e-4 -mini_batch_size = 8 -max_epochs = 100 +LEARNING_RATE = 0.65e-4 # 0.65e-4 - 0.4e-4 +MINI_BATCH_SIZE = 8 +MAX_EPOCHS = 100 -from funcs import save_to_file_by_address -import json -import os import datetime from pathlib import Path from flair.data import Corpus @@ -15,23 +12,9 @@ from flair.trainers import ModelTrainer from flair.models import SequenceTagger from flair.embeddings import TransformerDocumentEmbeddings -#model = os.getcwd() + "\\data\\final-model.pt" -#model = os.getcwd() + "/data/HooshvareLab--distilbert-fa-zwnj-base-ner" # مدل اولیه که تست شد و تا حدود 70 درصد در آخرین آموزش خوب جواب می داد -#model = os.getcwd() + "/data/distilbert-base-multilingual-cased-tavasi" -# model = "HooshvareLab/bert-fa-base-uncased-ner-peyma" -# model = "PooryaPiroozfar/Flair-Persian-NER" # 111111111111111 - -## --------------------------------------------------------- -## --- آخرین کار مورد استفاده در سامانه قانون یار از این آموزش دیده است -#model = "orgcatorg/xlm-v-base-ner" # بهترین توکنایزر فارسی *********************** -## --------------------------------------------------------- -# model = AutoModel.from_pretrained("/home/gpu/HFHOME/hub/models--orgcatorg--xlm-v-base-ner") - -#model = "pourmand1376/NER_Farsi" # -#model = "HooshvareLab/bert-base-parsbert-ner-uncased" # **** خوب جواب داد -#model = "SeyedAli/Persian-Text-NER-Bert-V1" # ***** خیلی خوب جواب داد -#model = "HooshvareLab/bert-base-parsbert-peymaner-uncased" # جالب نبود! -#model = "HooshvareLab/bert-base-parsbert-armanner-uncased" # جالب نبود! +def save_to_file_by_address(file_address, content): + with open(file_address, 'a+', encoding='utf-8') as file: + file.write(content) def digit_correct(input_num): if input_num <10: @@ -47,7 +30,9 @@ def main_train(model): """ time = datetime.datetime.now() + # ایجاد فرمتی برای نام مدل نهایی که با کمک تاریخ روز ساخته می شود model_title = f"{time.year}-{digit_correct(time.month)}-{digit_correct(time.day)}--{digit_correct(time.hour)}-{digit_correct(time.minute)}-{digit_correct(time.second)}--{model}".replace('/','--') + print(f'\nMODEL:: {model}\n') # define dataset columns @@ -107,9 +92,9 @@ def main_train(model): # begin training data try: result = trainer.fine_tune(f"./taggers/{model_title}", - learning_rate= learning_rate, - mini_batch_size= mini_batch_size, - max_epochs= max_epochs + learning_rate= LEARNING_RATE, + mini_batch_size= MINI_BATCH_SIZE, + max_epochs= MAX_EPOCHS ) except Exception as e: print(str(e.args[0])) @@ -119,11 +104,12 @@ def main_train(model): # plot training log to evaluate process try: from train_log_plotter import plot_diagram - plot_diagram(model_title) + result = plot_diagram(model_title) + print(result[1]) + except: print('log diagram failed due to error!') - print('fine-tune operation finished') operation_time = datetime.datetime.now() @@ -157,7 +143,7 @@ def main_train(model): F1 Score: {result} ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n''' - hyperparameters = f"""learning_rate: {learning_rate} - mini_batch_size: {mini_batch_size} - max_epochs: {max_epochs}""" + hyperparameters = f"""LEARNING_RATE: {LEARNING_RATE} - MINI_BATCH_SIZE: {MINI_BATCH_SIZE} - MAX_EPOCHS: {MAX_EPOCHS}""" final_result = f"""Model Name: {model} Fine-Tune Parameters: {hyperparameters} @@ -171,29 +157,15 @@ def main_train(model): return True -models = """ -HooshvareLab/bert-base-parsbert-ner-uncased -HooshvareLab/bert-fa-base-uncased-ner-peyma -HooshvareLab/bert-base-parsbert-armanner-uncased -HooshvareLab/bert-fa-base-uncased-ner-arman -HooshvareLab/bert-base-parsbert-peymaner-uncased -""" -models = """ -HooshvareLab/bert-fa-base-uncased-ner-peyma -""" -# HooshvareLab/distilbert-fa-zwnj-base-ner -models_with_error= """ -nicolauduran45/affilgood-ner-multilingual-v2 - error -Amirmerfan/bert-base-uncased-persian-ner-50k-base - error -AliFartout/Roberta-fa-en-ner - error -""" -model = 'HooshvareLab/bert-fa-base-uncased-ner-peyma' + +models = ["HooshvareLab/bert-fa-base-uncased-ner-peyma"] + if __name__ == "__main__": # model = 'HooshvareLab/bert-fa-base-uncased-ner-peyma' # main_train(model) # iterate models to train - for model in models.split('\n'): + for model in models: if model == '': continue diff --git a/train_comments.txt b/train_comments.txt new file mode 100644 index 0000000..5f0725b --- /dev/null +++ b/train_comments.txt @@ -0,0 +1,32 @@ +#model = os.getcwd() + "\\data\\final-model.pt" +#model = os.getcwd() + "/data/HooshvareLab--distilbert-fa-zwnj-base-ner" # مدل اولیه که تست شد و تا حدود 70 درصد در آخرین آموزش خوب جواب می داد +#model = os.getcwd() + "/data/distilbert-base-multilingual-cased-tavasi" +# model = "HooshvareLab/bert-fa-base-uncased-ner-peyma" +# model = "PooryaPiroozfar/Flair-Persian-NER" # 111111111111111 + +## --------------------------------------------------------- +## --- آخرین کار مورد استفاده در سامانه قانون یار از این آموزش دیده است +#model = "orgcatorg/xlm-v-base-ner" # بهترین توکنایزر فارسی *********************** +## --------------------------------------------------------- +# model = AutoModel.from_pretrained("/home/gpu/HFHOME/hub/models--orgcatorg--xlm-v-base-ner") + +#model = "pourmand1376/NER_Farsi" # +#model = "HooshvareLab/bert-base-parsbert-ner-uncased" # **** خوب جواب داد +#model = "SeyedAli/Persian-Text-NER-Bert-V1" # ***** خیلی خوب جواب داد +#model = "HooshvareLab/bert-base-parsbert-peymaner-uncased" # جالب نبود! +#model = "HooshvareLab/bert-base-parsbert-armanner-uncased" # جالب نبود! + +# HooshvareLab/distilbert-fa-zwnj-base-ner +models_with_error= """ +nicolauduran45/affilgood-ner-multilingual-v2 - error +Amirmerfan/bert-base-uncased-persian-ner-50k-base - error +AliFartout/Roberta-fa-en-ner - error +""" + +models = """ +HooshvareLab/bert-base-parsbert-ner-uncased +HooshvareLab/bert-fa-base-uncased-ner-peyma +HooshvareLab/bert-base-parsbert-armanner-uncased +HooshvareLab/bert-fa-base-uncased-ner-arman +HooshvareLab/bert-base-parsbert-peymaner-uncased +""" \ No newline at end of file diff --git a/train_log_plotter.py b/train_log_plotter.py index 8a72c2d..4cc1f4b 100644 --- a/train_log_plotter.py +++ b/train_log_plotter.py @@ -1,20 +1,30 @@ +""" +این سورس جهت خواندن فایل لاگ آموزش مدل و ترسیم نمودار پیشرفت بر اساس فاکتور LOSS ایجاد شده است +""" import matplotlib.pyplot as plt import os import csv -def find_newest_file(directory): - # دریافت لیست فایل‌ها در دایرکتوری - files = os.listdir(directory) - # بررسی اینکه آیا دایرکتوری خالی است یا خیر - if not files: - return None # اگر دایرکتوری خالی باشد، مقدار None بازگردانده می‌شود +# def find_newest_file(directory): +# # دریافت لیست فایل‌ها در دایرکتوری +# files = os.listdir(directory) +# # بررسی اینکه آیا دایرکتوری خالی است یا خیر +# if not files: +# return None # اگر دایرکتوری خالی باشد، مقدار None بازگردانده می‌شود - # ایجاد مسیر کامل برای فایل‌ها و پیدا کردن فایل جدیدتر با استفاده از max - full_paths = [os.path.join(directory, file) for file in files] - newest_file = max(full_paths, key=os.path.getctime) # بر اساس زمان ایجاد (creation time) - return newest_file +# # ایجاد مسیر کامل برای فایل‌ها و پیدا کردن فایل جدیدتر با استفاده از max +# full_paths = [os.path.join(directory, file) for file in files] +# newest_file = max(full_paths, key=os.path.getctime) # بر اساس زمان ایجاد (creation time) +# return newest_file -def save_diagram(progress_data, model_title): +def generate_diagram(progress_data:list, model_title:str): + """ + ایجاد و ذخیره دیاگرام بر اساس داده های مربوط به آموزش مدل + + Args: + progress_data(list): لیست فاکتور loss در ایپاک های آموزش مدل + model_title(str): نام مدل آموزش دیده + """ # آرایه داده‌هایی که قرار است در دیاگرام ترسیم شود data = progress_data @@ -43,7 +53,14 @@ def save_diagram(progress_data, model_title): -def read_log(file_path): +def read_log(file_path:str): + """ + خواندن محتوای فایل loss + + Args: + file_path(str): آدرس فایل loss + progress(list[tuple]): لیستی شامل شماره ایپاک و loss متناظر با آن + """ # read loss file with open(file_path, mode="r") as file: tsv_reader = csv.reader(file, delimiter="\t") @@ -61,12 +78,41 @@ def read_log(file_path): return progress -def plot_diagram(model_title): +def plot_diagram(model_title:str): + """ + ترسیم نمودار بر اساس لاگ فاکتور LOSS + + Args: + model_title(str): نام مدلی که آموزش داده شده است + + Returns: + tuple: شامل یک کلید بولین که وضعیت موفقیت عملیات را نشان می دهد و نیز توضیحاتی در مورد نتیجه عملیات + """ + + result = True, '' + + # تنظیم آدرس فایل loss مربوط به این مدل loss_log = f'./taggers/{model_title}/loss.tsv' - progress = read_log(loss_log) - save_diagram(progress, model_title) + # خواندن فایل loss مربوط به مدل + try: + progress = read_log(loss_log) + except: + result = False, "Error: please check model's path!" + return result + + # ایجاد و ذخیره دیاگرام در پوشه مربوط به مدل + generate_diagram(progress, model_title) print('loss diagram saved!') + + result = True, 'loss diagram generated and saved!' + + return result if __name__ == "__main__": - model_title = 'HooshvareLab--distilbert-fa-zwnj-base-ner--2025-7-20--3-41-58' - plot_diagram(model_title) + # تست سورس بر اساس نام مدل + # چنین مدلی باید در پوشه taggers وجود داشته باشد + model_title = '2025-07-21--17-51-49--HooshvareLab--bert-fa-base-uncased-ner-peyma' + + result = plot_diagram(model_title) + print(result[1]) +