# rag_qavanin_api/routers/rag_base.py
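"""FastAPI router for the RAG legal question-answering service.

The FAISS index, hybrid retriever/reranker, and async LLM runner are built
once at import time; the router then serves /run_chat (retrieval plus LLM
answer) and /run_rag (retrieval only).
"""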
from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
import time
import os
import traceback
from .base_model import Query, LLMOutput, LLMInput, Title
from .ai_data_parser import AsyncCore
from .chatbot_handler import (
    InitHybridRetrieverReranker,
    format_answer_bale,
    get_user_prompt2,
    get_user_prompt3,
    load_faiss_index,
    get_in_form,
)
from .static import (
    EMBED_MODEL_PATH,
    FAISS_INDEX_PATH,
    FAISS_METADATA_PATH,
    LLM_URL,
    SYSTEM_PROMPT_FINALL,
    RERANKER_MODEL_PATH,
    LLM_ERROR,
    MODEL_KEY,
    MODEL_NAME,
    OUTPUT_PATH_LLM,
    REASONING_EFFORT,
    TASK_NAME,
    LLM_TIME_OUT,
    MAX_TOKEN,
    SYSTEM_PROPMT2,
)
# ################################################## Global params
router = APIRouter(tags=["ragchat"])
# Load the FAISS index and its accompanying metadata once at import time.
METADATA_DICT, FAISS_INDEX = load_faiss_index(
    index_path=FAISS_INDEX_PATH, metadata_path=FAISS_METADATA_PATH
)

# Hybrid dense/sparse retriever with a cross-encoder reranker on top.
RAG = InitHybridRetrieverReranker(
    embeder_path=EMBED_MODEL_PATH,
    reranker_path=RERANKER_MODEL_PATH,
    dict_content=METADATA_DICT,
    faiss_index=FAISS_INDEX,
    dense_alpha=0.6,  # weight of the dense score in the hybrid mix
    device="cuda",
)

# Async runner that sends prompts to the LLM endpoint and validates the
# structured output against LLMOutput.
RUNNER_PROMPT = AsyncCore(
    model_name=MODEL_NAME,
    api_url=LLM_URL,
    output_path=OUTPUT_PATH_LLM,
    task_name=TASK_NAME,
    output_schema=LLMOutput,
    reasoning_effort=REASONING_EFFORT,
    ai_code_version=MODEL_KEY,
    request_timeout=LLM_TIME_OUT,
    max_token=MAX_TOKEN,
    save_number=1,
)
# Function-calling schema that forces the LLM to return a structured answer.
functions = [
    {
        "name": "legal_answer",
        "description": "Structured output of the legal analysis with full references to the source documents",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {
                    "type": "string",
                    "description": "Full answer text, including document references (qsID)",
                },
                "source": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "List of IDs of the documents used",
                },
            },
            "required": ["text", "source"],
        },
    }
]
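# Illustrative shape of the payload the schema above requests from the model
# (the qsID values here are made up):
#   {"text": "... as stated in [qs12345] ...", "source": ["qs12345", "qs67890"]}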

async def chat_bot_run(query):
    """Retrieve relevant sections, ask the LLM, and format the final answer."""
    try:
        s = time.time()
        sections_dict = await RAG.search_base(
            query,
            final_k=10,
            topk_dense=100,
            topk_sparse=100,
            pre_rerank_k=100,
        )
        e = time.time()
        # Alternative prompt builder, kept for reference:
        # input_data = LLMInput(query=query, knowledge=sections_dict)
        # prompt = get_user_prompt2(input_data)
        prompt = get_user_prompt3(query=query, knowledge_json=sections_dict)
        llm_answer, _ = await RUNNER_PROMPT.single_simple_async_proccess_item(
            item={"user_prompt": prompt, "system_prompt": SYSTEM_PROPMT2},
            functions=functions,
            function_name="legal_answer",
        )
        ee = time.time()
        result = format_answer_bale(
            answer_text=llm_answer["text"], sources=llm_answer["source"]
        )
        eee = time.time()
        print(
            f"Rag = {e - s:.2f}s",
            f"LLM = {ee - e:.2f}s",
            f"Form = {eee - ee:.2f}s",
            sep="\n",
        )
        return result
    except Exception:
        # Log the traceback, then re-raise so the endpoint returns LLM_ERROR
        # instead of silently replying with None.
        traceback.print_exc()
        raise

async def rag_run(query):
    """Retrieve relevant sections and return them formatted, without the LLM."""
    try:
        s = time.time()
        sections_dict = await RAG.search_base(
            query,
            final_k=10,
            topk_dense=100,
            topk_sparse=100,
            pre_rerank_k=100,
        )
        e = time.time()
        result = get_in_form(title=query, sections=sections_dict)
        ee = time.time()
        print(
            f"Rag = {e - s:.2f}s",
            f"Form = {ee - e:.2f}s",
            sep="\n",
        )
        return result
    except Exception:
        traceback.print_exc()
        raise
@router.post("/run_chat")
async def run_chat(payload: Query, request: Request):
s = time.time()
try:
answer = await chat_bot_run(payload.query)
except:
print(f"chat_bot_run FAIL!")
answer = LLM_ERROR
e = time.time()
print(f"Total Time {e-s:.2f}'s")
return JSONResponse({"result": answer}, status_code=201)
@router.post("/run_rag")
async def run_chat(payload: Query, request: Request):
s = time.time()
try:
answer = await rag_run(payload.query)
except:
print(f"chat_bot_run FAIL!")
answer = LLM_ERROR
e = time.time()
print(f"Total Time {e-s:.2f}'s")
return JSONResponse({"result": answer}, status_code=201)
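
# Example requests (illustrative; host, port, and mount path depend on how
# the app includes this router):
#   curl -X POST http://localhost:8000/run_chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "..."}'
#   curl -X POST http://localhost:8000/run_rag \
#        -H "Content-Type: application/json" \
#        -d '{"query": "..."}'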