create rag in git
commit 3bb5cc4235
271  RAG_llama3_chatbot_elastic_embd.py  Normal file

@@ -0,0 +1,271 @@
#pip install -q datasets sentence-transformers faiss-cpu accelerate
#from sentence_transformers import SentenceTransformer, util
import json
from tqdm import tqdm
import torch
#from sklearn.metrics.pairwise import cosine_similarity
#from sentence_transformers import SentenceTransformer
import numpy as np
import time
import os.path
# import pickle
import os
from datasets import Dataset, load_from_disk
import tensorflow as tf  # note: unused in this script
from transformers import AutoTokenizer, AutoModel

from funcs import read_from_json, read_dict_from_json

print('start')

start_time = time.time()

#----
# text cleaning utilities (hazm for Persian normalization, cleantext for generic cleanup)
import hazm
from cleantext import clean
import re

def cleanhtml(raw_html):
    # strip HTML tags with a non-greedy regex
    cleanr = re.compile('<.*?>')
    cleantext = re.sub(cleanr, '', raw_html)
    return cleantext


normalizer = hazm.Normalizer()

# remove weird patterns (emoji, pictographs and invisible directional marks)
weird_pattern = re.compile("["
    u"\U0001F600-\U0001F64F"  # emoticons
    u"\U0001F300-\U0001F5FF"  # symbols & pictographs
    u"\U0001F680-\U0001F6FF"  # transport & map symbols
    u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
    u"\U00002702-\U000027B0"
    u"\U000024C2-\U0001F251"
    u"\U0001f926-\U0001f937"
    u'\U00010000-\U0010ffff'
    u"\u200d"
    u"\u2640-\u2642"
    u"\u2600-\u2B55"
    u"\u23cf"
    u"\u23e9"
    u"\u231a"
    u"\u3030"
    u"\ufe0f"
    u"\u2069"
    u"\u2066"
    #u"\u200c"
    u"\u2068"
    u"\u2067"
    "]+", flags=re.UNICODE)

def cleaning(text):
    text = text.strip()

    # regular cleaning (full cleantext configuration kept for reference)
    # text = clean(text,
    #     fix_unicode=True,
    #     to_ascii=False,
    #     lower=True,
    #     no_line_breaks=True,
    #     no_urls=True,
    #     no_emails=True,
    #     no_phone_numbers=True,
    #     no_numbers=False,
    #     no_digits=False,
    #     no_currency_symbols=True,
    #     no_punct=False,
    #     replace_with_url="",
    #     replace_with_email="",
    #     replace_with_phone_number="",
    #     replace_with_number="",
    #     replace_with_digit="0",
    #     replace_with_currency_symbol="",
    # )
    text = clean(text,
                 extra_spaces=True,
                 lowercase=True
                 )

    # strip HTML tags
    text = cleanhtml(text)

    # normalize with hazm (Persian-specific normalization)
    # normalizer = hazm.Normalizer()
    text = normalizer.normalize(text)

    # drop emoji and other weird characters
    text = weird_pattern.sub(r'', text)

    # remove hashtag marks and collapse whitespace
    text = re.sub("#", "", text)
    text = re.sub(r"\s+", " ", text)

    return text

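# A minimal sketch of what cleaning() does end to end; the input string and the
# exact output are hypothetical, since the result depends on the installed
# cleantext/hazm versions:
#
#   sample = "<p>سلام   دنیا 😀 #تست</p>"
#   print(cleaning(sample))   # -> "سلام دنیا تست" (tags, emoji, '#' and extra spaces removed)
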
#---
#dataset = load_dataset("not-lain/wikipedia")

# dataset  # Let's check out our dataset
# >>> DatasetDict({
#        train: Dataset({
#            features: ['id', 'url', 'title', 'text'],
#            num_rows: 3000
#        })
#     })
#model_name_or_path = "mixedbread-ai/mxbai-embed-large-v1"
model_name_or_path = "/home/gpu/NLP/MLM/CODES/BERT/finetune/MODELS/roberta-fa-zwnj-base-law-2-pt"

if not os.path.exists(model_name_or_path + '/model.safetensors') or not os.path.exists(model_name_or_path + '/tokenizer.json'):
    print('model files do not exist in the model path directory.')
    exit(0)

# Mean pooling - take the attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

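# Worked example of the mean-pooling math above (a sketch with made-up shapes,
# not part of the original script). For a batch of 1 sentence with 3 tokens and
# hidden size 4, padding positions are zeroed by the mask before averaging:
#
#   tok = torch.ones(1, 3, 4)             # token embeddings (batch, seq, hidden)
#   mask = torch.tensor([[1, 1, 0]])      # last position is padding
#   pooled = mean_pooling((tok,), mask)   # averages only the 2 real tokens
#   assert pooled.shape == (1, 4)
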
# Load the embedding model from the local path (HuggingFace format)
tokenizer_bert = AutoTokenizer.from_pretrained(model_name_or_path)
model_bert = AutoModel.from_pretrained(model_name_or_path)


def encode(sentences):
    # Tokenize sentences
    encoded_input = tokenizer_bert(sentences, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model_bert(**encoded_input)

    # Perform pooling. In this case, mean pooling.
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    #print("Sentence embeddings:")
    #print(sentence_embeddings)
    return sentence_embeddings

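# Usage sketch (hypothetical sentences; the embedding width depends on the
# loaded checkpoint -- 768 for a roberta-base-sized model):
#
#   embs = encode(["قانون کار", "مرخصی زایمان"])
#   print(embs.shape)   # e.g. torch.Size([2, 768])
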
###
#----
text_arr = []
all_sections = []
"""
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
The search repository file was already built from the first pass over the section
embeddings, so re-reading them is unnecessary (which is why the loop below is
commented out) unless new embeddings are generated.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"""
# for i in range(1, 111):  # 111
#     file_address = f"./data/mj_qa_section_ai/mj_qa_section_ai{i}.json"
#     sections = read_from_json(file_address)
#     all_sections.extend(sections)
#     print(file_address)

#---
if not os.path.exists('mj/caches_test/embeddings_data'):
    corpus_embeddings = []
    for item in tqdm(all_sections):
        id = item['_id']
        source = item['_source']
        content = source['content']
        # embedding1 = encode(content)
        embedding = torch.tensor([source['embeddings']])
        # embedding = torch.tensor(source['embeddings']).reshape(-1)

        corpus_embeddings.append({'embeddings': embedding, 'clean_content': content, 'id': id})
    data = Dataset.from_list(corpus_embeddings)
    data.save_to_disk('mj/caches_test/embeddings_data')
else:
    data = load_from_disk('mj/caches_test/embeddings_data')

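# Each cached record ends up shaped roughly like this (a sketch inferred from
# the Elasticsearch-style '_source' documents read above; the vector width and
# field values are illustrative):
#
#   {'embeddings': tensor of shape (1, 768),   # precomputed section embedding
#    'clean_content': 'متن ماده قانونی ...',    # section text
#    'id': 'qs123456'}                          # section id used in result URLs
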
def remove_dem(example):
    # flatten the stored (1, dim) tensor to a flat (dim,) vector, which is what
    # add_faiss_index expects for each row
    example["embeddings"] = example["embeddings"][0]
    return example


if os.path.exists('mj/caches_test/embeddings_index.faiss'):
    data.load_faiss_index('embeddings', 'mj/caches_test/embeddings_index.faiss')
else:
    udata = data.map(remove_dem)
    udata = udata.add_faiss_index("embeddings")
    udata.save_faiss_index('embeddings', 'mj/caches_test/embeddings_index.faiss')
    data = udata

def search(query: str, k: int = 3):
    """embed a new query and return the k most similar indexed sections"""
    embedded_query = encode(query)  # embed the new query
    scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
        "embeddings", embedded_query.numpy(),  # compare the embedded query with the dataset embeddings
        k=k  # get only top k results
    )
    return scores, retrieved_examples


# example: search for a phrase and get the best 4 matching sections from the dataset
# scores, result = search("اداره کار و رفاه", 4)   # "Labor and Welfare Office"
# print(result['clean_content'])

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

#model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model_id = "PartAI/Dorna-Llama3-8B-Instruct"
# use 4-bit NF4 quantization (with double quantization) to lower GPU memory usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config
)
# stop generation at either the EOS token or Llama-3's end-of-turn token
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

SYS_PROMPT = """You are an assistant for answering questions.
You are given the extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer.لطفا جواب فارسی باشد"""
# the trailing Persian sentence instructs the model to answer in Persian ("please answer in Persian")


def format_prompt(prompt, retrieved_documents, k):
    """prompt the model with the retrieved documents as context"""
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['clean_content'][idx]}\n"
    return PROMPT

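# The assembled prompt therefore looks like this (the content lines are
# illustrative placeholders, not real retrieved sections):
#
#   Question:<user question>
#   Context:<section 1 text>
#   <section 2 text>
#   ...
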
def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # truncate the prompt to avoid GPU OOM
    messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
    # build the chat-formatted input and generate
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # decode only the newly generated tokens, skipping the prompt
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)

def rag_chatbot(prompt: str, k: int = 2):
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    # print a source link for each retrieved section
    for item in retrieved_documents['id']:
        print('https://majles.tavasi.ir/entity/detail/view/qsection/' + item)
    return generate(formatted_prompt)


# "What are the rules on night-shift work for a pregnant mother or one with a nursing infant?"
result = rag_chatbot("نوبت کار شبانه برای مادر باردار و دارای فرزند شیرخوار به چه صورت است؟", k=5)
print(result)
print('end')
259  RAG_llama3_multi_embds.py  Normal file

@@ -0,0 +1,259 @@
#pip install -q datasets sentence-transformers faiss-cpu accelerate
#from sentence_transformers import SentenceTransformer, util

from tqdm import tqdm
import torch
#from sklearn.metrics.pairwise import cosine_similarity
#from sentence_transformers import SentenceTransformer
import datetime
import os.path
import os
from datasets import Dataset, load_from_disk
from transformers import AutoTokenizer, AutoModel
from funcs import read_from_json

start_time = datetime.datetime.now()
print('start: ' + str(datetime.datetime.now()))

#----
# text cleaning utilities (hazm for Persian normalization, cleantext for generic cleanup)
import hazm
from cleantext import clean
import re

def cleanhtml(raw_html):
    # strip HTML tags with a non-greedy regex
    cleanr = re.compile('<.*?>')
    cleantext = re.sub(cleanr, '', raw_html)
    return cleantext


normalizer = hazm.Normalizer()

# remove weird patterns (emoji, pictographs and invisible directional marks)
weird_pattern = re.compile("["
    u"\U0001F600-\U0001F64F"  # emoticons
    u"\U0001F300-\U0001F5FF"  # symbols & pictographs
    u"\U0001F680-\U0001F6FF"  # transport & map symbols
    u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
    u"\U00002702-\U000027B0"
    u"\U000024C2-\U0001F251"
    u"\U0001f926-\U0001f937"
    u'\U00010000-\U0010ffff'
    u"\u200d"
    u"\u2640-\u2642"
    u"\u2600-\u2B55"
    u"\u23cf"
    u"\u23e9"
    u"\u231a"
    u"\u3030"
    u"\ufe0f"
    u"\u2069"
    u"\u2066"
    #u"\u200c"
    u"\u2068"
    u"\u2067"
    "]+", flags=re.UNICODE)

def cleaning(text):
    text = text.strip()

    # regular cleaning (full cleantext configuration kept for reference)
    # text = clean(text,
    #     fix_unicode=True,
    #     to_ascii=False,
    #     lower=True,
    #     no_line_breaks=True,
    #     no_urls=True,
    #     no_emails=True,
    #     no_phone_numbers=True,
    #     no_numbers=False,
    #     no_digits=False,
    #     no_currency_symbols=True,
    #     no_punct=False,
    #     replace_with_url="",
    #     replace_with_email="",
    #     replace_with_phone_number="",
    #     replace_with_number="",
    #     replace_with_digit="0",
    #     replace_with_currency_symbol="",
    # )
    text = clean(text,
                 extra_spaces=True,
                 lowercase=True
                 )

    # strip HTML tags
    text = cleanhtml(text)

    # normalize with hazm (Persian-specific normalization)
    # normalizer = hazm.Normalizer()
    text = normalizer.normalize(text)

    # drop emoji and other weird characters
    text = weird_pattern.sub(r'', text)

    # remove hashtag marks and collapse whitespace
    text = re.sub("#", "", text)
    text = re.sub(r"\s+", " ", text)

    return text

#---
#dataset = load_dataset("not-lain/wikipedia")

# dataset  # Let's check out our dataset
# >>> DatasetDict({
#        train: Dataset({
#            features: ['id', 'url', 'title', 'text'],
#            num_rows: 3000
#        })
#     })
#model_name_or_path = "mixedbread-ai/mxbai-embed-large-v1"
model_name_or_path = "/home/gpu/NLP/MLM/CODES/BERT/finetune/MODELS/roberta-fa-zwnj-base-law-2-pt"

if not os.path.exists(model_name_or_path + '/model.safetensors') or not os.path.exists(model_name_or_path + '/tokenizer.json'):
    print('model files do not exist in the model path directory.')
    exit(0)

# Mean pooling - take the attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# Load the embedding model from the local path (HuggingFace format)
tokenizer_bert = AutoTokenizer.from_pretrained(model_name_or_path)
model_bert = AutoModel.from_pretrained(model_name_or_path)


def encode(sentences):
    # Tokenize sentences
    encoded_input = tokenizer_bert(sentences, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model_bert(**encoded_input)

    # Perform pooling. In this case, mean pooling.
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    #print("Sentence embeddings:")
    #print(sentence_embeddings)
    return sentence_embeddings

###
#----
text_arr = []
# if os.path.exists('mj/caches/clean_content'):
#     with open('mj/caches/clean_content', 'rb') as fp:
#         text_arr = pickle.load(fp)
# else:
#     print('clean_content pickle file does not exist in the mj directory.')
#     exit(0)

address = './data/laws_list_170k_embed2.json'
text_arr = read_from_json(address)
all_sections = []

# the JSON maps each law id to its list of sections; flatten them all
for qanon in text_arr:
    sections = text_arr[qanon]
    all_sections.extend(sections)

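# The JSON read above is assumed to look roughly like this (a sketch inferred
# from how the fields are used below; keys and values are illustrative):
#
#   {"<law_id>": [
#       {"id": "qs1", "content": "متن ماده ...", "embed": [0.12, -0.03, ...]},
#       ...
#   ],
#    ...}
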
#---
if not os.path.exists('ej/caches_test/embeddings_data'):
    corpus_embeddings = []
    for item in tqdm(all_sections):
        id = item['id']
        content = item['content']
        # embedding = encode(content)
        embedding = torch.tensor([item['embed']])
        corpus_embeddings.append({'embeddings': embedding, 'clean_content': content, 'id': id})
    data = Dataset.from_list(corpus_embeddings)
    data.save_to_disk('ej/caches_test/embeddings_data')
else:
    data = load_from_disk('ej/caches_test/embeddings_data')


def remove_dem(example):
    # flatten the stored (1, dim) tensor to a flat (dim,) vector for FAISS
    example["embeddings"] = example["embeddings"][0]
    return example


if os.path.exists('ej/caches_test/embeddings_index.faiss'):
    data.load_faiss_index('embeddings', 'ej/caches_test/embeddings_index.faiss')
else:
    udata = data.map(remove_dem)
    udata = udata.add_faiss_index("embeddings")
    udata.save_faiss_index('embeddings', 'ej/caches_test/embeddings_index.faiss')
    data = udata

def search(query: str, k: int = 3):
    """embed a new query and return the k most similar indexed sections"""
    embedded_query = encode(query)  # embed the new query
    scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
        "embeddings", embedded_query.numpy(),  # compare the embedded query with the dataset embeddings
        k=k  # get only top k results
    )
    return scores, retrieved_examples

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

#model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model_id = "PartAI/Dorna-Llama3-8B-Instruct"
# use 4-bit NF4 quantization (with double quantization) to lower GPU memory usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config
)
# stop generation at either the EOS token or Llama-3's end-of-turn token
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

SYS_PROMPT = """You are an assistant for answering questions.
You are given the extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer.لطفا جواب فارسی باشد"""
# the trailing Persian sentence instructs the model to answer in Persian ("please answer in Persian")


def format_prompt(prompt, retrieved_documents, k):
    """prompt the model with the retrieved documents as context"""
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['clean_content'][idx]}\n"
    return PROMPT

def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # truncate the prompt to avoid GPU OOM
    messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
    # build the chat-formatted input and generate
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.01,  # 0.6 originally; near-greedy decoding for more deterministic answers
        top_p=0.9,
    )
    # decode only the newly generated tokens, skipping the prompt
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)

def rag_chatbot(prompt: str, k: int = 2):
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    # print a source link for each retrieved section
    for item in retrieved_documents['id']:
        print('https://majles.tavasi.ir/entity/detail/view/qsection/' + item)
    return generate(formatted_prompt)


# "What are the rules on night-shift work for pregnant mothers and mothers with a nursing infant?"
result = rag_chatbot("نوبت کاری شبانه برای مادران باردار و مادران دارای فرزند شیرخوار به چه صورت است؟", k=5)

print(result)
print('end')
68  funcs.py  Normal file

@@ -0,0 +1,68 @@
import json


def read_file_by_address(file_address):
    # 'a+' creates the file if it does not exist yet
    with open(file_address, 'a+', encoding='utf-8') as file:
        text = ''
        try:
            # move the cursor to the beginning of the file to read its content
            file.seek(0)
            text = file.read()
        except:
            pass
    return text

def save_to_file_by_address(file_address, text):
    # note: 'a+' appends, so repeated calls accumulate text in the file
    with open(file_address, 'a+', encoding='utf-8') as file:
        previous_result = ''
        try:
            previous_result = file.read()  # unused; with 'a+' the cursor is at EOF, so this reads ''
        except:
            pass
        file.write(text)

def create_curpos(laws_list: dict):
    # flatten {law_id: [sections]} into a single list of sections
    sections_170k = []
    for law_id, laws_sections in laws_list.items():
        sections = laws_sections
        for section in sections:
            sections_170k.append(section)
    return sections_170k

def write_to_json(data, file_address):

    # convert the dictionary to JSON text
    json_data = json.dumps(data, indent=2, ensure_ascii=False)

    # save the file
    with open(file_address, 'w+', encoding='utf-8') as file:
        file.write(json_data)

    return True

def read_from_json(file_address):
    data_dict = []
    # read data from the JSON file
    with open(file_address, 'r', encoding='utf-8') as file:
        loaded_data = json.load(file)

    # collect the loaded items
    # for item in loaded_data:
    #     data_dict.append(item)
    return loaded_data

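# Round-trip sketch for the two helpers above (hypothetical path and payload):
#
#   write_to_json({"a": 1}, '/tmp/example.json')
#   assert read_from_json('/tmp/example.json') == {"a": 1}
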
def read_dict_from_json(file_address):

    # read data from the JSON file
    with open(file_address, 'r', encoding='utf-8') as file:
        loaded_data = json.load(file)

    return loaded_data

if __name__ == "__main__":

    print('start ... ')
    laws_list = read_dict_from_json('./data/laws_list_170k_embed2.json')
    sections_170k = create_curpos(laws_list=laws_list)
    write_to_json(sections_170k, './data/sections_170k.json')
    print('finished!')