# pip install -q datasets sentence-transformers faiss-cpu accelerate
# (the script additionally imports transformers, hazm, cleantext, and
# bitsandbytes for the 4-bit quantization config below)

import os
import re
import time

import torch
from tqdm import tqdm
from datasets import Dataset, load_from_disk
from transformers import AutoTokenizer, AutoModel

import hazm
from cleantext import clean

from funcs import read_from_json, read_dict_from_json

print('start')

start_time = time.time()

#---- Persian text cleaning ----

def cleanhtml(raw_html):
    """Strip HTML tags from a string."""
    cleanr = re.compile('<.*?>')
    cleantext = re.sub(cleanr, '', raw_html)
    return cleantext
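
# e.g. (illustrative): cleanhtml("<p>متن <b>قانون</b></p>") -> "متن قانون"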


normalizer = hazm.Normalizer()

# pattern matching emoji, pictographs, and directional/invisible marks
# ("weird" characters) that should be stripped from the text
weird_pattern = re.compile("["
                           u"\U0001F600-\U0001F64F"  # emoticons
                           u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                           u"\U0001F680-\U0001F6FF"  # transport & map symbols
                           u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                           u"\U00002702-\U000027B0"
                           u"\U000024C2-\U0001F251"
                           u"\U0001f926-\U0001f937"
                           u'\U00010000-\U0010ffff'
                           u"\u200d"
                           u"\u2640-\u2642"
                           u"\u2600-\u2B55"
                           u"\u23cf"
                           u"\u23e9"
                           u"\u231a"
                           u"\u3030"
                           u"\ufe0f"
                           u"\u2069"
                           u"\u2066"
                           # u"\u200c"  # ZWNJ is intentionally kept; it is meaningful in Persian
                           u"\u2068"
                           u"\u2067"
                           "]+", flags=re.UNICODE)

def cleaning(text):
    text = text.strip()

    # regular cleaning; an earlier, fuller pass used the clean-text API
    # (fix_unicode=True, to_ascii=False, lower=True, no_line_breaks=True,
    # no_urls=True, no_emails=True, no_phone_numbers=True,
    # no_currency_symbols=True, with empty replacement strings) and is disabled
    text = clean(text,
                 extra_spaces=True,
                 lowercase=True
                 )

    # strip HTML tags
    text = cleanhtml(text)

    # normalize with hazm
    text = normalizer.normalize(text)

    # drop emoji and other unwanted characters
    text = weird_pattern.sub(r'', text)

    # remove hashtag marks and collapse extra whitespace
    text = re.sub("#", "", text)
    text = re.sub(r"\s+", " ", text)

    return text
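
# Illustration of the full pipeline (hypothetical input; the exact output
# depends on the hazm normalizer version):
# cleaning("<p>  سلام  دنیا 😀 #تست </p>")
# -> roughly "سلام دنیا تست"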


#---- Embedding model ----
# dataset = load_dataset("not-lain/wikipedia")
# dataset  # let's check out our dataset
# >>> DatasetDict({
#         train: Dataset({
#             features: ['id', 'url', 'title', 'text'],
#             num_rows: 3000
#         })
#     })

# model_name_or_path = "mixedbread-ai/mxbai-embed-large-v1"
model_name_or_path = "/home/gpu/NLP/MLM/CODES/BERT/finetune/MODELS/roberta-fa-zwnj-base-law-2-pt"

if not os.path.exists(model_name_or_path + '/model.safetensors') \
        or not os.path.exists(model_name_or_path + '/tokenizer.json'):
    print('Model files do not exist in the model path directory.')
    exit(1)  # non-zero exit status signals the error

# Mean pooling - take the attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # sum the unmasked token vectors, then divide by the count of real tokens
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
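
# Toy check of mean_pooling (hypothetical tensors): with a (1, 3, 2) batch of
# token embeddings and a mask that keeps only the first two tokens,
# mean_pooling((torch.tensor([[[1., 1.], [3., 3.], [9., 9.]]]),),
#              torch.tensor([[1, 1, 0]]))
# -> tensor([[2., 2.]]), the average of the two unmasked rows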

# Load the tokenizer and model from the local fine-tuned checkpoint
tokenizer_bert = AutoTokenizer.from_pretrained(model_name_or_path)
model_bert = AutoModel.from_pretrained(model_name_or_path)


def encode(sentences):
    """Embed a sentence (or list of sentences) into mean-pooled vectors."""
    # Tokenize sentences
    encoded_input = tokenizer_bert(sentences, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model_bert(**encoded_input)

    # Perform pooling. In this case, mean pooling.
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    return sentence_embeddings
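
# Quick sanity check (illustrative; the embedding width depends on the
# checkpoint, e.g. 768 for a roberta-base-sized model):
# emb = encode(["این یک جمله آزمایشی است"])
# print(emb.shape)  # -> torch.Size([1, 768])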


#---- Build / load the embedded corpus and FAISS index ----
text_arr = []
all_sections = []

# NOTE: the search repository file was already built from the first pass over
# the section embeddings, so the sections do not need to be re-read (which is
# why the loop below is commented out), unless new embeddings are generated.
# for i in range(1, 111):
#     file_address = f"./data/mj_qa_section_ai/mj_qa_section_ai{i}.json"
#     sections = read_from_json(file_address)
#     all_sections.extend(sections)
#     print(file_address)

if not os.path.exists('mj/caches_test/embeddings_data'):
    corpus_embeddings = []
    for item in tqdm(all_sections):
        section_id = item['_id']
        source = item['_source']
        content = source['content']
        # the precomputed embedding is stored in the section itself;
        # it could also be recomputed with encode(content)
        embedding = torch.tensor([source['embeddings']])
        # embedding = torch.tensor(source['embeddings']).reshape(-1)

        corpus_embeddings.append({'embeddings': embedding, 'clean_content': content, 'id': section_id})
    data = Dataset.from_list(corpus_embeddings)
    data.save_to_disk('mj/caches_test/embeddings_data')
else:
    data = load_from_disk('mj/caches_test/embeddings_data')
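
# Each cached row looks roughly like this (illustrative values):
# {'embeddings': [[0.12, -0.03, ...]], 'clean_content': 'متن سکشن', 'id': 'qs123'}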

def remove_dem(example):
    # each stored embedding has shape (1, dim); drop the extra batch
    # dimension so the FAISS index sees one flat vector per row
    example["embeddings"] = example["embeddings"][0]
    return example


if os.path.exists('mj/caches_test/embeddings_index.faiss'):
    data.load_faiss_index('embeddings', 'mj/caches_test/embeddings_index.faiss')
else:
    udata = data.map(remove_dem)
    udata = udata.add_faiss_index("embeddings")
    udata.save_faiss_index('embeddings', 'mj/caches_test/embeddings_index.faiss')
    data = udata

def search(query: str, k: int = 3):
    """Embed a new query and return the k most similar corpus entries."""
    embedded_query = encode(query)  # embed the new query
    scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
        "embeddings", embedded_query.numpy(),  # compare the embedded query with the dataset embeddings
        k=k  # get only the top-k results
    )
    return scores, retrieved_examples


# Example: search for a phrase and get the 4 best-matching values from the dataset
# scores, result = search("اداره کار و رفاه", 4)
# print(result['clean_content'])


#---- Generator model (Dorna-Llama3, 4-bit quantized) ----
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model_id = "PartAI/Dorna-Llama3-8B-Instruct"

# use 4-bit quantization to lower GPU memory usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config
)
# stop generation at either the EOS token or Llama-3's end-of-turn token
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

# the trailing Persian sentence asks the model to answer in Persian
SYS_PROMPT = """You are an assistant for answering questions.
You are given the extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer. لطفا جواب فارسی باشد"""

def format_prompt(prompt, retrieved_documents, k):
    """Using the retrieved documents, build the prompt for the model to answer from."""
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['clean_content'][idx]}\n"
    return PROMPT
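
# Illustrative result (hypothetical documents):
# format_prompt("سوال", {"clean_content": ["سند اول", "سند دوم"]}, 2)
# -> "Question:سوال\nContext:سند اول\nسند دوم\n"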

def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # truncate the prompt to avoid GPU OOM
    messages = [{"role": "system", "content": SYS_PROMPT},
                {"role": "user", "content": formatted_prompt}]
    # build the chat-formatted input and let the model generate
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    response = outputs[0][input_ids.shape[-1]:]  # keep only the newly generated tokens
    return tokenizer.decode(response, skip_special_tokens=True)

def rag_chatbot(prompt: str, k: int = 2):
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    # print a source link for each retrieved section
    for item in retrieved_documents['id']:
        print('https://majles.tavasi.ir/entity/detail/view/qsection/' + item)
    return generate(formatted_prompt)


result = rag_chatbot("نوبت کار شبانه برای مادر باردار و دارای فرزند شیرخوار به چه صورت است؟", k=5)
print(result)
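
# report the total elapsed time; start_time is captured at the top of the script
print(f'elapsed: {time.time() - start_time:.1f}s')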
print('end')