llama/llama3_classification_ds.py
2025-07-17 20:36:22 +03:30

128 lines
4.4 KiB
Python

"""
این فایل با نرمالایزر هضم کار می کند
"""
import json
with open('./data/classes51new.txt', 'r') as file:
classes = file.read()
classes_list = classes.splitlines()
with open('./data/classification_ds.json', 'r') as file:
sections = json.load(file)
# send content of some sections and classes to llama chat
# and ask for the best class
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import json
from funcs import write_to_json
if torch.cuda.is_available():
model_id = "PartAI/Dorna-Llama3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
counter = 0
total = 0
remained = 0
id = ''
keywords_count = 15
def command(text):
global remained
try:
messages = [{"role": "system", "content": "تو یک حقوق دان هستی و باید بتوانی متن های قانونی و حقوقی را به صورت حرفه ای تفسیر کنی. " },
{"role": "user", "content":
'''51 دسته در متن پرامپت می آید. در بین این دسته ها، شماره دسته ای که از نظر معنایی می تواند نزدیک ترین عنوان دسته برای متن باشد را انتخاب کن
تاکید می کنم که فقط اجازه داری یک دسته را انتخاب کنی.
فقط شماره دسته را به صورت یک عدد در خروجی بیاور.
به هیچ عنوان، هیچ توضیح اضافه ای پیش یا پس از شماره دسته در خروجی ننویس.
"متن": {}
'''.format(text)
},
{"role": "user", "content":
'''دسته های 51 گانه عبارت اند از: {}
'''.format(classes)
},
]
input_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
terminators = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
model.generation_config.pad_token_id = tokenizer.pad_token_id
outputs = model.generate(
input_ids,
max_new_tokens=256,
eos_token_id=terminators,
do_sample=True,
temperature=0.1,
top_p=0.85,
)
response = outputs[0][input_ids.shape[-1]:]
result = tokenizer.decode(response, skip_special_tokens=True)
return result
except Exception as inst:
print(type(inst)) # the exception type
print(inst.args) # arguments stored in .args
print("Exception: " + str(inst))
counter = 1
if __name__ == "__main__":
start_time = time.time()
try:
classes_dict = []
count = 1
for content_item in sections:
id = sections[counter]['id']
prev_class = sections[counter]['domain_name']
content = sections[counter]['content']
new_class = command(content)
try:
new_class_title = (classes_list[(int(new_class))-1].split('-')[1]).strip()
except:
new_class_title = '-'
print("section " + str(count) + "/" + str(len(sections)) + " class extracting ... ")
classes_dict.append({
'id': count,
'content':content,
'prev-class': prev_class,
'new-class-title': new_class_title,
'new-class': new_class,
})
count+= 1
counter+= 500
if counter > 49387:
break
write_to_json(classes_dict, "./data/result4.json")
except Exception as inst:
print(type(inst)) # the exception type
print(inst.args) # arguments stored in .args
end_time = time.time()
print(end_time)
operation_time = (int(end_time-start_time)/60)/60
print(f"elapsed time: {operation_time} hours")
print(f" Finished!!! ")