# llama/llama3_sentence.py
""" این فایل، یک جیسون که حاوی لیستی از متن است را دریافت می کند
و به کمک لاما3، هر متن را با تعدادی جمله ساده تر، بازنویسی می کند"""
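# Sketch of the expected JSON shapes (an assumption, inferred from the fields read
# in the __main__ block below; the real files are not included with this script):
#
#   input  (qa_sections_85.json / qa_sections_3k.json):
#       [{"id": "<section id>", "content": "<Persian legal text>"}, ...]
#   output (simplized_sentences_*.json):
#       [{"id": "<section id>", "content": "<original text>", "result": "<model rewrite>"}, ...]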
from elasticsearch import Elasticsearch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import os
from funcs import read_from_json, write_to_json
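# `funcs` is a local helper module that is not part of this file. A minimal sketch
# of what it is assumed to provide, as thin wrappers around the standard json
# module (hypothetical; the real implementation may differ):
#
#     import json
#
#     def read_from_json(path):
#         with open(path, encoding="utf-8") as f:
#             return json.load(f)
#
#     def write_to_json(data, path):
#         with open(path, "w", encoding="utf-8") as f:
#             json.dump(data, f, ensure_ascii=False, indent=2)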
# Load the Persian Llama-3 instruction model. `model` and `tokenizer` are only
# defined when a CUDA device is available.
if torch.cuda.is_available():
    model_id = "PartAI/Dorna-Llama3-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
index_name_i = 'semantic_search-v10'
es = Elasticsearch(
    "http://127.0.0.1:6900",
    # ca_certs="/path/to/http_ca.crt",
    basic_auth=("elastic", "SG*7eGwg+KG2_*-1_mMm")
)

# Progress bookkeeping shared between the scroll helper and the main loop.
counter = 0
total = 0
remained = 0
id = ''
keywords_count = 15
def es_iterate_all_documents(es, index, pagesize=250, scroll_timeout="12m", **kwargs):
    """
    Helper to iterate ALL values from a single index
    Yields all the documents.
    """
    global counter
    global total
    global remained
    is_first = True
    while True:
        # Scroll next
        if is_first:  # Initialize scroll
            # result = es.search(index=index, scroll="12m", **kwargs, body={
            #     "size": pagesize
            # })
            result = es.search(index=index, scroll="12m", **kwargs, size=pagesize)
            total = result["hits"]["total"]['value']
            remained = total
            print('total = %d' % total)
            is_first = False
        else:
            # result = es.scroll(body={
            #     "scroll_id": scroll_id,
            #     "scroll": scroll_timeout
            # })
            result = es.scroll(scroll_id=scroll_id, scroll=scroll_timeout)
        scroll_id = result["_scroll_id"]
        hits = result["hits"]["hits"]
        counter += len(hits)
        print("progress -> %.2f %% , count: %d" % ((counter / total) * 100, counter))
        # Stop after no more docs
        if not hits:
            break
        # Yield each entry
        yield from ({"source": hit['_source'], "id": hit['_id']} for hit in hits)
def generateKeywords(text):
    """Rewrite `text` as a number of short, simple Persian sentences using the
    Llama-3 model loaded above (despite its name, this function simplifies text
    rather than extracting keywords)."""
    global remained
    try:
        # Roughly 15 target sentences per 1000 characters of input, at least 1.
        sen_count = (len(text) / 1000) * 15
        sen_count = int(sen_count)
        if sen_count == 0:
            sen_count = 1
        # System prompt (Persian): "You are a lawyer / legal expert; be legally precise in your answer."
        # User prompt (Persian): "Rewrite the following text as {sen_count} separate, simple, fluent
        # Persian sentences, put them between two *, and do not add any explanation before or after
        # the answer. Text: {text}"
        messages = [
            {"role": "system", "content": "تو یک وکیل یا حقوق دان هستی و در پاسخ دقت قانونی داشته باشد "},
            {"role": "user", "content":
                f"متن زیر را در قالب {sen_count} جمله جداگانه، ساده و روان به زبان فارسی، بازنویسی کن و بین دو * قرار بده و هیچ گونه توضیحی در ابتدا یا انتهای پاسخ، اضافه نکن\n متن:{text}"
            }
        ]
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]
        # Llama-3 tokenizers usually ship without a pad token; fall back to EOS.
        model.generation_config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
        outputs = model.generate(
            input_ids,
            max_new_tokens=256,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.85,
        )
        # lock0.release()
        # Strip the prompt tokens and decode only the newly generated part.
        response = outputs[0][input_ids.shape[-1]:]
        keywords = tokenizer.decode(response, skip_special_tokens=True)
        # lock1.acquire()
        # resp = es.update(index=index_name_i, id=id, doc={"content_keywords-llama3-str": str(keywords)})
        return keywords
    except Exception as inst:
        print(type(inst))   # the exception type
        print(inst.args)    # arguments stored in .args
        print("Exception: " + str(inst))
if __name__ == "__main__":
    # base_address = os.getcwd() + "/llama"  # debugger
    base_address = "/home/gpu/tnlp/jokar/llama"  # terminal
    json_address_85_sections = base_address + "/data/qa_sections_85.json"
    json_address_3k_sections = base_address + "/data/qa_sections_3k.json"
    datalist_85 = read_from_json(json_address_85_sections)
    datalist_3k = read_from_json(json_address_3k_sections)
    start_time = time.time()
    result_list = []
    try:
        # part = datalist_3k[:600]
        # datalist_85 = datalist_85[:20]
        for i, line in enumerate(datalist_85):
            print(i + 1)
            id = line['id']
            content = line['content']
            result = generateKeywords(content)
            print("++++++++++++++++++++++++++++++++++++++++++++++++++++")
            print(result)
            print("++++++++++++++++++++++++++++++++++++++++++++++++++++")
            result_list.append({
                "id": id,
                "content": content,
                "result": result
            })
        destination_address = base_address + "/data/simplized_sentences_85_02.json"
        write_to_json(result_list, destination_address)
        # for j, line in enumerate(part):
        #     print(j + 1 + 85)
        #     id = line['id']
        #     content = line['content']
        #     result = generateKeywords(content)
        #     result_list.append({
        #         "id": id,
        #         "content": content,
        #         "result": result
        #     })
        destination_address = base_address + "/data/simplized_sentences_3k.json"
        write_to_json(result_list, destination_address)
    except Exception as inst:
        print(type(inst))   # the exception type
        print(inst.args)    # arguments stored in .args
        print(inst)         # __str__ allows args to be printed directly,
                            # but may be overridden in exception subclasses
        # `total` stays 0 unless the Elasticsearch scroll helper ran, so guard the division.
        print("Exception:=> %s -> %.2f " % (id, counter / total if total else 0))
    end_time = time.time()
    print(f"elapsed time: {end_time-start_time}")
    print(" *** finished! *** ")