176 lines
6.3 KiB
Python
176 lines
6.3 KiB
Python
from __future__ import annotations
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
|
|
from app.core.map_index_reader import MapIndexReader
|
|
from app.core.elastic_query_builder import ElasticQueryBuilder
|
|
from app.core.field_processor import FieldProcessor
|
|
from app.core.response_helper import ResponseHelper
|
|
from app.routes.v1.models import (
|
|
SearchRequest,
|
|
InsertRequest,
|
|
UpdateByQueryRequest,
|
|
DeleteByQueryRequest,
|
|
)
|
|
from app.routes.rag.models import (
|
|
RagInsertRequest,
|
|
)
|
|
from typing import Any, Dict, List, Optional
|
|
import time
|
|
from app.routes.v1.elastic import (
|
|
insert,
|
|
search,
|
|
)
|
|
import uuid
|
|
import requests
|
|
from datetime import datetime
|
|
from app.config.settings import get_settings, Settings
|
|
|
|
|
|
# Router that the endpoints in this module register on; tagged "ragchat".
router = APIRouter(tags=["ragchat"])
|
|
# Settings object resolved once at import time; the handlers in this module
# read `ai_rag_host` and `ai_rag_host_gpu` from it.
settings= get_settings()
|
|
|
|
|
|
def get_elastic_helper(request: Request):
    """Return the Elasticsearch helper stored on the application state.

    Args:
        request: Incoming request; the helper is looked up on
            ``request.app.state.elastic_helper``.

    Returns:
        The object stored under ``elastic_helper``.

    Raises:
        RuntimeError: If the attribute is missing or set to ``None``.
    """
    app_state = request.app.state
    elastic_helper = getattr(app_state, "elastic_helper", None)
    if elastic_helper is not None:
        return elastic_helper
    raise RuntimeError("Elasticsearch helper not initialized")
|
|
|
|
|
|
@router.post("/{type_name}/credit_refresh")
async def credit_refresh(type_name: str, request: Request):
    """Forward a credit-refresh call to the configured RAG service.

    Sends a POST with an empty JSON object to
    ``<settings.ai_rag_host>/credit_refresh`` and returns the upstream
    response text with every double-quote character removed.

    Args:
        type_name: Path parameter; not used by this handler.
        request: Incoming request; not used by this handler.

    Returns:
        The upstream response body as text without double quotes, or ``None``
        when ``settings.ai_rag_host`` is not set (no upstream call is made).

    Raises:
        HTTPException: Status 400 carrying the error text when anything in
            the upstream call fails (connection error, timeout, ...).
    """
    # Function-scope import so this handler stays self-contained; FastAPI is
    # already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    # Upper bound (seconds) for the upstream call. Without a timeout a hung
    # RAG service would keep this request open indefinitely.
    timeout_seconds = 30

    try:
        print("credit_refresh ...->", settings.ai_rag_host)
        if not settings.ai_rag_host:
            # No RAG host configured: nothing to call. Kept as an explicit
            # ``None`` so the response for this case does not change.
            return None

        url = settings.ai_rag_host + "/" + "credit_refresh"
        headers = {"accept": "application/json", "Content-Type": "application/json"}
        # `requests` is blocking; run it in the thread pool so the event loop
        # keeps serving other requests while we wait for the upstream.
        response = await run_in_threadpool(
            requests.request,
            "POST",
            url,
            headers=headers,
            json={},
            timeout=timeout_seconds,
        )
        return response.text.replace('"', '')
    except Exception as exc:  # noqa: BLE001
        # Every failure is reported to the client as a 400 with the error
        # text; chain the cause so the traceback is preserved in logs.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
