# rag_api/main.py
#
# 290 lines
# 9.1 KiB
# Python

import datetime
from fastapi import FastAPI
import torch
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import hazm
from cleantext import clean
import re
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
# Wall-clock time at which this module began executing; printed once so the
# startup moment (before the model download/load below) shows up in the logs.
start_time = datetime.datetime.now()
print(f'start: {start_time}')
#
#----
# One markup tag: "<", any run of non-">" characters, then ">".
# Unlike the lazy "<.*?>" (whose "." stops at line breaks), this also removes
# tags whose attributes are wrapped over several lines. Compiled once at import.
_HTML_TAG_RE = re.compile(r'<[^>]*>')


def cleanhtml(raw_html):
    """Strip HTML/XML tags from *raw_html* and return the remaining text.

    This is a lightweight regex strip, not an HTML parser: entities such as
    ``&amp;`` are left untouched, and any literal ``<...>`` span in plain text
    (e.g. ``a <b> c``) is treated as a tag and removed.
    """
    return _HTML_TAG_RE.sub('', raw_html)
# Shared hazm text normalizer, constructed once at import and reused by `cleaning`.
normalizer = hazm.Normalizer()
# Runs of characters deleted by `cleaning`. The class lists:
#   * emoji and pictograph blocks, dingbats;
#   * the zero-width joiner (U+200D) and variation selector-16 (U+FE0F);
#   * the bidi isolate controls (U+2066..U+2069).
# ZWNJ (U+200C) is intentionally left out — presumably because it is part of
# normal Persian spelling.
#
# NOTE(review): two spans are much broader than "emoji":
#   * U+24C2..U+1F251 also covers CJK ideographs, Hangul, full-width forms
#     and both Arabic Presentation Forms blocks (U+FB50..U+FDFF,
#     U+FE70..U+FEFF);
#   * U+10000..U+10FFFF is every supplementary-plane character. It already
#     contains the first four spans and U+1F926..U+1F937.
# Confirm this is intended before running non-emoji text through it.
wierd_pattern = re.compile(
    "["
    + "".join(
        [
            "\U0001F600-\U0001F64F",  # emoticons
            "\U0001F300-\U0001F5FF",  # symbols & pictographs
            "\U0001F680-\U0001F6FF",  # transport & map symbols
            "\U0001F1E0-\U0001F1FF",  # flags (iOS)
            "\U00002702-\U000027B0",  # dingbats
            "\U000024C2-\U0001F251",  # very broad, see NOTE above
            "\U0001f926-\U0001f937",
            "\U00010000-\U0010ffff",  # all supplementary planes
            "\u200d",                 # zero-width joiner
            "\u2640-\u2642",
            "\u2600-\u2B55",
            "\u23cf",
            "\u23e9",
            "\u231a",
            "\u3030",
            "\ufe0f",                 # variation selector-16
            "\u2069",                 # bidi isolate controls ...
            "\u2066",
            "\u2068",
            "\u2067",
        ]
    )
    + "]+",
    flags=re.UNICODE,
)
def cleaning(text):
    """Clean and normalize a raw text snippet (hazm normalization).

    Steps, in order:

    1. trim surrounding whitespace;
    2. run ``clean`` with ``extra_spaces=True`` and ``lowercase=True``;
    3. strip HTML tags with ``cleanhtml``;
    4. normalize with the shared hazm ``normalizer``;
    5. delete characters matched by ``wierd_pattern`` (emoji, pictographs,
       bidi controls, ...);
    6. drop ``#`` characters and collapse whitespace runs to a single space.

    Returns:
        The cleaned string, with no leading or trailing whitespace.
    """
    text = text.strip()
    # NOTE(review): these keyword arguments match the `cleantext` PyPI
    # package's clean(). Confirm which package is installed and whether its
    # other defaults (e.g. punctuation removal, if on by default) are wanted.
    # If "<" and ">" are removed here, the cleanhtml step has nothing to match.
    text = clean(text, extra_spaces=True, lowercase=True)
    text = cleanhtml(text)
    text = normalizer.normalize(text)
    text = wierd_pattern.sub('', text)
    # Hashtags: keep the word, drop the "#".
    text = text.replace('#', '')
    # Raw string: "\s" in a plain literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    # Final strip(): removing tags or emoji at either end can leave a dangling
    # space that the initial strip() could not catch.
    return re.sub(r'\s+', ' ', text).strip()
