Flair_NER/train_log_plotter.py

73 lines
2.4 KiB
Python

import matplotlib.pyplot as plt
import os
import csv
def find_newest_file(directory):
# دریافت لیست فایل‌ها در دایرکتوری
files = os.listdir(directory)
# بررسی اینکه آیا دایرکتوری خالی است یا خیر
if not files:
return None # اگر دایرکتوری خالی باشد، مقدار None بازگردانده می‌شود
# ایجاد مسیر کامل برای فایل‌ها و پیدا کردن فایل جدیدتر با استفاده از max
full_paths = [os.path.join(directory, file) for file in files]
newest_file = max(full_paths, key=os.path.getctime) # بر اساس زمان ایجاد (creation time)
return newest_file
def save_diagram(progress_data, model_title):
# آرایه داده‌هایی که قرار است در دیاگرام ترسیم شود
data = progress_data
# استخراج مقادیر x و y از آرایه
x = [int(point[0]) for point in data]
y = [float(point[1]) for point in data]
# ترسیم نمودار
plt.figure(figsize=(8, 6)) # تنظیم اندازه نمودار
plt.plot(x, y, marker='', linestyle='-', color='b', label='Data Line') # ترسیم خط همراه با نقاط
# تنظیم عنوان و برچسب‌های محور
plt.title("Loss Diagram in Train Process", fontsize=14)
plt.xlabel("EPOCHS", fontsize=12)
plt.ylabel("LOSS", fontsize=12)
# نمایش خطوط شبکه
plt.grid(True, linestyle='--', alpha=0.7)
# نمایش legend
plt.legend()
plt.savefig(f"./taggers/{model_title}/loss-diagram.png", dpi=300, bbox_inches='tight')
plt.close()
def read_log(file_path):
# read loss file
with open(file_path, mode="r") as file:
tsv_reader = csv.reader(file, delimiter="\t")
progress = []
# iterate each line
for i,row in enumerate(tsv_reader):
if i == 0:
continue
epoch = row[0]
loss = row[3]
progress.append((epoch, loss))
return progress
def plot_diagram(model_title):
loss_log = f'./taggers/{model_title}/loss.tsv'
progress = read_log(loss_log)
save_diagram(progress, model_title)
print('loss diagram saved!')
if __name__ == "__main__":
model_title = 'HooshvareLab--distilbert-fa-zwnj-base-ner--2025-7-20--3-41-58'
plot_diagram(model_title)