rag_api code

This commit is contained in:
ajokar 2025-03-05 16:05:44 +03:30
commit 2f74923291
2 changed files with 340 additions and 0 deletions

68
funcs.py Normal file

@ -0,0 +1,68 @@
import json
def read_file_by_address(file_address):
    # 'a+' creates the file if it does not exist, so reading never raises FileNotFoundError
    with open(file_address, 'a+', encoding='utf-8') as file:
        # move the cursor to the beginning of the file to read its content
        file.seek(0)
        text = file.read()
    return text
def save_to_file_by_address(file_address, text):
    # 'a+' appends to the end of the file, creating it first if necessary
    with open(file_address, 'a+', encoding='utf-8') as file:
        file.write(text)
def create_corpus(laws_list: dict):
    # flatten the per-law section lists into one list of sections
    sections_170k = []
    for law_id, law_sections in laws_list.items():
        sections_170k.extend(law_sections)
    return sections_170k
def write_to_json(data, file_address):
    # convert the dictionary to JSON text
    json_data = json.dumps(data, indent=2, ensure_ascii=False)
    # save the file
    with open(file_address, 'w', encoding='utf-8') as file:
        file.write(json_data)
    return True
def read_from_json(file_address):
    # read the data back from the JSON file
    with open(file_address, 'r', encoding='utf-8') as file:
        loaded_data = json.load(file)
    return loaded_data
def read_dict_from_json(file_address):
    # read a dictionary from a JSON file (same behaviour as read_from_json)
    with open(file_address, 'r', encoding='utf-8') as file:
        loaded_data = json.load(file)
    return loaded_data
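# A small round-trip sketch for the JSON helpers (the file name is illustrative):
# write_to_json({'a': 1}, './data/example.json')
# assert read_dict_from_json('./data/example.json') == {'a': 1}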
if __name__ == "__main__":
    print('start ... ')
    laws_list = read_dict_from_json('./data/laws_list_170k_embed2.json')
    sections_170k = create_corpus(laws_list=laws_list)
    write_to_json(sections_170k, './data/sections_170k.json')
    print('finished!')

272
main.py Normal file

@ -0,0 +1,272 @@
import datetime
import re

import hazm
import torch
from cleantext import clean
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
start_time = datetime.datetime.now()
print('start: ' + str(start_time))
#
#----
def cleanhtml(raw_html):
    # strip HTML tags with a minimal regex
    cleanr = re.compile('<.*?>')
    cleantext = re.sub(cleanr, '', raw_html)
    return cleantext
normalizer = hazm.Normalizer()
# removing weird patterns (emoji and related symbol ranges)
weird_pattern = re.compile("["
    u"\U0001F600-\U0001F64F"  # emoticons
    u"\U0001F300-\U0001F5FF"  # symbols & pictographs
    u"\U0001F680-\U0001F6FF"  # transport & map symbols
    u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
    u"\U00002702-\U000027B0"
    u"\U000024C2-\U0001F251"
    u"\U0001f926-\U0001f937"
    u'\U00010000-\U0010ffff'
    u"\u200d"
    u"\u2640-\u2642"
    u"\u2600-\u2B55"
    u"\u23cf"
    u"\u23e9"
    u"\u231a"
    u"\u3030"
    u"\ufe0f"
    u"\u2069"
    u"\u2066"
    #u"\u200c"
    u"\u2068"
    u"\u2067"
    "]+", flags=re.UNICODE)
def cleaning(text):
    text = text.strip()
    # regular cleaning
    # text = clean(text,
    #     fix_unicode=True,
    #     to_ascii=False,
    #     lower=True,
    #     no_line_breaks=True,
    #     no_urls=True,
    #     no_emails=True,
    #     no_phone_numbers=True,
    #     no_numbers=False,
    #     no_digits=False,
    #     no_currency_symbols=True,
    #     no_punct=False,
    #     replace_with_url="",
    #     replace_with_email="",
    #     replace_with_phone_number="",
    #     replace_with_number="",
    #     replace_with_digit="0",
    #     replace_with_currency_symbol="",
    # )
    text = clean(text,
        extra_spaces=True,
        lowercase=True
    )
    # cleaning HTML
    text = cleanhtml(text)
    # normalizing
    # normalizer = hazm.Normalizer()
    text = normalizer.normalize(text)
    text = weird_pattern.sub(r'', text)
    # removing hashtag marks and extra spaces
    text = re.sub("#", "", text)
    text = re.sub(r"\s+", " ", text)
    return text
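# A quick illustrative check of the pipeline (output is approximate, since
# hazm normalization may adjust characters):
# cleaning("<b>متن   قانون</b> #تست")  # -> "متن قانون تست"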
#---
#dataset = load_dataset("not-lain/wikipedia")
# dataset  # let's check out our dataset
# >>> DatasetDict({
#     train: Dataset({
#         features: ['id', 'url', 'title', 'text'],
#         num_rows: 3000
#     })
# })
#model_name_or_path = "mixedbread-ai/mxbai-embed-large-v1"
# model_name_or_path = "/home/gpu/NLP/MLM/CODES/BERT/finetune/MODELS/roberta-fa-zwnj-base-law-2-pt"
# if not os.path.exists(model_name_or_path+'/model.safetensors') or not os.path.exists(model_name_or_path+'/tokenizer.json'):
#     print('model files do not exist in the model path directory.')
#     exit(0)
# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # sum token embeddings where attention_mask == 1, then divide by the count of real tokens
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
# Load the embedding model from the HuggingFace Hub.
# NOTE: encode() below requires tokenizer_bert and model_bert; uncomment these
# (and set model_name_or_path above) before calling it.
#tokenizer_bert = AutoTokenizer.from_pretrained(model_name_or_path)
#model_bert = AutoModel.from_pretrained(model_name_or_path)
def encode(sentences):
    # Tokenize sentences
    encoded_input = tokenizer_bert(sentences, padding=True, truncation=True, return_tensors='pt')
    # Compute token embeddings
    with torch.no_grad():
        model_output = model_bert(**encoded_input)
    # Perform pooling. In this case, mean pooling.
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    #print("Sentence embeddings:")
    #print(sentence_embeddings)
    return sentence_embeddings
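# A minimal illustrative call, once tokenizer_bert/model_bert above are
# loaded (the sentences are placeholders):
# embeddings = encode(["متن بخش اول", "متن بخش دوم"])  # tensor of shape (2, hidden_size)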
###
#----
# text_arr = []
# if os.path.exists('mj/caches/clean_content'):
#     with open('mj/caches/clean_content', 'rb') as fp:
#         text_arr = pickle.load(fp)
# else:
#     print('clean_content pickle file does not exist in the mj directory.')
#     exit(0)
# address = './data/laws_list_170k_embed2.json'
# # text_arr = read_from_json(address)
# all_sections = []
# for qanon in text_arr:
#     sections = text_arr[qanon]
#     all_sections.extend(sections)
#---
# if not os.path.exists('ej/caches_test/embeddings_data'):
#     corpus_embeddings = []
#     for item in tqdm(all_sections):
#         id = item['id']
#         content = item['content']
#         # embedding = encode(content)
#         embedding = torch.tensor([item['embed']])
#         corpus_embeddings.append({'embeddings': embedding, 'clean_content': content, 'id': id})
#     data = Dataset.from_list(corpus_embeddings)
#     data.save_to_disk('ej/caches_test/embeddings_data')
# else:
#     data = load_from_disk('ej/caches_test/embeddings_data')
# def remove_dem(example):
#     example["embeddings"] = example["embeddings"][0]
#     return example
# if os.path.exists('ej/caches_test/embeddings_index.faiss'):
#     data.load_faiss_index('embeddings', 'ej/caches_test/embeddings_index.faiss')
# else:
#     udata = data.map(remove_dem)
#     udata = udata.add_faiss_index("embeddings")
#     udata.save_faiss_index('embeddings', 'ej/caches_test/embeddings_index.faiss')
#     data = udata
# def search(query: str, k: int = 3):
#     """a function that embeds a new query and returns the most probable results"""
#     embedded_query = encode(query)  # embed the new query
#     scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
#         "embeddings", embedded_query.numpy(),  # compare our embedded query with the dataset embeddings
#         k=k  # get only the top k results
#     )
#     return scores, retrieved_examples
#model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model_id = "PartAI/Dorna-Llama3-8B-Instruct"
# use 4-bit NF4 quantization to lower GPU memory usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config
)
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
SYS_PROMPT = """You are an assistant for answering questions.
You are given the extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer. لطفا جواب فارسی باشد"""  # the trailing Persian sentence asks the model to answer in Persian
def format_prompt(prompt, retrieved_documents, k) -> str:
    """using the retrieved documents we will prompt the model to generate our responses"""
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents[idx]}\n"
    return PROMPT
def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # truncate the prompt to avoid GPU OOM
    messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
    # tell the model to generate
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.01,  # 0.6
        top_p=0.9,
    )
    # decode only the newly generated tokens, skipping the prompt
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)
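# Illustrative direct call (assumes the chat model above loaded successfully;
# the question and context strings are placeholders):
# answer = generate(format_prompt("سوال نمونه", ["متن بازیابی‌شده"], 1))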
def rag_chatbot(prompt: str, k: int = 2):
    # search() lives in the commented-out FAISS block above and must be
    # re-enabled before this helper can run
    scores, retrieved_documents = search(prompt, k)
    # get_nearest_examples returns a dict of columns, so pass the text column
    formatted_prompt = format_prompt(prompt, retrieved_documents['clean_content'], k)
    for item in retrieved_documents['id']:
        print('https://majles.tavasi.ir/entity/detail/view/qsection/' + item)
    return generate(formatted_prompt)
#result = rag_chatbot("نوبت کاری شبانه برای مادران باردار و مادران دارای فرزند شیرخوار به چه صورت است؟", k = 5)
class TextList(BaseModel):
    text: list[str]

app = FastAPI()
@app.post("/rag_chat/")
async def rag_chat(prompt: str, input_texts: TextList):
    try:
        # build the prompt from the caller-supplied context passages
        formatted_prompt = format_prompt(prompt, input_texts.text, len(input_texts.text))
    except Exception as e:
        return {"status": "error", "message": str(e)}
    try:
        result = generate(formatted_prompt)
        return {"status": "success", "result": result}
    except Exception as e:
        return {"status": "error", "message": str(e)}
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="192.168.23.51", port=8000)
# async def rag_chat(prompt: str, input_texts: Union[str, List[str]]):
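# A minimal client-side sketch for exercising the endpoint above (host/port
# taken from the uvicorn call; the prompt and passages are placeholders):
# import requests
# resp = requests.post(
#     "http://192.168.23.51:8000/rag_chat/",
#     params={"prompt": "سوال نمونه"},
#     json={"text": ["متن ماده اول", "متن ماده دوم"]},
# )
# print(resp.json())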