209 lines
9.0 KiB
Python
209 lines
9.0 KiB
Python
"""
|
||
استخراج واژگان کلیدی از اجزاء قانونی
|
||
"""
|
||
|
||
import json
|
||
import datetime
|
||
import torch
|
||
import os
|
||
from elastic_helper import ElasticHelper
|
||
from transformers import AutoTokenizer, AutoModel
|
||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||
|
||
os.environ['HF_HOME'] = "/home/admin/HFHOME"
|
||
|
||
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||
#model_id = "meta-llama/Llama-3.1-70B-Instruct"
|
||
|
||
date = datetime.datetime.now()
|
||
today = f'{date.year}-{date.month}-{date.day}-{date.hour}'
|
||
|
||
bnb_config = BitsAndBytesConfig(
|
||
load_in_8bit=True, bnb_8bit_use_double_quant=True, bnb_8bit_quant_type="nf8", bnb_8bit_compute_dtype=torch.bfloat16
|
||
)
|
||
|
||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||
model = AutoModelForCausalLM.from_pretrained(
|
||
model_id,
|
||
torch_dtype=torch.bfloat16,
|
||
device_map="auto",
|
||
quantization_config=bnb_config
|
||
)
|
||
terminators = [
|
||
tokenizer.eos_token_id,
|
||
tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
||
]
|
||
model.generation_config.pad_token_id = tokenizer.eos_token_id #tokenizer.pad_token_id
|
||
|
||
# SYS_PROMPT = """You receive a Persian legal text and extract from it the keywords that are most important.
|
||
# And you don't need to provide explanations or additional text.
|
||
# Put each keyword on a single line."
|
||
# """# Explain your answer step by step.
|
||
|
||
SYS_PROMPT = """You are a highly accurate and detail-oriented assistant specialized in analyzing Persian legal texts."""# Explain your answer step by step.
|
||
SYS_PROMPT = """شما یک دستیار متخصص در استخراج عبارات کلیدی از متون حقوقی فارسی هستید. وظیفه شما این است که از متن ارائه شده توسط کاربر، عبارات اسمی کلیدی مهم و معنادار را استخراج کنید."""# gemini prompt
|
||
|
||
|
||
def format_prompt(SENTENCE):
|
||
# PROMPT = f"Persian legal text: {SENTENCE}."
|
||
PROMPT = f"متن: {SENTENCE}."
|
||
return PROMPT
|
||
|
||
def kw_count_calculator(text):
|
||
keywords_count = (len(text) / 1000) * 15
|
||
keywords_count = int(keywords_count)
|
||
if keywords_count == 0:
|
||
keywords_count = 1
|
||
return keywords_count
|
||
|
||
def generate(formatted_prompt):
|
||
keywords_count = kw_count_calculator(formatted_prompt)
|
||
|
||
# Gemini Prompt
|
||
USER_PROMPT = f"""از متن ارائه شده، حداقل {keywords_count} عبارت کلیدی مهم و معنادار را استخراج کنید. خروجی باید به صورت زیر باشد:
|
||
• یک لیست فارسی
|
||
• هیچ عدد یا علامت و نمادی در ابتدا یا انتهای هر عبارت کلیدی قرار نگیرد
|
||
• هر عبارت در یک خط جداگانه
|
||
• بدون هیچ توضیح اضافی در ابتدا یا انتهای پاسخ
|
||
موارد زیر در استخراج عبارات کلیدی الزامی است:
|
||
• بسیار مهم و حیاتی است که عبارات کلیدی باید دقیقاً در متن موجود باشند، بنابراین هرگز از خلاقیت برای ساختن کلیدواژهای که دقیقا در متن وجود ندارد، استفاده نکن.
|
||
• طول هر عبارت کلیدی حداقل دو کلمه باشد. عبارات تک-کلمهای قابل قبول نیستند.
|
||
• نام سازمانها، مؤسسات و اشخاص حقوقی باید به عنوان عبارات کلیدی در نظر گرفته شوند.
|
||
• عبارات کلیدی نباید فعل یا حرف اضافه باشند و فقط باید شامل اسمهایی باشند که به هم اضافه شدهاند (عبارت اسمی).
|
||
• عبارات کلیدی نباید به حرف اضافه یا حرف "و" ختم شوند.
|
||
• عبارات کلیدی نباید شامل کلمات "ماده"، "تبصره" و "بند" یا تاریخها باشند.
|
||
به عنوان مثال، اگر متن "قانون مدنی جمهوری اسلامی ایران در ماده ۲۲۲ در خصوص مسئولیت مدنی اشخاص حقوقی در تبصره ۱ به موضوع خسارت ناشی از تقصیر کارکنان اشاره دارد." باشد، خروجی باید به شکل زیر باشد:
|
||
1. قانون مدنی جمهوری اسلامی ایران
|
||
2. مسئولیت مدنی اشخاص حقوقی
|
||
3. خسارت ناشی از تقصیر کارکنان
|
||
اکنون متن مورد نظر را دریافت خواهید کرد.
|
||
"""
|
||
formatted_prompt = formatted_prompt[:50000] # to avoid GPU OOM
|
||
messages = [
|
||
{"role":"system","content":SYS_PROMPT},
|
||
{"role":"user","content":USER_PROMPT},
|
||
{"role":"user","content":formatted_prompt}
|
||
]
|
||
# tell the model to generate
|
||
input_ids = tokenizer.apply_chat_template(
|
||
messages,
|
||
add_generation_prompt=True,
|
||
return_tensors="pt"
|
||
).to(model.device)
|
||
outputs = model.generate(
|
||
input_ids,
|
||
max_new_tokens=2048,
|
||
eos_token_id=terminators,
|
||
do_sample=True,
|
||
temperature=0.1,
|
||
top_p=0.9,
|
||
)
|
||
response = outputs[0][input_ids.shape[-1]:]
|
||
return tokenizer.decode(response, skip_special_tokens=True)
|
||
|
||
|
||
def single_section_get_keyword(sentence):
|
||
"""
|
||
این متد، کلیدواژه های قانونی موجود در متن ورودی را استخراج می کند
|
||
|
||
**Args:
|
||
sentence(str): متن یک سکشن قانونی
|
||
**Returns:
|
||
kws(list): لیستی از کلیدواژه های تشخیص داده شده
|
||
"""
|
||
formatted_prompt = format_prompt(sentence)
|
||
keywords = generate(formatted_prompt).split('\n')
|
||
kws = [kw.strip() for kw in keywords if kw.strip()]
|
||
# حذف کلیدواژه های تکراری
|
||
kws = list(set(kws))
|
||
return kws
|
||
|
||
def get_sections():
|
||
sections_path = "/home/gpu/data_11/14040423/mj_qa_section.zip"
|
||
eh_obj = ElasticHelper()
|
||
sections = eh_obj.iterateJsonFile(sections_path, True)
|
||
sections = convert_to_dict(sections)
|
||
return sections
|
||
|
||
def convert_to_dict(sections):
|
||
sections_dict = {}
|
||
for item in sections:
|
||
id = item['id']
|
||
source = item['source']
|
||
sections_dict[id] = source
|
||
|
||
return sections_dict
|
||
|
||
def do_keyword_extract(sections):
|
||
start_time = datetime.datetime.now()
|
||
print(f'start time: {start_time}')
|
||
|
||
# prev_ids = read_file_by_address('./data/prev_kw_ids_170k.txt').splitlines()
|
||
counter = 1
|
||
file_counter = 1
|
||
temp_dict = []
|
||
temp_data_text = ''
|
||
period_sections = []
|
||
period_ids_text = f'***** file number: {file_counter} *****\n'
|
||
for index, item in enumerate(sections):
|
||
# if item['id'] in prev_ids:
|
||
# continue
|
||
|
||
# # برای تست خروجی
|
||
# if counter > 10:
|
||
# with open('./data/keyword/prev_kw_ids.txt', 'a+', encoding='utf-8') as file:
|
||
# file.write(period_ids_text)
|
||
# break
|
||
|
||
id = item
|
||
content = sections[id]['content']
|
||
try:
|
||
sections[id]['keywords'] = single_section_get_keyword(content)
|
||
except Exception as error:
|
||
print(f"section kw error : {id}\n")
|
||
error_content = f'id: {id} - error: {str(error)}\n'
|
||
with open(f'./data/keyword/keyword_errors_{today}.txt', 'a+', encoding='utf-8') as file:
|
||
file.write(error_content)
|
||
counter += 1
|
||
continue
|
||
|
||
period_ids_text += f"{id} \n"
|
||
|
||
print(f"section kw extracting: {counter} - id: {id}")
|
||
# temp_dict.append(item)
|
||
if counter % 5000 == 0:
|
||
outputfile = open(f'./data/keyword/sections_kw_llama8b_{str(file_counter)}_{today}.json', "a+", encoding='utf-8')
|
||
outputfile.write(json.dumps(period_sections, ensure_ascii=False, indent=2))
|
||
outputfile.close()
|
||
print(f"file {str(file_counter)} created in {str(datetime.datetime.now())} +++++++++++++++++++++++++\n ")
|
||
|
||
file_counter += 1
|
||
period_sections = []
|
||
# save proccessed sections id for next executions of this code
|
||
with open(f'./data/keyword/prev_kw_ids_{today}.txt', 'a+', encoding='utf-8') as file:
|
||
file.write(period_ids_text)
|
||
counter += 1
|
||
period_ids_text = f'***** file number: {file_counter} *****\n'
|
||
|
||
|
||
outputfile = open(f'./data/keyword/sections_kw_llama8b_{str(file_counter)}_{today}.json', "w", encoding='utf-8')
|
||
outputfile.write(json.dumps(period_sections, ensure_ascii=False, indent = 4))
|
||
outputfile.close()
|
||
print(f"file {str(file_counter)} created in {str(datetime.datetime.now())} +++++++++++++++++++++++++ ")
|
||
|
||
end_time = datetime.datetime.now()
|
||
print(f"end_time: {end_time}")
|
||
print(f"elapsed time: {(end_time-start_time)} Hours!!! ")
|
||
print(f"elapsed time: {(end_time-start_time)/86400} Days!!! ")
|
||
print("end")
|
||
|
||
operation_result = True
|
||
return operation_result, sections
|
||
|
||
if __name__ == "__main__":
|
||
print(f'start: {datetime.datetime.now()}')
|
||
sections = get_sections()
|
||
|
||
operation_result = do_keyword_extract(sections)
|
||
|
||
print(f'end: {datetime.datetime.now()}') |