import datetime
import re

import torch
import hazm
from cleantext import clean
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

start_time = datetime.datetime.now()
print('start: ' + str(start_time))


# ----
def cleanhtml(raw_html):
    """Strip HTML tags from a string."""
    cleanr = re.compile('<.*?>')
    cleantext = re.sub(cleanr, '', raw_html)
    return cleantext


normalizer = hazm.Normalizer()

# removing weird patterns (emoji and other pictographic codepoints)
weird_pattern = re.compile("["
    u"\U0001F600-\U0001F64F"  # emoticons
    u"\U0001F300-\U0001F5FF"  # symbols & pictographs
    u"\U0001F680-\U0001F6FF"  # transport & map symbols
    u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
    u"\U00002702-\U000027B0"
    u"\U000024C2-\U0001F251"
    u"\U0001f926-\U0001f937"
    u'\U00010000-\U0010ffff'
    u"\u200d"
    u"\u2640-\u2642"
    u"\u2600-\u2B55"
    u"\u23cf"
    u"\u23e9"
    u"\u231a"
    u"\u3030"
    u"\ufe0f"
    u"\u2069"
    u"\u2066"
    # u"\u200c"  # ZWNJ is kept out of the pattern: it is meaningful in Persian text
    u"\u2068"
    u"\u2067"
    "]+", flags=re.UNICODE)


def cleaning(text):
    """Normalize raw (Persian) text: trim, collapse spaces, strip HTML,
    apply hazm normalization, and remove emoji and hashtags."""
    text = text.strip()

    # regular cleaning -- fuller configuration kept for reference; note these
    # keyword arguments follow the clean-text package's signature, whereas the
    # live call below uses extra_spaces/lowercase as in the cleantext package
    # text = clean(text,
    #     fix_unicode=True,
    #     to_ascii=False,
    #     lower=True,
    #     no_line_breaks=True,
    #     no_urls=True,
    #     no_emails=True,
    #     no_phone_numbers=True,
    #     no_numbers=False,
    #     no_digits=False,
    #     no_currency_symbols=True,
    #     no_punct=False,
    #     replace_with_url="",
    #     replace_with_email="",
    #     replace_with_phone_number="",
    #     replace_with_number="",
    #     replace_with_digit="0",
    #     replace_with_currency_symbol="",
    # )
    text = clean(text,
                 extra_spaces=True,
                 lowercase=True)

    # cleaning HTML
    text = cleanhtml(text)

    # normalizing
    # normalizer = hazm.Normalizer()
    text = normalizer.normalize(text)
    text = weird_pattern.sub(r'', text)

    # removing extra spaces and hashtags
    text = re.sub("#", "", text)
    text = re.sub(r"\s+", " ", text)

    return text


# ---
# dataset = load_dataset("not-lain/wikipedia")
# Let's check out our dataset:
# >>> DatasetDict({
#         train: Dataset({
#             features: ['id', 'url', 'title', 'text'],
#             num_rows: 3000
#         })
#     })

# model_name_or_path = "mixedbread-ai/mxbai-embed-large-v1"
# model_name_or_path = "/home/gpu/NLP/MLM/CODES/BERT/finetune/MODELS/roberta-fa-zwnj-base-law-2-pt"
# if not os.path.exists(model_name_or_path + '/model.safetensors') or not os.path.exists(model_name_or_path + '/tokenizer.json'):
#     print('model files do not exist in the model path directory.')
#     exit(0)


# Mean pooling -- take the attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


# Load the embedding model from the HuggingFace Hub. Currently disabled, but
# encode() below requires these two objects, so they must be loaded before
# the retrieval path (search/rag_chatbot) can be used.
# tokenizer_bert = AutoTokenizer.from_pretrained(model_name_or_path)
# model_bert = AutoModel.from_pretrained(model_name_or_path)


def encode(sentences):
    """Embed sentences with the BERT model; requires tokenizer_bert and
    model_bert to be loaded above."""
    # Tokenize sentences
    encoded_input = tokenizer_bert(sentences, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model_bert(**encoded_input)

    # Perform pooling. In this case, mean pooling.
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    # print("Sentence embeddings:")
    # print(sentence_embeddings)
    return sentence_embeddings
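
# A minimal sanity check for mean_pooling, added for illustration: with a
# dummy batch where the last token is padding, pooling should average only
# the three real positions. The tensor values here are made up for the check.
_tok_emb = torch.ones(1, 4, 8)        # batch=1, seq_len=4, hidden=8
_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
assert torch.allclose(mean_pooling((_tok_emb,), _mask), torch.ones(1, 8))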

###
# ----
# Offline corpus loading and FAISS index construction, kept for reference.
# It loads the cleaned sections, builds per-section embeddings, and attaches
# the FAISS index that search() and rag_chatbot() below depend on.
# text_arr = []
# if os.path.exists('mj/caches/clean_content'):
#     with open('mj/caches/clean_content', 'rb') as fp:
#         text_arr = pickle.load(fp)
# else:
#     print('clean_content pickle file does not exist in the mj directory.')
#     exit(0)

# address = './data/laws_list_170k_embed2.json'
# text_arr = read_from_json(address)
# all_sections = []
# for qanon in text_arr:
#     sections = text_arr[qanon]
#     all_sections.extend(sections)

# ---
# if not os.path.exists('ej/caches_test/embeddings_data'):
#     corpus_embeddings = []
#     for item in tqdm(all_sections):
#         id = item['id']
#         content = item['content']
#         # embedding = encode(content)
#         embedding = torch.tensor([item['embed']])
#         corpus_embeddings.append({'embeddings': embedding, 'clean_content': content, 'id': id})
#     data = Dataset.from_list(corpus_embeddings)
#     data.save_to_disk('ej/caches_test/embeddings_data')
# else:
#     data = load_from_disk('ej/caches_test/embeddings_data')

# def remove_dem(example):
#     example["embeddings"] = example["embeddings"][0]
#     return example

# if os.path.exists('ej/caches_test/embeddings_index.faiss'):
#     data.load_faiss_index('embeddings', 'ej/caches_test/embeddings_index.faiss')
# else:
#     udata = data.map(remove_dem)
#     udata = udata.add_faiss_index("embeddings")
#     udata.save_faiss_index('embeddings', 'ej/caches_test/embeddings_index.faiss')
#     data = udata

# def search(query: str, k: int = 3):
#     """Embed a new query and return the most probable results."""
#     embedded_query = encode(query)  # embed the new query
#     scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
#         "embeddings", embedded_query.numpy(),  # compare the query embedding with the dataset embeddings
#         k=k  # keep only the top-k results
#     )
#     return scores, retrieved_examples


# model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model_id = "PartAI/Dorna-Llama3-8B-Instruct"

# use 4-bit quantization to lower GPU memory usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

SYS_PROMPT = """You are an assistant for answering questions.
You are given the extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer.
لطفا جواب فارسی باشد"""  # the last line asks the model to answer in Persian
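
# Illustrative sketch (not part of the original flow): for Llama-3-style chat
# models such as this one, apply_chat_template wraps each message in
# <|start_header_id|>role<|end_header_id|> ... <|eot_id|> markers, and
# add_generation_prompt=True appends the assistant header so the model
# continues as the assistant. To inspect the exact prompt string:
# demo = tokenizer.apply_chat_template(
#     [{"role": "system", "content": SYS_PROMPT},
#      {"role": "user", "content": "Question:...\nContext:..."}],
#     tokenize=False, add_generation_prompt=True)
# print(demo)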

def format_prompt_0(prompt, retrieved_documents, k) -> str:
    """Build the model prompt from the query and the k retrieved documents."""
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents[idx]}\n"
    return PROMPT


def format_prompt(prompt, text) -> str:
    """Build the model prompt from the query and a single context text."""
    PROMPT = f"Question:{prompt}\nContext:"
    PROMPT += f"{text}\n"
    return PROMPT


def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # truncate to avoid GPU OOM
    messages = [{"role": "system", "content": SYS_PROMPT},
                {"role": "user", "content": formatted_prompt}]
    # tell the model to generate
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.01,  # 0.6
        top_p=0.9,
    )
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)


def rag_chatbot(prompt: str, k: int = 2):
    """Full RAG flow: retrieve the top-k sections (requires the FAISS index
    code above to be enabled), then generate an answer from them."""
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt_0(prompt, retrieved_documents, k)
    for item in retrieved_documents['id']:
        print('https://majles.tavasi.ir/entity/detail/view/qsection/' + item)
    return generate(formatted_prompt)


# Example query: "What are the rules on night-shift work for pregnant women
# and mothers with nursing infants?"
# result = rag_chatbot("نوبت کاری شبانه برای مادران باردار و مادران دارای فرزند شیرخوار به چه صورت است؟", k=5)


class TextList(BaseModel):
    query: str
    text: str


app = FastAPI()

# origins = ['*']
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=origins,
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )


@app.post("/rag_chat")
async def rag_chat(input: TextList):
    try:
        formatted_prompt = format_prompt(input.query, input.text)
    except Exception as e:
        return {"status": "error", "message": str(e)}

    try:
        result = generate(formatted_prompt)
        return {"status": "success", "result": result}
    except Exception as e:
        return {"status": "error", "message": str(e)}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="192.168.23.51", port=8000)

# async def rag_chat(prompt: str, input_texts: Union[str, List[str]]):
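
# Example client call for /rag_chat (illustrative; assumes the host/port
# configured above and the `requests` package on the client side):
# import requests
# r = requests.post("http://192.168.23.51:8000/rag_chat",
#                   json={"query": "...", "text": "..."})
# print(r.json())  # -> {"status": "success", "result": "..."} on success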