From 2f74923291039de1db6d85293c479dc16623dfda Mon Sep 17 00:00:00 2001
From: ajokar
Date: Wed, 5 Mar 2025 16:05:44 +0330
Subject: [PATCH] rag_api codes

---
 funcs.py |  68 ++++++++++++++
 main.py  | 272 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 340 insertions(+)
 create mode 100644 funcs.py
 create mode 100644 main.py

diff --git a/funcs.py b/funcs.py
new file mode 100644
index 0000000..f26377a
--- /dev/null
+++ b/funcs.py
@@ -0,0 +1,68 @@
+import json
+
+def read_file_by_address(file_address):
+    # 'a+' creates the file if it is missing; seek(0) moves the cursor back
+    # to the beginning so the content can be read
+    with open(file_address, 'a+', encoding='utf-8') as file:
+        text = ''
+        try:
+            file.seek(0)
+            text = file.read()
+        except OSError:
+            pass
+    return text
+
+def save_to_file_by_address(file_address, text):
+    # append to the file, creating it if necessary
+    with open(file_address, 'a+', encoding='utf-8') as file:
+        file.write(text)
+
+def create_curpos(laws_list: dict):
+    # flatten the per-law section lists into a single corpus list
+    sections_170k = []
+    for law_id, law_sections in laws_list.items():
+        sections_170k.extend(law_sections)
+    return sections_170k
+
+def write_to_json(data, file_address):
+    # convert the dictionary to JSON format
+    json_data = json.dumps(data, indent=2, ensure_ascii=False)
+
+    # save the file
+    with open(file_address, 'w+', encoding='utf-8') as file:
+        file.write(json_data)
+
+    return True
+
+def read_from_json(file_address):
+    # read the data from the JSON file
+    with open(file_address, 'r', encoding='utf-8') as file:
+        loaded_data = json.load(file)
+    return loaded_data
+
+def read_dict_from_json(file_address):
+    # read the dictionary from the JSON file
+    # (same behaviour as read_from_json, kept for backward compatibility)
+    with open(file_address, 'r', encoding='utf-8') as file:
+        loaded_data = json.load(file)
+    return loaded_data
+
+if __name__ == "__main__":
+    print('start ...')
+    laws_list = read_dict_from_json('./data/laws_list_170k_embed2.json')
+    sections_170k = create_curpos(laws_list=laws_list)
+    write_to_json(sections_170k, './data/sections_170k.json')
+    print('finished!')
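
For reference, a minimal round-trip sketch of the JSON helpers above; the file path and the sample records here are made up, not part of the patch:

    from funcs import write_to_json, read_from_json

    sections = [{'id': 'qs1', 'content': 'first section'},
                {'id': 'qs2', 'content': 'second section'}]
    write_to_json(sections, './data/example.json')   # written with indent=2, ensure_ascii=False
    assert read_from_json('./data/example.json') == sections
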
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..c0a6d6a
--- /dev/null
+++ b/main.py
@@ -0,0 +1,272 @@
+import datetime
+import re
+
+import hazm
+import torch
+from cleantext import clean
+from fastapi import FastAPI
+from pydantic import BaseModel
+from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig
+
+start_time = datetime.datetime.now()
+print('start: ' + str(start_time))
+
+#----
+
+def cleanhtml(raw_html):
+    # strip HTML tags
+    cleanr = re.compile('<.*?>')
+    cleantext = re.sub(cleanr, '', raw_html)
+    return cleantext
+
+normalizer = hazm.Normalizer()
+
+# pattern for removing weird characters (emoji, pictographs, directional marks, ...)
+weird_pattern = re.compile("["
+    u"\U0001F600-\U0001F64F"  # emoticons
+    u"\U0001F300-\U0001F5FF"  # symbols & pictographs
+    u"\U0001F680-\U0001F6FF"  # transport & map symbols
+    u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
+    u"\U00002702-\U000027B0"
+    u"\U000024C2-\U0001F251"
+    u"\U0001f926-\U0001f937"
+    u'\U00010000-\U0010ffff'
+    u"\u200d"
+    u"\u2640-\u2642"
+    u"\u2600-\u2B55"
+    u"\u23cf"
+    u"\u23e9"
+    u"\u231a"
+    u"\u3030"
+    u"\ufe0f"
+    u"\u2069"
+    u"\u2066"
+    # u"\u200c"  # ZWNJ is kept: it is meaningful in Persian text
+    u"\u2068"
+    u"\u2067"
+    "]+", flags=re.UNICODE)
+
+def cleaning(text):
+    text = text.strip()
+
+    # regular cleaning; the fuller clean-text configuration is kept below for reference
+    # text = clean(text,
+    #              fix_unicode=True,
+    #              to_ascii=False,
+    #              lower=True,
+    #              no_line_breaks=True,
+    #              no_urls=True,
+    #              no_emails=True,
+    #              no_phone_numbers=True,
+    #              no_numbers=False,
+    #              no_digits=False,
+    #              no_currency_symbols=True,
+    #              no_punct=False,
+    #              replace_with_url="",
+    #              replace_with_email="",
+    #              replace_with_phone_number="",
+    #              replace_with_number="",
+    #              replace_with_digit="0",
+    #              replace_with_currency_symbol="",
+    #              )
+    text = clean(text,
+                 extra_spaces=True,
+                 lowercase=True)
+
+    # cleaning HTML remnants
+    text = cleanhtml(text)
+
+    # normalizing
+    text = normalizer.normalize(text)
+
+    # removing weird characters
+    text = weird_pattern.sub(r'', text)
+
+    # removing hashtags and extra spaces
+    text = re.sub("#", "", text)
+    text = re.sub(r"\s+", " ", text)
+
+    return text
+
+#---
+# dataset = load_dataset("not-lain/wikipedia")
+# dataset  # let's check out our dataset
+# >>> DatasetDict({
+#         train: Dataset({
+#             features: ['id', 'url', 'title', 'text'],
+#             num_rows: 3000
+#         })
+#     })
+
+# model_name_or_path = "mixedbread-ai/mxbai-embed-large-v1"
+# model_name_or_path = "/home/gpu/NLP/MLM/CODES/BERT/finetune/MODELS/roberta-fa-zwnj-base-law-2-pt"
+# if not os.path.exists(model_name_or_path + '/model.safetensors') or not os.path.exists(model_name_or_path + '/tokenizer.json'):
+#     print('model files do not exist in the model path directory.')
+#     exit(0)
+
+# mean pooling - take the attention mask into account for correct averaging
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+# load the embedding model from the HuggingFace Hub
+# NOTE: encode() below depends on tokenizer_bert/model_bert; uncomment these lines to use it
+# tokenizer_bert = AutoTokenizer.from_pretrained(model_name_or_path)
+# model_bert = AutoModel.from_pretrained(model_name_or_path)

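
As a quick sanity check, the pooling math can be exercised on its own with dummy tensors; this is a sketch with illustrative shapes, and no model download is needed:

    import torch

    token_embeddings = torch.ones(2, 4, 3)               # batch of 2, 4 tokens, hidden size 3
    attention_mask = torch.tensor([[1, 1, 1, 0],         # 0 marks padding positions
                                   [1, 1, 0, 0]])

    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    pooled = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
    print(pooled.shape)  # torch.Size([2, 3]); padded tokens do not dilute the average
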
+def encode(sentences):
+    # tokenize sentences
+    encoded_input = tokenizer_bert(sentences, padding=True, truncation=True, return_tensors='pt')
+
+    # compute token embeddings
+    with torch.no_grad():
+        model_output = model_bert(**encoded_input)
+
+    # perform pooling (mean pooling over the attention mask)
+    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+    return sentence_embeddings
+
+###
+#----
+# text_arr = []
+# if os.path.exists('mj/caches/clean_content'):
+#     with open('mj/caches/clean_content', 'rb') as fp:
+#         text_arr = pickle.load(fp)
+# else:
+#     print('clean_content pickle file does not exist in the mj directory.')
+#     exit(0)
+
+# address = './data/laws_list_170k_embed2.json'
+# # text_arr = read_from_json(address)
+# all_sections = []
+# for qanon in text_arr:
+#     sections = text_arr[qanon]
+#     all_sections.extend(sections)
+
+#---
+# build (or load) the embedded corpus, then index it with FAISS
+# if not os.path.exists('ej/caches_test/embeddings_data'):
+#     corpus_embeddings = []
+#     for item in tqdm(all_sections):
+#         id = item['id']
+#         content = item['content']
+#         # embedding = encode(content)
+#         embedding = torch.tensor([item['embed']])
+#         corpus_embeddings.append({'embeddings': embedding, 'clean_content': content, 'id': id})
+#     data = Dataset.from_list(corpus_embeddings)
+#     data.save_to_disk('ej/caches_test/embeddings_data')
+# else:
+#     data = load_from_disk('ej/caches_test/embeddings_data')

+# def remove_dem(example):
+#     example["embeddings"] = example["embeddings"][0]
+#     return example

+# if os.path.exists('ej/caches_test/embeddings_index.faiss'):
+#     data.load_faiss_index('embeddings', 'ej/caches_test/embeddings_index.faiss')
+# else:
+#     udata = data.map(remove_dem)
+#     udata = udata.add_faiss_index("embeddings")
+#     udata.save_faiss_index('embeddings', 'ej/caches_test/embeddings_index.faiss')
+#     data = udata

+# def search(query: str, k: int = 3):
+#     """embed a new query and return the most probable results"""
+#     embedded_query = encode(query)  # embed the new query
+#     scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
+#         "embeddings", embedded_query.numpy(),  # compare the embedded query with the dataset embeddings
+#         k=k  # get only the top-k results
+#     )
+#     return scores, retrieved_examples
+# NOTE: rag_chatbot() further down calls search(), so it only works when this
+# block and the embedding model above are enabled.
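
The commented-out retrieval path can be tried end to end without the real corpus. Here is a minimal standalone sketch, assuming the `datasets` and `faiss-cpu` packages are installed; the random vectors stand in for the real section embeddings and every value is made up:

    import numpy as np
    from datasets import Dataset

    rng = np.random.default_rng(0)
    data = Dataset.from_dict({
        "id": [str(i) for i in range(100)],
        "clean_content": [f"section {i}" for i in range(100)],
        "embeddings": rng.random((100, 8)).astype("float32").tolist(),
    })
    data.add_faiss_index("embeddings")                   # index the toy corpus

    query = np.asarray(rng.random(8), dtype=np.float32)  # stand-in for encode(query).numpy()
    scores, retrieved = data.get_nearest_examples("embeddings", query, k=3)
    print(retrieved["id"], scores)                       # top-3 nearest sections
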
+
+# model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+model_id = "PartAI/Dorna-Llama3-8B-Instruct"
+
+# use 4-bit quantization to lower GPU memory usage
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16
+)
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    quantization_config=bnb_config
+)
+terminators = [
+    tokenizer.eos_token_id,
+    tokenizer.convert_tokens_to_ids("<|eot_id|>")
+]
+
+SYS_PROMPT = """You are an assistant for answering questions.
+You are given the extracted parts of a long document and a question. Provide a conversational answer.
+If you don't know the answer, just say "I do not know." Don't make up an answer. Please answer in Persian."""
+
+def format_prompt(prompt, retrieved_documents, k) -> str:
+    """Build the generation prompt from the question and the retrieved documents."""
+    PROMPT = f"Question:{prompt}\nContext:"
+    for idx in range(k):
+        PROMPT += f"{retrieved_documents[idx]}\n"
+    return PROMPT
+
+def generate(formatted_prompt):
+    formatted_prompt = formatted_prompt[:2000]  # truncate to avoid GPU OOM
+    messages = [{"role": "system", "content": SYS_PROMPT},
+                {"role": "user", "content": formatted_prompt}]
+    # build the chat-formatted input and generate
+    input_ids = tokenizer.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        return_tensors="pt"
+    ).to(model.device)
+    outputs = model.generate(
+        input_ids,
+        max_new_tokens=1024,
+        eos_token_id=terminators,
+        do_sample=True,
+        temperature=0.01,  # 0.6
+        top_p=0.9,
+    )
+    response = outputs[0][input_ids.shape[-1]:]  # keep only the newly generated tokens
+    return tokenizer.decode(response, skip_special_tokens=True)
+
+def rag_chatbot(prompt: str, k: int = 2):
+    # requires the commented-out search()/FAISS block above to be enabled
+    scores, retrieved_documents = search(prompt, k)
+    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
+    for item in retrieved_documents['id']:
+        print('https://majles.tavasi.ir/entity/detail/view/qsection/' + item)
+    return generate(formatted_prompt)
+# result = rag_chatbot("نوبت کاری شبانه برای مادران باردار و مادران دارای فرزند شیرخوار به چه صورت است؟", k=5)
+# ("How are night shifts handled for pregnant mothers and mothers with a nursing infant?")
+
+class TextList(BaseModel):
+    text: list[str]
+
+app = FastAPI()
+
+@app.post("/rag_chat/")
+async def rag_chat(prompt: str, input_texts: TextList):
+    # build the prompt from the caller-supplied context passages
+    try:
+        formatted_prompt = format_prompt(prompt, input_texts.text, len(input_texts.text))
+    except Exception as e:
+        return {"status": "error", "message": str(e)}
+
+    try:
+        result = generate(formatted_prompt)
+        return {"status": "success", "result": result}
+    except Exception as e:
+        return {"status": "error", "message": str(e)}
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="192.168.23.51", port=8000)
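
A sketch of exercising the endpoint with the `requests` package; the host and port are taken from the uvicorn call above, and the question and context passages are placeholders. FastAPI binds `prompt` as a query parameter, while the JSON body must match the TextList model:

    import requests

    resp = requests.post(
        "http://192.168.23.51:8000/rag_chat/",
        params={"prompt": "نوبت کاری شبانه برای مادران باردار چگونه است؟"},  # the question, as a query parameter
        json={"text": ["retrieved section text 1", "retrieved section text 2"]},  # TextList body
    )
    print(resp.json())  # {"status": "success", "result": "..."} on success
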