# rag/RAG_llama3_multi_embds.py
# 2025-03-05 17:32:39 +03:30
#
# 260 lines
# 8.3 KiB
# Python

#pip install -q datasets sentence-transformers faiss-cpu accelerate
#from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
import torch
#from sklearn.metrics.pairwise import cosine_similarity
#from sentence_transformers import SentenceTransformer
import datetime
import os.path
from transformers import AutoTokenizer
import os
from datasets import Dataset, load_from_disk
import os
from transformers import AutoTokenizer, AutoModel
import torch
from funcs import read_from_json
# Wall-clock start of the run; the same timestamp is printed so the log line
# and start_time always agree.
start_time = datetime.datetime.now()
print(f'start: {start_time}')
#
#----
import hazm
from cleantext import clean
import re
def cleanhtml(raw_html):
    """Remove every ``<...>`` tag-like span from *raw_html*.

    The match is non-greedy and ``.`` does not cross newlines, so a tag is
    only removed when it sits entirely on one line.
    """
    return re.sub(r'<.*?>', '', raw_html)
# Shared hazm text normalizer, constructed once at import time and reused by
# cleaning() on every call.
normalizer = hazm.Normalizer()
# Characters stripped from text before embedding: emoji, pictographs and other
# symbol / invisible formatting characters.
#
# The widely copied emoji regex this was based on contained the range
# U+24C2-U+1F251, which silently spans almost the entire upper BMP: CJK,
# Hangul, full-width forms and the Arabic Presentation Forms blocks
# (U+FB50-U+FDFF, U+FE70-U+FEFC). Those are real letters, which matters for
# Persian/Arabic input, and they were being deleted. The ranges below still
# remove the symbol/emoji/format characters the old pattern removed, but leave
# letter-bearing blocks alone.
# ZWNJ (U+200C) is intentionally NOT listed: it is meaningful in Persian text.
wierd_pattern = re.compile(
    "["
    "\U00010000-\U0010FFFF"  # all supplementary planes: emoticons, pictographs, flags, ...
    "\u200d"                 # zero-width joiner (glues emoji sequences together)
    "\u2066-\u2069"          # bidi isolate controls (LRI, RLI, FSI, PDI)
    "\u231a\u23cf\u23e9"     # watch, eject, fast-forward symbols
    "\u24c2-\u2bff"          # enclosed alphanumerics .. misc symbols & arrows (incl. dingbats)
    "\u3030\u3297\u3299"     # wavy dash, circled ideographs "congratulation" / "secret"
    "\ue000-\uf8ff"          # private use area
    "\ufe00-\ufe0f"          # variation selectors (e.g. U+FE0F emoji presentation)
    "\ufeff"                 # zero-width no-break space / BOM
    "\ufff0-\uffff"          # specials (incl. U+FFFD replacement character)
    "]+",
    flags=re.UNICODE,
)
def cleaning(text):
    """Clean a raw text snippet before indexing/embedding.

    Steps, in order:
      1. strip surrounding whitespace;
      2. ``cleantext.clean`` with ``extra_spaces=True, lowercase=True``;
      3. drop HTML tags (``cleanhtml``);
      4. apply the shared hazm ``normalizer``;
      5. delete characters matched by ``wierd_pattern`` (emoji/symbols);
      6. remove ``#`` characters and collapse whitespace runs to one space.

    Args:
        text: the string to clean.

    Returns:
        The cleaned string, with no leading or trailing whitespace.
    """
    text = text.strip()
    text = clean(text, extra_spaces=True, lowercase=True)
    text = cleanhtml(text)
    text = normalizer.normalize(text)
    text = wierd_pattern.sub('', text)
    # Drop hashtag markers but keep the tagged word itself.
    text = text.replace('#', '')
    # Raw string: "\s" in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12).
    text = re.sub(r'\s+', ' ', text)
    # Deletions above can leave a single space at either edge; trim it.
    return text.strip()
#---
#dataset = load_dataset("not-lain/wikipedia")
# dataset # Let's checkout our dataset
# >>> DatasetDict({
# train: Dataset({
# features: ['id', 'url', 'title', 'text'],
# num_rows: 3000
# })
# })
#model_name_or_path = "mixedbread-ai/mxbai-embed-large-v1"
import sys

