اصلاح ورودی ها
This commit is contained in:
parent
2f74923291
commit
752554deba
34
main.py
34
main.py
|
@ -10,6 +10,7 @@ import hazm
|
|||
from cleantext import clean
|
||||
import re
|
||||
from pydantic import BaseModel
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
print('start: ' + str(start_time))
|
||||
|
@ -211,13 +212,19 @@ You are given the extracted parts of a long document and a question. Provide a c
|
|||
If you don't know the answer, just say "I do not know." Don't make up an answer.لطفا جواب فارسی باشد"""
|
||||
|
||||
|
||||
def format_prompt(prompt, retrieved_documents, k) -> str:
|
||||
def format_prompt_0(prompt, retrieved_documents, k) -> str:
|
||||
"""using the retrieved documents we will prompt the model to generate our responses"""
|
||||
PROMPT = f"Question:{prompt}\nContext:"
|
||||
for idx in range(k) :
|
||||
PROMPT+= f"{retrieved_documents[idx]}\n"
|
||||
return PROMPT
|
||||
|
||||
def format_prompt(prompt, text) -> str:
|
||||
"""using the retrieved documents we will prompt the model to generate our responses"""
|
||||
PROMPT = f"Question:{prompt}\nContext:"
|
||||
PROMPT+= f"{text}\n"
|
||||
return PROMPT
|
||||
|
||||
def generate(formatted_prompt):
|
||||
formatted_prompt = formatted_prompt[:2000] # to avoid GPU OOM
|
||||
messages = [{"role":"system","content":SYS_PROMPT},{"role":"user","content":formatted_prompt}]
|
||||
|
@ -240,23 +247,34 @@ def generate(formatted_prompt):
|
|||
|
||||
def rag_chatbot(prompt:str,k:int=2):
|
||||
scores , retrieved_documents = search(prompt, k)
|
||||
formatted_prompt = format_prompt(prompt,retrieved_documents,k)
|
||||
formatted_prompt = format_prompt_0(prompt,retrieved_documents,k)
|
||||
for item in retrieved_documents['id']:
|
||||
print('https://majles.tavasi.ir/entity/detail/view/qsection/' + item)
|
||||
return generate(formatted_prompt)
|
||||
#result = rag_chatbot("نوبت کاری شبانه برای مادران باردار و مادران دارای فرزند شیرخوار به چه صورت است؟", k = 5)
|
||||
|
||||
class TextList(BaseModel):
|
||||
text: list[str]
|
||||
query: str
|
||||
text: str
|
||||
|
||||
app = FastAPI()
|
||||
@app.post("/rag_chat/")
|
||||
async def rag_chat(prompt: str, input_texts: TextList):
|
||||
|
||||
# origins = ['*']
|
||||
|
||||
# app.add_middleware(
|
||||
# CORSMiddleware,
|
||||
# allow_origins=origins,
|
||||
# allow_credentials=True,
|
||||
# allow_methods=["*"],
|
||||
# allow_headers=["*"],
|
||||
# )
|
||||
|
||||
|
||||
@app.post("/rag_chat")
|
||||
async def rag_chat(input: TextList):
|
||||
|
||||
try:
|
||||
input_texts= []
|
||||
prompt = ''
|
||||
formatted_prompt = format_prompt(prompt, input_texts,len(input_texts))
|
||||
formatted_prompt = format_prompt(input.query, input.text)
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user