""" این فایل، یک جیسون که حاوی لیستی از متن است را دریافت می کند
|
||
و به کمک لاما3، هر متن را با تعدادی جمله ساده تر، بازنویسی می کند"""
|
||
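# Intended to be run directly as a script; it assumes funcs.py provides
# read_from_json/write_to_json and that the input JSON files exist under
# base_address (set in the __main__ block below).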
from html import escape
from lxml import etree
from datetime import datetime
from elasticsearch import Elasticsearch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
from threading import Thread
import torch
import time
from concurrent.futures import ThreadPoolExecutor
import concurrent
import threading
import json
import os.path

from funcs import read_from_json, write_to_json
import os

# Note: several of the imports above (html, lxml, datetime, the threading and
# concurrent helpers, json, os/os.path) are not used in this script and appear
# to be leftovers from earlier experiments.

if torch.cuda.is_available():
    model_id = "PartAI/Dorna-Llama3-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
else:
    # Without a GPU, `model` and `tokenizer` would be undefined and the script
    # would later fail with a NameError, so fail early with a clear message.
    raise RuntimeError("A CUDA-capable GPU is required to load PartAI/Dorna-Llama3-8B-Instruct.")

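# Note: device_map="auto" above relies on Accelerate to place the 8B model's
# weights on the available GPU(s), and torch.bfloat16 roughly halves the memory
# footprint compared to full fp32 weights.
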
index_name_i = 'semantic_search-v10'


es = Elasticsearch(
    "http://127.0.0.1:6900",
    # ca_certs="/path/to/http_ca.crt",
    basic_auth=("elastic", "SG*7eGwg+KG2_*-1_mMm")
)

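# Note: the Elasticsearch client and index name above are only needed by the
# commented-out es.update call in generateKeywords and by
# es_iterate_all_documents; the __main__ block reads its input from JSON files.
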
# Module-level state: progress counters, the id of the item being processed,
# and a (currently unused) keyword count.
counter = 0
total = 0
remained = 0
id = ''
keywords_count = 15


def es_iterate_all_documents(es, index, pagesize=250, scroll_timeout="12m", **kwargs):
    """
    Helper to iterate ALL values from a single index.
    Yields all the documents.
    """
    global counter
    global total
    global remained
    is_first = True

    while True:
        # Scroll next
        if is_first:  # Initialize scroll
            # result = es.search(index=index, scroll="12m", **kwargs, body={
            #     "size": pagesize
            # })
            result = es.search(index=index, scroll="12m", **kwargs, size=pagesize)
            total = result["hits"]["total"]['value']
            remained = total
            print('total = %d' % total)
            is_first = False
        else:
            # result = es.scroll(body={
            #     "scroll_id": scroll_id,
            #     "scroll": scroll_timeout
            # })
            result = es.scroll(scroll_id=scroll_id, scroll=scroll_timeout)
        scroll_id = result["_scroll_id"]
        hits = result["hits"]["hits"]
        counter += len(hits)
        print("progress -> %.2f %% , count: %d" % ((counter / total) * 100, counter))
        # Stop after no more docs
        if not hits:
            break
        # Yield each entry
        yield from ({"source": hit['_source'], "id": hit['_id']} for hit in hits)

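# A minimal usage sketch (not executed in this script): the helper above can be
# consumed as a generator to stream every document of an index together with
# its _id and _source, e.g.
#
#     for doc in es_iterate_all_documents(es, index_name_i):
#         print(doc["id"], list(doc["source"].keys()))
#
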
def generateKeywords(text):
    global remained
    try:
        # Aim for roughly 15 simple sentences per 1000 characters of input, at least 1.
        sen_count = (len(text) / 1000) * 15
        sen_count = int(sen_count)
        if sen_count == 0:
            sen_count = 1
        # System prompt: "You are a lawyer / legal expert; be legally precise in your answer."
        # User prompt: "Rewrite the following text as {sen_count} separate, simple and fluent
        # Persian sentences, place them between two *, and add no explanation at the
        # beginning or end of the answer."
        messages = [{"role": "system", "content": "تو یک وکیل یا حقوق دان هستی و در پاسخ دقت قانونی داشته باش"},
                    {"role": "user", "content":
                        f"متن زیر را در قالب {sen_count} جمله جداگانه، ساده و روان به زبان فارسی، بازنویسی کن و بین دو * قرار بده و هیچ گونه توضیحی در ابتدا یا انتهای پاسخ، اضافه نکن\n متن:{text}"
                    }]

        # Build the chat-formatted prompt tensor and move it to the model's device.
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        # Stop generation at either the regular EOS token or Llama 3's <|eot_id|> token.
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]
        model.generation_config.pad_token_id = tokenizer.pad_token_id

        outputs = model.generate(
            input_ids,
            max_new_tokens=256,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.85,
        )
        # lock0.release()
        # Decode only the newly generated tokens (everything after the prompt).
        response = outputs[0][input_ids.shape[-1]:]
        keywords = tokenizer.decode(response, skip_special_tokens=True)
        # lock1.acquire()
        # resp = es.update(index=index_name_i, id=id, doc={"content_keywords-llama3-str": str(keywords)})

        return keywords

    except Exception as inst:
        print(type(inst))    # the exception type
        print(inst.args)     # arguments stored in .args
        print("Exception: " + str(inst))

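# The user prompt asks the model to wrap its rewrite between two '*' characters,
# but generateKeywords stores the raw decoded string. If only the wrapped part is
# wanted, something along these lines (an illustrative assumption, not part of the
# original pipeline) could strip the markers:
#
#     import re
#     match = re.search(r"\*(.+?)\*", keywords, flags=re.DOTALL)
#     simplified = match.group(1).strip() if match else keywords
#
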
if __name__ == "__main__":
|
||
|
||
#base_address = os.getcwd() + "/llama" # debugger
|
||
base_address = "/home/gpu/tnlp/jokar/llama" # terminal
|
||
|
||
json_address_85_sections = base_address + "/data/qa_sections_85.json"
|
||
json_address_3k_sections = base_address + "/data/qa_sections_3k.json"
|
||
|
||
datalist_85 = read_from_json(json_address_85_sections)
|
||
datalist_3k = read_from_json(json_address_3k_sections)
|
||
|
||
start_time = time.time()
|
||
|
||
result_list = []
|
||
try:
|
||
        # part = datalist_3k[:600]
        # datalist_85 = datalist_85[:20]
        for i, line in enumerate(datalist_85):
            print(i + 1)
            id = line['id']
            content = line['content']
            result = generateKeywords(content)
            print("++++++++++++++++++++++++++++++++++++++++++++++++++++")
            print(result)
            print("++++++++++++++++++++++++++++++++++++++++++++++++++++")

            result_list.append({
                "id": id,
                "content": content,
                "result": result
            })

        destination_address = base_address + "/data/simplized_sentences_85_02.json"
        write_to_json(result_list, destination_address)

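        # Each record written above has the shape:
        #     {"id": <section id>, "content": <original text>, "result": <model rewrite>}
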
        # for j, line in enumerate(part):
        #     print(j + 1 + 85)
        #     id = line['id']
        #     content = line['content']
        #     result = generateKeywords(content)

        #     result_list.append({
        #         "id": id,
        #         "content": content,
        #         "result": result
        #     })

        # Note: with the 3k loop above commented out, this second write simply
        # saves the same result_list again under the 3k file name.
        destination_address = base_address + "/data/simplized_sentences_3k.json"
        write_to_json(result_list, destination_address)

    except Exception as inst:
        print(type(inst))    # the exception type
        print(inst.args)     # arguments stored in .args
        print(inst)          # __str__ allows args to be printed directly,
                             # but may be overridden in exception subclasses
        # Guard against division by zero: total is only set inside
        # es_iterate_all_documents, which this script never calls.
        print("Exception:=> %s -> %.2f " % (id, (counter / total) if total else 0.0))

    end_time = time.time()
    print(f"elapsed time: {end_time - start_time}")
    print(" *** finished! *** ")