from transformers import pipeline
from transformers import AutoTokenizer
import transformers
import json
from normalizer import cleaning  # optional text cleaning (currently unused below)

print(transformers.__version__)

# model_checkpoint = "./BERT/findtuned_classification_model_15"  # 15-epoch checkpoint
model_checkpoint = '/home/gpu/tnlp/jokar/Classifier/Models/findtuned_classification_model-15'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

window_size = 200  # sliding-window width in tokens (was 512)
step_size = 100    # window stride; consecutive windows overlap by window_size - step_size tokens

# IDs of the sections that failed in a previous run.
with open('./data/errors.txt', 'r', encoding='utf-8') as input_file:
    error_sections_id = set(input_file.read().splitlines())

with open('./data/final_sections_class.json', 'r', encoding='utf-8') as input_file:
    sections = json.load(input_file)

# Keep only the sections that previously errored; everything else is already classified.
sections = [item for item in sections if item['id'] in error_sections_id]

classifier = pipeline("text-classification", model_checkpoint, framework="pt")


def get_class(sentences, top_k: int = 4):
    """Return the top_k class labels and scores for the given text."""
    # sentences = cleaning(sentences)
    return classifier(sentences, top_k=top_k)


def get_window_classes(text):
    """Classify text in overlapping token windows so long sections fit the model."""
    text_classes = []
    tokens = tokenizer(text)['input_ids'][1:-1]  # drop [CLS] and [SEP]
    if len(tokens) > window_size:
        for i in range(0, len(tokens) - window_size + 1, step_size):
            # Recover the character span of this window by re-decoding the
            # tokens before it (start offset) and the window itself (length).
            start_window_slice = tokens[0:i]
            window_slice = tokens[i:i + window_size]
            start_char = len(tokenizer.decode(start_window_slice))
            char_len = len(tokenizer.decode(window_slice))
            context_slice = text[start_char:start_char + char_len]
            tokens_len = len(tokenizer(context_slice)['input_ids'][1:-1])
            print(f'token-len: {tokens_len}', flush=True)
            text_classes.append(get_class(context_slice))
    else:
        text_classes.append(get_class(text))
    return text_classes


errors = ''
for index, item in enumerate(sections):
    # if index > 11:
    #     break
    section_id = item['id']
    content = item['content']
    try:
        section_classes = get_window_classes(content)
    except Exception as e:
        # Record the failing section ID so it can be retried on the next run.
        print(f'failed on section {section_id}: {e}', flush=True)
        errors += section_id + "\n"
        continue
    item['classes'] = section_classes
    print(f'section: {index + 1}/{len(sections)}', flush=True)

with open('./data/error_sections_classes.json', 'w', encoding='utf-8') as output_file:
    json_data = json.dumps(sections, indent=4, ensure_ascii=False)
    output_file.write(json_data)

with open('./data/errors.txt', 'w', encoding='utf-8') as output_file:
    output_file.write(errors)

print('finished!')
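
# ------------------------------------------------------------------------
# Hedged sketch, not called anywhere above: the window-to-character mapping
# in get_window_classes re-decodes token slices, and decoded lengths can
# drift from the source text when the tokenizer normalizes whitespace or
# unknown characters. A fast (Rust-backed) tokenizer exposes exact per-token
# character offsets via return_offsets_mapping; get_window_spans below is an
# illustrative alternative under that assumption, not part of the original
# pipeline.
def get_window_spans(text, window=window_size, step=step_size):
    encoding = tokenizer(text, return_offsets_mapping=True)  # requires a fast tokenizer
    offsets = encoding['offset_mapping'][1:-1]  # drop the [CLS]/[SEP] entries
    spans = []
    for i in range(0, len(offsets) - window + 1, step):
        start_char = offsets[i][0]               # exact start of the window's first token
        end_char = offsets[i + window - 1][1]    # exact end of the window's last token
        spans.append(text[start_char:end_char])
    return spans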