keyword_llama/conflict.py
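"""
Detect logical conflicts between pairs of Persian legal rules with
Meta-Llama-3.1-8B-Instruct, loaded in 8-bit quantization to fit on GPU.
"""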

import json
import os
import time

import torch
from tqdm import tqdm

# HF_HOME must be set before transformers/huggingface_hub resolve their cache
# paths at import time, so the import comes after this assignment.
os.environ['HF_HOME'] = "/home/admin/HFHOME"

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# model_id = "meta-llama/Llama-3.1-70B-Instruct"

# Use quantization to lower GPU memory usage.
# 4-bit:
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )
# 8-bit (double quantization and the "nf" quant types are 4-bit-only options,
# so plain load_in_8bit is all that is needed here):
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
)
# Stop on either the default EOS token or Llama 3's end-of-turn token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
# Llama 3 ships without a pad token; reuse EOS so generate() does not warn.
model.generation_config.pad_token_id = tokenizer.eos_token_id
SYS_PROMPT = """
You receive two Persian legal rule texts. Analyze them carefully and explain
your reasoning step by step to determine whether the two rules logically
conflict with each other. Finally, state your conclusion on the presence or
absence of a conflict with the single word "yes" or "no".
"""
def format_prompt(sentence1, sentence2):
    return f"Rule 1: {sentence1}. Rule 2: {sentence2}."
def generate(formatted_prompt):
    # Truncate very long inputs to avoid GPU OOM.
    formatted_prompt = formatted_prompt[:50000]
    messages = [
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": formatted_prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=2048,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)
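# A minimal sketch of extracting the final yes/no verdict from the model's
# free-form answer. `parse_verdict` is a hypothetical helper, not part of the
# original pipeline; it assumes the model follows SYS_PROMPT and closes its
# reasoning with the word "yes" or "no".
def parse_verdict(answer):
    lowered = answer.strip().lower()
    # Scan from the end so the concluding word wins over earlier mentions.
    for token in reversed(lowered.replace(".", " ").replace('"', " ").split()):
        if token in ("yes", "no"):
            return token
    return None  # no explicit verdict found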
def conflict(sentence1, sentence2):
    formatted_prompt = format_prompt(sentence1, sentence2)
    return generate(formatted_prompt)
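# A hedged sketch of how the dataset file referenced in the commented-out line
# below might be processed in bulk; this loop is not in the original script,
# and the JSON layout is assumed: a list of records with hypothetical keys
# "rule1" and "rule2".
def conflict_dataset(path):
    with open(path, "r", encoding="utf-8") as inputfile:
        records = json.load(inputfile)
    results = []
    for record in tqdm(records):
        results.append(conflict(record["rule1"], record["rule2"]))
    return results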
if __name__ == "__main__":
    print('start')
    start_time = time.time()
    # inputfile = open('./main_rules_lama70B_dataset_02.json', "r", encoding='utf-8')
    # Rule 1: "Hossein has currently been employed by the government at the
    # Ministry of Oil and Energy."
    # Rule 2: "Hossein is now studying at a state university, and by law anyone
    # who is studying cannot be employed by the government."
    result = conflict("حسین در حال حاضر در وزارت نفت و انرژی به استخدام دولت در آمده است.",
                      "حسین الان در دانشگاه دولتی مشغول به تحصیل است و هرکسی که در حال تحصیل باشد، از نظر قانون نمی تواند در دولت استخدام شود")
    end_time = time.time()
    print("*********************************************************")
    print("*********************************************************")
    print()
    print(result)
    print()
    print("*********************************************************")
    print("*********************************************************")
    print(f"elapsed time: {end_time - start_time:.2f} s")
    print("end")