اصلاح ورودی ها

This commit is contained in:
ajokar 2025-03-08 16:18:15 +03:30
parent 2f74923291
commit 752554deba

34
main.py
View File

@ -10,6 +10,7 @@ import hazm
from cleantext import clean
import re
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
start_time = datetime.datetime.now()
print('start: ' + str(start_time))
@ -211,13 +212,19 @@ You are given the extracted parts of a long document and a question. Provide a c
If you don't know the answer, just say "I do not know." Don't make up an answer.لطفا جواب فارسی باشد"""
def format_prompt(prompt, retrieved_documents, k) -> str:
def format_prompt_0(prompt, retrieved_documents, k) -> str:
"""using the retrieved documents we will prompt the model to generate our responses"""
PROMPT = f"Question:{prompt}\nContext:"
for idx in range(k) :
PROMPT+= f"{retrieved_documents[idx]}\n"
return PROMPT
def format_prompt(prompt, text) -> str:
"""using the retrieved documents we will prompt the model to generate our responses"""
PROMPT = f"Question:{prompt}\nContext:"
PROMPT+= f"{text}\n"
return PROMPT
def generate(formatted_prompt):
formatted_prompt = formatted_prompt[:2000] # to avoid GPU OOM
messages = [{"role":"system","content":SYS_PROMPT},{"role":"user","content":formatted_prompt}]
@ -240,23 +247,34 @@ def generate(formatted_prompt):
def rag_chatbot(prompt:str,k:int=2):
scores , retrieved_documents = search(prompt, k)
formatted_prompt = format_prompt(prompt,retrieved_documents,k)
formatted_prompt = format_prompt_0(prompt,retrieved_documents,k)
for item in retrieved_documents['id']:
print('https://majles.tavasi.ir/entity/detail/view/qsection/' + item)
return generate(formatted_prompt)
#result = rag_chatbot("نوبت کاری شبانه برای مادران باردار و مادران دارای فرزند شیرخوار به چه صورت است؟", k = 5)
class TextList(BaseModel):
text: list[str]
query: str
text: str
app = FastAPI()
@app.post("/rag_chat/")
async def rag_chat(prompt: str, input_texts: TextList):
# origins = ['*']
# app.add_middleware(
# CORSMiddleware,
# allow_origins=origins,
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
@app.post("/rag_chat")
async def rag_chat(input: TextList):
try:
input_texts= []
prompt = ''
formatted_prompt = format_prompt(prompt, input_texts,len(input_texts))
formatted_prompt = format_prompt(input.query, input.text)
except Exception as e:
return {"status": "error", "message": str(e)}