import time

import torch
from elasticsearch import Elasticsearch
from transformers import AutoModelForCausalLM, AutoTokenizer

# The Persian instruction-tuned Llama-3 model is loaded onto the GPU;
# fail early if no CUDA device is available, since the model is used
# unconditionally below.
if not torch.cuda.is_available():
    raise SystemExit("A CUDA device is required to run this script.")

model_id = "PartAI/Dorna-Llama3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

index_name_i = 'semantic_search-v10'
es = Elasticsearch(
    "http://127.0.0.1:6900",
    # ca_certs="/path/to/http_ca.crt",
    basic_auth=("elastic", "SG*7eGwg+KG2_*-1_mMm"),
)

counter = 0
total = 0
doc_id = ''


def es_iterate_all_documents(es, index, pagesize=250, scroll_timeout="12m", **kwargs):
    """Iterate over ALL documents of a single index using the scroll API,
    yielding one {"source": ..., "id": ...} dict per document."""
    global counter, total
    is_first = True
    scroll_id = None
    while True:
        if is_first:
            # Initialize the scroll with the first page of results.
            result = es.search(index=index, scroll=scroll_timeout, size=pagesize, **kwargs)
            total = result["hits"]["total"]["value"]
            print('total = %d' % total)
            is_first = False
        else:
            # Fetch the next page using the scroll id from the previous call.
            result = es.scroll(scroll_id=scroll_id, scroll=scroll_timeout)
        scroll_id = result["_scroll_id"]
        hits = result["hits"]["hits"]
        counter += len(hits)
        print("progress -> %.2f %% , count: %d" % ((counter / total) * 100, counter))
        # Stop once the scroll returns no more documents.
        if not hits:
            break
        yield from ({"source": hit["_source"], "id": hit["_id"]} for hit in hits)


try:
    for mentry in es_iterate_all_documents(es, index_name_i):
        entry = mentry["source"]
        doc_id = mentry["id"]
        text = entry.get("content", "")
        print("%s -> %.2f " % (doc_id, counter / total))
        try:
            full_path = entry.get("other_info", {"full_path": ""})["full_path"]
            text_len = len(text)
            # Skip structural sections -- 'عنوان' (title), 'موخره' (epilogue),
            # 'امضاء' (signature) -- as well as empty documents.
            if full_path in ("عنوان", "موخره", "امضاء") or text_len == 0:
                continue
            # Ask for roughly 15 keywords per 1000 characters; skip texts too
            # short to yield even a fraction of a keyword, and request at
            # least one keyword otherwise.
            keywords_count = (text_len / 1000) * 15
            if keywords_count < 0.3:
                continue
            keywords_count = max(int(keywords_count), 1)
            # Persian prompt, roughly: 'From "the text", extract at least {n}
            # important keywords and print them as a list in Persian.
            # "the text": {text}'
            messages = [{
                "role": "user",
                "content": '''از "متن" حداقل {} کلیدواژه مهم و پراهمیت را استخراج کن و در قالب لیست به زبان فارسی چاپ کن. 
"متن": {} '''.format(keywords_count, text),
            }]

            input_ids = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(model.device)

            terminators = [
                tokenizer.eos_token_id,
                tokenizer.convert_tokens_to_ids("<|eot_id|>"),
            ]
            # Llama-3 defines no pad token; reuse EOS so generate() does not
            # warn about a missing pad_token_id.
            model.generation_config.pad_token_id = tokenizer.eos_token_id

            start_time = time.time()
            outputs = model.generate(
                input_ids,
                max_new_tokens=256,
                eos_token_id=terminators,
                do_sample=True,
                temperature=0.6,
                top_p=0.85,
            )
            end_time = time.time()
            print(f"elapsed time: {end_time - start_time:.2f}s")

            # Decode only the newly generated tokens (everything after the prompt).
            response = outputs[0][input_ids.shape[-1]:]
            keywords = tokenizer.decode(response, skip_special_tokens=True)

            # Store the extracted keywords back onto the source document.
            es.update(
                index=index_name_i,
                id=doc_id,
                doc={"content_keywords-llama3-str": str(keywords)},
            )
        except Exception as inst:
            print(type(inst))   # the exception type
            print(inst.args)    # arguments stored in .args
            print("Exception: " + str(inst))
except Exception as inst:
    print(type(inst))   # the exception type
    print(inst.args)    # arguments stored in .args
    print(inst)         # __str__ allows args to be printed directly,
                        # but may be overridden in exception subclasses
    print("%s -> %.2f " % (doc_id, counter / total))