اصلاح ورودی ها

This commit is contained in:
ajokar 2025-03-08 16:18:15 +03:30
parent 2f74923291
commit 752554deba

34
main.py
View File

@ -10,6 +10,7 @@ import hazm
from cleantext import clean from cleantext import clean
import re import re
from pydantic import BaseModel from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
start_time = datetime.datetime.now() start_time = datetime.datetime.now()
print('start: ' + str(start_time)) print('start: ' + str(start_time))
@ -211,13 +212,19 @@ You are given the extracted parts of a long document and a question. Provide a c
If you don't know the answer, just say "I do not know." Don't make up an answer.لطفا جواب فارسی باشد""" If you don't know the answer, just say "I do not know." Don't make up an answer.لطفا جواب فارسی باشد"""
def format_prompt_0(prompt, retrieved_documents, k) -> str:
    """Build the user message from a question and the top retrieved documents.

    Args:
        prompt: The user's question.
        retrieved_documents: Integer-indexable collection of retrieved
            documents; each entry is rendered with ``str()`` formatting.
        k: Maximum number of documents to include.

    Returns:
        ``"Question:<prompt>\\nContext:"`` followed by one line per included
        document.
    """
    # NOTE(review): rag_chatbot also reads retrieved_documents['id'], which
    # suggests a dict-like result; integer indexing may not be what is wanted
    # for that shape -- confirm against search().
    # Clamp to the documents actually available so a short result set does
    # not raise IndexError.
    limit = min(k, len(retrieved_documents))
    context = "".join(f"{retrieved_documents[idx]}\n" for idx in range(limit))
    return f"Question:{prompt}\nContext:{context}"
def format_prompt(prompt, text) -> str:
    """Build the user message from a question and a caller-supplied context.

    Args:
        prompt: The user's question.
        text: The context passage to place after the question.

    Returns:
        ``"Question:<prompt>\\nContext:<text>\\n"``.
    """
    return f"Question:{prompt}\nContext:{text}\n"
def generate(formatted_prompt): def generate(formatted_prompt):
formatted_prompt = formatted_prompt[:2000] # to avoid GPU OOM formatted_prompt = formatted_prompt[:2000] # to avoid GPU OOM
messages = [{"role":"system","content":SYS_PROMPT},{"role":"user","content":formatted_prompt}] messages = [{"role":"system","content":SYS_PROMPT},{"role":"user","content":formatted_prompt}]
@ -240,23 +247,34 @@ def generate(formatted_prompt):
def rag_chatbot(prompt: str, k: int = 2):
    """Answer a question using retrieved documents as context.

    Calls ``search`` for ``k`` results, formats them with
    ``format_prompt_0``, prints a link for each retrieved id, and returns
    whatever ``generate`` produces for the formatted prompt.

    Args:
        prompt: The user's question.
        k: Number of results requested from ``search`` and passed on to
            ``format_prompt_0``.

    Returns:
        The result of ``generate`` for the formatted prompt.
    """
    # The scores come back from search() but are not used here.
    _scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt_0(prompt, retrieved_documents, k)
    # Print a link for every retrieved id so the source can be inspected.
    for item in retrieved_documents['id']:
        print('https://majles.tavasi.ir/entity/detail/view/qsection/' + item)
    return generate(formatted_prompt)
#result = rag_chatbot("نوبت کاری شبانه برای مادران باردار و مادران دارای فرزند شیرخوار به چه صورت است؟", k = 5) #result = rag_chatbot("نوبت کاری شبانه برای مادران باردار و مادران دارای فرزند شیرخوار به چه صورت است؟", k = 5)
class TextList(BaseModel):
    """Request body accepted by the ``rag_chat`` endpoint."""

    # The question; passed to format_prompt as ``prompt``.
    query: str
    # The context passage; passed to format_prompt as ``text``.
    text: str
app = FastAPI() app = FastAPI()
@app.post("/rag_chat/")
async def rag_chat(prompt: str, input_texts: TextList): # origins = ['*']
# app.add_middleware(
# CORSMiddleware,
# allow_origins=origins,
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
@app.post("/rag_chat")
async def rag_chat(input: TextList):
try: try:
input_texts= [] formatted_prompt = format_prompt(input.query, input.text)
prompt = ''
formatted_prompt = format_prompt(prompt, input_texts,len(input_texts))
except Exception as e: except Exception as e:
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}