# Flair_NER/temp.py
import os

from flair.datasets import ColumnCorpus
from flair.embeddings import FlairEmbeddings, StackedEmbeddings, WordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
# Paths to the new dataset
data_folder = './data'                    # folder containing the dataset files
train_file = 'DATASET140402_no_aref.txt'  # qavanin, ~36K tokens
test_file = 'test_ds_new.txt'             # test set, 110 sections, ~6.7K tokens
# Column format for your dataset (adjust as necessary)
# For example: 0 = text, 1 = NER tags
columns = {0: 'text', 1: 'ner'}
# Load the corpus
corpus = ColumnCorpus(
    data_folder=data_folder,
    column_format=columns,
    train_file=train_file,
    test_file=test_file,
)
print(corpus)
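# Optional sanity check (the exact label-access API varies a little across
# Flair versions): inspect the first parsed training sentence and its gold
# NER labels to confirm the column format was read correctly.
sample = corpus.train[0]
print(sample)
for label in sample.get_labels('ner'):
    print(label)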
# Define the tag dictionary (new NER tags)
tag_type = 'ner'
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
# (make_tag_dictionary is deprecated in newer Flair releases; there, use
# corpus.make_label_dictionary(label_type=tag_type) instead)
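# Optional: verify the tag set the new model will be trained on.
print(tag_dictionary.get_items())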
# Load the previously trained model whose weights will be transferred
tagger = SequenceTagger.load("/home/gpu/tnlp/jokar/Models/catorg/14030906_before_aitools_ds_finetune/final-model.pt")
# Define embeddings (modify as needed; see the note after the stack below)
embedding_types = [
    WordEmbeddings('glove'),           # pre-trained GloVe word embeddings
    FlairEmbeddings('news-forward'),   # forward Flair character-LM embeddings
    FlairEmbeddings('news-backward'),  # backward Flair character-LM embeddings
]
embeddings = StackedEmbeddings(embeddings=embedding_types)
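# Note (assumption about your setup): for the weight transfer below to reuse
# the old LSTM weights, this embedding stack must match the dimensionality of
# the stack the old model was trained with; GloVe and the news-* Flair models
# are English embeddings. If unsure, reuse the old model's own embeddings:
# embeddings = tagger.embeddings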
# Create a new tagger using the updated tag dictionary
new_tagger = SequenceTagger(
    hidden_size=256,  # size of the LSTM hidden layer; adjust as needed
    embeddings=embeddings,
    tag_dictionary=tag_dictionary,
    tag_type=tag_type,
    use_crf=True,
)
# Transfer weights from the old model to the new tagger. strict=False skips
# layers whose names do not match, but PyTorch still raises on tensors whose
# shapes differ (e.g. the output/CRF layers once the tag set changes), so
# only name- and shape-compatible tensors are copied here.
new_state = new_tagger.state_dict()
compatible = {name: tensor for name, tensor in tagger.state_dict().items()
              if name in new_state and tensor.shape == new_state[name].shape}
new_tagger.load_state_dict(compatible, strict=False)
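# Optional: report how much of the old model could actually be reused.
print(f'transferred {len(compatible)} of {len(new_state)} tensors')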
# Train the model with the new dataset
trainer = ModelTrainer(new_tagger, corpus)
# Start training
trainer.train('./output', # Output folder for the model
learning_rate=0.1,
mini_batch_size=32,
max_epochs=10) # Adjust parameters as needed
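# Note: by default the trainer also evaluates on corpus.test after training
# and writes its own copy of the weights to ./output/final-model.pt.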
# Save the fine-tuned model (note: save() does not create the folder itself)
os.makedirs('./trained', exist_ok=True)
new_tagger.save('./trained/fine_tuned_mdl.pt')
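# Usage sketch (illustrative; the sample text is a hypothetical placeholder):
# reload the fine-tuned model and tag a new sentence with it.
from flair.data import Sentence

loaded_tagger = SequenceTagger.load('./trained/fine_tuned_mdl.pt')
sentence = Sentence('your sample text here')
loaded_tagger.predict(sentence)
for span in sentence.get_spans('ner'):
    print(span)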