Add AvalAI API to chat with LLMs
This commit is contained in:
parent
4b517ce00e
commit
8ba2c2038f
117
oil_domain_chatbot_avalai_05.py
Normal file
117
oil_domain_chatbot_avalai_05.py
Normal file
|
@ -0,0 +1,117 @@
|
|||
import json
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import torch
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel # for pytorch
|
||||
from transformers import TFAutoModelForTokenClassification # for tensorflow
|
||||
from transformers import pipeline
|
||||
import os
|
||||
from datasets import Dataset, load_from_disk
|
||||
import g4f.Provider
|
||||
from g4f.client import Client
|
||||
import g4f
|
||||
|
||||
from langchain_openai import ChatOpenAI # pip install -U langchain_openai
|
||||
|
||||
from openai import OpenAI # pip install -U openai
|
||||
|
||||
|
||||
# OpenAI-compatible client pointed at the AvalAI endpoint.
# The key is taken from the AVALAI_API_KEY environment variable so a real
# credential never has to be committed; the in-file placeholder remains the
# fallback so setups that edit this file directly keep working.
client = OpenAI(
    base_url="https://api.avalai.ir/v1",
    api_key=os.environ.get("AVALAI_API_KEY", "aa-..."),
)

print('start')
|
||||
#---
|
||||
# NOTE: the json module cannot serialize NumPy scalars or arrays by itself;
# this encoder converts them to built-in Python values first.
class NumpyFloatValuesEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy values to built-in Python types.

    Usage: ``json.dumps(data, cls=NumpyFloatValuesEncoder)``.
    """

    def default(self, obj):
        """Return a JSON-serializable equivalent of *obj*.

        NumPy integers become ``int``, NumPy floats (``np.float32`` included)
        become ``float``, NumPy booleans become ``bool`` and arrays become
        nested lists. Any other object is handed to the base class, which
        raises ``TypeError``.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        # np.floating already covers np.float32, so one check is enough.
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is neither np.integer nor np.floating and needs its own case.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
|
||||
#----
|
||||
|
||||
#---
|
||||
|
||||
def create_prompt(query, related_articles):
    """Build the user prompt asking the model to find abrogated articles.

    Args:
        query: Text of the rule being examined; inserted between triple
            backticks.
        related_articles: Text listing the candidate articles; inserted
            between triple backticks.

    Returns:
        The prompt string: the task instruction, the required answer format,
        then the query and the related articles.
    """
    # The instruction wording was corrected ("existes" -> "any exist",
    # "which is" -> "which are"); the layout of the prompt is unchanged.
    return f"""
Given a query (which is delimited with triple backticks) and the related articles (which are also delimited with triple backticks). For the given query, find abrogated articles from the related articles if any exist.
To answer, please use the following format:
Step-by-step reasoning: <your step-by-step reasoning>
Answer: <a clear response from the related articles>
Query: ```{query}```
Related articles: ```{related_articles}```

"""
|
||||
|
||||
|
||||
# Model name passed to run_prompt below and embedded in the output filename.
model_name_rag = "gpt-4o"
# Alternative LangChain-based client, currently disabled and kept for reference:
# llm = ChatOpenAI(
#     model=model_name_rag,
#     base_url="https://api.avalai.ir/v1",
#     api_key="aa-"
# )
|
||||
|
||||
|
||||
def run_prompt(user_prompt, model: str, *, system_prompt=None,
               temperature=0.3, max_tokens=1024, api_client=None):
    """Send one system + user chat request and return the reply text.

    Args:
        user_prompt: Content of the user message.
        model: Model name passed to the chat-completions call.
        system_prompt: Content of the system message. Defaults to the
            legal-assistant instruction that asks for answers in Persian.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Completion token limit forwarded to the API.
        api_client: Object exposing ``chat.completions.create``. Defaults to
            the module-level ``client``.

    Returns:
        The message content of the first choice. If anything raises, the
        string ``"An error occurred: <exception>"`` is returned instead of
        propagating the exception (existing behaviour that callers rely on);
        callers that must tell the two apart need to check for that prefix.
    """
    if system_prompt is None:
        system_prompt = (
            "You are a legal assistant and will carefully review user requests "
            "and respond according to the prompts given. "
            "Please produce the answers in Persian."
        )

    try:
        # Resolved inside the try so a missing module-level client is also
        # reported through the error string, as before.
        active_client = client if api_client is None else api_client
        response = active_client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return response.choices[0].message.content
    except Exception as e:  # deliberate best-effort: report, do not raise
        return f"An error occurred: {e}"
|
||||
|
||||
|
||||
# NOTE: related_data is not used anywhere below; the name is kept only so the
# module-level variable still exists.
related_data = []
print('loading data')

# Base name of the retrieval-results file under ./mj; also reused in the
# output filename.
model_name_or_path = "reranked_10_oil_similar_20_oil_BAAI_bge-m3"
#model_name_or_path = "reranked_10_oil_similar_20_oil_jinaai_jina-embeddings-v3"

# Load the input and release the handle immediately, instead of holding the
# file open for the whole run of API calls.
with open(f'./mj/{model_name_or_path}.json', "r", encoding='utf-8') as content_file:
    oil_data = json.load(content_file)

for qan in tqdm(oil_data):
    # Each retrieved entry is read positionally: index 1 is used as the
    # article text and index 2 as its id.
    related_articles = "".join(
        f"Article ({relateds[2]}): {relateds[1]}. \n"
        for relateds in qan["retrieved_rules"]
    )
    prompt = create_prompt(qan["rule"], related_articles)
    # The model reply (or the error string from run_prompt) is stored on the
    # record itself so it is written out with the rest of the data.
    qan["result"] = run_prompt(prompt, model_name_rag)

##########
filename = f"./mj/final_{model_name_rag}_oil_{model_name_or_path}.json"
with open(filename, "w", encoding='utf-8') as similarity_file:
    json.dump(
        oil_data,
        similarity_file,
        ensure_ascii=False,
        cls=NumpyFloatValuesEncoder,
        indent=4,
    )

print('end')
|
Loading…
Reference in New Issue
Block a user