"""Extract keywords from Persian legal texts with a Llama 3.1 Instruct model.

Reads a JSON object whose values are lists of items carrying a ``content``
field, asks the model for the most important keywords of each item, stores
them in the item under ``keywords`` and writes the enriched JSON to a new file.

Note: the tokenizer and the (quantized) model are loaded at module import
time, so this file is meant to be run as a script.
"""
import json
import os
import time

# HF_HOME must be set *before* transformers / huggingface_hub are imported:
# huggingface_hub reads its cache locations from the environment once, at
# import time, so assigning it after the import does not change where the
# model is cached.
os.environ['HF_HOME'] = "/home/admin/HFHOME"

import torch  # noqa: E402  (must follow the HF_HOME assignment)
from tqdm import tqdm  # noqa: E402
from transformers import (  # noqa: E402
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

#model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model_id = "meta-llama/Llama-3.1-70B-Instruct"

# Use quantization to lower GPU memory usage.
# 4 bit:
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )
# 8 bit. Double quantization, quant type and compute dtype are 4-bit-only
# options (``bnb_4bit_*``). There are no ``bnb_8bit_*`` parameters:
# BitsAndBytesConfig would ignore them with an "Unused kwargs" warning,
# so they are not passed.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
)

# Stop generating at either the tokenizer's EOS token or the chat
# end-of-turn token <|eot_id|>.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
# Pad with the EOS id during generation (alternative tried: tokenizer.pad_token_id).
model.generation_config.pad_token_id = tokenizer.eos_token_id

SYS_PROMPT = (
    "You receive a Persian legal text and extract from it the keywords that "
    "are most important. And you don't need to provide explanations or "
    "additional text. Put each keyword on a single line."
)
# Alternative instruction that was tried: "Explain your answer step by step."

# Character-level cap on the user message, to avoid GPU out-of-memory errors
# on very long documents.
MAX_PROMPT_CHARS = 50000

INPUT_PATH = './data/main_classes_dataset_03.json'
OUTPUT_PATH = './data/main_keywords_lama70B_dataset_03.json'


def format_prompt(SENTENCE):
    """Wrap one legal text in the user-message template sent to the model.

    Args:
        SENTENCE: the raw text to analyse.

    Returns:
        The string ``"Persian legal text: <SENTENCE>."``.
    """
    return f"Persian legal text: {SENTENCE}."


def generate(formatted_prompt, max_chars=MAX_PROMPT_CHARS):
    """Run one sampled chat completion and return only the newly generated text.

    Args:
        formatted_prompt: the user message (see ``format_prompt``).
        max_chars: the user message is truncated to this many characters
            before tokenization, to bound GPU memory use.

    Returns:
        The decoded assistant reply, with special tokens removed. The prompt
        tokens are sliced off before decoding.
    """
    formatted_prompt = formatted_prompt[:max_chars]
    messages = [
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": formatted_prompt},
    ]
    # add_generation_prompt=True appends the assistant header so the model
    # starts answering instead of continuing the user turn.
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(
        input_ids,
        # Single unpadded sequence: every token is attended to. Passing the
        # mask explicitly avoids generate() having to guess it, since the pad
        # id equals the EOS id.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=2048,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)


def get_rules(sentence):
    """Return the keywords the model extracts from ``sentence``.

    The system prompt asks for one keyword per line. The reply is split on
    '\\n', each line is stripped, and blank lines are dropped.
    """
    reply = generate(format_prompt(sentence))
    return [line.strip() for line in reply.split('\n') if line.strip()]


def main():
    """Add a ``keywords`` list to every item of INPUT_PATH and write OUTPUT_PATH.

    INPUT_PATH must hold a JSON object whose values are lists of dicts with a
    ``content`` key (presumably one list per class label; confirm against
    the dataset).
    """
    print('start')
    start_time = time.perf_counter()

    with open(INPUT_PATH, "r", encoding='utf-8') as inputfile:
        data = json.load(inputfile)

    counter = 1
    for c in tqdm(data):
        for item in tqdm(data[c]):
            item['keywords'] = get_rules(item['content'])
            print(f"section {counter} ...")
            counter += 1

    # Serialize before opening the output file, so a serialization error
    # cannot leave behind an empty (truncated) output file.
    serialized = json.dumps(data, ensure_ascii=False, indent=4)
    with open(OUTPUT_PATH, "w", encoding='utf-8') as outputfile:
        outputfile.write(serialized)

    end_time = time.perf_counter()
    print(f"elapsed time: {end_time-start_time}")
    print("end")


if __name__ == "__main__":
    main()

# Archived prompt variants kept for reference; this string is not used by the
# script.
""" system prompt version 2 for test:
You are a lawyer and you must be able to explain legal texts without changing technical terms in a way that non-lawyers can understand the meaning of the text.

user prompt version 2 for test:
Extract at least {} important and significant key phrases from the "text" and print the key phrases in the form of a list in Persian and put each key phrase on a new line and do not add any explanation at the beginning or end of the answer. Each key phrase has a sequential number at the beginning. The key phrases must be present exactly in the text.
It is very important and essential that the length of each key phrase has at least two tokens and a single-token key phrase is not acceptable. I emphasize that no key phrase should have only one token. The names of organizations, institutions and legal entities must be considered as key phrases. No key phrase should be a verb or a preposition and should only include nouns that are added together. No key phrase should end with a preposition or the letter "و". It is essential that key phrases do not include "ماده", "تبصره", "بند" or "تاریخ ها". """