#---
#dataset = load_dataset("not-lain/wikipedia")
# dataset # Let's checkout our dataset
# >>> DatasetDict({
# train: Dataset({
# features: ['id', 'url', 'title', 'text'],
# num_rows: 3000
# })
# })
#model_name_or_path = "mixedbread-ai/mxbai-embed-large-v1"
# model_name_or_path = "/home/gpu/NLP/MLM/CODES/BERT/finetune/MODELS/roberta-fa-zwnj-base-law-2-pt"
# if not os.path.exists(model_name_or_path+'/model.safetensors') or not os.path.exists(model_name_or_path+'/tokenizer.json'):
# print('model files is not exists in model path directory.')
# exit(0)
# Masked mean pooling: padding positions do not contribute to the average.
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions where *attention_mask* is set.

    Args:
        model_output: encoder output whose first element holds the token
            embeddings. It must be ``expand``-compatible with
            ``attention_mask.unsqueeze(-1)`` — typically ``(batch, seq, hidden)``
            (TODO confirm for the encoder in use).
        attention_mask: 0/1 mask over tokens, assumed ``(batch, seq)``.

    Returns:
        The embeddings with dimension 1 (the sequence axis) averaged away.
        A row whose mask is all zeros yields zeros, because the divisor is
        clamped to 1e-9 instead of being 0.
    """
    hidden = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts
# Load model from HuggingFace Hub
#tokenizer_bert = AutoTokenizer.from_pretrained(model_name_or_path)
#model_bert = AutoModel.from_pretrained(model_name_or_path)
def encode(sentences, tokenizer=None, model=None):
    """Embed *sentences* with a Hugging Face encoder and mean-pool the tokens.

    Args:
        sentences: a string or list of strings, passed straight to the
            tokenizer (padded and truncated, PyTorch tensors).
        tokenizer: tokenizer to use. Defaults to the module-level
            ``tokenizer_bert``.
        model: encoder to use. Defaults to the module-level ``model_bert``.

    Returns:
        Sentence embeddings produced by ``mean_pooling`` on the encoder output
        (presumably ``(batch, hidden)`` — TODO confirm against the encoder).

    Raises:
        RuntimeError: if no tokenizer/model is passed and the module-level
            defaults are not defined.
    """
    if tokenizer is None:
        tokenizer = globals().get("tokenizer_bert")
    if model is None:
        model = globals().get("model_bert")
    if tokenizer is None or model is None:
        # The loaders for tokenizer_bert/model_bert next to this function are
        # commented out. Fail with an explicit message rather than a bare
        # NameError.
        raise RuntimeError(
            "encode() needs an embedding tokenizer and model: pass them "
            "explicitly or define module-level tokenizer_bert / model_bert"
        )
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        model_output = model(**encoded_input)
    # Mean pooling over non-padding tokens (not max pooling).
    return mean_pooling(model_output, encoded_input['attention_mask'])
###
#----
# text_arr = []
# if os.path.exists('mj/caches/clean_content'):
# with open ('mj/caches/clean_content', 'rb') as fp:
# text_arr = pickle.load(fp)
# else:
# print('clean_content pickle file is not exists in mj directory.')
# exit(0)
# address = './data/laws_list_170k_embed2.json'
# # text_arr = read_from_json(address)
# all_sections = []
# for qanon in text_arr:
# sections = text_arr[qanon]
# all_sections.extend(sections)
#---
# if not os.path.exists('ej/caches_test/embeddings_data'):
# corpus_embeddings = []
# for item in tqdm(all_sections):
# id = item['id']
# content = item['content']
# # embedding = encode(content)
# embedding = torch.tensor([item['embed']])
# corpus_embeddings.append({'embeddings':embedding, 'clean_content': content, 'id': id})
# data = Dataset.from_list(corpus_embeddings)
# data.save_to_disk('ej/caches_test/embeddings_data')
# else:
# data = load_from_disk('ej/caches_test/embeddings_data')
# def remove_dem(example):
# example["embeddings"] = example["embeddings"][0]
# return example
# if os.path.exists('ej/caches_test/embeddings_index.faiss'):
# data.load_faiss_index('embeddings', 'ej/caches_test/embeddings_index.faiss')
# else:
# udata = data.map(remove_dem)
# udata = udata.add_faiss_index("embeddings")
# udata.save_faiss_index('embeddings', 'ej/caches_test/embeddings_index.faiss')
# data = udata
# def search(query: str, k: int = 3 ):
# """a function that embeds a new query and returns the most probable results"""
# embedded_query = encode(query) # embed new query
# scores, retrieved_examples = data.get_nearest_examples( # retrieve results
# "embeddings", embedded_query.numpy(), # compare our new embedded query with the dataset embeddings
# k=k # get only top k results
# )
# return scores, retrieved_examples
# Chat model served by this API.
# Commented-out alternative: "meta-llama/Meta-Llama-3-8B-Instruct".
model_id = "PartAI/Dorna-Llama3-8B-Instruct"

# Load the weights in 4 bits to lower GPU memory usage. The settings are:
#   * NF4 quantization;
#   * double quantization enabled;
#   * bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Token ids that stop generation: the tokenizer's EOS token and the Llama 3
# end-of-turn marker "<|eot_id|>".
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
# System message sent with every chat request.
# The final sentence is Persian for "Please answer in Persian". It is
# appended directly after the last English sentence, with no separator.
SYS_PROMPT = "\n".join(
    (
        "You are an assistant for answering questions.",
        "You are given the extracted parts of a long document and a question. Provide a conversational answer.",
        'If you don\'t know the answer, just say "I do not know." Don\'t make up an answer.'
        + "لطفا جواب فارسی باشد",
    )
)
def format_prompt_0(prompt, retrieved_documents, k, text_key="clean_content") -> str:
    """Build a "Question/Context" prompt from the first *k* retrieved passages.

    Args:
        prompt: the user question.
        retrieved_documents: either a sequence of passage texts, or a mapping
            of column name -> list of values. The mapping form is what
            ``rag_chatbot`` passes, since it reads ``retrieved_documents['id']``.
            For a mapping, the passages are taken from the *text_key* column.
        k: maximum number of passages to include. Fewer are used if fewer
            were retrieved; ``k <= 0`` includes none.
        text_key: column holding the passage text when a mapping is given.
            Defaults to ``'clean_content'`` — presumably the text column of
            the search index; confirm.

    Returns:
        ``Question:`` + *prompt*, a newline, ``Context:``, then each passage
        followed by a newline.
    """
    from collections.abc import Mapping

    if isinstance(retrieved_documents, Mapping):
        # Integer-indexing a column mapping raises KeyError, so pick the
        # passage-text column first.
        passages = retrieved_documents[text_key]
    else:
        passages = retrieved_documents
    # Slicing instead of range(k): no IndexError when fewer than k came back.
    context = "".join(f"{passage}\n" for passage in passages[:max(k, 0)])
    return f"Question:{prompt}\nContext:{context}"
def format_prompt(prompt, text) -> str:
    """Build the user message for ``generate`` from a question and one context text.

    The result is ``Question:`` + *prompt*, a newline, ``Context:`` + *text*,
    and a trailing newline. Non-string values are formatted with ``str()``.
    """
    return f"Question:{prompt}\nContext:{text}\n"
def generate(formatted_prompt, *, max_prompt_chars=2000, max_new_tokens=1024):
    """Answer *formatted_prompt* with the chat model and return the reply text.

    The prompt is sent as the user turn after ``SYS_PROMPT`` (system turn) and
    rendered with the tokenizer's chat template.

    Args:
        formatted_prompt: user message, e.g. built by ``format_prompt``.
        max_prompt_chars: the prompt is cut to this many *characters* (not
            tokens) before templating, to avoid GPU out-of-memory errors.
            Because the question comes first, the tail of the context is what
            gets dropped.
        max_new_tokens: generation budget passed to ``model.generate``.

    Returns:
        The decoded completion only (prompt tokens excluded), with special
        tokens skipped.
    """
    formatted_prompt = formatted_prompt[:max_prompt_chars]
    messages = [
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": formatted_prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        # A single, unpadded sequence: every position is real input. Passing
        # the mask and pad id explicitly stops generate() from trying to infer
        # them, which it warns about. The pad id it would otherwise fall back
        # to is also the EOS id.
        attention_mask=torch.ones_like(input_ids),
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        # Sampling at a near-zero temperature is effectively greedy decoding.
        do_sample=True,
        temperature=0.01,  # previously 0.6
        top_p=0.9,
    )
    # Keep only the newly generated tokens.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)
def rag_chatbot(prompt: str, k: int = 2):
    """Retrieve the top-*k* passages for *prompt* and answer from them.

    Prints the source URL of each retrieved passage, then returns the model's
    answer from ``generate``.

    NOTE(review): relies on a module-level ``search(query, k)`` that returns
    ``(scores, examples)``, with ``examples`` indexable by ``'id'``. The
    ``search`` definition visible in this file is commented out, so this
    raises NameError unless ``search`` is defined elsewhere.
    """
    _scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt_0(prompt, retrieved_documents, k)
    for doc_id in retrieved_documents['id']:
        # f-string rather than "+": also works when ids are integers.
        print(f'https://majles.tavasi.ir/entity/detail/view/qsection/{doc_id}')
    return generate(formatted_prompt)
#result = rag_chatbot("نوبت کاری شبانه برای مادران باردار و مادران دارای فرزند شیرخوار به چه صورت است؟", k = 5)
class TextList(BaseModel):
    """Request body for POST /rag_chat."""
    query: str  # the user's question
    text: str  # context text; placed after "Context:" in the prompt
# FastAPI application; the /rag_chat route below is registered on it.
app = FastAPI()
# origins = ['*']
# app.add_middleware(
# CORSMiddleware,
# allow_origins=origins,
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
import logging
import threading

from fastapi.concurrency import run_in_threadpool

_logger = logging.getLogger(__name__)

# Serializes calls to generate(). Previously the coroutine ran generate()
# directly on the event loop, which implicitly allowed only one generation at
# a time. The lock keeps that guarantee, and the same GPU memory profile, now
# that the work runs in a worker thread.
_generate_lock = threading.Lock()


def _generate_serialized(formatted_prompt):
    """Run ``generate`` while holding ``_generate_lock`` (runs in a worker thread)."""
    with _generate_lock:
        return generate(formatted_prompt)


@app.post("/rag_chat")
async def rag_chat(input: TextList):  # `input` kept for compatibility (shadows builtin)
    """Answer ``input.query`` using ``input.text`` as context.

    Returns ``{"status": "success", "result": <answer>}``. On any failure it
    returns ``{"status": "error", "message": str(exception)}`` in the response
    body; no HTTP error code is set.
    """
    try:
        formatted_prompt = format_prompt(input.query, input.text)
        # Generation is blocking, GPU-bound work. Run it off the event loop so
        # other requests and health checks are not frozen while it runs.
        result = await run_in_threadpool(_generate_serialized, formatted_prompt)
    except Exception as e:
        _logger.exception("rag_chat failed")
        return {"status": "error", "message": str(e)}
    return {"status": "success", "result": result}
if __name__ == "__main__":
    import os

    import uvicorn

    # The bind address and port can be overridden per deployment through
    # environment variables. The defaults are the previously hard-coded
    # values, so existing launches behave the same. A fixed LAN IP fails to
    # bind on any other machine.
    host = os.environ.get("RAG_API_HOST", "192.168.23.51")
    port = int(os.environ.get("RAG_API_PORT", "8000"))
    uvicorn.run(app, host=host, port=port)
# async def rag_chat(prompt: str, input_texts: Union[str, List[str]]):