llama/elastic_keywords_llama3_3.py
2025-07-13 19:05:59 +03:30

178 lines
5.8 KiB
Python

from html import escape
from lxml import etree
from datetime import datetime
from elasticsearch import Elasticsearch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
from threading import Thread
import torch
import time
from concurrent.futures import ThreadPoolExecutor
import concurrent
import threading
# Single lock shared by the worker threads; generateKeywords holds it while it
# generates text and updates the shared bookkeeping.
lock = threading.Lock()

# The model and tokenizer are loaded only when torch reports a CUDA device;
# otherwise `model_id`, `model` and `tokenizer` are left undefined.
if torch.cuda.is_available():
    model_id = "PartAI/Dorna-Llama3-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

# Progress bookkeeping, shared through module globals.
counter = total = remained = 0
id = ''
keywords_count = 15

# Persian for "please answer only in Persian".
# NOTE(review): not referenced by the functions visible in this section.
messages = [
    {'role': 'user', 'content': 'لطفا فقط فارسی جواب بدهید'},
]

# Elasticsearch index that is read from and updated, and the client used for it.
# NOTE(review): the credentials are embedded in the source — consider loading
# them from the environment or a secrets store instead.
index_name_i = 'semantic_search-v10'
es = Elasticsearch(
    "http://127.0.0.1:6900",
    basic_auth=("elastic", "SG*7eGwg+KG2_*-1_mMm"),
)
def es_iterate_all_documents(es, index, pagesize=250, scroll_timeout="12m", **kwargs):
    """Yield every document of a single index by paging through a scroll.

    Args:
        es: Client object exposing ``search``, ``scroll`` and ``clear_scroll``.
        index: Name of the index to read.
        pagesize: Number of hits requested per page.
        scroll_timeout: Scroll keep-alive, sent with the initial search and
            with every follow-up scroll request.
        **kwargs: Extra keyword arguments forwarded to the initial ``es.search``.

    Yields:
        dict: ``{"source": hit["_source"], "id": hit["_id"]}`` for each hit.

    Side effects:
        Resets the module global ``counter`` to 0 and then adds the size of
        every fetched page to it; sets the globals ``total`` and ``remained``
        to the hit total reported by the first response; prints the total and
        a progress line per page.
    """
    global counter
    global total
    global remained

    # Start from zero so the progress figures are right on every invocation.
    counter = 0
    scroll_id = None
    try:
        # Open the scroll; honour the caller's keep-alive instead of a fixed one.
        result = es.search(index=index, scroll=scroll_timeout, size=pagesize, **kwargs)
        total = result["hits"]["total"]['value']
        remained = total
        print('total = %d' % total)

        while True:
            scroll_id = result["_scroll_id"]
            hits = result["hits"]["hits"]
            counter += len(hits)
            # An empty index reports total == 0; avoid dividing by it.
            percent = (counter / total) * 100 if total else 0.0
            print("progress -> %.2f %% , counte: %d" % (percent, counter))

            # Stop after no more docs
            if not hits:
                break

            for hit in hits:
                yield {"source": hit['_source'], "id": hit['_id']}

            result = es.scroll(scroll_id=scroll_id, scroll=scroll_timeout)
    finally:
        # Release the scroll context instead of letting it linger until the
        # keep-alive expires; this also runs when the consumer stops early.
        if scroll_id is not None:
            try:
                es.clear_scroll(scroll_id=scroll_id)
            except Exception as exc:  # best-effort cleanup only
                print("could not clear scroll context: %s" % exc)
