# File-viewer header captured with this paste: 117 lines, 3.6 KiB, Python.
import json
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
import torch
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
from transformers import AutoTokenizer
|
|
from transformers import AutoModel # for pytorch
|
|
from transformers import TFAutoModelForTokenClassification # for tensorflow
|
|
from transformers import pipeline
|
|
import os
|
|
from datasets import Dataset, load_from_disk
|
|
import g4f.Provider
|
|
from g4f.client import Client
|
|
import g4f
|
|
|
|
from langchain_openai import ChatOpenAI # pip install -U langchain_openai
|
|
|
|
from openai import OpenAI # pip install -U openai
|
|
|
|
|
|
# OpenAI-compatible client pointed at the AvalAI gateway.
# The key and endpoint can be supplied via AVALAI_API_KEY / AVALAI_BASE_URL so
# credentials need not live in source control; the previous literal values
# remain the fallback, so behavior is unchanged when the variables are unset.
client = OpenAI(
    base_url=os.environ.get("AVALAI_BASE_URL", "https://api.avalai.ir/v1"),
    api_key=os.environ.get("AVALAI_API_KEY", "aa-..."),
)

print('start')
|
|
#---
|
|
# NOTE: the stdlib ``json`` module cannot serialize NumPy scalar/array types
# (e.g. ``np.float32`` values), so this encoder converts them to native Python
# values. Usage: json.dumps(d, cls=NumpyFloatValuesEncoder)
class NumpyFloatValuesEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in types.

    Conversions: ``np.integer`` -> ``int``, ``np.floating`` -> ``float``,
    ``np.bool_`` -> ``bool``, ``np.ndarray`` -> (nested) ``list``.
    Any other unsupported object is delegated to :class:`json.JSONEncoder`,
    which raises ``TypeError``.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for *obj*.

        Raises:
            TypeError: if *obj* is neither a supported NumPy type nor
                otherwise serializable by the base encoder.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        # np.floating covers float16/float32 (and float64, which as a Python
        # float subclass is normally encoded without reaching this hook).
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is not a subclass of bool or np.integer, so it needs its own case.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            # tolist() also converts the elements to native Python scalars.
            return obj.tolist()
        return super().default(obj)
|
|
#----
|
|
|
|
#---
|
|
|
|
def create_prompt(query, related_articles):
    """Build the user prompt asking which related articles are abrogated.

    Args:
        query: Text of the rule being checked.
        related_articles: Pre-formatted text of the candidate articles.

    Returns:
        The prompt string. It wraps *query* and *related_articles* each in
        triple backticks and asks for a step-by-step reasoning section
        followed by an answer.
    """
    result = f"""
Given a query (which is delimited with triple backticks) and the related articles (which are also delimited with triple backticks): for the given query, find the abrogated articles among the related articles, if any exist.

To answer, please use the following format:

Step-by-step reasoning: <your step-by-step reasoning>

Answer: <a clear response from the related articles>

Query: ```{query}```

Related articles: ```{related_articles}```
"""
    return result
|
|
|
|
|
|
# Chat model used to answer each prompt; it is passed to run_prompt and also
# embedded in the output filename.
model_name_rag = "gpt-4o"
|
|
|
|
|
|
# Default system instruction sent ahead of every user prompt.
_SYSTEM_PROMPT = (
    "You are a legal assistant and will carefully review user requests and "
    "respond according to the prompts given. Please produce the answers in "
    "Persian."
)


def run_prompt(user_prompt, model: str, *, system_prompt=_SYSTEM_PROMPT,
               temperature=0.3, max_tokens=1024, llm_client=None):
    """Send one prompt to the chat-completions API and return the reply text.

    Args:
        user_prompt: The user message content.
        model: Model name passed to the API.
        system_prompt: System message sent before *user_prompt*.
        temperature: Sampling temperature (default 0.3).
        max_tokens: Completion token limit (default 1024).
        llm_client: Object exposing ``chat.completions.create``. Defaults
            to the module-level ``client``.

    Returns:
        The content of the first choice's message. If the request or the
        response handling raises, returns ``"An error occurred: <exc>"``
        instead, so one failed item does not abort the calling batch loop.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    try:
        api = client if llm_client is None else llm_client
        response = api.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return response.choices[0].message.content
    except Exception as e:  # deliberate best-effort boundary for batch runs
        return f"An error occurred: {e}"
|
|
|
|
|
|
related_data = []

print('loading data')

# Retrieval run whose results are sent to the LLM.
model_name_or_path = "reranked_10_oil_similar_20_oil_BAAI_bge-m3"
# Alternative run: "reranked_10_oil_similar_20_oil_jinaai_jina-embeddings-v3"

# Close the input as soon as it is parsed; the loop below makes one API call
# per item, so there is no reason to hold the handle open for the whole run.
with open(f'./mj/{model_name_or_path}.json', "r", encoding='utf-8') as content_file:
    oil_data = json.load(content_file)

for qan in tqdm(oil_data):
    rule = qan["rule"]
    # NOTE(review): each retrieved entry is indexed as [1] = article text and
    # [2] = article id; confirm this matches the retriever's output format.
    related_articles = "".join(
        f"Article ({relateds[2]}): {relateds[1]}. \n"
        for relateds in qan["retrieved_rules"]
    )
    prompt = create_prompt(rule, related_articles)
    qan["result"] = run_prompt(prompt, model_name_rag)

##########

filename = "./mj/final_{}_oil_{}.json".format(model_name_rag, model_name_or_path)

# `with` guarantees the output file is closed even if serialization fails.
with open(filename, "w", encoding='utf-8') as similarity_file:
    json.dump(oil_data, similarity_file, ensure_ascii=False,
              cls=NumpyFloatValuesEncoder, indent=4)

print('end')