# File: NLP_tutorial/3-NLP_services/src/Multi2OIE/train.py
# Last modified: 2025-04-09 09:39:40 +03:30
#
# Size: 80 lines (2.9 KiB)
# Language: Python
#

import os
import torch
import torch.nn as nn
import utils.bio as bio
from tqdm import tqdm
from extract import extract
from utils import utils
from src.Multi2OIE.test import do_eval
def _evaluate_and_checkpoint(args, epoch, step, model, dev_loaders,
                             summarizer, trn_results):
    """Evaluate `model` on every dev set, record a summary row and save a checkpoint.

    For each dev set, `extract` writes predictions under
    ``<args.save_path>/epoch<epoch>_dev/step<step>/<dev_name>`` and `do_eval`
    scores them against the matching gold file. Each score list is extended
    with ``score[0] + score[-1]`` (printed as F1 and AUC). The sum of those
    values over all dev sets is passed to `utils.set_model_name` to name the
    saved checkpoint.

    Args:
        args: Run configuration; reads ``dev_data_path``, ``dev_gold_path``
            and ``save_path``.
        epoch: Current epoch index (used in labels and paths).
        step: Current step index within the epoch (used in labels and paths).
        model: Model being trained; its ``state_dict()`` is saved.
        dev_loaders: Dev-set loaders, zipped in order with
            ``args.dev_data_path`` and ``args.dev_gold_path``.
        summarizer: Object whose ``save_results`` receives one row:
            ``[step] + trn_results + all dev scores + [total]``.
        trn_results: Current running training losses (a list).
    """
    dev_results = []
    total_sum = 0
    for dev_input, dev_gold, dev_loader in zip(
            args.dev_data_path, args.dev_gold_path, dev_loaders):
        # os.path.basename equals split('/')[-1] on POSIX and also handles
        # Windows separators.
        dev_name = os.path.basename(dev_input).replace('.pkl', '')
        output_path = os.path.join(args.save_path, f'epoch{epoch}_dev/step{step}/{dev_name}')
        extract(args, model, dev_loader, output_path)
        dev_result = do_eval(output_path, dev_gold)
        utils.print_results(f"EPOCH{epoch} STEP{step} EVAL",
                            dev_result, ["F1 ", "PREC", "REC ", "AUC "])
        # first + last score (labelled F1 and AUC above) is the selection metric
        combined = dev_result[0] + dev_result[-1]
        total_sum += combined
        dev_result.append(combined)
        dev_results += dev_result
    summarizer.save_results([step] + trn_results + dev_results + [total_sum])
    model_name = utils.set_model_name(total_sum, epoch, step)
    torch.save(model.state_dict(), os.path.join(args.save_path, model_name))


def train(args,
          epoch,
          model,
          trn_loader,
          dev_loaders,
          summarizer,
          optimizer,
          scheduler):
    """Run one training epoch with periodic dev evaluation and loss logging.

    For each batch, the model is put in training mode and the returned
    ``batch_loss`` is back-propagated. Gradients are then clipped to a max
    norm of 1.0, and ``optimizer`` and ``scheduler`` are each stepped once.
    The epoch stops after ``int(args.total_steps / args.epochs)`` updates or
    when ``trn_loader`` is exhausted, whichever comes first.

    Every ``args.eval_step`` steps (default 1000 if ``args`` has no such
    attribute, excluding step 0) the model is evaluated on the dev sets and
    a checkpoint is saved. Every ``args.summary_step`` steps (excluding
    step 0) the running training losses are printed.

    Args:
        args: Run configuration. Reads ``total_steps``, ``epochs``,
            ``device``, ``summary_step``, ``save_path``, ``dev_data_path``,
            ``dev_gold_path`` and optionally ``eval_step``.
        epoch: Current epoch index, used in log labels and output paths.
        model: Called with keyword arguments ``input_ids``,
            ``attention_mask``, ``predicate_mask``, ``total_pred_labels`` and
            ``arg_labels``. It must return ``(batch_loss, pred_loss, arg_loss)``.
        trn_loader: Yields 5-item batches ``(token_ids, att_mask,
            single_pred_label, single_arg_label, all_pred_label)``. Each item
            must support ``.to(device)``.
        dev_loaders: Dev-set loaders matched in order with the dev paths.
        summarizer: Object whose ``save_results`` receives one row per
            interim evaluation.
        optimizer: Stepped once per batch.
        scheduler: Stepped once per batch, after the optimizer.

    Returns:
        ``[mean_pred_loss, mean_arg_loss]`` averaged over the batches trained
        this epoch, or ``None`` if ``trn_loader`` yielded no batches.
    """
    total_pred_loss, total_arg_loss, trn_results = 0, 0, None
    # number of optimizer updates allotted to a single epoch
    epoch_steps = int(args.total_steps / args.epochs)
    # interim-evaluation interval; 1000 preserves the historical behaviour
    eval_step = getattr(args, 'eval_step', 1000)

    iterator = tqdm(enumerate(trn_loader), desc='steps', total=epoch_steps)
    for step, batch in iterator:
        token_ids, att_mask, single_pred_label, single_arg_label, all_pred_label = (
            x.to(args.device) for x in batch)
        pred_mask = bio.get_pred_mask(single_pred_label)

        # Re-enter training mode every step. Interim evaluation below calls
        # `extract`, which presumably switches the model to eval mode
        # (TODO confirm).
        model.train()
        model.zero_grad()

        # feed to predicate model
        batch_loss, pred_loss, arg_loss = model(
            input_ids=token_ids,
            attention_mask=att_mask,
            predicate_mask=pred_mask,
            total_pred_labels=all_pred_label,
            arg_labels=single_arg_label)

        # accumulate per-batch losses for the running averages
        total_pred_loss += pred_loss.item()
        total_arg_loss += arg_loss.item()

        batch_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        steps_done = step + 1
        trn_results = [total_pred_loss / steps_done, total_arg_loss / steps_done]

        # interim evaluation
        if step != 0 and step % eval_step == 0:
            _evaluate_and_checkpoint(args, epoch, step, model, dev_loaders,
                                     summarizer, trn_results)

        if step != 0 and step % args.summary_step == 0:
            utils.print_results(f"EPOCH{epoch} STEP{step} TRAIN",
                                trn_results, ["PRED LOSS", "ARG LOSS "])

        # Stop after exactly `epoch_steps` updates. The former guard
        # (`step > epoch_steps`, checked after the optimizer step) permitted
        # epoch_steps + 2 updates per epoch.
        if steps_done >= epoch_steps:
            break

    # end-of-epoch summary (nothing to report when no batch was trained)
    if trn_results is not None:
        utils.print_results(f"EPOCH{epoch} TRAIN",
                            trn_results, ["PRED LOSS", "ARG LOSS "])
    return trn_results