"""Ask a Llama 3.1 instruct model whether two Persian legal rules conflict.

The module loads the tokenizer and an 8-bit quantized model at import time,
then exposes three helpers:

* ``format_prompt`` builds the user message from two rule texts,
* ``generate`` runs one chat completion for a user message,
* ``conflict`` combines the two for a pair of rules.

Run as a script, it evaluates one hard-coded example pair and prints the
model's answer together with the elapsed time.
"""
import json
import os
import time

# HF_HOME has to be set BEFORE transformers / huggingface_hub are imported:
# the hub library resolves its cache directories at import time, so assigning
# the variable afterwards does not change where models are cached.
os.environ['HF_HOME'] = "/home/admin/HFHOME"

import torch  # noqa: E402
from tqdm import tqdm  # noqa: E402
from transformers import (  # noqa: E402
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
#model_id = "meta-llama/Llama-3.1-70B-Instruct"

# Upper bound on the number of prompt characters sent to the model
# (to avoid GPU OOM on very long inputs).
MAX_PROMPT_CHARS = 50000

# use quantization to lower GPU usage
# 4 bit:
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )
# 8 bit:
# Only `load_in_8bit` is passed: BitsAndBytesConfig has no `bnb_8bit_*`
# options (double quantization / quant type / compute dtype exist for 4-bit
# only), so such keyword arguments would simply be ignored.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
)

# Generation stops on either the tokenizer's EOS token or the "<|eot_id|>"
# token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
model.generation_config.pad_token_id = tokenizer.eos_token_id  #tokenizer.pad_token_id

SYS_PROMPT = (
    " You receive two Persian legal rule texts and analyze them carefully, "
    "explain your answer step by step, to see whether these two rules "
    "logically conflict with each other. Finally, state the final conclusion "
    'for the presence or absence of conflict with the words "yes" or "no". '
)


def format_prompt(SENTENCE1: str, SENTENCE2: str) -> str:
    """Return the user message presenting the two rules to the model.

    Args:
        SENTENCE1: Text of the first rule.
        SENTENCE2: Text of the second rule.

    Returns:
        A single string of the form ``"Rule 1: <...>. Rule 2: <...>."``.
    """
    return f"Rule 1: {SENTENCE1}. Rule 2: {SENTENCE2}."


def generate(formatted_prompt: str, *, max_new_tokens: int = 2048) -> str:
    """Run one chat completion and return only the newly generated text.

    Args:
        formatted_prompt: User message; truncated to ``MAX_PROMPT_CHARS``
            characters before tokenization.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The decoded model response with special tokens stripped and the
        prompt tokens removed.
    """
    formatted_prompt = formatted_prompt[:MAX_PROMPT_CHARS]  # to avoid GPU OOM
    messages = [
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": formatted_prompt},
    ]

    # return_dict=True yields both input_ids and attention_mask. Passing the
    # mask explicitly matters here because the pad token is set to the EOS
    # token, so generate() cannot infer the mask from the ids alone.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # tell the model to generate
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )

    # The output sequence starts with the prompt; keep only what follows it.
    prompt_length = inputs["input_ids"].shape[-1]
    response = outputs[0][prompt_length:]
    return tokenizer.decode(response, skip_special_tokens=True)


def conflict(sentence1: str, sentence2: str) -> str:
    """Return the model's analysis of whether the two rules conflict.

    Args:
        sentence1: Text of the first rule.
        sentence2: Text of the second rule.

    Returns:
        The raw text generated by the model for the formatted rule pair.
    """
    formatted_prompt = format_prompt(sentence1, sentence2)
    return generate(formatted_prompt)


def main() -> None:
    """Evaluate one hard-coded rule pair and print the answer and timing."""
    print('start')
    start_time = time.perf_counter()
    #inputfile = open('./main_rules_lama70B_dataset_02.json', "r", encoding='utf-8')
    result = conflict(
        "حسین در حال حاضر در وزارت نفت و انرژی به استخدام دولت در آمده است.",
        "حسین الان در دانشگاه دولتی مشغول به تحصیل است و هرکسی که در حال تحصیل باشد، از نظر قانون نمی تواند در دولت استخدام شود",
    )
    end_time = time.perf_counter()

    banner = "*********************************************************"
    print(banner)
    print(banner)
    print()
    print(result)
    print()
    print(banner)
    print(banner)
    print(f"elapsed time: {end_time-start_time}")
    print("end")


if __name__ == "__main__":
    main()