Flair_NER/train_log_plotter.py
2025-08-17 16:55:36 +03:30

119 lines
4.1 KiB
Python

"""
این سورس جهت خواندن فایل لاگ آموزش مدل و ترسیم نمودار پیشرفت بر اساس فاکتور LOSS ایجاد شده است
"""
import matplotlib.pyplot as plt
import os
import csv
# def find_newest_file(directory):
# # دریافت لیست فایل‌ها در دایرکتوری
# files = os.listdir(directory)
# # بررسی اینکه آیا دایرکتوری خالی است یا خیر
# if not files:
# return None # اگر دایرکتوری خالی باشد، مقدار None بازگردانده می‌شود
# # ایجاد مسیر کامل برای فایل‌ها و پیدا کردن فایل جدیدتر با استفاده از max
# full_paths = [os.path.join(directory, file) for file in files]
# newest_file = max(full_paths, key=os.path.getctime) # بر اساس زمان ایجاد (creation time)
# return newest_file
def generate_diagram(progress_data:list, model_title:str):
"""
ایجاد و ذخیره دیاگرام بر اساس داده های مربوط به آموزش مدل
Args:
progress_data(list): لیست فاکتور loss در ایپاک های آموزش مدل
model_title(str): نام مدل آموزش دیده
"""
# آرایه داده‌هایی که قرار است در دیاگرام ترسیم شود
data = progress_data
# استخراج مقادیر x و y از آرایه
x = [int(point[0]) for point in data]
y = [float(point[1]) for point in data]
# ترسیم نمودار
plt.figure(figsize=(8, 6)) # تنظیم اندازه نمودار
plt.plot(x, y, marker='', linestyle='-', color='b', label='Data Line') # ترسیم خط همراه با نقاط
# تنظیم عنوان و برچسب‌های محور
plt.title("Loss Diagram in Train Process", fontsize=14)
plt.xlabel("EPOCHS", fontsize=12)
plt.ylabel("LOSS", fontsize=12)
# نمایش خطوط شبکه
plt.grid(True, linestyle='--', alpha=0.7)
# نمایش legend
plt.legend()
plt.savefig(f"./taggers/{model_title}/loss-diagram.png", dpi=300, bbox_inches='tight')
plt.close()
def read_log(file_path:str):
"""
خواندن محتوای فایل loss
Args:
file_path(str): آدرس فایل loss
progress(list[tuple]): لیستی شامل شماره ایپاک و loss متناظر با آن
"""
# read loss file
with open(file_path, mode="r") as file:
tsv_reader = csv.reader(file, delimiter="\t")
progress = []
# iterate each line
for i,row in enumerate(tsv_reader):
if i == 0:
continue
epoch = row[0]
loss = row[3]
progress.append((epoch, loss))
return progress
def plot_diagram(model_title:str):
"""
ترسیم نمودار بر اساس لاگ فاکتور LOSS
Args:
model_title(str): نام مدلی که آموزش داده شده است
Returns:
tuple: شامل یک کلید بولین که وضعیت موفقیت عملیات را نشان می دهد و نیز توضیحاتی در مورد نتیجه عملیات
"""
result = True, ''
# تنظیم آدرس فایل loss مربوط به این مدل
loss_log = f'./taggers/{model_title}/loss.tsv'
# خواندن فایل loss مربوط به مدل
try:
progress = read_log(loss_log)
except:
result = False, "Error: please check model's path!"
return result
# ایجاد و ذخیره دیاگرام در پوشه مربوط به مدل
generate_diagram(progress, model_title)
print('loss diagram saved!')
result = True, 'loss diagram generated and saved!'
return result
if __name__ == "__main__":
# تست سورس بر اساس نام مدل
# چنین مدلی باید در پوشه taggers وجود داشته باشد
model_title = '2025-07-21--17-51-49--HooshvareLab--bert-fa-base-uncased-ner-peyma'
result = plot_diagram(model_title)
print(result[1])