# ai_dataset/main_qa_data/data_selector.py
# (exported 2024-09-17 16:57:26 +00:00)
"""
این کد برای انتخاب سکشن های اصلی که بند و عنوان و موخره و امضا نیستند
و معتبر هستند و فقط از مصوبات مجلس هستند ایجاد شده است که شامل حدود 9 هزار سکشن می شود.
"""
from html import escape
from lxml import etree
from datetime import datetime
from elasticsearch import Elasticsearch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
from threading import Thread
import torch
import time
from concurrent.futures import ThreadPoolExecutor
import concurrent
import threading
import json
import numpy as np
from funcs import write_to_json, write_to_excel
import os
# Source index holding the per-section QA documents.
index_name_i = 'mj_qa_section-v02'# semantic_search-v10
# Elasticsearch client for the local cluster.
# NOTE(review): credentials are hard-coded in source — move to env/config.
es = Elasticsearch(
    "http://127.0.0.1:6900",
    basic_auth=("elastic", "SG*7eGwg+KG2_*-1_mMm")
)
counter = 0        # documents fetched so far (updated by es_iterate_all_documents)
total = 0          # total matching documents (set by es_iterate_all_documents)
remained = 0       # unused here; presumably kept for parity with sibling scripts — TODO confirm
id = ''            # id of the section currently being processed (used in error reporting); shadows builtin id()
keywords_count = 15  # unused in this file — TODO confirm whether still needed
body_query = "{'query': {'match_all': {}}}"  # unused; note this is a string, not a dict
def es_iterate_all_documents(es, index, pagesize=250, scroll_timeout="25m", **kwargs):
    """Iterate over ALL matching documents in *index* via the scroll API.

    The embedded query keeps only sections that are not clause/title/
    epilogue/signature paths, belong to valid ("معتبر") ordinary laws of
    the Islamic Consultative Assembly, sorted by ``sort_date_timestamp``
    descending.

    Parameters:
        es: Elasticsearch client instance.
        index: name of the index to scan.
        pagesize: number of hits fetched per scroll page.
        scroll_timeout: keep-alive window for the scroll context.
        **kwargs: extra keyword arguments forwarded to ``es.search``.

    Yields:
        dicts of the form ``{"source": <_source>, "id": <_id>}``.

    Side effects:
        Updates the module-level ``total`` and ``counter`` globals and
        prints progress to stdout.
    """
    global counter
    global total
    scroll_id = None
    is_first = True
    try:
        while True:
            if is_first:  # Initialize scroll
                result = es.search(
                    index=index,
                    # was hard-coded "2m", silently ignoring scroll_timeout;
                    # a slow consumer could outlive the context between pages
                    scroll=scroll_timeout,
                    **kwargs,
                    size=pagesize,
                    body={
                        "query": {
                            "bool": {
                                "must": [
                                    {
                                        "bool": {
                                            "must_not": [
                                                {"wildcard": {"other_info.full_path": "بند*"}},
                                                {"match": {"other_info.full_path": "موخره"}},
                                                {"match": {"other_info.full_path": "امضاء"}},
                                                {"match": {"other_info.full_path": "عنوان"}}
                                            ]
                                        }
                                    },
                                    {
                                        "bool": {
                                            "filter": {
                                                "bool": {
                                                    "must": [
                                                        {"term": {"qanon_etebar": "معتبر"}},
                                                        {"term": {"title_type": "عادی"}},
                                                        {"term": {"ts_ref.keyword": "مجلس شورای اسلامی"}},
                                                        {"term": {"sub_type": "عادی"}}
                                                    ]
                                                }
                                            }
                                        }
                                    }
                                ]
                            }
                        },
                        "sort": {
                            "sort_date_timestamp": {"order": "desc"}
                        },
                        "track_total_hits": True,
                        "aggs": {
                            "total_collapse": {
                                "cardinality": {"field": "qanon_id"}
                            }
                        }
                    }
                )
                total = result["hits"]["total"]["value"]
                print("total = %d" % total)
                is_first = False
            else:
                result = es.scroll(scroll_id=scroll_id, scroll=scroll_timeout)
            scroll_id = result["_scroll_id"]
            hits = result["hits"]["hits"]
            counter += len(hits)
            # Guard against ZeroDivisionError when the query matches nothing.
            if total:
                print("progress -> %.2f %%" % ((counter / total) * 100))
            # Stop after no more docs
            if not hits:
                break
            # Yield each entry
            yield from ({"source": hit["_source"], "id": hit["_id"]} for hit in hits)
    finally:
        # Release the server-side scroll context instead of letting it
        # linger until scroll_timeout expires; best-effort only.
        if scroll_id is not None:
            try:
                es.clear_scroll(scroll_id=scroll_id)
            except Exception:
                pass
