llama/elastic_keywords_llama3_5.py
2025-07-13 19:05:59 +03:30

196 lines
6.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
این فایل با نرمالایزر هضم کار می کند
"""
from html import escape
from lxml import etree
from datetime import datetime
from elasticsearch import Elasticsearch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
from threading import Thread
import torch
import time
from concurrent.futures import ThreadPoolExecutor
import concurrent
import threading
import json
import os.path
import os
import normalizer
from funcs import write_to_json, read_from_json
#lock = threading.Lock()
#lock1 = threading.Lock()
#from cleantext import clean
#import re
address = os.getcwd()
sections_list = read_from_json(address + '/data/clean_sections_11k.json') # Main File
destination_ids = """qs211587
qs211591
qs882217
qs905974
qs2574729
qs1060308
qs2052110
qs1421241
qs2051993""".split()
if torch.cuda.is_available():
model_id = "PartAI/Dorna-Llama3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# pipe = pipeline(
# "text-generation",
# model=model,
# tokenizer=tokenizer,
# torch_dtype=torch.float16,
# device_map="auto",
# )
index_name_i = 'semantic_search-v10'
es = Elasticsearch(
"http://127.0.0.1:6900",
# ca_certs="/path/to/http_ca.crt",
basic_auth=("elastic", "SG*7eGwg+KG2_*-1_mMm")
)
counter = 0
total = 0
remained = 0
id = ''
keywords_count = 15
def es_iterate_all_documents(es, index, pagesize=250, scroll_timeout="12m", **kwargs):
"""
Helper to iterate ALL values from a single index
Yields all the documents.
"""
global counter
global total
global remained
is_first = True
while True:
# Scroll next
if is_first: # Initialize scroll
# result = es.search(index=index, scroll="12m", **kwargs, body={
# "size": pagesize
# })
result = es.search(index=index, scroll="12m", **kwargs, size=pagesize)
total = result["hits"]["total"]['value']
remained = total
print('total = %d' % total)
is_first = False
else:
# result = es.scroll(body={
# "scroll_id": scroll_id,
# "scroll": scroll_timeout
# })
result = es.scroll( scroll_id = scroll_id, scroll = scroll_timeout )
scroll_id = result["_scroll_id"]
hits = result["hits"]["hits"]
counter += len(hits)
print("progress -> %.2f %% , count: %d" % ((counter / total)*100, counter))
# Stop after no more docs
if not hits:
break
# Yield each entry
yield from ({"source":hit['_source'], "id":hit['_id']} for hit in hits)
def generateKeywords(text):
global remained
try:
keywords_count = (len(text) / 1000) * 15
keywords_count = int(keywords_count)
if keywords_count == 0:
keywords_count = 1
messages = [{"role": "system", "content": "تو یک وکیل حقوق دان هستی و باید بتوانی متن های قانونی و حقوقی را بدون تغییر اصطلاحات فنی، به صورتی توضیح دهی که افراد غیر حقوق دان، معنای متن را درک کنند. " },
{"role": "user", "content":
'''از "متن" حداقل {} کلیدواژه مهم و پراهمیت را استخراج کن که حداقل بین 1 تا 5 کلمه داشته باشد و کلیدواژه ها را در قالب لیست به زبان فارسی چاپ کن و هر کلید واژه را در یک خط جدید قرار بده و هیچ گونه توضیحی در ابتدا یا انتهای پاسخ، اضافه نکن.
هر کلیدواژه دارای یک شماره ترتیبی در ابتدای آن باشد. کلیدواژه ها، دقیقا در متن موجود باشد. نام سازمان ها و نهادها و اشخاص حقوقی، حتما به عنوان کلید واژه درنظر گرفته شود. هیچ کلیدواژه ای، فعل یا حرف اضافه نباشد و فقط شامل اسامی تک کلمه ای یا کلماتی باشد که به هم اضافه شده اند. هیچ کلیدواژه ای نباید با حرف اضافه تمام شود. کلیدواژه ها برابر با ماده، بند، یا تبصره و تاریخ ها نباشند.
"متن": {}
'''.format(keywords_count, text)
}]
input_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
terminators = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
model.generation_config.pad_token_id = tokenizer.pad_token_id
outputs = model.generate(
input_ids,
max_new_tokens=256,
eos_token_id=terminators,
do_sample=True,
temperature=0.6,
top_p=0.85,
)
#lock0.release()
response = outputs[0][input_ids.shape[-1]:]
keywords = tokenizer.decode(response, skip_special_tokens=True)
#lock1.acquire()
# resp = es.update(index=index_name_i, id=id, doc={"content_keywords-llama3-str": str(keywords)})
return keywords
except Exception as inst:
print(type(inst)) # the exception type
print(inst.args) # arguments stored in .args
print("Exception: " + str(inst))
if __name__ == "__main__":
start_time = time.time()
try:
keywords_dict = []
count = 1
for content_item in sections_list:
id = content_item['id']
if not id in destination_ids:
continue
content = content_item['content']
# content = normalizer.cleaning(content)
keywords = generateKeywords(content)
print("section " + str(count) + "/" + str(len(sections_list)) + " keyword extracting ... ")
keywords_dict.append({
'id':id,
'keywords':keywords
})
count+=1
write_to_json(keywords_dict, address+"/data/sections_kw_11ktest_hazm_nimfasele.json")
except Exception as inst:
print(type(inst)) # the exception type
print(inst.args) # arguments stored in .args
end_time = time.time()
print(end_time)
operation_time = (int(end_time-start_time)/60)/60
print(f"elapsed time: {operation_time} hours")
print(f" Finished!!! ")