170 lines
4.4 KiB
Python
Executable File
170 lines
4.4 KiB
Python
Executable File
from fastapi import APIRouter, Request
|
|
from fastapi.responses import JSONResponse
|
|
import time, os, traceback
|
|
from .base_model import Query, LLMOutput, LLMInput, Title
|
|
from .ai_data_parser import AsyncCore
|
|
from .chatbot_handler import (
|
|
InitHybridRetrieverReranker,
|
|
format_answer_bale,
|
|
get_user_prompt2,
|
|
get_user_prompt3,
|
|
load_faiss_index,
|
|
get_in_form,
|
|
|
|
)
|
|
from .static import (
|
|
EMBED_MODEL_PATH,
|
|
FAISS_INDEX_PATH,
|
|
FAISS_METADATA_PATH,
|
|
LLM_URL,
|
|
SYSTEM_PROMPT_FINALL,
|
|
RERANKER_MODEL_PATH,
|
|
LLM_ERROR,
|
|
MODEL_KEY,
|
|
MODEL_NAME,
|
|
OUTPUT_PATH_LLM,
|
|
REASONING_EFFORT,
|
|
TASK_NAME,
|
|
LLM_TIME_OUT,MAX_TOKEN, SYSTEM_PROPMT2
|
|
)
|
|
|
|
|
|
# ------------------------------------------------------------ Global-params
# Module-level objects shared by every request this router serves.

router = APIRouter(
    tags=["ragchat"],
)
|
|
|
|
|
|
# Load the FAISS index and its metadata once at import time so all requests
# share them. load_faiss_index yields (metadata, index) in that order.
METADATA_DICT, FAISS_INDEX = load_faiss_index(
    metadata_path=FAISS_METADATA_PATH,
    index_path=FAISS_INDEX_PATH,
)
|
|
|
|
# Retriever + reranker built once at import time over METADATA_DICT and
# FAISS_INDEX; used by chat_bot_run and rag_run via RAG.search_base.
# dense_alpha presumably weights the dense score in the hybrid fusion
# (TODO confirm in InitHybridRetrieverReranker).
# The device defaults to "cuda" exactly as before; set the RAG_DEVICE
# environment variable (e.g. RAG_DEVICE=cpu) to run on hosts without a GPU.
RAG = InitHybridRetrieverReranker(
    embeder_path=EMBED_MODEL_PATH,
    reranker_path=RERANKER_MODEL_PATH,
    dict_content=METADATA_DICT,
    faiss_index=FAISS_INDEX,
    dense_alpha=0.6,
    device=os.environ.get("RAG_DEVICE", "cuda"),
)
|
|
|
|
# Shared async LLM runner; chat_bot_run awaits its
# single_simple_async_proccess_item(). Arguments are grouped by concern;
# every value is the same as before.
RUNNER_PROMPT = AsyncCore(
    # Model and endpoint selection.
    model_name=MODEL_NAME,
    api_url=LLM_URL,
    ai_code_version=MODEL_KEY,
    # Generation / request limits.
    reasoning_effort=REASONING_EFFORT,
    max_token=MAX_TOKEN,
    request_timeout=LLM_TIME_OUT,
    # Output schema and (presumably) where/how results are saved.
    output_schema=LLMOutput,
    output_path=OUTPUT_PATH_LLM,
    task_name=TASK_NAME,
    save_number=1,
)
|
|
|
|
# Function-calling schema passed to the LLM by chat_bot_run. The model must
# answer through ``legal_answer`` with the answer text (citing qsIDs) and the
# list of document IDs it relied on. The Persian descriptions are sent to the
# model verbatim; in English they read, in order: "Structured output of a
# legal analysis with full references to the documents", "Full answer text
# including references (qsID)", "List of IDs of the documents used".
functions = [
    dict(
        name="legal_answer",
        description="خروجی ساختیافته از تحلیل حقوقی با ارجاع کامل به اسناد",
        parameters=dict(
            type="object",
            properties=dict(
                text=dict(
                    type="string",
                    description="متن کامل پاسخ شامل ارجاع (qsID)",
                ),
                source=dict(
                    type="array",
                    items=dict(type="string"),
                    description="فهرست شناسه اسناد استفاده شده",
                ),
            ),
            required=["text", "source"],
        ),
    ),
]
|
|
|
|
|
|
async def chat_bot_run(query):
    """Answer *query* with retrieval-augmented generation, formatted for Bale.

    Pipeline:
      1. ``RAG.search_base`` retrieves and reranks the top 10 sections.
      2. ``get_user_prompt3`` builds the user prompt from query + sections.
      3. ``RUNNER_PROMPT`` calls the LLM, forcing the ``legal_answer``
         function schema (``functions``) with ``SYSTEM_PROPMT2``.
      4. ``format_answer_bale`` renders the LLM's ``text`` and ``source``.
    Per-stage timings are printed to stdout.

    Args:
        query: The user's question, passed unchanged to the retriever and
            the prompt builder.

    Returns:
        The value returned by ``format_answer_bale``, or ``None`` if any
        stage raised an ``Exception`` (its traceback is printed to stderr).
    """
    try:
        t_start = time.perf_counter()
        sections_dict = await RAG.search_base(
            query,
            final_k=10,
            topk_dense=100,
            topk_sparse=100,
            pre_rerank_k=100,
        )
        t_retrieved = time.perf_counter()

        prompt = get_user_prompt3(query=query, knowledge_json=sections_dict)
        llm_answer, _ = await RUNNER_PROMPT.single_simple_async_proccess_item(
            item={"user_prompt": prompt, "system_prompt": SYSTEM_PROPMT2},
            functions=functions,
            function_name="legal_answer",
        )
        t_answered = time.perf_counter()

        # KeyError/TypeError here (malformed or missing LLM output) is handled
        # by the except clause below like any other stage failure.
        finall = format_answer_bale(
            answer_text=llm_answer["text"], sources=llm_answer["source"]
        )
        t_formatted = time.perf_counter()

        print(
            f'Rag = {t_retrieved - t_start}',
            f'llm_answer = {t_answered - t_retrieved}',
            f'Form = {t_formatted - t_answered}',
            sep='\n',
        )
        return finall
    except Exception:
        # Narrowed from a bare ``except:``: asyncio.CancelledError (a
        # BaseException since Python 3.8) and KeyboardInterrupt must propagate
        # so request cancellation and shutdown keep working.
        traceback.print_exc()
        return None
