260 lines
8.3 KiB
Python
260 lines
8.3 KiB
Python
#pip install -q datasets sentence-transformers faiss-cpu accelerate
|
|
#from sentence_transformers import SentenceTransformer, util
|
|
|
|
from tqdm import tqdm
|
|
import torch
|
|
#from sklearn.metrics.pairwise import cosine_similarity
|
|
#from sentence_transformers import SentenceTransformer
|
|
import datetime
|
|
import os.path
|
|
from transformers import AutoTokenizer
|
|
import os
|
|
from datasets import Dataset, load_from_disk
|
|
import os
|
|
from transformers import AutoTokenizer, AutoModel
|
|
import torch
|
|
from funcs import read_from_json
|
|
|
|
# Wall-clock start of the run; reused for the log line so both show the same instant.
start_time = datetime.datetime.now()
print(f'start: {start_time}')
|
|
#
|
|
#----
|
|
import hazm
|
|
from cleantext import clean
|
|
import re
|
|
|
|
def cleanhtml(raw_html):
    """Return *raw_html* with every ``<...>`` tag removed.

    The match is non-greedy and, because ``.`` does not match a newline here,
    a tag that spans several lines is left untouched.
    """
    return re.sub(r'<.*?>', '', raw_html)
|
|
|
|
# Shared hazm text normalizer, built once at import time and reused by cleaning().
normalizer = hazm.Normalizer()
|
|
# Remove "weird" characters: emoji, pictographs, dingbats, and invisible
# formatting / bidi-isolate controls (ZWJ, VS-16, LRI/RLI/FSI/PDI).
# ZWNJ (U+200C) is intentionally kept (its entry is commented out below),
# presumably because it is meaningful in Persian text.
#
# Fix: the previous pattern contained the single range U+24C2..U+1F251, which
# also deleted every BMP character from U+24C2 to U+FFFF -- CJK text and the
# Arabic Presentation Forms blocks (U+FB50..U+FDFF, U+FE70..U+FEFF) included.
# It is split into what it was presumably meant to cover: U+24C2 and the
# enclosed alphanumeric/ideographic supplements U+1F170..U+1F251. BMP emoji
# that only the over-broad range used to catch (U+303D, U+3297, U+3299) are
# listed explicitly so they are still removed.
wierd_pattern = re.compile(
    "["
    u"\U0001F600-\U0001F64F"  # emoticons
    u"\U0001F300-\U0001F5FF"  # symbols & pictographs
    u"\U0001F680-\U0001F6FF"  # transport & map symbols
    u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
    u"\U00002702-\U000027B0"  # dingbats
    u"\U000024C2"  # circled latin capital letter M
    u"\U0001F170-\U0001F251"  # enclosed alphanumeric / ideographic supplements
    u"\U0001f926-\U0001f937"
    u"\U00010000-\U0010ffff"  # every supplementary-plane character
    u"\u200d"  # zero width joiner
    u"\u2640-\u2642"
    u"\u2600-\u2B55"  # U+2600..U+2B55 symbol blocks
    u"\u23cf"
    u"\u23e9"
    u"\u231a"
    u"\u3030"
    u"\u303d"
    u"\u3297"
    u"\u3299"
    u"\ufe0f"  # variation selector-16
    u"\u2069"  # pop directional isolate
    u"\u2066"  # left-to-right isolate
    # u"\u200c" (ZWNJ) is deliberately NOT removed
    u"\u2068"  # first strong isolate
    u"\u2067"  # right-to-left isolate
    "]+",
    flags=re.UNICODE,
)
|
|
|
|
def cleaning(text):
    """Clean one raw document into a single-line string.

    Steps, in order: strip surrounding whitespace; ``clean`` with
    ``extra_spaces=True, lowercase=True``; remove HTML tags; hazm
    normalization; remove characters matched by ``wierd_pattern``; drop
    ``#``; collapse every whitespace run (newlines included) to one space.

    Args:
        text: raw input string.

    Returns:
        The cleaned string.
    """
    text = text.strip()

    # NOTE(review): the keyword names `extra_spaces` / `lowercase` match the
    # `cleantext` package (not `clean-text`). If so, its `clean_all` argument
    # presumably defaults to True and may also strip punctuation/stopwords and
    # stem -- which would also remove '<'/'>' before cleanhtml runs. Confirm
    # this is intended for Persian text.
    text = clean(text, extra_spaces=True, lowercase=True)

    text = cleanhtml(text)
    text = normalizer.normalize(text)
    text = wierd_pattern.sub('', text)

    # Drop hashtag markers, then collapse whitespace runs to a single space.
    # Raw string: a plain "\s+" is an invalid escape and warns on Python 3.12+.
    text = text.replace('#', '')
    text = re.sub(r'\s+', ' ', text)
    return text
