diff --git a/main.py b/main.py index c0a6d6a..b2795c2 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ import hazm from cleantext import clean import re from pydantic import BaseModel +from fastapi.middleware.cors import CORSMiddleware start_time = datetime.datetime.now() print('start: ' + str(start_time)) @@ -211,13 +212,19 @@ You are given the extracted parts of a long document and a question. Provide a c If you don't know the answer, just say "I do not know." Don't make up an answer.لطفا جواب فارسی باشد""" -def format_prompt(prompt, retrieved_documents, k) -> str: +def format_prompt_0(prompt, retrieved_documents, k) -> str: """using the retrieved documents we will prompt the model to generate our responses""" PROMPT = f"Question:{prompt}\nContext:" for idx in range(k) : PROMPT+= f"{retrieved_documents[idx]}\n" return PROMPT +def format_prompt(prompt, text) -> str: + """using the retrieved documents we will prompt the model to generate our responses""" + PROMPT = f"Question:{prompt}\nContext:" + PROMPT+= f"{text}\n" + return PROMPT + def generate(formatted_prompt): formatted_prompt = formatted_prompt[:2000] # to avoid GPU OOM messages = [{"role":"system","content":SYS_PROMPT},{"role":"user","content":formatted_prompt}] @@ -240,23 +247,34 @@ def generate(formatted_prompt): def rag_chatbot(prompt:str,k:int=2): scores , retrieved_documents = search(prompt, k) - formatted_prompt = format_prompt(prompt,retrieved_documents,k) + formatted_prompt = format_prompt_0(prompt,retrieved_documents,k) for item in retrieved_documents['id']: print('https://majles.tavasi.ir/entity/detail/view/qsection/' + item) return generate(formatted_prompt) #result = rag_chatbot("نوبت کاری شبانه برای مادران باردار و مادران دارای فرزند شیرخوار به چه صورت است؟", k = 5) class TextList(BaseModel): - text: list[str] + query: str + text: str app = FastAPI() -@app.post("/rag_chat/") -async def rag_chat(prompt: str, input_texts: TextList): + +# origins = ['*'] + +# app.add_middleware( +# CORSMiddleware, +# allow_origins=origins, +# allow_credentials=True, +# allow_methods=["*"], +# allow_headers=["*"], +# ) + + +@app.post("/rag_chat") +async def rag_chat(input: TextList): try: - input_texts= [] - prompt = '' - formatted_prompt = format_prompt(prompt, input_texts,len(input_texts)) + formatted_prompt = format_prompt(input.query, input.text) except Exception as e: return {"status": "error", "message": str(e)}