|
|
|
|
|
|
|
|
async def rag_run(query):
    """Return the sections retrieved for *query*, without calling the LLM.

    Runs ``RAG.search_base`` (top 10 after reranking, same settings as
    ``chat_bot_run``) and shapes the hits with ``get_in_form``, using the
    query as the title. Per-stage timings are printed to stdout.

    Args:
        query: The user's search text.

    Returns:
        The value returned by ``get_in_form``, or ``None`` if any stage
        raised an ``Exception`` (its traceback is printed to stderr).
    """
    try:
        t_start = time.perf_counter()
        sections_dict = await RAG.search_base(
            query,
            final_k=10,
            topk_dense=100,
            topk_sparse=100,
            pre_rerank_k=100,
        )
        t_retrieved = time.perf_counter()

        finall = get_in_form(title=query, sections=sections_dict)
        t_formatted = time.perf_counter()

        print(
            f'Rag = {t_retrieved - t_start}',
            f'Form = {t_formatted - t_retrieved}',
            sep='\n',
        )
        return finall
    except Exception:
        # Narrowed from a bare ``except:`` so asyncio.CancelledError and
        # KeyboardInterrupt are not swallowed.
        traceback.print_exc()
        return None
|
|
|
|
@router.post("/run_chat")
async def run_chat(payload: Query, request: Request):
    """POST /run_chat: answer ``payload.query`` with the RAG + LLM pipeline.

    ``request`` is unused; it stays in the signature to keep the endpoint
    interface unchanged.

    Returns:
        ``JSONResponse({"result": answer}, status_code=201)``. ``answer`` is
        ``LLM_ERROR`` when ``chat_bot_run`` raises *or* returns ``None``. It
        returns ``None`` when it swallows an internal error, which previously
        leaked to clients as ``{"result": null}``.
    """
    t_start = time.perf_counter()
    try:
        answer = await chat_bot_run(payload.query)
    except Exception:
        # Not a bare except: client-disconnect cancellation must propagate.
        traceback.print_exc()
        answer = None
    if answer is None:
        print("chat_bot_run FAIL!")
        answer = LLM_ERROR
    print(f"Total Time {time.perf_counter() - t_start:.2f}'s")
    return JSONResponse({"result": answer}, status_code=201)
|
|
|
|
|
|
|
|
@router.post("/run_rag")
async def run_rag(payload: Query, request: Request):
    """POST /run_rag: return retrieved sections for ``payload.query`` (no LLM).

    Previously this handler was also named ``run_chat``. That rebound the
    module-level name away from the /run_chat handler and gave both routes
    the same route name. The route path itself is unchanged.

    ``request`` is unused; it stays in the signature to keep the endpoint
    interface unchanged.

    Returns:
        ``JSONResponse({"result": answer}, status_code=201)``. ``answer`` is
        ``LLM_ERROR`` when ``rag_run`` raises or returns ``None`` (it returns
        ``None`` after swallowing an internal error).
    """
    t_start = time.perf_counter()
    try:
        answer = await rag_run(payload.query)
    except Exception:
        # Not a bare except: client-disconnect cancellation must propagate.
        traceback.print_exc()
        answer = None
    if answer is None:
        print("rag_run FAIL!")
        answer = LLM_ERROR
    print(f"Total Time {time.perf_counter() - t_start:.2f}'s")
    return JSONResponse({"result": answer}, status_code=201)
|