RD_relation/relation/complete_sections_170K.py

205 lines
7.5 KiB
Python
Raw Normal View History

2025-01-19 15:46:21 +00:00
"""
این فایل ویژگی شناسهالد را از الستیک می خواند و به جیسون 170 هزار ماده اصلی اضافه می کند"""
from html import escape
from lxml import etree
from datetime import datetime
from elasticsearch import Elasticsearch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
from threading import Thread
import torch
import time
from concurrent.futures import ThreadPoolExecutor
import concurrent
import threading
import json
import os.path
import os
from general_functions import normalize_content
from funcs import write_to_json, read_from_json
#lock = threading.Lock()
#lock1 = threading.Lock()
#from cleantext import clean
#import re
from normalizer import Normalizer
from tokenizer import *
_normalizer = Normalizer(date_normalizing_needed=True)
address = os.getcwd()
# sections_list = read_from_json(address + '/data/clean_sections_11k.json') # Main File
sections_list = read_from_json('../data/clean_sections_11k.json') # Main File
# destination_ids = """qs211587
# qs211591
# qs882217
# qs905974
# qs2574729
# qs1060308
# qs2052110
# qs1421241
# qs2051993""".split()
if torch.cuda.is_available():
model_id = "PartAI/Dorna-Llama3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# pipe = pipeline(
# "text-generation",
# model=model,
# tokenizer=tokenizer,
# torch_dtype=torch.float16,
# device_map="auto",
# )
index_name_i = 'mj_qa_section-v02'
es = Elasticsearch(
"http://127.0.0.1:6900",
# ca_certs="/path/to/http_ca.crt",
basic_auth=("elastic", "SG*7eGwg+KG2_*-1_mMm")
)
counter = 0
total = 0
remained = 0
id = ''
keywords_count = 15
def es_iterate_all_documents(es, index, pagesize=250, scroll_timeout="12m", **kwargs):
"""
Helper to iterate ALL values from a single index
Yields all the documents.
"""
global counter
global total
global remained
is_first = True
while True:
# Scroll next
if is_first: # Initialize scroll
# result = es.search(index=index, scroll="12m", **kwargs, body={
# "size": pagesize
# })
result = es.search(index=index, scroll="12m", **kwargs, size=pagesize)
total = result["hits"]["total"]['value']
remained = total
print('total = %d' % total)
is_first = False
else:
# result = es.scroll(body={
# "scroll_id": scroll_id,
# "scroll": scroll_timeout
# })
result = es.scroll( scroll_id = scroll_id, scroll = scroll_timeout )
scroll_id = result["_scroll_id"]
hits = result["hits"]["hits"]
counter += len(hits)
print("progress -> %.2f %% , count: %d" % ((counter / total)*100, counter))
# Stop after no more docs
if not hits:
break
# Yield each entry
yield from ({"source":hit['_source'], "id":hit['_id']} for hit in hits)
def generateKeywords(text):
global remained
try:
keywords_count = (len(text) / 1000) * 15
keywords_count = int(keywords_count)
if keywords_count == 0:
keywords_count = 1
messages = [{"role": "system", "content": "تو یک وکیل حقوق دان هستی و باید بتوانی متن های قانونی و حقوقی را بدون تغییر اصطلاحات فنی، به صورتی توضیح دهی که افراد غیر حقوق دان، معنای متن را درک کنند. " },
{"role": "user", "content":
'''از "متن" حداقل {} کلیدواژه مهم و پراهمیت را استخراج کن و کلیدواژه ها را در قالب لیست به زبان فارسی چاپ کن و هر کلید واژه را در یک خط جدید قرار بده و هیچ گونه توضیحی در ابتدا یا انتهای پاسخ، اضافه نکن.
هر کلیدواژه دارای یک شماره ترتیبی در ابتدای آن باشد. کلیدواژه ها، دقیقا در متن موجود باشد. بسیار مهم و ضروری است که طول هر کلیدواژه حداقل دو توکن داشته باشد و کلیدواژه ی یک توکنی قابل قبول نیست. نام سازمان ها و نهادها و اشخاص حقوقی، حتما به عنوان کلیدواژه درنظر گرفته شود. هیچ کلیدواژه ای، فعل یا حرف اضافه نباشد و فقط شامل اسم هایی باشد که به هم اضافه شده اند. هیچ کلیدواژه ای نباید با حرف اضافه یا حرف «و» تمام شود. ضروری است که کلیدواژه ها شامل ماده، بند، تبصره یا تاریخ ها نباشند.
"متن": {}
'''.format(keywords_count, text)
}]
input_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
terminators = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
model.generation_config.pad_token_id = tokenizer.pad_token_id
outputs = model.generate(
input_ids,
max_new_tokens=256,
eos_token_id=terminators,
do_sample=True,
temperature=0.6,
top_p=0.85,
)
#lock0.release()
response = outputs[0][input_ids.shape[-1]:]
keywords = tokenizer.decode(response, skip_special_tokens=True)
#lock1.acquire()
# resp = es.update(index=index_name_i, id=id, doc={"content_keywords-llama3-str": str(keywords)})
return keywords
except Exception as inst:
print(type(inst)) # the exception type
print(inst.args) # arguments stored in .args
print("Exception: " + str(inst))
if __name__ == "__main__":
start_time = time.time()
print("start_time: "+str(datetime.now()))
try:
keywords_dict = []
count = 1
for content_item in sections_list:
id = content_item['id']
# if not id in destination_ids:
# continue
content = content_item['content']
content_len = len(content.split())
# کنارگذاشتن محتواهای با حجم زیاد
if content_len > 2000:
print("too long content " + str(id))
continue
content = _normalizer.sub_alphabets(content)
keywords = generateKeywords(content)
print("section " + str(count) + "/" + str(len(sections_list)) + " keyword extracting ... ")
keywords_dict.append({
'id':id,
'keywords':keywords
})
count+=1
write_to_json(keywords_dict, "../data/sections_kw_11k_new.json")
except Exception as inst:
print(type(inst)) # the exception type
print(inst.args) # arguments stored in .args
end_time = time.time()
print("end_time: "+ str(datetime.now()))
operation_time = (int(end_time-start_time)/60)/60
print(f"elapsed time: {operation_time} hours")
print(f" Finished!!! ")