Classifier/classification_dataset_add_fullpath.py

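# Builds "Dataset_classification_v2_fullpath.json": for each section of the existing
# classification dataset, the section content is prefixed with its hierarchical full
# path (taken from the Elastic dump mj_qa_section) and the title of the law it belongs to.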

from transformers import pipeline, AutoTokenizer
from normalizer import cleaning
from elastic_helper import ElasticHelper
import transformers
import json
import datetime
import pandas as pd  # imported but not used below

print(transformers.__version__)
# model_checkpoint = "./BERT/findtuned_classification_model_15"  # 15-epoch checkpoint
model_checkpoint = '/home/gpu/tnlp/jokar/Classifier/Models/findtuned_classification_model-15'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Sliding-window parameters; not used in this script.
window_size = tokenizer.model_max_length  # 512
step_size = 350
Top_k = 10
# with open('./data/errors.txt', 'r', encoding='utf-8') as input_file:
#     error_sections_id = input_file.read().splitlines()
eh_obj = ElasticHelper()
path = "/home/gpu/data_11/mj_qa_section.zip"
sections = eh_obj.iterateJsonFile(path, True)
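# Text-classification pipeline for the fine-tuned checkpoint; it is loaded here but never
# called in this script (only the tokenizer is used below, for length checks).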
classifier = pipeline("text-classification", model_checkpoint, framework="pt")
print(f'start: {datetime.datetime.now()}')
# Load the base classification dataset and index it by section id.
with open('./data/Dataset_classification_v2.json', 'r', encoding='utf-8') as input_file:
    dataset = json.load(input_file)

dataset_dict = {}
for itm in dataset:
    try:
        item_content = itm['content']
    except KeyError:
        print(f'empty section content: {itm["id"]}')
        continue
    dataset_dict[itm["id"]] = {
        "domain_id": itm["domain_id"],
        "domain_name": itm['domain_name'],
        "content": item_content
    }

dataset_ids = [item['id'] for item in dataset]
print(f'len(dataset_ids): {len(dataset_ids)}')
# Unused in this script:
cc_counter = 1
test_counter = 1
qanon_title_list = []
new_sections_dict = {}

total_sections = 282671  # total number of sections in the Elastic dump (used only in the progress print below)
failed_counter = 0  # sections skipped because the rewritten text exceeds 512 tokens
new_dataset = []
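
# Main pass: for every Elastic section that also exists in the base dataset, build a
# readable "full path" phrase from source['other_info']['full_path'], prepend it together
# with the law (qanon) title to the cleaned content, and keep the result only if it still
# fits within the model's 512-token limit.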
for index, item in enumerate(sections):
    # if index > 500:
    #     break
    id = item['id']
    if id not in dataset_dict:
        continue
    source = item['source']
    qanon_title = source['qanon_title']
    full_path = source['other_info']['full_path'].split(">")

    # Build the path phrase by walking full_path in reverse, joining the parts with
    # " از " ("of"); "تبصره" ("note") is appended without it.
    full_path_text = ''
    # if id == 'qs2121517':
    #     pass
    for i, path_item in enumerate(reversed(full_path)):
        if i == len(full_path) - 1:
            full_path_text += path_item
            break
        if path_item == 'تبصره':
            full_path_text += f'{path_item} '
        else:
            full_path_text += f'{path_item} از '
    full_path_text = full_path_text.strip()

    try:
        content = cleaning(dataset_dict[id]['content'])
        # Persian prefix: "The content of <full_path> of <law title> is the following text."
        pre_content = f"محتوای {full_path_text} {cleaning(qanon_title)} متن زیر است.\n"
        new_content = f"{pre_content} {content}"
        # Skip sections whose prefixed text no longer fits in 512 tokens
        # (the first and last ids are the special [CLS]/[SEP] tokens and are excluded).
        if len(tokenizer(new_content)['input_ids'][1:-1]) > 512:
            failed_counter += 1
            continue
        new_dataset.append({
            "id": id,
            "domain_id": dataset_dict[id]["domain_id"],
            "domain_name": dataset_dict[id]['domain_name'],
            "content": content,
            "content_with_fullpath": new_content,
        })
    except Exception as e:
        with open('./data/errors_log_dataset.txt', 'a', encoding='utf-8') as output_file:
            output_file.write(id + " >> " + str(e) + "\n")
        continue
    # print(f'section: {total_sections}/{id}/{index+1}', flush=True)
# Write the enriched dataset next to the original one.
with open('./data/Dataset_classification_v2_fullpath.json', 'w', encoding='utf-8') as output_file:
    json_data = json.dumps(new_dataset, indent=4, ensure_ascii=False)
    output_file.write(json_data)

print(f"failed_counter ::: {failed_counter}")
print(f'end: {datetime.datetime.now()}')
print('finished!')
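
# A minimal usage sketch, assuming a transformers version whose text-classification
# pipeline accepts `top_k` and tokenizer kwargs such as `truncation`; the `classifier`
# loaded above could be applied to the enriched texts like this:
#
# if new_dataset:
#     preds = classifier(new_dataset[0]['content_with_fullpath'], top_k=Top_k, truncation=True)
#     print(preds)  # list of {'label': ..., 'score': ...} dicts for the Top_k domains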