def generateKeywords(text, keywords_count, id):
    """Generate keywords for one document and store them on that document.

    Builds a Persian prompt asking for at least ``keywords_count`` important
    keywords from ``text``, runs it through the module-level ``model`` /
    ``tokenizer``, and writes the decoded answer to the field
    ``content_keywords-llama3-str`` of document ``id`` in ``index_name_i``.

    The module-level ``lock`` is held for the whole generate-and-update step,
    so only one worker thread runs it at a time and the ``remained`` counter
    is decremented without races.

    Args:
        text: Document text to extract keywords from.
        keywords_count: Minimum number of keywords requested in the prompt.
        id: Elasticsearch document id to update.

    Returns:
        ``id`` on success; ``None`` when any exception occurred (the exception
        is printed, not re-raised).
    """
    global remained
    try:
        # `with` releases the lock even when generation or the update raises;
        # a bare acquire()/release() pair would leave it held after the first
        # failure and block every remaining worker forever.
        with lock:
            if torch.cuda.is_available():
                print('ok')

            prompt = 'از "متن" حداقل {} کلیدواژه مهم و پراهمیت را استخراج کن و در قالب لیست به زبان فارسی چاپ کن. "متن": {}\n'.format(
                keywords_count, text
            )
            messages = [{"role": "user", "content": prompt}]

            # NOTE(review): input_ids is handed to generate() without being
            # moved to model.device (that call was commented out) — confirm
            # this is intended.
            input_ids = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt"
            )

            # Generation stops on either the tokenizer's EOS or "<|eot_id|>".
            terminators = [
                tokenizer.eos_token_id,
                tokenizer.convert_tokens_to_ids("<|eot_id|>")
            ]

            model.generation_config.pad_token_id = tokenizer.pad_token_id
            outputs = model.generate(
                input_ids,
                max_new_tokens=256,
                eos_token_id=terminators,
                do_sample=True,
                temperature=0.6,
                top_p=0.85,
            )

            # Keep only the newly generated tokens (drop the echoed prompt).
            response = outputs[0][input_ids.shape[-1]:]
            keywords = tokenizer.decode(response, skip_special_tokens=True)

            es.update(index=index_name_i, id=id, doc={"content_keywords-llama3-str": str(keywords)})
            remained = remained - 1
            print("update id = {}, remained items = {} ".format(id,remained))
        return id
    except Exception as inst:
        print(type(inst))    # the exception type
        print(inst.args)     # arguments stored in .args
        print("Exception: " + str(inst))
        return None
# full_path values (Persian section labels) whose documents are skipped.
_SKIPPED_FULL_PATHS = ('عنوان', 'موخره', 'امضاء')


def _keywords_for_length(text_len):
    """Return how many keywords to request for a text of ``text_len`` characters.

    The count scales at 15 keywords per 1000 characters. Returns ``None`` when
    the scaled value is below 0.3 (the document is skipped), otherwise the
    truncated value with a minimum of 1.
    """
    scaled = (text_len / 1000) * 15
    if scaled < 0.3:
        return None
    return max(int(scaled), 1)


def main():
    """Submit a keyword-generation task for every eligible document and wait.

    Iterates all documents of ``index_name_i``, skips those whose
    ``other_info.full_path`` is one of ``_SKIPPED_FULL_PATHS`` or whose
    content is empty or too short, and submits the rest to a thread pool
    running ``generateKeywords``. Prints elapsed time once submission is done
    and again after every task has completed.
    """
    start_time = time.time()
    futures = []
    # Kept outside the try so the error report below always has a value.
    doc_id = ''
    with ThreadPoolExecutor() as executor:
        for mentry in es_iterate_all_documents(es, index_name_i):
            try:
                doc_id = mentry['id']
                entry = mentry['source']
                # Treat a missing or null field as empty instead of raising.
                text = entry.get('content') or ''
                full_path = (entry.get('other_info') or {}).get('full_path', '')
                if full_path in _SKIPPED_FULL_PATHS or not text:
                    continue

                wanted = _keywords_for_length(len(text))
                if wanted is None:
                    continue

                futures.append(executor.submit(generateKeywords, text, wanted, doc_id))
            except Exception as inst:
                print(type(inst))    # the exception type
                print(inst.args)     # arguments stored in .args
                print(inst)          # __str__ allows args to be printed directly,
                                     # but may be overridden in exception subclasses
                # Guard the ratio so the handler itself cannot raise.
                ratio = counter / total if total else 0.0
                print("Exception:=> %s -> %.2f " % (doc_id, ratio))

        print(f"pre end...elapsed time: {time.time()-start_time}")

        # Wait for each task; result() surfaces anything a worker did not catch.
        for future in concurrent.futures.as_completed(futures):
            future.result()
            print('-')

    end_time = time.time()
    print(f"elapsed time: {end_time-start_time}")


if __name__ == "__main__":
    main()