data_processes/p4_keyword_extractor.py
2025-08-16 14:24:11 +03:30

208 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
استخراج واژگان کلیدی از اجزاء قانونی
"""
import json
import datetime
import torch
import os
from elastic_helper import ElasticHelper
from transformers import AutoTokenizer, AutoModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
os.environ['HF_HOME'] = "/home/admin/HFHOME"
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
#model_id = "meta-llama/Llama-3.1-70B-Instruct"
today = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}-{datetime.datetime.now().hour}'
bnb_config = BitsAndBytesConfig(
load_in_8bit=True, bnb_8bit_use_double_quant=True, bnb_8bit_quant_type="nf8", bnb_8bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=bnb_config
)
terminators = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
model.generation_config.pad_token_id = tokenizer.eos_token_id #tokenizer.pad_token_id
# SYS_PROMPT = """You receive a Persian legal text and extract from it the keywords that are most important.
# And you don't need to provide explanations or additional text.
# Put each keyword on a single line."
# """# Explain your answer step by step.
SYS_PROMPT = """You are a highly accurate and detail-oriented assistant specialized in analyzing Persian legal texts."""# Explain your answer step by step.
SYS_PROMPT = """شما یک دستیار متخصص در استخراج عبارات کلیدی از متون حقوقی فارسی هستید. وظیفه شما این است که از متن ارائه شده توسط کاربر، عبارات اسمی کلیدی مهم و معنادار را استخراج کنید."""# gemini prompt
def format_prompt(SENTENCE):
# PROMPT = f"Persian legal text: {SENTENCE}."
PROMPT = f"متن: {SENTENCE}."
return PROMPT
def kw_count_calculator(text):
keywords_count = (len(text) / 1000) * 15
keywords_count = int(keywords_count)
if keywords_count == 0:
keywords_count = 1
return keywords_count
def generate(formatted_prompt):
keywords_count = kw_count_calculator(formatted_prompt)
# Gemini Prompt
USER_PROMPT = f"""از متن ارائه شده، حداقل {keywords_count} عبارت کلیدی مهم و معنادار را استخراج کنید. خروجی باید به صورت زیر باشد:
• یک لیست فارسی
• هیچ عدد یا علامت و نمادی در ابتدا یا انتهای هر عبارت کلیدی قرار نگیرد
• هر عبارت در یک خط جداگانه
• بدون هیچ توضیح اضافی در ابتدا یا انتهای پاسخ
موارد زیر در استخراج عبارات کلیدی الزامی است:
• بسیار مهم و حیاتی است که عبارات کلیدی باید دقیقاً در متن موجود باشند، بنابراین هرگز از خلاقیت برای ساختن کلیدواژه‌ای که دقیقا در متن وجود ندارد، استفاده نکن.
• طول هر عبارت کلیدی حداقل دو کلمه باشد. عبارات تک-کلمه‌ای قابل قبول نیستند.
• نام سازمان‌ها، مؤسسات و اشخاص حقوقی باید به عنوان عبارات کلیدی در نظر گرفته شوند.
• عبارات کلیدی نباید فعل یا حرف اضافه باشند و فقط باید شامل اسم‌هایی باشند که به هم اضافه شده‌اند (عبارت اسمی).
• عبارات کلیدی نباید به حرف اضافه یا حرف "و" ختم شوند.
• عبارات کلیدی نباید شامل کلمات "ماده"، "تبصره" و "بند" یا تاریخ‌ها باشند.
به عنوان مثال، اگر متن "قانون مدنی جمهوری اسلامی ایران در ماده ۲۲۲ در خصوص مسئولیت مدنی اشخاص حقوقی در تبصره ۱ به موضوع خسارت ناشی از تقصیر کارکنان اشاره دارد." باشد، خروجی باید به شکل زیر باشد:
1. قانون مدنی جمهوری اسلامی ایران
2. مسئولیت مدنی اشخاص حقوقی
3. خسارت ناشی از تقصیر کارکنان
اکنون متن مورد نظر را دریافت خواهید کرد.
"""
formatted_prompt = formatted_prompt[:50000] # to avoid GPU OOM
messages = [
{"role":"system","content":SYS_PROMPT},
{"role":"user","content":USER_PROMPT},
{"role":"user","content":formatted_prompt}
]
# tell the model to generate
input_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
outputs = model.generate(
input_ids,
max_new_tokens=2048,
eos_token_id=terminators,
do_sample=True,
temperature=0.1,
top_p=0.9,
)
response = outputs[0][input_ids.shape[-1]:]
return tokenizer.decode(response, skip_special_tokens=True)
def single_section_get_keyword(sentence):
"""
این متد، کلیدواژه های قانونی موجود در متن ورودی را استخراج می کند
**Args:
sentence(str): متن یک سکشن قانونی
**Returns:
kws(list): لیستی از کلیدواژه های تشخیص داده شده
"""
formatted_prompt = format_prompt(sentence)
keywords = generate(formatted_prompt).split('\n')
kws = [kw.strip() for kw in keywords if kw.strip()]
# حذف کلیدواژه های تکراری
kws = list(set(kws))
return kws
def get_sections():
sections_path = "/home/gpu/data_11/14040423/mj_qa_section.zip"
eh_obj = ElasticHelper()
sections = eh_obj.iterateJsonFile(sections_path, True)
sections = convert_to_dict(sections)
return sections
def convert_to_dict(sections):
sections_dict = {}
for item in sections:
id = item['id']
source = item['source']
sections_dict[id] = source
return sections_dict
def do_keyword_extract(sections):
start_time = datetime.datetime.now()
print(f'start time: {start_time}')
# prev_ids = read_file_by_address('./data/prev_kw_ids_170k.txt').splitlines()
counter = 1
file_counter = 1
temp_dict = []
temp_data_text = ''
period_sections = []
period_ids_text = f'***** file number: {file_counter} *****\n'
for index, item in enumerate(sections):
# if item['id'] in prev_ids:
# continue
# # برای تست خروجی
# if counter > 10:
# with open('./data/keyword/prev_kw_ids.txt', 'a+', encoding='utf-8') as file:
# file.write(period_ids_text)
# break
id = item
content = sections[id]['content']
try:
sections[id]['keywords'] = single_section_get_keyword(content)
except Exception as error:
print(f"section kw error : {id}\n")
error_content = f'id: {id} - error: {str(error)}\n'
with open(f'./data/keyword/keyword_errors_{today}.txt', 'a+', encoding='utf-8') as file:
file.write(error_content)
counter += 1
continue
period_ids_text += f"{id} \n"
print(f"section kw extracting: {counter} - id: {id}")
# temp_dict.append(item)
if counter % 5000 == 0:
outputfile = open(f'./data/keyword/sections_kw_llama8b_{str(file_counter)}_{today}.json', "a+", encoding='utf-8')
outputfile.write(json.dumps(period_sections, ensure_ascii=False, indent=2))
outputfile.close()
print(f"file {str(file_counter)} created in {str(datetime.datetime.now())} +++++++++++++++++++++++++\n ")
file_counter += 1
period_sections = []
# save proccessed sections id for next executions of this code
with open(f'./data/keyword/prev_kw_ids_{today}.txt', 'a+', encoding='utf-8') as file:
file.write(period_ids_text)
counter += 1
period_ids_text = f'***** file number: {file_counter} *****\n'
outputfile = open(f'./data/keyword/sections_kw_llama8b_{str(file_counter)}_{today}.json', "w", encoding='utf-8')
outputfile.write(json.dumps(period_sections, ensure_ascii=False, indent = 4))
outputfile.close()
print(f"file {str(file_counter)} created in {str(datetime.datetime.now())} +++++++++++++++++++++++++ ")
end_time = datetime.datetime.now()
print(f"end_time: {end_time}")
print(f"elapsed time: {(end_time-start_time)} Hours!!! ")
print(f"elapsed time: {(end_time-start_time)/86400} Days!!! ")
print("end")
operation_result = True
return operation_result, sections
if __name__ == "__main__":
print(f'start: {datetime.datetime.now()}')
sections = get_sections()
operation_result = do_keyword_extract(sections)
print(f'end: {datetime.datetime.now()}')