data_processes/p4_keyword_extractor.py
2025-08-10 18:14:42 +03:30

193 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
استخراج واژگان کلیدی از اجزاء قانونی
"""
import json
import datetime
import torch
import os
from elastic_helper import ElasticHelper
from transformers import AutoTokenizer, AutoModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
os.environ['HF_HOME'] = "/home/admin/HFHOME"
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
#model_id = "meta-llama/Llama-3.1-70B-Instruct"
bnb_config = BitsAndBytesConfig(
load_in_8bit=True, bnb_8bit_use_double_quant=True, bnb_8bit_quant_type="nf8", bnb_8bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=bnb_config
)
terminators = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
model.generation_config.pad_token_id = tokenizer.eos_token_id #tokenizer.pad_token_id
# SYS_PROMPT = """You receive a Persian legal text and extract from it the keywords that are most important.
# And you don't need to provide explanations or additional text.
# Put each keyword on a single line."
# """# Explain your answer step by step.
SYS_PROMPT = """You are a highly accurate and detail-oriented assistant specialized in analyzing Persian legal texts."""# Explain your answer step by step.
SYS_PROMPT = """شما یک دستیار متخصص در استخراج عبارات کلیدی از متون حقوقی فارسی هستید. وظیفه شما این است که از متن ارائه شده توسط کاربر، عبارات اسمی کلیدی مهم و معنادار را استخراج کنید."""# gemini prompt
def format_prompt(SENTENCE):
# PROMPT = f"Persian legal text: {SENTENCE}."
PROMPT = f"متن: {SENTENCE}."
return PROMPT
def kw_count_calculator(text):
keywords_count = (len(text) / 1000) * 15
keywords_count = int(keywords_count)
if keywords_count == 0:
keywords_count = 1
return keywords_count
def generate(formatted_prompt):
keywords_count = kw_count_calculator(formatted_prompt)
# Gemini Prompt
USER_PROMPT = f"""از متن ارائه شده، حداقل {keywords_count} عبارت کلیدی مهم و معنادار را استخراج کنید. خروجی باید به صورت زیر باشد:
• یک لیست فارسی
• شماره ترتیبی در ابتدای هر عبارت
• هر عبارت در یک خط جداگانه *
• بدون هیچ توضیح اضافی در ابتدا یا انتهای پاسخ
موارد زیر در استخراج عبارات کلیدی الزامی است:
• بسیار مهم و حیاتی است که عبارات کلیدی باید دقیقاً در متن موجود باشند، بنابراین هرگز از خلاقیت برای ساختن کلیدواژه‌ای که دقیقا در متن وجود ندارد، استفاده نکن.
• طول هر عبارت کلیدی حداقل دو کلمه باشد. عبارات تک-کلمه‌ای قابل قبول نیستند.
• نام سازمان‌ها، مؤسسات و اشخاص حقوقی باید به عنوان عبارات کلیدی در نظر گرفته شوند.
• عبارات کلیدی نباید فعل یا حرف اضافه باشند و فقط باید شامل اسم‌هایی باشند که به هم اضافه شده‌اند (عبارت اسمی).
• عبارات کلیدی نباید به حرف اضافه یا حرف "و" ختم شوند.
• عبارات کلیدی نباید شامل کلمات "ماده"، "تبصره" و "بند" یا تاریخ‌ها باشند.
به عنوان مثال، اگر متن "قانون مدنی جمهوری اسلامی ایران در ماده ۲۲۲ در خصوص مسئولیت مدنی اشخاص حقوقی در تبصره ۱ به موضوع خسارت ناشی از تقصیر کارکنان اشاره دارد." باشد، خروجی باید به شکل زیر باشد:
1. قانون مدنی جمهوری اسلامی ایران
2. مسئولیت مدنی اشخاص حقوقی
3. خسارت ناشی از تقصیر کارکنان
اکنون متن مورد نظر را دریافت خواهید کرد.
"""
formatted_prompt = formatted_prompt[:50000] # to avoid GPU OOM
messages = [
{"role":"system","content":SYS_PROMPT},
{"role":"user","content":USER_PROMPT},
{"role":"user","content":formatted_prompt}
]
# tell the model to generate
input_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
outputs = model.generate(
input_ids,
max_new_tokens=2048,
eos_token_id=terminators,
do_sample=True,
temperature=0.6,
top_p=0.9,
)
response = outputs[0][input_ids.shape[-1]:]
return tokenizer.decode(response, skip_special_tokens=True)
def get_rules(sentence):
formatted_prompt = format_prompt(sentence)
rules = generate(formatted_prompt).split('\n')
result = [r.strip() for r in rules if r.strip()]
return result
def get_sections():
sections_path = "/home/gpu/data_11/14040423/mj_qa_section.zip"
eh_obj = ElasticHelper()
sections = eh_obj.iterateJsonFile(sections_path, True)
sections = convert_to_dict(sections)
return sections
def convert_to_dict(sections):
sections_dict = {}
for item in sections:
id = item['id']
source = item['source']
sections_dict[id] = source
return sections_dict
def do_keyword_extract(sections):
start_time = datetime.datetime.now()
print(f'start time: {start_time}')
# prev_ids = read_file_by_address('./data/prev_kw_ids_170k.txt').splitlines()
counter = 1
file_counter = 1
temp_dict = []
temp_data_text = ''
period_sections = []
period_ids_text = f'***** file number: {file_counter} *****\n'
for index, item in enumerate(sections):
# if item['id'] in prev_ids:
# continue
if counter > 10:
with open('./data/keyword/prev_kw_ids.txt', 'a+', encoding='utf-8') as file:
file.write(period_ids_text)
break
id = item
content = sections[id]['content']
try:
sections[id]['keywords'] = get_rules(content)
except:
print(f"section kw error : {id}\n")
counter += 1
continue
period_ids_text += f"{id} \n"
print(f"section: {counter}-id: {id}")
# temp_dict.append(item)
if counter % 1000 == 0:
outputfile = open(f'./data/keyword/sections_kw_llama8b_{str(file_counter)}.json', "a+", encoding='utf-8')
outputfile.write(json.dumps(period_sections, ensure_ascii=False, indent=2))
outputfile.close()
print(f"file {str(file_counter)} created in {str(datetime.datetime.now())} +++++++++++++++++++++++++\n ")
file_counter += 1
period_sections = []
# save proccessed sections id for next executions of this code
with open('./data/keyword/prev_kw_ids.txt', 'a+', encoding='utf-8') as file:
file.write(period_ids_text)
counter += 1
period_ids_text = f'***** file number: {file_counter} *****\n'
outputfile = open(f'./data/keyword/sections_kw_llama8b_{str(file_counter)}.json', "w", encoding='utf-8')
outputfile.write(json.dumps(period_sections, ensure_ascii=False, indent = 4))
outputfile.close()
print(f"file {str(file_counter)} created in {str(datetime.datetime.now())} +++++++++++++++++++++++++ ")
end_time = datetime.datetime.now()
print(f"end_time: {end_time}")
print(f"elapsed time: {(end_time-start_time)} Hours!!! ")
print(f"elapsed time: {(end_time-start_time)/86400} Days!!! ")
print("end")
return True
if __name__ == "__main__":
print(f'start: {datetime.datetime.now()}')
sections = get_sections()
sections = do_keyword_extract(sections)
print(f'end: {datetime.datetime.now()}')