model_name_or_path = "/home/gpu/NLP/MLM/CODES/BERT/finetune/MODELS/roberta-fa-zwnj-base-law-2-pt"
# Fail fast if the local checkpoint is incomplete: both the weights file and
# the tokenizer file must exist before anything is loaded.
if not (
    os.path.exists(os.path.join(model_name_or_path, 'model.safetensors'))
    and os.path.exists(os.path.join(model_name_or_path, 'tokenizer.json'))
):
    # sys.exit(<str>) prints the message to stderr and exits with status 1,
    # so shells/schedulers see the failure (exit(0) reported success and
    # depends on the `site`-provided builtin).
    sys.exit(f'model files do not exist in model path directory: {model_name_or_path}')
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the non-padding positions.

    Args:
        model_output: encoder output whose first element holds the token
            embeddings (indexed as ``model_output[0]``).
        attention_mask: mask with 1 for real tokens and 0 for padding.

    Returns:
        One pooled vector per sequence (the token axis, dim 1, is reduced).
        Fully masked rows yield zeros thanks to the 1e-9 floor on the count.
    """
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
# Load the sentence-encoder tokenizer and weights from the local checkpoint
# directory `model_name_or_path` (a filesystem path, not a Hub model id).
tokenizer_bert, model_bert = (
    AutoTokenizer.from_pretrained(model_name_or_path),
    AutoModel.from_pretrained(model_name_or_path),
)
def encode(sentences):
    """Embed text with the local encoder loaded into ``model_bert``.

    Args:
        sentences: a string or a list of strings.

    Returns:
        Attention-mask-aware *mean*-pooled embeddings, one row per input,
        computed without gradient tracking.
    """
    batch = tokenizer_bert(sentences, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        outputs = model_bert(**batch)
    # Average token embeddings over real (non-padding) tokens.
    return mean_pooling(outputs, batch['attention_mask'])
###
#----
# Load the law corpus. Each section dict carries 'id', 'content' and a
# precomputed 'embed' vector (used below).
# NOTE(review): read_from_json presumably returns a dict of
# {law key: [section dict, ...]}; confirm against funcs.read_from_json.
address = './data/laws_list_170k_embed2.json'
text_arr = read_from_json(address)
# Flatten every law's section list into a single list, preserving order.
all_sections = [section for qanon in text_arr for section in text_arr[qanon]]
#---
if not os.path.exists('ej/caches_test/embeddings_data'):
    # First run: build the embeddings dataset from the vectors already present
    # in the JSON ('embed' field) rather than re-encoding 'content'. Each
    # vector is wrapped in an extra list, so every stored row has a leading
    # axis of length 1 (presumably (1, dim)); remove_dem() unwraps it when the
    # FAISS index is built.
    # A comprehension keeps its loop variable local, so no module-level names
    # (notably the builtin ``id``) get rebound.
    corpus_embeddings = [
        {
            'embeddings': torch.tensor([item['embed']]),
            'clean_content': item['content'],
            'id': item['id'],
        }
        for item in tqdm(all_sections)
    ]
    data = Dataset.from_list(corpus_embeddings)
    data.save_to_disk('ej/caches_test/embeddings_data')
else:
    # Later runs reuse the cached dataset. NOTE: the cache is not invalidated
    # when the JSON source changes; delete the directory to rebuild it.
    data = load_from_disk('ej/caches_test/embeddings_data')
def remove_dem(example):
    """Unwrap the outer list around a row's ``embeddings`` value, in place.

    The row's ``embeddings`` entry is replaced by its first element and the
    same (mutated) row dict is returned, which is the shape ``Dataset.map``
    accepts.
    """
    wrapped = example["embeddings"]
    example["embeddings"] = wrapped[0]
    return example
# Attach a FAISS index over the "embeddings" column. If no saved index file
# exists yet, build one (after unwrapping each stored vector with remove_dem)
# and save it for the next run; otherwise load the saved index.
if not os.path.exists('ej/caches_test/embeddings_index.faiss'):
    data = data.map(remove_dem).add_faiss_index("embeddings")
    data.save_faiss_index('embeddings', 'ej/caches_test/embeddings_index.faiss')
else:
    data.load_faiss_index('embeddings', 'ej/caches_test/embeddings_index.faiss')
def search(query: str, k: int = 3):
    """Embed *query* and return the ``k`` nearest corpus entries.

    Returns:
        ``(scores, retrieved_examples)`` from ``data.get_nearest_examples`` on
        the ``"embeddings"`` FAISS index; ``retrieved_examples`` is indexed by
        column name (e.g. ``'clean_content'``, ``'id'``).
    """
    # A single string is encoded as a batch of one.
    query_vector = encode(query).numpy()
    scores, retrieved_examples = data.get_nearest_examples("embeddings", query_vector, k=k)
    return scores, retrieved_examples
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
# Generator LLM. Alternative checkpoint: "meta-llama/Meta-Llama-3-8B-Instruct".
model_id = "PartAI/Dorna-Llama3-8B-Instruct"

# 4-bit NF4 quantization (double-quantized, bf16 compute) to reduce GPU memory.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Token ids that end generation: the tokenizer's EOS and the <|eot_id|> token.
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
# System prompt sent with every generation request. The trailing Persian
# sentence means "Please answer in Persian."
SYS_PROMPT = """You are an assistant for answering questions.
You are given the extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer.لطفا جواب فارسی باشد"""
def format_prompt(prompt, retrieved_documents, k):
    """Build the user message for the LLM from a question and retrieved text.

    The result is ``Question:<prompt>`` on the first line, then ``Context:``
    immediately followed by the ``clean_content`` of each of the first ``k``
    retrieved documents, each terminated by a newline.

    Args:
        prompt: the user's question.
        retrieved_documents: mapping with a ``'clean_content'`` list, as
            returned by ``search``.
        k: maximum number of documents to include. If fewer were retrieved
            (e.g. the corpus holds fewer than ``k`` entries), all of them are
            used instead of raising IndexError.

    Returns:
        The formatted prompt string.
    """
    contents = retrieved_documents['clean_content'][:k]
    context = "".join(f"{content}\n" for content in contents)
    return f"Question:{prompt}\nContext:{context}"
def generate(formatted_prompt):
    """Generate an answer from the chat LLM for an already-formatted prompt.

    Args:
        formatted_prompt: question plus retrieved context (see
            ``format_prompt``). Only its first 2000 characters are sent.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        sliced off), with special tokens removed.
    """
    # Character cap (not a token cap) to keep the prompt small enough to avoid
    # GPU OOM; it trims the tail, i.e. the retrieved context.
    formatted_prompt = formatted_prompt[:2000]
    messages = [
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": formatted_prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    # One unpadded sequence, so every position is a real token. Pass the mask
    # and pad id explicitly rather than letting generate() infer the mask from
    # pad tokens, which transformers warns is unreliable when pad == eos.
    attention_mask = torch.ones_like(input_ids)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=1024,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.01,  # near-greedy (an earlier setting was 0.6)
        top_p=0.9,
    )
    # outputs[0] is prompt tokens + continuation; keep only the continuation.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)
def rag_chatbot(prompt: str, k: int = 2):
    """Answer *prompt* with retrieval-augmented generation.

    Retrieves the ``k`` nearest sections, prints a link to each retrieved
    section, builds the LLM prompt from them and returns the generated answer.
    """
    _scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    for section_id in retrieved_documents['id']:
        # f-string instead of `+`: works whether ids are stored as str or int.
        print(f'https://majles.tavasi.ir/entity/detail/view/qsection/{section_id}')
    return generate(formatted_prompt)
# Demo query (Persian): "How is night-shift work handled for pregnant mothers
# and mothers of breastfeeding infants?"
result = rag_chatbot(
    "نوبت کاری شبانه برای مادران باردار و مادران دارای فرزند شیرخوار به چه صورت است؟",
    k=5,
)
print(result)
print('end')