def add_section(section):
    """Project an Elasticsearch `_source` dict onto the flat record we export.

    ``main_topic`` is the first entry of ``tcode_main_old``; ``all_topics``
    is the full list. KeyError/IndexError propagate if the fields are
    missing (callers pre-filter for that).
    """
    topics = section["tcode_main_old"]
    return dict(
        id=section["id"],
        qanon_id=section["qanon_id"],
        content=section["content"],
        main_topic=topics[0],
        all_topics=topics,
        ts_year=section["ts_year"],
        state_etebar=section["state_etebar"],
        ners=section["ners_v1"],
    )
if __name__ == "__main__":
    base_address = os.getcwd()  # debugger
    #base_address = "/home/gpu/tnlp/jokar/llama" # terminal
    # json_address_15k_sections = base_address + "/data/qa_sections_15k.json"
    start_time = time.time()

    # Materialize all [id, source] pairs. No need to wrap in np.array:
    # that produced an object array and was only ever iterated.
    all_sections_arr = [
        [mentry["id"], mentry["source"]]
        for mentry in es_iterate_all_documents(es, index_name_i)
    ]

    selected_sections = []
    # main_topic -> number of sections already selected for that topic.
    # Replaces the per-iteration np.array/np.where rebuild, which was O(n)
    # per lookup (O(n^2) overall) for a plain membership/count check.
    topic_counts = {}
    missing_topic_count = 0  # sections skipped for lacking tcode_main_old
    try:
        for i, dataline in enumerate(all_sections_arr):
            line = dataline[1]
            id = line['id']  # module-level id: used in the error report below
            content = line['content']
            state_etebar = line["state_etebar"]
            # Keep only currently-valid sections with usable content.
            if state_etebar != "معتبر": continue
            if not content: continue
            # Skip attachment references and "تبصره" (clause) sections.
            if "به‌ شرح پیوست" in content \
               or content.startswith("تبصره"):
                continue
            # Skip sections that end with a colon (lead-ins to sub-items).
            if content.endswith(":"): continue
            # Keep only medium-length sections (30..150 words).
            words = content.split()
            if len(words) < 30 or len(words) > 150: continue
            try:
                main_topic = line["tcode_main_old"][0]
            except (KeyError, IndexError, TypeError):
                # Missing/empty topic codes: count and skip (was bare except).
                missing_topic_count += 1
                continue
            # Cap at 10 selected sections per main topic.
            repeat_factor = topic_counts.get(main_topic, 0)
            if repeat_factor >= 10:
                continue
            topic_counts[main_topic] = repeat_factor + 1
            selected_sections.append(add_section(line))
            print(topic_counts)
            print(i+1)
    except Exception as inst:
        print(type(inst))    # the exception type
        print(inst.args)     # arguments stored in .args
        print(inst)          # __str__ allows args to be printed directly,
                             # but may be overridden in exception subclasses
        # Guard the ratio: total may still be 0 if the scan failed early.
        print("Exception:=> %s -> %.2f " % (id, (counter / total) if total else 0.0))

    print(len(selected_sections))
    path = "./data/selected_sections.json" # os.getcwd() +
    write_to_json(selected_sections, path)
    excel_path = "./data/selected_sections.xlsx"
    write_to_excel(selected_sections, excel_path)
    end_time = time.time()
    print(f"elapsed time: {end_time-start_time}")
    print(" *** finished! *** ")