import json
import os
import time

from tqdm import tqdm

# HF_HOME has to be set before transformers / huggingface_hub are imported,
# otherwise the custom cache location is ignored.
os.environ['HF_HOME'] = "/home/admin/HFHOME"

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# model_id = "meta-llama/Llama-3.1-70B-Instruct"

# Use quantization to lower GPU memory usage.
# 4-bit:
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
# )
# 8-bit (the double-quant / quant-type / compute-dtype options only exist for
# 4-bit loading, so load_in_8bit=True is the whole configuration):
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config
)

# Stop generation at the EOS token or Llama 3's end-of-turn token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
# Llama 3.1 ships without a dedicated pad token, so reuse the EOS token for padding.
model.generation_config.pad_token_id = tokenizer.eos_token_id

SYS_PROMPT = """
You receive two Persian legal rule texts. Analyze them carefully and explain your reasoning step by step to determine whether the two rules logically conflict with each other.
Finally, state your conclusion about the presence or absence of a conflict with the word "yes" or "no".
"""


def format_prompt(SENTENCE1, SENTENCE2):
    PROMPT = f"Rule 1: {SENTENCE1}. Rule 2: {SENTENCE2}."
    return PROMPT


def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:50000]  # truncate very long inputs to avoid GPU OOM
    messages = [{"role": "system", "content": SYS_PROMPT},
                {"role": "user", "content": formatted_prompt}]
    # Build the chat-formatted input and move it to the model's device.
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=2048,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # Drop the prompt tokens and decode only the newly generated text.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)
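

# The system prompt asks the model to end its answer with "yes" or "no".
# A small helper like the sketch below (not part of the original script) could
# extract that verdict from the generated text, assuming the model follows the
# requested format.
def parse_verdict(answer):
    lines = answer.strip().splitlines()
    last_line = lines[-1].lower() if lines else ""
    if "yes" in last_line:
        return True
    if "no" in last_line:
        return False
    return None  # the model did not follow the requested output format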


def conflict(sentence1, sentence2):
    formatted_prompt = format_prompt(sentence1, sentence2)
    return generate(formatted_prompt)
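

# The unused json/tqdm imports and the commented-out dataset file in __main__
# suggest a batch mode over a JSON file of rule pairs. The sketch below shows
# one possible loop; the file layout and the "rule1"/"rule2" keys are
# assumptions, not part of the original script.
def conflict_dataset(path):
    with open(path, "r", encoding="utf-8") as f:
        pairs = json.load(f)  # assumed: a list of {"rule1": ..., "rule2": ...} objects
    results = []
    for pair in tqdm(pairs, desc="checking rule pairs"):
        results.append({**pair, "model_answer": conflict(pair["rule1"], pair["rule2"])})
    return results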


if __name__ == "__main__":
    print('start')
    start_time = time.time()
    # inputfile = open('./main_rules_lama70B_dataset_02.json', "r", encoding='utf-8')
    # Example pair (Persian). Rule 1: "Hossein has currently been hired by the
    # government at the Ministry of Oil and Energy." Rule 2: "Hossein is currently
    # studying at a state university, and by law anyone who is studying cannot be
    # employed by the government."
    result = conflict("حسین در حال حاضر در وزارت نفت و انرژی به استخدام دولت در آمده است.",
                      "حسین الان در دانشگاه دولتی مشغول به تحصیل است و هرکسی که در حال تحصیل باشد، از نظر قانون نمی تواند در دولت استخدام شود")
    end_time = time.time()

    print("*********************************************************")
    print("*********************************************************")
    print()
    print(result)
    print()
    print("*********************************************************")
    print("*********************************************************")

    print(f"elapsed time: {end_time-start_time}")
    print("end")