from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
import time
import os
import traceback

from .base_model import Query, LLMOutput, LLMInput, Title
from .ai_data_parser import AsyncCore
from .chatbot_handler import (
    InitHybridRetrieverReranker,
    format_answer_bale,
    get_user_prompt2,
    get_user_prompt3,
    load_faiss_index,
    get_in_form,
)
from .static import (
    EMBED_MODEL_PATH,
    FAISS_INDEX_PATH,
    FAISS_METADATA_PATH,
    LLM_URL,
    SYSTEM_PROMPT_FINALL,
    RERANKER_MODEL_PATH,
    LLM_ERROR,
    MODEL_KEY,
    MODEL_NAME,
    OUTPUT_PATH_LLM,
    REASONING_EFFORT,
    TASK_NAME,
    LLM_TIME_OUT,
    MAX_TOKEN,
    SYSTEM_PROPMT2,
)

# ################################################## Global params

router = APIRouter(tags=["ragchat"])
# settings = get_settings()

# Load the FAISS index and its metadata once at import time so every request
# reuses the same in-memory index.
METADATA_DICT, FAISS_INDEX = load_faiss_index(
    index_path=FAISS_INDEX_PATH, metadata_path=FAISS_METADATA_PATH
)

# Hybrid retriever: dense (FAISS) and sparse candidates blended with
# dense_alpha=0.6, then reranked by a cross-encoder on the GPU.
RAG = InitHybridRetrieverReranker(
    embeder_path=EMBED_MODEL_PATH,
    reranker_path=RERANKER_MODEL_PATH,
    dict_content=METADATA_DICT,
    faiss_index=FAISS_INDEX,
    dense_alpha=0.6,
    device="cuda",
)

# Async LLM client configured for structured (function-call) output.
RUNNER_PROMPT = AsyncCore(
    model_name=MODEL_NAME,
    api_url=LLM_URL,
    output_path=OUTPUT_PATH_LLM,
    task_name=TASK_NAME,
    output_schema=LLMOutput,
    reasoning_effort=REASONING_EFFORT,
    ai_code_version=MODEL_KEY,
    request_timeout=LLM_TIME_OUT,
    max_token=MAX_TOKEN,
    save_number=1,
)

# Function-call schema for the "legal_answer" tool. The descriptions stay in
# Persian because they are shown to a Persian-language LLM; English glosses
# are given in the comments.
functions = [
    {
        "name": "legal_answer",
        # "Structured output of the legal analysis with full document references"
        "description": "خروجی ساخت‌یافته از تحلیل حقوقی با ارجاع کامل به اسناد",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {
                    "type": "string",
                    # "Full answer text, including (qsID) references"
                    "description": "متن کامل پاسخ شامل ارجاع (qsID)",
                },
                "source": {
                    "type": "array",
                    "items": {"type": "string"},
                    # "List of the document IDs that were used"
                    "description": "فهرست شناسه اسناد استفاده شده",
                },
            },
            "required": ["text", "source"],
        },
    }
]


async def chat_bot_run(query):
    """Full RAG + LLM pipeline: retrieve sections, ask the LLM, format the answer."""
    try:
        t_start = time.time()
        sections_dict = await RAG.search_base(
            query,
            final_k=10,
            topk_dense=100,
            topk_sparse=100,
            pre_rerank_k=100,
        )
        t_rag = time.time()

        # Alternative prompt builder (kept from the original, currently unused):
        # input_data = LLMInput(query=query, knowledge=sections_dict)
        # prompt = get_user_prompt2(input_data)
        prompt = get_user_prompt3(query=query, knowledge_json=sections_dict)

        llm_answer, _ = await RUNNER_PROMPT.single_simple_async_proccess_item(
            item={"user_prompt": prompt, "system_prompt": SYSTEM_PROPMT2},
            functions=functions,
            function_name="legal_answer",
        )
        t_llm = time.time()

        final = format_answer_bale(
            answer_text=llm_answer["text"], sources=llm_answer["source"]
        )
        t_form = time.time()

        print(
            f"Rag = {t_rag - t_start:.2f}s",
            f"llm_answer = {t_llm - t_rag:.2f}s",
            f"Form = {t_form - t_llm:.2f}s",
            sep="\n",
        )
        return final
    except Exception:
        # Log the traceback and re-raise so the endpoint returns LLM_ERROR
        # instead of silently responding with a null result.
        traceback.print_exc()
        raise


async def rag_run(query):
    """Retrieval-only pipeline: return the top sections without calling the LLM."""
    try:
        t_start = time.time()
        sections_dict = await RAG.search_base(
            query,
            final_k=10,
            topk_dense=100,
            topk_sparse=100,
            pre_rerank_k=100,
        )
        t_rag = time.time()

        final = get_in_form(title=query, sections=sections_dict)
        t_form = time.time()

        print(
            f"Rag = {t_rag - t_start:.2f}s",
            f"Form = {t_form - t_rag:.2f}s",
            sep="\n",
        )
        return final
    except Exception:
        traceback.print_exc()
        raise


@router.post("/run_chat")
async def run_chat(payload: Query, request: Request):
    t_start = time.time()
    try:
        answer = await chat_bot_run(payload.query)
    except Exception:
        print("chat_bot_run FAIL!")
        answer = LLM_ERROR
    t_end = time.time()
    print(f"Total Time {t_end - t_start:.2f}s")
    return JSONResponse({"result": answer}, status_code=201)


@router.post("/run_rag")
async def run_rag(payload: Query, request: Request):
    t_start = time.time()
    try:
        answer = await rag_run(payload.query)
    except Exception:
        print("rag_run FAIL!")
        answer = LLM_ERROR
    t_end = time.time()
    print(f"Total Time {t_end - t_start:.2f}s")
    return JSONResponse({"result": answer}, status_code=201)
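
# ---------------------------------------------------------------------------
# Usage sketch: a minimal local smoke test for this router, not part of the
# original module. It assumes the heavy globals above (FAISS index, embedder,
# reranker, GPU, LLM endpoint) are all available in the environment; the
# example query string is a placeholder, and TestClient also requires the
# httpx package.
if __name__ == "__main__":
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    app = FastAPI()
    app.include_router(router)
    client = TestClient(app)

    # Query is a Pydantic model whose `query` field the endpoints read.
    resp = client.post("/run_chat", json={"query": "example legal question"})
    # Both endpoints reply with HTTP 201 and {"result": <formatted answer>}.
    print(resp.status_code, resp.json()["result"])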