128 lines
4.4 KiB
Python
128 lines
4.4 KiB
Python
"""
|
|
این فایل با نرمالایزر هضم کار می کند
|
|
"""
|
|
import json
|
|
|
|
with open('./data/classes51new.txt', 'r') as file:
|
|
classes = file.read()
|
|
|
|
classes_list = classes.splitlines()
|
|
|
|
with open('./data/classification_ds.json', 'r') as file:
|
|
sections = json.load(file)
|
|
|
|
# send content of some sections and classes to llama chat
|
|
# and ask for the best class
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
import torch
|
|
import time
|
|
import json
|
|
from funcs import write_to_json
|
|
|
|
if torch.cuda.is_available():
|
|
model_id = "PartAI/Dorna-Llama3-8B-Instruct"
|
|
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
|
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
|
|
counter = 0
|
|
total = 0
|
|
remained = 0
|
|
id = ''
|
|
keywords_count = 15
|
|
|
|
def command(text):
|
|
global remained
|
|
try:
|
|
|
|
messages = [{"role": "system", "content": "تو یک حقوق دان هستی و باید بتوانی متن های قانونی و حقوقی را به صورت حرفه ای تفسیر کنی. " },
|
|
|
|
{"role": "user", "content":
|
|
'''51 دسته در متن پرامپت می آید. در بین این دسته ها، شماره دسته ای که از نظر معنایی می تواند نزدیک ترین عنوان دسته برای متن باشد را انتخاب کن
|
|
تاکید می کنم که فقط اجازه داری یک دسته را انتخاب کنی.
|
|
فقط شماره دسته را به صورت یک عدد در خروجی بیاور.
|
|
به هیچ عنوان، هیچ توضیح اضافه ای پیش یا پس از شماره دسته در خروجی ننویس.
|
|
"متن": {}
|
|
'''.format(text)
|
|
},
|
|
{"role": "user", "content":
|
|
'''دسته های 51 گانه عبارت اند از: {}
|
|
'''.format(classes)
|
|
},
|
|
]
|
|
|
|
input_ids = tokenizer.apply_chat_template(
|
|
messages,
|
|
add_generation_prompt=True,
|
|
return_tensors="pt"
|
|
).to(model.device)
|
|
|
|
terminators = [
|
|
tokenizer.eos_token_id,
|
|
tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
|
]
|
|
model.generation_config.pad_token_id = tokenizer.pad_token_id
|
|
|
|
outputs = model.generate(
|
|
input_ids,
|
|
max_new_tokens=256,
|
|
eos_token_id=terminators,
|
|
do_sample=True,
|
|
temperature=0.1,
|
|
top_p=0.85,
|
|
)
|
|
|
|
response = outputs[0][input_ids.shape[-1]:]
|
|
result = tokenizer.decode(response, skip_special_tokens=True)
|
|
|
|
return result
|
|
|
|
except Exception as inst:
|
|
print(type(inst)) # the exception type
|
|
print(inst.args) # arguments stored in .args
|
|
print("Exception: " + str(inst))
|
|
|
|
counter = 1
|
|
if __name__ == "__main__":
|
|
start_time = time.time()
|
|
try:
|
|
classes_dict = []
|
|
count = 1
|
|
for content_item in sections:
|
|
|
|
id = sections[counter]['id']
|
|
|
|
prev_class = sections[counter]['domain_name']
|
|
content = sections[counter]['content']
|
|
|
|
new_class = command(content)
|
|
try:
|
|
new_class_title = (classes_list[(int(new_class))-1].split('-')[1]).strip()
|
|
except:
|
|
new_class_title = '-'
|
|
|
|
print("section " + str(count) + "/" + str(len(sections)) + " class extracting ... ")
|
|
classes_dict.append({
|
|
'id': count,
|
|
'content':content,
|
|
'prev-class': prev_class,
|
|
'new-class-title': new_class_title,
|
|
'new-class': new_class,
|
|
})
|
|
count+= 1
|
|
counter+= 500
|
|
|
|
if counter > 49387:
|
|
break
|
|
write_to_json(classes_dict, "./data/result4.json")
|
|
|
|
except Exception as inst:
|
|
print(type(inst)) # the exception type
|
|
print(inst.args) # arguments stored in .args
|
|
|
|
end_time = time.time()
|
|
print(end_time)
|
|
operation_time = (int(end_time-start_time)/60)/60
|
|
print(f"elapsed time: {operation_time} hours")
|
|
print(f" Finished!!! ")
|