|
|
|
|
#---
|
|
#dataset = load_dataset("not-lain/wikipedia")
|
|
|
|
# dataset # Let's checkout our dataset
|
|
# >>> DatasetDict({
|
|
# train: Dataset({
|
|
# features: ['id', 'url', 'title', 'text'],
|
|
# num_rows: 3000
|
|
# })
|
|
# })
|
|
#model_name_or_path = "mixedbread-ai/mxbai-embed-large-v1"
|
|
# Local checkpoint directory of the sentence-embedding encoder.
model_name_or_path = "/home/gpu/NLP/MLM/CODES/BERT/finetune/MODELS/roberta-fa-zwnj-base-law-2-pt"

# Fail fast if the checkpoint is incomplete. SystemExit(msg) prints the message
# to stderr and exits with status 1; the old exit(0) reported success on failure.
if not (
    os.path.exists(os.path.join(model_name_or_path, 'model.safetensors'))
    and os.path.exists(os.path.join(model_name_or_path, 'tokenizer.json'))
):
    raise SystemExit('model files is not exists in model path directory.')
|
|
|
|
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions where ``attention_mask`` is set.

    Args:
        model_output: model output whose first element holds the token
            embeddings, presumably shaped (batch, seq_len, hidden).
        attention_mask: (batch, seq_len) mask; 0 marks positions to ignore.

    Returns:
        A (batch, hidden) tensor of masked means.
    """
    tokens = model_output[0]
    # Broadcast the mask over the hidden dimension so masked positions contribute zero.
    weights = attention_mask.unsqueeze(-1).expand(tokens.size()).float()
    summed = (tokens * weights).sum(dim=1)
    # Clamp keeps an all-masked row from dividing by zero.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts
|
|
|
|
# Load the sentence-embedding encoder and its tokenizer from the local checkpoint
# directory `model_name_or_path` (a filesystem path, not a HuggingFace Hub id).
tokenizer_bert = AutoTokenizer.from_pretrained(model_name_or_path)
model_bert = AutoModel.from_pretrained(model_name_or_path)
|
|
|
|
def encode(sentences):
    """Embed text with the local encoder and mean-pool over unpadded tokens.

    Args:
        sentences: a string or list of strings, passed straight to ``tokenizer_bert``.

    Returns:
        A tensor with one pooled embedding row per input sentence.
    """
    # Pad to the longest input in the batch; truncation=True without max_length
    # lets the tokenizer cap inputs at the model's maximum length.
    encoded_input = tokenizer_bert(sentences, padding=True, truncation=True, return_tensors='pt')

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        model_output = model_bert(**encoded_input)

    # Mean pooling (not max pooling): average token embeddings, weighted by the attention mask.
    return mean_pooling(model_output, encoded_input['attention_mask'])
|
|
###
|
|
#----
|
|
# Load the pre-embedded corpus and flatten it into a single list of sections.
# NOTE(review): read_from_json presumably returns a mapping of law id -> list of
# section dicts (each with 'id', 'content', 'embed' as used below); confirm in funcs.py.
address = './data/laws_list_170k_embed2.json'
text_arr = read_from_json(address)

all_sections = []
for sections in text_arr.values():
    all_sections.extend(sections)
|
|
|
|
#---
|
|
# Build (or reload from disk) the dataset of precomputed section embeddings.
EMBEDDINGS_DATA_PATH = 'ej/caches_test/embeddings_data'

if not os.path.exists(EMBEDDINGS_DATA_PATH):
    # One row per section. 'embeddings' wraps the stored vector in an extra
    # leading dimension (torch.tensor([vec])); remove_dem strips it before
    # indexing. A comprehension avoids leaking loop names such as `id`, which
    # previously shadowed the builtin for the rest of the module.
    # NOTE(review): the FAISS index file is cached separately; if this dataset
    # is rebuilt, delete the stale index file so it is rebuilt to match.
    corpus_embeddings = [
        {
            'embeddings': torch.tensor([item['embed']]),
            'clean_content': item['content'],
            'id': item['id'],
        }
        for item in tqdm(all_sections)
    ]
    data = Dataset.from_list(corpus_embeddings)
    data.save_to_disk(EMBEDDINGS_DATA_PATH)
else:
    data = load_from_disk(EMBEDDINGS_DATA_PATH)
|
|
|
|
def remove_dem(example):
    """Replace ``example["embeddings"]`` with its first element, in place, and return *example*.

    Mapped over the dataset so each stored single-row embedding becomes a flat
    vector before the FAISS index is built.
    """
    nested = example["embeddings"]
    example["embeddings"] = nested[0]
    return example
|
|
|
|
# Attach a FAISS index over the "embeddings" column, reusing the cached index file when present.
if not os.path.exists('ej/caches_test/embeddings_index.faiss'):
    # No cached index: flatten each embedding row, build the index, persist it.
    udata = data.map(remove_dem).add_faiss_index("embeddings")
    udata.save_faiss_index('embeddings', 'ej/caches_test/embeddings_index.faiss')
    data = udata
else:
    # Only the index is loaded here; the rows keep their un-flattened embeddings.
    data.load_faiss_index('embeddings', 'ej/caches_test/embeddings_index.faiss')
|
|
|
|
def search(query: str, k: int = 3):
    """Embed *query* and fetch its *k* nearest corpus rows from the "embeddings" index.

    Returns:
        ``(scores, retrieved_examples)`` as produced by ``data.get_nearest_examples``.
    """
    query_vector = encode(query).numpy()
    nearest = data.get_nearest_examples("embeddings", query_vector, k=k)
    scores, retrieved_examples = nearest
    return scores, retrieved_examples
|
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
|
import torch
|
|
|
|
#model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
|
model_id = "PartAI/Dorna-Llama3-8B-Instruct"

# Chat tokenizer for the generator model.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# 4-bit NF4 quantization with double quantization and bfloat16 compute, to lower GPU memory use.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Stop generation at either the tokenizer's EOS id or the "<|eot_id|>" end-of-turn token.
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
|
|
|
|
SYS_PROMPT = """You are an assistant for answering questions.
|
|
You are given the extracted parts of a long document and a question. Provide a conversational answer.
|
|
If you don't know the answer, just say "I do not know." Don't make up an answer.لطفا جواب فارسی باشد"""
|
|
|
|
|
|
def format_prompt(prompt, retrieved_documents, k):
    """Build the LLM user message from a question and its retrieved sections.

    Args:
        prompt: the user's question.
        retrieved_documents: examples returned by ``search``; only its
            ``'clean_content'`` column is read.
        k: maximum number of retrieved sections to include.

    Returns:
        "Question:" + prompt + a newline + "Context:", followed by each included
        section, each terminated by a newline.
    """
    # Slice rather than index range(k): the retriever can return fewer than k
    # rows, which would otherwise raise IndexError.
    contexts = retrieved_documents['clean_content'][:k]
    return f"Question:{prompt}\nContext:" + "".join(f"{doc}\n" for doc in contexts)
|
|
|
|
def generate(formatted_prompt):
    """Answer *formatted_prompt* with the chat LLM and return the decoded reply text.

    The user message is cut to its first 2000 characters (a crude guard against
    GPU out-of-memory), paired with SYS_PROMPT as the system message, and passed
    through the tokenizer's chat template before generation.
    """
    user_content = formatted_prompt[:2000]  # to avoid GPU OOM
    chat = [
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": user_content},
    ]
    prompt_ids = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    generated = model.generate(
        prompt_ids,
        max_new_tokens=1024,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.01,  # previously 0.6
        top_p=0.9,
    )
    # Keep only the newly generated tokens, dropping the echoed prompt.
    new_tokens = generated[0][prompt_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
|
|
|
|
def rag_chatbot(prompt: str, k: int = 2):
    """Retrieve the top-*k* sections for *prompt*, print their source links, and answer.

    Args:
        prompt: the user's question.
        k: number of sections to retrieve and include as context.

    Returns:
        The reply string produced by ``generate``.
    """
    _scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    # Print a link per retrieved section; the f-string also tolerates non-str ids,
    # where the old '+' concatenation raised TypeError.
    for section_id in retrieved_documents['id']:
        print(f'https://majles.tavasi.ir/entity/detail/view/qsection/{section_id}')
    return generate(formatted_prompt)
|
|
# Demo query (Persian): "How is night-shift work handled for pregnant mothers and
# mothers of nursing infants?" -- answered from the 5 nearest law sections.
result = rag_chatbot(
    prompt="نوبت کاری شبانه برای مادران باردار و مادران دارای فرزند شیرخوار به چه صورت است؟",
    k=5,
)
print(result)
print('end')
|