|
|
|
|
|
|
@router.post("/{type_name}/insert")
async def insert_rag(type_name: str, payload: RagInsertRequest, request: Request):
    """Build a RAG chat document from the payload and store it via ``insert``.

    The previous body referenced ``title``, ``chat_id`` and ``answer`` without
    defining them (and used the builtin ``id``), so every call failed, and the
    handler never returned its response. The identifiers are now derived the
    same way the ``/query`` handler in this module derives them.

    Args:
        type_name: Path parameter, forwarded to ``insert`` unchanged.
        payload: Request body; ``chat_id``, ``title`` and ``user_query`` are
            read from it.
        request: Incoming request, forwarded to ``insert`` unchanged.

    Returns:
        The value returned by ``insert`` with the stored document added under
        the ``"answer"`` key.
    """

    def _as_int(value: Any) -> int:
        """Coerce a duration-like value to ``int``; fall back to 0."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0

    # Short random identifier: first 8 hex characters of a UUID4.
    doc_id = uuid.uuid4().hex[:8]
    chat_id = payload.chat_id or uuid.uuid4().hex[:8]

    title = payload.title
    if not title and payload.user_query:
        # No explicit title: use the first 50 characters of the query.
        title = payload.user_query[0:50]

    # NOTE(review): the original code used an undefined `answer` mapping.
    # Presumably the caller supplies a precomputed answer on the payload;
    # confirm against RagInsertRequest. Anything that is not a dict is
    # treated as "no answer", so the defaults below apply.
    answer = getattr(payload, "answer", None)
    if not isinstance(answer, dict):
        answer = {}

    time_stamp = int(datetime.now().timestamp())
    user_id = 0
    document = {
        'id': doc_id,
        'title': title,
        'chat_id': chat_id,
        'user_id': user_id,
        # Placeholder block; this handler has no values to fill in here.
        'bale_info': {
            "user_name": "",
            "first_name": "",
            "last_name": "",
        },
        'user_query': payload.user_query,
        'model_key': answer.get("model_key", ""),
        'retrived_passage': answer.get("retrived_passage", ""),
        'retrived_ref_ids': answer.get("retrived_ref_ids", []),
        'retrived_duration': _as_int(answer.get("retrived_duration", 0)),
        'prompt_type': answer.get("prompt_type", ""),
        'llm_duration': _as_int(answer.get("llm_duration", 0)),
        'full_duration': _as_int(answer.get("full_duration", 0)),
        'time_create': time_stamp,
        'used_ref_ids': answer.get("used_ref_ids", []),
        'prompt_answer': answer.get("prompt_answer", ""),
        'status_text': answer.get("status_text", ""),
        'status': answer.get("status", False),
    }

    insert_request = InsertRequest(id=doc_id, document=document)
    response = await insert(type_name, insert_request, request)
    response['answer'] = document
    return response
|
|
|
|
|
|
# Seconds to wait for the GPU service's /ping readiness probe. Kept short so
# an unreachable GPU host only delays the fallback instead of hanging it.
_GPU_PING_TIMEOUT_SECONDS = 3
# (connect, read) timeout in seconds for the chat call itself. The read part
# is deliberately generous; tune it to the slowest expected model response.
_RAG_CHAT_TIMEOUT_SECONDS = (5, 300)


@router.post("/{type_name}/query")
async def query_rag(type_name: str, payload: RagInsertRequest, request: Request):
    """Send the user's query to a RAG service and store the exchange.

    Steps:
      1. Generate a document id, and a chat id / title when the payload does
         not provide them.
      2. If ``settings.ai_rag_host_gpu`` is set, probe ``<host>/ping``; an
         HTTP 200 selects ``<gpu host>/run_chat``. Otherwise the call goes to
         ``<settings.ai_rag_host>/run_chatbot``.
      3. POST ``{"query": <user_query>, "text": ""}`` to the selected URL and
         read the ``"answer"`` object from the JSON reply. This step is
         skipped when ``payload.user_query`` is empty.
      4. Build the document and store it through ``insert``.

    Args:
        type_name: Path parameter, forwarded to ``insert`` unchanged.
        payload: Request body; ``chat_id``, ``title`` and ``user_query`` are
            read from it.
        request: Incoming request, forwarded to ``insert`` unchanged.

    Returns:
        The value returned by ``insert`` with the stored document added under
        the ``"answer"`` key.

    Raises:
        HTTPException: Status 400 when no RAG host is configured for a
            non-empty query, or when the chat call / JSON decoding fails.
    """
    # Function-scope import so this handler stays self-contained; FastAPI is
    # already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    def _as_int(value: Any) -> int:
        """Coerce a duration-like value to ``int``; fall back to 0."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0

    # Only 8 characters of the uuid, e.g. '9f1c2a7b'.
    doc_id = uuid.uuid4().hex[:8]
    print("query_rag start ... ", doc_id)

    chat_id = payload.chat_id or uuid.uuid4().hex[:8]

    title = payload.title
    if not title and payload.user_query:
        # No explicit title: use the first 50 characters of the query.
        title = payload.user_query[0:50]

    headers = {"accept": "application/json", "Content-Type": "application/json"}

    # ---------------------------
    # Readiness probe for the GPU service. Best effort: any failure simply
    # means "not ready" and the non-GPU host is used instead.
    is_gpu_service_ready = False
    if settings.ai_rag_host_gpu:
        try:
            # `requests` is blocking; run it in the thread pool so the event
            # loop is not stalled while waiting for the probe.
            ping = await run_in_threadpool(
                requests.request,
                "GET",
                settings.ai_rag_host_gpu + '/ping',
                headers=headers,
                timeout=_GPU_PING_TIMEOUT_SECONDS,
            )
            is_gpu_service_ready = ping.status_code == 200
        except Exception:  # noqa: BLE001
            is_gpu_service_ready = False

    print(" settings.ai_rag_host_gpu ", settings.ai_rag_host_gpu, is_gpu_service_ready)

    answer = {}
    if payload.user_query:
        if is_gpu_service_ready:
            url = settings.ai_rag_host_gpu + "/" + "run_chat"
        elif settings.ai_rag_host:
            url = settings.ai_rag_host + "/" + "run_chatbot"
        else:
            # Previously this fell through with `url` unbound and surfaced as
            # a 400 about an unassigned local variable. Same status, but with
            # a message that names the actual problem.
            print(" ******* error in settings.ai_rag_host ...")
            raise HTTPException(
                status_code=400,
                detail="No RAG service host is configured",
            )

        print("settings.ai_rag_host ...", url)
        body = {"query": payload.user_query, "text": ""}
        try:
            response = await run_in_threadpool(
                requests.request,
                "POST",
                url,
                headers=headers,
                json=body,
                timeout=_RAG_CHAT_TIMEOUT_SECONDS,
            )
            print(response.status_code)
            data = response.json()
            answer = data.get("answer", {})
        except Exception as exc:  # noqa: BLE001
            raise HTTPException(status_code=400, detail=str(exc)) from exc

        # The upstream may send `"answer": null` or a non-object value; treat
        # that as an empty answer so the defaults below apply instead of
        # failing with an unhandled error after the call already succeeded.
        if not isinstance(answer, dict):
            answer = {}
    # ---------------------------

    time_stamp = int(datetime.now().timestamp())
    user_id = 0
    document = {
        'id': doc_id,
        'title': title,
        'chat_id': chat_id,
        'user_id': user_id,
        'user_query': payload.user_query,
        'model_key': answer.get("model_key", ""),
        'retrived_passage': answer.get("retrived_passage", ""),
        'retrived_ref_ids': answer.get("retrived_ref_ids", []),
        'retrived_duration': _as_int(answer.get("retrived_duration", 0)),
        'prompt_type': answer.get("prompt_type", ""),
        'llm_duration': _as_int(answer.get("llm_duration", 0)),
        'full_duration': _as_int(answer.get("full_duration", 0)),
        'time_create': time_stamp,
        'used_ref_ids': answer.get("used_ref_ids", []),
        'prompt_answer': answer.get("prompt_answer", ""),
        'status_text': answer.get("status_text", ""),
        'status': answer.get("status", False),
    }

    insert_request = InsertRequest(id=doc_id, document=document)
    response = await insert(type_name, insert_request, request)
    response['answer'] = document
    return response
|