# Classifier/sections_window_1.py

from transformers import pipeline, AutoTokenizer
from normalizer import cleaning  # currently unused: the cleaning() call in get_class is commented out
import transformers
import json

print(transformers.__version__)
# model_checkpoint = "./BERT/findtuned_classification_model_15"  # 15 epochs
model_checkpoint = '/home/gpu/tnlp/jokar/Classifier/Models/findtuned_classification_model-15'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
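
# Sliding-window parameters, measured in tokens.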
window_size = 200  # 512
step_size = 100
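
# ids of sections that failed in a previous run (one id per line in errors.txt).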
with open('./data/errors.txt', 'r', encoding='utf-8') as input_file:
    error_sections_id = input_file.read().splitlines()

sections = []
with open('./data/final_sections_class.json', 'r', encoding='utf-8') as input_file:
    sections = json.load(input_file)
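
# Keep only the sections whose id is in the error list, so just the failed ones are re-classified.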
error_sections = []
for item in sections:
    if item['id'] in error_sections_id:
        error_sections.append(item)
sections = error_sections
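
# Text-classification pipeline backed by the same fine-tuned checkpoint (PyTorch).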
classifier = pipeline("text-classification", model_checkpoint, framework="pt")


def get_class(sentences, top_k: int = 4):
    """Return the top_k class predictions for `sentences` from the classifier pipeline."""
    # sentences = cleaning(sentences)
    out = classifier(sentences, top_k=top_k)
    return out


def get_window_classes(text):
    """Classify `text` with a sliding window of `window_size` tokens advanced by `step_size` tokens."""
    text_classes = []
    tokens = tokenizer(text)['input_ids'][1:-1]  # drop the [CLS]/[SEP] special tokens
    # print(len(tokens))
    if len(tokens) > window_size:
        for i in range(0, len(tokens) - window_size + 1, step_size):
            # Decode the token prefix and the window to estimate character offsets into the
            # original text (approximate, since decode() may not reproduce the exact spacing),
            # then classify the corresponding character slice.
            start_window_slice = tokens[0:i]
            window_slice = tokens[i:i + window_size]
            start_char = len(tokenizer.decode(start_window_slice))
            char_len = len(tokenizer.decode(window_slice))
            context_slice = text[start_char:start_char + char_len]
            tokens_len = len(tokenizer(context_slice)['input_ids'][1:-1])
            print(f'token-len: {tokens_len}', flush=True)
            results = get_class(context_slice)
            text_classes.append(results)
    else:
        text_classes.append(get_class(text))
    return text_classes
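

# Classify every previously failed section; ids that fail again are collected in `errors`.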
errors = ''
for index, item in enumerate(sections):
    # if index > 11:
    #     break
    section_id = item['id']
    content = item['content']
    try:
        section_classes = get_window_classes(content)
    except Exception as e:
        print(e, flush=True)
        errors += section_id + "\n"
        continue
    item['classes'] = section_classes
    print(f'section: {len(sections)}/{index + 1}', flush=True)
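
# Write the classified sections and overwrite errors.txt with the ids that still failed.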
with open('./data/error_sections_classes.json', 'w', encoding='utf-8') as output_file:
    json_data = json.dumps(sections, indent=4, ensure_ascii=False)
    output_file.write(json_data)

with open('./data/errors.txt', 'w', encoding='utf-8') as output_file:
    output_file.write(errors)

print('finished!')