keyword/ner_extractor_3800.py
2025-01-20 16:24:18 +00:00

207 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
# from tqdm import tqdm
import time
import datetime
from funcs import save_to_file_by_address, read_file_by_address#, read_from_json
# from pandas import read_excel
import torch
import os
from transformers import AutoTokenizer, AutoModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
os.environ['HF_HOME'] = "/home/admin/HFHOME"
#model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model_id = "meta-llama/Llama-3.1-70B-Instruct"
bnb_config = BitsAndBytesConfig(
load_in_8bit=True, bnb_8bit_use_double_quant=True, bnb_8bit_quant_type="nf8", bnb_8bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=bnb_config
)
terminators = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
model.generation_config.pad_token_id = tokenizer.eos_token_id #tokenizer.pad_token_id
# SYS_PROMPT = """You receive a Persian legal text and extract from it the keywords that are most important.
# And you don't need to provide explanations or additional text.
# Put each keyword on a single line."
# """# Explain your answer step by step.
SYS_PROMPT = """شما یک مدل زبانی هوش مصنوعی هستید که برای استخراج موجودیت‌های نامدار (NER) از متون طراحی شده‌اید. وظیفه شما استخراج دقیق موجودیت‌های مشخص‌شده از متن ورودی است، بدون تولید یا افزودن هیچ اطلاعاتی خارج از متن اصلی. شما تنها اطلاعاتی را که مستقیماً از متن استخراج شده است، ارائه می‌دهید. موجودیت‌های تکراری را تنها یک بار ذکر می‌کنید. هر موجودیت باید در فرمت زیر ارائه شود:
1. هر موجودیت یک خط جدید باشد.
2. نوع موجودیت، مقدار آن، و جایگاه توکن‌های شروع و پایان آن در متن مشخص شود.
3. در هیچ کدام از موجودیت های نامدار، از استنباط پرهیز کن و دقیقا بر کلماتی که در متن وجود دارد تمرکز کن
4. انواع موجودیت های نامدار مورد نظر:
- `ref`: شامل عناوین دقیق قوانین در متن.
- `h_ref`: عباراتی مرتبط با عناوین قوانین مانند "ماده"، "تبصره"، "بند"، "این قانون"، "قانون مذکور"، "قانون یادشده"، و ...
- `per`: نام اشخاص حقیقی که دقیقا در متن ذکر شده باشد.
- `org`: سازمان‌ها، وزارتخانه‌ها، شرکت‌ها، تشکیلات نظامی و هر مجموعه حقوقی و ساختار مردم نهاد و NGO.
- `loc`: مکان‌ها، شهرها، کشورها، و مناطق جغرافیایی.
- `event`: رویدادهای رسمی و تقویمی.
- `fac`: امکانات و تاسیسات و زیرساخت ها.
- `date`: انواع فرمت های تاریخ‌ به صورت عددی یا حروفی. دقت شود که اعداد با تاریخ اشتباه نشود
- `sub`: موضوع اصلی متن که دقیقاً در متن ذکر شده است.
هیچ توضیح اضافی در پاسخ وجود نداشته باشد. تنها موجودیت‌های استخراج‌شده را در فرمت زیر ارائه کن.
فرمت خروجی: لیستی از موجودیت ها به صورت زیر:
[{'type':'org', 'value':'دیوان محاسبات کشور', 'token_start':'5', 'token_end':'8'}, {'type':'sub', 'value':'حقوق بازنشستگان لشکری', 'token_start':'27', 'token_end':'30'}]
"""# gpt prompt
def format_prompt(SENTENCE):
# PROMPT = f"Persian legal text: {SENTENCE}."
PROMPT = f"متن: {SENTENCE}"
return PROMPT
def kw_count_calculator(text):
keywords_count = (len(text) / 1000) * 15
keywords_count = int(keywords_count)
if keywords_count == 0:
keywords_count = 1
return keywords_count
def generate(input_text):
USER_PROMPT = f"""متن زیر را پردازش کن و موجودیت‌های نامدار را طبق دستورالعمل استخراج کن:
"""
formatted_prompt = input_text[:50000] # to avoid GPU OOM
messages = [
{"role":"system","content":SYS_PROMPT},
{"role":"user","content":USER_PROMPT},
{"role":"user","content":formatted_prompt}
]
# tell the model to generate
input_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
outputs = model.generate(
input_ids,
max_new_tokens=2048,
eos_token_id=terminators,
do_sample=True,
temperature=0.6,
top_p=0.9,
)
response = outputs[0][input_ids.shape[-1]:]
return tokenizer.decode(response, skip_special_tokens=True)
def do_prompt(sentence):
formatted_prompt = format_prompt(sentence)
results = generate(formatted_prompt).split('\n')
result = [r.strip() for r in results if r.strip()]
return result
def get_tehran_time():
return datetime.datetime.now() + datetime.timedelta(hours=3, minutes=30)
if __name__ == "__main__":
start_time = get_tehran_time()
print(f'start time: {start_time}')
# inputfile = open('./data/main_classes_dataset_03.json', "r", encoding='utf-8')
inputfile = open('./data/new_3800_sections.json', "r", encoding='utf-8')
data = json.load(inputfile)
inputfile.close()
prev_ids = read_file_by_address('./data/prev_ner_ids_3800.txt').splitlines()
counter = 1
file_counter = 1
temp_dict = []
temp_data_text = ''
period_sections = []
period_ids_text = f'***** file number: {file_counter} *****\n'
for item in (data):
if item['id'] in prev_ids:
continue
content = item['content']
try:
item['ners_prompt'] = do_prompt(content)
except:
print(f"section ner error : {item[id]}\n")
counter += 1
continue
period_ids_text += f"{item['id']} \n"
period_sections.append(item)
temp_dict.append(item)
print(f"section:{counter}-id:{item['id']}")
# temp_dict.append(item)
if counter % 1000 == 0:
print(f"period ==>> {(start_time-get_tehran_time())/3600} hours for {counter} sections +++ \n")
outputfile = open(f'./data/sections_ner_llama70b_3800_2_{str(file_counter)}.json', "a+", encoding='utf-8')
outputfile.write(json.dumps(period_sections, ensure_ascii=False))
outputfile.close()
print(f"file {str(file_counter)} created in {str(get_tehran_time())} +++++++++++++++\n ")
# temp_dict.append(item)
file_counter += 1
period_sections = []
# save proccessed sections id for next executions of this code
save_to_file_by_address('./data/prev_ner_ids_3800.txt', period_ids_text.strip())
period_ids_text = ''
if counter == 10:
test_file_name = './data/sections_ner_llama70b_3800_test2_.json'
outputfile = open(test_file_name, "a+", encoding='utf-8')
outputfile.write(json.dumps(temp_dict, ensure_ascii=False))
outputfile.close()
print(f'test file {test_file_name} created ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ')
counter += 1
outputfile = open(f'./data/sections_ner_llama70b_3800_2_{str(file_counter)}.json', "w", encoding='utf-8')
outputfile.write(json.dumps(period_sections, ensure_ascii=False, indent = 4))
outputfile.close()
print(f"file {str(file_counter)} created in {str(get_tehran_time())} +++++++++++++++++++++++++ ")
end_time = get_tehran_time()
print(f"end_time: {end_time}")
print(f"elapsed time: {(end_time-start_time)/3600} Hours!!! ")
print(f"elapsed time: {(end_time-start_time)/86400} Days!!! ")
print("end")
exit()
"""
system prompt version 2 for test:
You are a lawyer and you must be able to explain legal texts without changing technical terms in a way that non-lawyers can understand the meaning of the text.
user prompt version 2 for test:
Extract at least {} important and significant key phrases from the "text" and print the key phrases in the form of a list in Persian and put each key phrase on a new line and do not add any explanation at the beginning or end of the answer.
Each key phrase has a sequential number at the beginning. The key phrases must be present exactly in the text. It is very important and essential that the length of each key phrase has at least two tokens and a single-token key phrase is not acceptable. I emphasize that no key phrase should have only one token. The names of organizations, institutions and legal entities must be considered as key phrases. No key phrase should be a verb or a preposition and should only include nouns that are added together. No key phrase should end with a preposition or the letter "و". It is essential that key phrases do not include "ماده", "تبصره", "بند" or "تاریخ ها".
"""
# Deepseek suggestion
"""
system prompt:
You are a highly accurate and detail-oriented assistant specialized in analyzing Persian legal texts.
user prompt:
Extract at least {} important and significant key phrases from the provided text. Follow these guidelines strictly:
Print the key phrases as a numbered list in Persian, with each key phrase on a new line.
Do not add any explanations, introductions, or conclusions to the output.
Each key phrase must:
Be present exactly in the text.
Consist of at least two tokens (single-token key phrases are not acceptable).
Be a noun phrase (no verbs, prepositions, or single-token words).
Not end with a preposition or the letter "و".
Exclude the following terms: "ماده", "تبصره", "بند", "تاریخ ها".
Include names of organizations, institutions, and legal entities as key phrases.
Ensure the output is clean and adheres to all the above rules.
"""