keyword/llama_givechi_ttl.py

"""
this code reads a json file which has sections and generates
keywords for it by llama and then stores merge of old data and
new generated keywords
"""
import json
# from tqdm import tqdm
import time
import datetime
from funcs import save_to_file_by_address, read_file_by_address#, read_from_json
# from pandas import read_excel
import torch
import os

# HF_HOME must be set before transformers is imported so the cache location is respected
os.environ['HF_HOME'] = "/home/admin/HFHOME"
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# read sections in json file (leftover from the keyword-extraction version; `data` is not used below)
inputfile = open('./data/sections_1060.json', "r", encoding='utf-8')
data = json.load(inputfile)
inputfile.close()
#model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model_id = "meta-llama/Llama-3.1-70B-Instruct"
# 8-bit quantization; the bnb_8bit_* options passed previously (double quant, quant type,
# compute dtype) are not valid BitsAndBytesConfig arguments -- those exist only as bnb_4bit_*.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
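# If more aggressive quantization is needed, a 4-bit NF4 config is a common alternative
# (shown here as an illustration only, not what this script was run with):
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )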
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config
)
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
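# Llama 3 chat models end each assistant turn with "<|eot_id|>", so it is added alongside
# the regular EOS token as a stop condition for generation.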
model.generation_config.pad_token_id = tokenizer.eos_token_id #tokenizer.pad_token_id
# SYS_PROMPT = """You receive a Persian legal text and extract from it the keywords that are most important.
# And you don't need to provide explanations or additional text.
# Put each keyword on a single line."
# """# Explain your answer step by step.
SYS_PROMPT = """Purpose
The purpose of this prompt is to identify translatable verbs in TTL files derived from VerbNet and add Persian labels to them. This should be done taking into account the ontology structure and semantic constraints from VerbNet.
## Input
A TTL file that contains:
- Definitions of verb classes in English
- Hierarchical relationships (via rdfs:subClassOf)
- Semantic constraints (via owl:Restriction)
- Prefixes and base namespaces
## Analysis and translation instructions
1. Check the basic structure of the file:
- Identify the main class and subclasses
- Check the constraints, such as EventHasAgent, EventHasMaterial.
2. Verb analysis:
- Identify verbs that do not have a number in their identifier.
- Remove superclasses (such as build-26.1) from the translation list.
- Check the semantic relationship of each verb and pay attention to subtle semantic differences.
## The main rule for multiple equivalent verbs
Some English verbs have multiple valid translations into Persian. For these cases, you should:
- Create a separate triple for each Persian equivalent with rdfs:label, following Turtle syntax.
For example, if the verb cut has three equivalents:
go:cut rdfs:label "بريدن"@fa ,
"برش دادن"@fa ,
"قطع کردن"@fa .
## Translation instructions
1. Identify verbs: Only translate verbs without numbers in the URI.
2. Find equivalents:
- Identify all valid Persian translations.
- Each translation should be recorded separately.
3. Output writing:
- Fixed format for a triple: URI + rdfs:label + quoted equivalent + @fa .
## Some examples
For the verb arrange which has three equivalents:
go:arrange rdfs:label "چیدن"@fa,
"مرتب کردن"@fa,
"تنظیم کردن"@fa.
For the verb build which has two equivalents:
go:build rdfs:label "ساختن"@fa .
go:build rdfs:label "بنا کردن"@fa .
## Final output
The output should contain only the translation lines and no triples copied from the original file.
## Important Note
Your answer should contain only lines formatted like the examples above, without any explanation.
"""# givechi prompt
def format_prompt(SENTENCE):
    # PROMPT = f"Persian legal text: {SENTENCE}."
    PROMPT = f"متن: {SENTENCE}."  # "متن" = "Text"
    return PROMPT
def kw_count_calculator(text):
    keywords_count = (len(text) / 1000) * 15
    keywords_count = int(keywords_count)
    if keywords_count == 0:
        keywords_count = 1
    return keywords_count
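# Example: a 2,500-character text yields int((2500 / 1000) * 15) = 37 keywords.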
def generate(formatted_prompt):
    keywords_count = kw_count_calculator(formatted_prompt)  # computed but not used by the TTL prompt
    # Gemini Prompt
    formatted_prompt = formatted_prompt[:50000] # to avoid GPU OOM
    messages = [
        {"role":"system","content":SYS_PROMPT},
        {"role":"user","content":formatted_prompt}
    ]
    # tell the model to generate
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=2048,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)
def get_rules(sentence):
    formatted_prompt = format_prompt(sentence)
    rules = generate(formatted_prompt).split('\n')
    result = [r.strip() for r in rules if r.strip()]
    return result
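# Hypothetical usage (the file name and the exact output lines are illustrative,
# mirroring the examples in SYS_PROMPT):
#   get_rules(read_file_by_address("build-26.1.ttl"))
#   -> ['go:build rdfs:label "ساختن"@fa .', 'go:build rdfs:label "بنا کردن"@fa .']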
if __name__ == "__main__":
start_time = datetime.datetime.now()
print(f'start time: {start_time}')
# Get the list of all files and directories
path = "."
dir_list = os.listdir(path)
counter = 1
file_counter = 1
temp_dict = []
temp_data_text = ''
period_sections = []
period_ids_text = f'***** file number: {file_counter} *****\n'
final_results = []
for ttl_file_address in dir_list:
ttl_content = read_file_by_address(ttl_file_address)
try:
prompt_result = get_rules(ttl_content)
except:
print(f"ttl_error: {ttl_file_address}\n")
counter += 1
continue
final_results.append({
"ttl_file_address": ttl_file_address,
"ttl_file_content": ttl_content,
"prompt_result": prompt_result
})
print(f"file: {counter} - {ttl_file_address}")
# temp_dict.append(item)
file_counter += 1
# save proccessed sections id for next executions of this code
# save_to_file_by_address('./data/prev_ids_ttl.txt', period_ids_text.strip())
counter += 1
outputfile = open(f'./data/result_ttl_prompt.json', "w", encoding='utf-8')
outputfile.write(json.dumps(final_results, ensure_ascii=False, indent = 4))
outputfile.close()
end_time = datetime.datetime.now()
print(f"end_time: {end_time}")
print(f"elapsed time: {(end_time-start_time)/3600} Hours!!! ")
print(f"elapsed time: {(end_time-start_time)/86400} Days!!! ")
print("end")
exit()
"""
system prompt version 2 for test:
You are a lawyer and you must be able to explain legal texts without changing technical terms in a way that non-lawyers can understand the meaning of the text.
user prompt version 2 for test:
Extract at least {} important and significant key phrases from the "text" and print the key phrases in the form of a list in Persian and put each key phrase on a new line and do not add any explanation at the beginning or end of the answer.
Each key phrase has a sequential number at the beginning. The key phrases must be present exactly in the text. It is very important and essential that the length of each key phrase has at least two tokens and a single-token key phrase is not acceptable. I emphasize that no key phrase should have only one token. The names of organizations, institutions and legal entities must be considered as key phrases. No key phrase should be a verb or a preposition and should only include nouns that are added together. No key phrase should end with a preposition or the letter "و". It is essential that key phrases do not include "ماده", "تبصره", "بند" or "تاریخ ها".
"""
# Deepseek suggestion
"""
system prompt:
You are a highly accurate and detail-oriented assistant specialized in analyzing Persian legal texts.
user prompt:
Extract at least {} important and significant key phrases from the provided text. Follow these guidelines strictly:
- Print the key phrases as a numbered list in Persian, with each key phrase on a new line.
- Do not add any explanations, introductions, or conclusions to the output.
- Each key phrase must:
  - Be present exactly in the text.
  - Consist of at least two tokens (single-token key phrases are not acceptable).
  - Be a noun phrase (no verbs, prepositions, or single-token words).
  - Not end with a preposition or the letter "و".
  - Exclude the following terms: "ماده", "تبصره", "بند", "تاریخ ها".
- Include names of organizations, institutions, and legal entities as key phrases.
- Ensure the output is clean and adheres to all the above rules.
"""