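"""Re-run text classification on sections that failed in a previous run.

Reads the failed-section IDs from ./data/errors.txt, classifies each section's
content with the fine-tuned classification model using an overlapping sliding
window over its tokens, then writes the results to
./data/error_sections_classes.json and any still-failing IDs back to
./data/errors.txt.
"""
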
import json

import transformers
from transformers import AutoTokenizer, pipeline

from normalizer import cleaning  # only used by the commented-out call in get_class()

print(transformers.__version__)

# model_checkpoint = "./BERT/findtuned_classification_model_15"  # 15 epochs
model_checkpoint = '/home/gpu/tnlp/jokar/Classifier/Models/findtuned_classification_model-15'

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

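# Sliding-window parameters, in tokens: each window covers `window_size` tokens
# and windows start every `step_size` tokens, so consecutive windows overlap by
# window_size - step_size (here 200 - 100 = 100) tokens.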
window_size = 200  # 512
step_size = 100

# IDs of the sections whose classification failed previously.
with open('./data/errors.txt', 'r', encoding='utf-8') as input_file:
    error_sections_id = input_file.read().splitlines()

sections = []
with open('./data/final_sections_class.json', 'r', encoding='utf-8') as input_file:
    sections = json.load(input_file)

# Keep only the sections that are listed as failed.
error_sections = []
for item in sections:
    if item['id'] in error_sections_id:
        error_sections.append(item)

sections = error_sections

classifier = pipeline("text-classification", model_checkpoint, framework="pt")


def get_class(sentences, top_k: int = 4):
    # sentences = cleaning(sentences)
    out = classifier(sentences, top_k=top_k)
    return out


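# Long sections are classified window by window (presumably to stay under the
# model's maximum sequence length): slide a window of `window_size` tokens with
# stride `step_size`, map each token window back to a character span of the
# original text via tokenizer.decode(), and classify each span separately.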
def get_window_classes(text):
    text_classes = []
    tokens = tokenizer(text)['input_ids'][1:-1]  # drop the leading/trailing special tokens ([CLS]/[SEP])
    # print(len(tokens))
    if len(tokens) > window_size:
        for i in range(0, len(tokens) - window_size + 1, step_size):
            # Decode the tokens before the window and the window itself to find
            # the (approximate) character offsets of this window in `text`.
            start_window_slice = tokens[0: i]
            window_slice = tokens[i: i + window_size]
            start_char = len(tokenizer.decode(start_window_slice))
            char_len = len(tokenizer.decode(window_slice))
            context_slice = text[start_char: start_char + char_len]
            tokens_len = len(tokenizer(context_slice)['input_ids'][1:-1])
            print(f'token-len: {tokens_len}', flush=True)
            results = get_class(context_slice)
            text_classes.append(results)
    else:
        # Short texts fit in a single window.
        text_classes.append(get_class(text))

    return text_classes


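# Classify every previously-failed section; IDs that fail again are collected
# in `errors` so they can be retried later.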
errors = ''
for index, item in enumerate(sections):
    # if index > 11:
    #     break
    section_id = item['id']
    content = item['content']
    try:
        section_classes = get_window_classes(content)
    except Exception:
        errors += section_id + "\n"
        continue
    item['classes'] = section_classes
    print(f'section: {len(sections)}/{index+1}', flush=True)

with open('./data/error_sections_classes.json', 'w', encoding='utf-8') as output_file:
    json_data = json.dumps(sections, indent=4, ensure_ascii=False)
    output_file.write(json_data)

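# Overwrite errors.txt with the IDs that still failed (empty if everything succeeded).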
with open('./data/errors.txt', 'w', encoding='utf-8') as output_file:
    output_file.write(errors)

print('finished!')