# rag/RAG_llama3_chatbot_elastic_embd.py
#pip install -q datasets sentence-transformers faiss-cpu accelerate
#from sentence_transformers import SentenceTransformer, util
import json
from tqdm import tqdm
import torch
#from sklearn.metrics.pairwise import cosine_similarity
#from sentence_transformers import SentenceTransformer
import numpy as np
import time
import os
# import pickle
from datasets import Dataset, load_from_disk
from transformers import AutoTokenizer, AutoModel
from funcs import read_from_json, read_dict_from_json
print('start')
start_time = time.time()
#
#----
import hazm
from cleantext import clean
import re
def cleanhtml(raw_html):
    cleanr = re.compile('<.*?>')
    cleantext = re.sub(cleanr, '', raw_html)
    return cleantext
normalizer = hazm.Normalizer()
# removing weird patterns (emoji and other symbol ranges)
weird_pattern = re.compile("["
    u"\U0001F600-\U0001F64F"  # emoticons
    u"\U0001F300-\U0001F5FF"  # symbols & pictographs
    u"\U0001F680-\U0001F6FF"  # transport & map symbols
    u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
    u"\U00002702-\U000027B0"
    u"\U000024C2-\U0001F251"
    u"\U0001f926-\U0001f937"
    u'\U00010000-\U0010ffff'
    u"\u200d"
    u"\u2640-\u2642"
    u"\u2600-\u2B55"
    u"\u23cf"
    u"\u23e9"
    u"\u231a"
    u"\u3030"
    u"\ufe0f"
    u"\u2069"
    u"\u2066"
    #u"\u200c"
    u"\u2068"
    u"\u2067"
    "]+", flags=re.UNICODE)
def cleaning(text):
    text = text.strip()
    # regular cleaning
    # text = clean(text,
    #     fix_unicode=True,
    #     to_ascii=False,
    #     lower=True,
    #     no_line_breaks=True,
    #     no_urls=True,
    #     no_emails=True,
    #     no_phone_numbers=True,
    #     no_numbers=False,
    #     no_digits=False,
    #     no_currency_symbols=True,
    #     no_punct=False,
    #     replace_with_url="",
    #     replace_with_email="",
    #     replace_with_phone_number="",
    #     replace_with_number="",
    #     replace_with_digit="0",
    #     replace_with_currency_symbol="",
    # )
    text = clean(text,
        extra_spaces=True,
        lowercase=True
    )
    # cleaning HTML
    text = cleanhtml(text)
    # normalizing
    # normalizer = hazm.Normalizer()
    text = normalizer.normalize(text)
    text = weird_pattern.sub(r'', text)
    # removing hashtags and extra spaces
    text = re.sub("#", "", text)
    text = re.sub(r"\s+", " ", text)
    return text
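
# Minimal usage sketch for cleaning() (hypothetical input/output, illustration only):
# cleaning('<b>سلام    دنیا 😀 #تست</b>')
# -> 'سلام دنیا تست'
# (HTML tags, emoji, hash signs, and repeated spaces removed; text normalized by hazm)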
#---
#dataset = load_dataset("not-lain/wikipedia")
# dataset  # Let's check out our dataset
# >>> DatasetDict({
#         train: Dataset({
#             features: ['id', 'url', 'title', 'text'],
#             num_rows: 3000
#         })
#     })
#model_name_or_path = "mixedbread-ai/mxbai-embed-large-v1"
model_name_or_path = "/home/gpu/NLP/MLM/CODES/BERT/finetune/MODELS/roberta-fa-zwnj-base-law-2-pt"
if not os.path.exists(model_name_or_path + '/model.safetensors') or not os.path.exists(model_name_or_path + '/tokenizer.json'):
    print('model files do not exist in the model path directory.')
    exit(1)  # exit with a non-zero status since this is an error
# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
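# Shape sketch: given token_embeddings of shape (batch, seq_len, hidden) and
# attention_mask of shape (batch, seq_len), mean_pooling returns a (batch, hidden)
# tensor: a per-sentence average taken over non-padding tokens only.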
# Load model from HuggingFace Hub
tokenizer_bert = AutoTokenizer.from_pretrained(model_name_or_path)
model_bert = AutoModel.from_pretrained(model_name_or_path)
def encode(sentences):
    # Tokenize sentences
    encoded_input = tokenizer_bert(sentences, padding=True, truncation=True, return_tensors='pt')
    # Compute token embeddings
    with torch.no_grad():
        model_output = model_bert(**encoded_input)
    # Perform pooling. In this case, mean pooling.
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    # print("Sentence embeddings:")
    # print(sentence_embeddings)
    return sentence_embeddings
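# Usage sketch for encode(): it accepts a string or a list of strings and returns
# a (batch_size, hidden_size) float tensor of mean-pooled sentence embeddings.
# Assuming a base-sized RoBERTa (hidden_size 768):
# emb = encode(["متن نمونه"])  # -> torch.Size([1, 768])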
###
#----
text_arr = []
all_sections = []
"""
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
چون بر اساس اولین خوانش امبدینگ سکشن ها فایل مخزن برای جستجو ایجاد شده است ، نیاز به خوانش مجدد نیست(به همین دلیل حلقه زیر کامنت شده است) مگر اینکه امبدینگ جدیدی تولید شود
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"""
# for i in range(1, 111):  # 111
#     file_address = f"./data/mj_qa_section_ai/mj_qa_section_ai{i}.json"
#     sections = read_from_json(file_address)
#     all_sections.extend(sections)
#     print(file_address)
#---
if not os.path.exists('mj/caches_test/embeddings_data'):
    corpus_embeddings = []
    for item in tqdm(all_sections):
        doc_id = item['_id']
        source = item['_source']
        content = source['content']
        # embedding1 = encode(content)
        embedding = torch.tensor([source['embeddings']])
        # embedding = torch.tensor(source['embeddings']).reshape(-1)
        corpus_embeddings.append({'embeddings': embedding, 'clean_content': content, 'id': doc_id})
    data = Dataset.from_list(corpus_embeddings)
    data.save_to_disk('mj/caches_test/embeddings_data')
else:
    data = load_from_disk('mj/caches_test/embeddings_data')
def remove_dem(example):
    # flatten the nested [1, dim] embedding to a flat [dim] vector for FAISS
    example["embeddings"] = example["embeddings"][0]
    return example

if os.path.exists('mj/caches_test/embeddings_index.faiss'):
    data.load_faiss_index('embeddings', 'mj/caches_test/embeddings_index.faiss')
else:
    udata = data.map(remove_dem)
    udata = udata.add_faiss_index("embeddings")
    udata.save_faiss_index('embeddings', 'mj/caches_test/embeddings_index.faiss')
    data = udata
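# Note (assumed cache workflow): if new section embeddings are generated, delete both
# 'mj/caches_test/embeddings_data' and 'mj/caches_test/embeddings_index.faiss' so the
# cached dataset and the FAISS index are rebuilt consistently with each other.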
def search(query: str, k: int = 3):
    """Embed a new query and return the k most similar sections."""
    embedded_query = encode(query)  # embed the new query
    scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
        "embeddings", embedded_query.numpy(),  # compare the embedded query with the dataset embeddings
        k=k  # return only the top k results
    )
    return scores, retrieved_examples
# example: search for "اداره کار و رفاه" and get the best 4 matching sections
# scores, result = search("اداره کار و رفاه", 4)
# print(result['clean_content'])
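# Return values: `scores` is a numpy array of FAISS scores (with the default flat
# index this should be an L2 distance, so lower means closer) and
# `retrieved_examples` is a dict mapping column names to lists, e.g.
# retrieved_examples['clean_content'][0] is the best-matching section text.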
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
#model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model_id = "PartAI/Dorna-Llama3-8B-Instruct"
# use 4-bit quantization to lower GPU memory usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config
)
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
SYS_PROMPT = """You are an assistant for answering questions.
You are given the extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer. لطفا جواب فارسی باشد"""
# the Persian tail of the system prompt asks the model to answer in Persian
def format_prompt(prompt, retrieved_documents, k):
    """Build the user prompt from the question plus the k retrieved documents."""
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['clean_content'][idx]}\n"
    return PROMPT
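# Illustrative prompt layout (hypothetical contents), for k=2:
# "Question:<user question>\nContext:<section 1 text>\n<section 2 text>\n"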
def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # truncate the prompt to avoid GPU OOM
    messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
    # build the chat-formatted input and generate
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    response = outputs[0][input_ids.shape[-1]:]  # keep only the newly generated tokens
    return tokenizer.decode(response, skip_special_tokens=True)
def rag_chatbot(prompt: str, k: int = 2):
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    for item in retrieved_documents['id']:
        print('https://majles.tavasi.ir/entity/detail/view/qsection/' + item)
    return generate(formatted_prompt)
# query: "How does night-shift work apply to a mother who is pregnant or has a nursing infant?"
result = rag_chatbot("نوبت کار شبانه برای مادر باردار و دارای فرزند شیرخوار به چه صورت است؟", k=5)
print(result)
print('end')