200 lines
9.1 KiB
Python
200 lines
9.1 KiB
Python
from __future__ import annotations
|
|
import uvicorn
|
|
import torch
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from starlette.status import HTTP_201_CREATED
|
|
|
|
from src.database_handler import Database
|
|
from src.requests_data import Requests
|
|
from model.request_models import InputNerType, InputSummaryType, InputNREType
|
|
from model.response_models import BaseResponse, ResponseNerModel, ResponseSummaryModel, ResponseOIEModel, \
|
|
ResponseNREModel
|
|
from src.response_ner import NerResponse, ResponsesNer
|
|
from fastapi import BackgroundTasks
|
|
|
|
from src.response_nre import NREResponse, ResponsesNRE
|
|
from src.response_oie import OIEResponse, ResponsesOIE
|
|
from src.response_summary import ShortResponse, LongResponse, ResponsesSummary, NormalizerTexts
|
|
|
|
print("torch version : ", torch.__version__)

app = FastAPI()

# Environment-driven configuration (currently disabled; the hard-coded values
# further down are what the service actually uses):
# ner_model_name = os.getenv("NER_MODEL_NAME")
# normalize = os.getenv("NORMALIZE")
# short_model_name = os.getenv("SHORT_MODEL_NAME")
# long_model_name = os.getenv("LONG_MODEL_NAME")
# pos_model_name = os.getenv("POS_MODEL_NAME")
# batch_size = int(os.getenv("BATCH_SIZE"))
# device = os.getenv("DEVICE")
#
# print(ner_model_name, normalize, short_model_name, long_model_name, pos_model_name, batch_size, device)

# initial docker parameters
#
# Each name is assigned exactly once. Values that used to be assigned here and
# were immediately overwritten are kept as reference only:
#   pos_model_name   = "./data/pos-model.pt"
#   ner_model_name   = "./data/ner-model.pt"
#   short_model_name = "csebuetnlp/mT5_multilingual_XLSum"
#   long_model_name  = "alireza7/ARMAN-MSR-persian-base-PN-summary"
#   oie_model_name   = "default"

# For the NER / OIE / NRE / POS models the loading code in this module treats
# "" as "do not load" and "default" (any casing) as "use the path under ./data".
ner_model_name = "default"
pos_model_name = "default"
oie_model_name = ""
nre_model_name = ""

# Passed as the BERT argument when the OIE model is constructed.
bert_oie = './data/mbert-base-parsinlu-entailment'

# An empty name means the corresponding summarisation model is not loaded.
short_model_name = ""
long_model_name = ""

# Kept as a string (not a bool): it is compared case-insensitively to "false".
normalize = "False"

batch_size = 4
device = "cpu"
|
|
|
|
def _resolve_model_path(configured, bundled_path):
    """Return *bundled_path* when *configured* is "default" (any casing), else *configured*."""
    return bundled_path if configured.lower() == "default" else configured


# Named-entity recognition: built only when a model name is configured.
if ner_model_name != "":
    ner_model_name = _resolve_model_path(ner_model_name, "./data/ner-model.pt")
    ner_response = NerResponse(ner_model_name, batch_size=batch_size, device=device)
else:
    ner_response = None

# Open information extraction: built only when a model name is configured.
if oie_model_name != "":
    oie_model_name = _resolve_model_path(oie_model_name, "./data/mbertEntail-2GBEn.bin")
    oie_response = OIEResponse(oie_model_name, BERT=bert_oie, batch_size=batch_size, device=device)
else:
    oie_response = None

# Relation extraction: built only when a model name is configured.
if nre_model_name != "":
    nre_model_name = _resolve_model_path(nre_model_name, "./data/model.ckpt.pth.tar")
    nre_response = NREResponse(nre_model_name)
else:
    nre_response = None

# Summarisers are always constructed; a text normalizer is attached only when
# normalisation is switched on AND a POS model name is configured.
if normalize.lower() != "false" and pos_model_name != "":
    pos_model_name = _resolve_model_path(pos_model_name, "./data/pos-model.pt")
    normalizer_pos = NormalizerTexts(pos_model_name)
    short_response = ShortResponse(normalizer_input=normalizer_pos, device=device)
    long_response = LongResponse(normalizer_input=normalizer_pos, device=device)
else:
    short_response = ShortResponse(normalizer_input=None, device=device)
    long_response = LongResponse(normalizer_input=None, device=device)

# Summarisation weights are loaded only for the models that have a name set.
if short_model_name != "":
    short_response.load_model(short_model_name)
if long_model_name != "":
    long_response.load_model(long_model_name)

# Module-level singletons used by every endpoint: request bookkeeping and
# client credential checks.
requests: Requests = Requests()
database: Database = Database()
|
|
|
|
|
|
# GET /: unauthenticated endpoint that returns a fixed greeting payload.
@app.get("/")
async def root():
    greeting = {"message": "Hello World"}
    return greeting
|
|
|
|
|
|
# GET /ner/{item_id}: look up the NER response registered under `item_id`.
#
# Responses:
#   403 - the Client-Id / Client-Password headers are rejected by
#         database.check_client, or no NER model is loaded (ner_response falsy).
#   404 - `item_id` is not an integer, or ResponsesNer has no entry for it.
#   200 - whatever ResponsesNer.get_response returns for the id.
@app.get("/ner/{item_id}", response_model=ResponseNerModel)
async def read_item(item_id, request: Request) -> ResponseNerModel:
    authorized = database.check_client(client_id=request.headers.get('Client-Id'),
                                       client_password=request.headers.get('Client-Password'))
    if not (authorized and ner_response):
        raise HTTPException(status_code=403, detail="you are not allowed")

    # `item_id` is an untyped path parameter, so it arrives as a string.
    # A non-numeric value used to escape as ValueError (HTTP 500); it can
    # never match a stored id, so report it as "not found" instead.
    try:
        response_id = int(item_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Item not found") from None

    if not ResponsesNer.check_id(response_id):
        raise HTTPException(status_code=404, detail="Item not found")
    return ResponsesNer.get_response(response_id)
|
|
|
|
|
|
# GET /oie/{item_id}: look up the OIE response registered under `item_id`.
#
# Responses:
#   403 - the Client-Id / Client-Password headers are rejected by
#         database.check_client, or no OIE model is loaded (oie_response falsy).
#   404 - `item_id` is not an integer, or ResponsesOIE has no entry for it.
#   200 - whatever ResponsesOIE.get_response returns for the id.
@app.get("/oie/{item_id}", response_model=ResponseOIEModel)
async def read_item(item_id, request: Request) -> ResponseOIEModel:
    authorized = database.check_client(client_id=request.headers.get('Client-Id'),
                                       client_password=request.headers.get('Client-Password'))
    if not (authorized and oie_response):
        raise HTTPException(status_code=403, detail="you are not allowed")

    # `item_id` is an untyped path parameter, so it arrives as a string.
    # A non-numeric value used to escape as ValueError (HTTP 500); it can
    # never match a stored id, so report it as "not found" instead.
    try:
        response_id = int(item_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Item not found") from None

    if not ResponsesOIE.check_id(response_id):
        raise HTTPException(status_code=404, detail="Item not found")
    return ResponsesOIE.get_response(response_id)
|
|
|
|
|
|
# GET /nre/{item_id}: look up the NRE response registered under `item_id`.
#
# Responses:
#   403 - the Client-Id / Client-Password headers are rejected by
#         database.check_client, or no NRE model is loaded (nre_response falsy).
#   404 - `item_id` is not an integer, or ResponsesNRE has no entry for it.
#   200 - whatever ResponsesNRE.get_response returns for the id.
@app.get("/nre/{item_id}", response_model=ResponseNREModel)
async def read_item(item_id, request: Request) -> ResponseNREModel:
    authorized = database.check_client(client_id=request.headers.get('Client-Id'),
                                       client_password=request.headers.get('Client-Password'))
    if not (authorized and nre_response):
        raise HTTPException(status_code=403, detail="you are not allowed")

    # `item_id` is an untyped path parameter, so it arrives as a string.
    # A non-numeric value used to escape as ValueError (HTTP 500); it can
    # never match a stored id, so report it as "not found" instead.
    try:
        response_id = int(item_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Item not found") from None

    if not ResponsesNRE.check_id(response_id):
        raise HTTPException(status_code=404, detail="Item not found")
    return ResponsesNRE.get_response(response_id)
|
|
|
|
|
|
# GET /summary/{item_id}: look up the summary response registered under `item_id`.
#
# Unlike the NER/OIE/NRE lookups, this handler only checks the client
# credentials; it does not test whether a summarisation model is loaded.
#
# Responses:
#   403 - the Client-Id / Client-Password headers are rejected by
#         database.check_client.
#   404 - `item_id` is not an integer, or ResponsesSummary has no entry for it.
#   200 - whatever ResponsesSummary.get_response returns for the id.
@app.get("/summary/{item_id}", response_model=ResponseSummaryModel)
async def read_item(item_id, request: Request) -> ResponseSummaryModel:
    authorized = database.check_client(client_id=request.headers.get('Client-Id'),
                                       client_password=request.headers.get('Client-Password'))
    if not authorized:
        raise HTTPException(status_code=403, detail="you are not allowed")

    # `item_id` is an untyped path parameter, so it arrives as a string.
    # A non-numeric value used to escape as ValueError (HTTP 500); it can
    # never match a stored id, so report it as "not found" instead.
    try:
        response_id = int(item_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Item not found") from None

    if not ResponsesSummary.check_id(response_id):
        raise HTTPException(status_code=404, detail="Item not found")
    return ResponsesSummary.get_response(response_id)
|
|
|
|
|
|
# POST /ner: register the payload, queue NER prediction as a background task
# and reply immediately (HTTP 201) with the id assigned to the request.
# Replies 403 when database.check_client rejects the Client-Id /
# Client-Password headers or when no NER model name is configured.
@app.post("/ner", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def ner(input_json: InputNerType, background_tasks: BackgroundTasks, request: Request) -> BaseResponse:
    headers = request.headers
    is_known_client = database.check_client(client_id=headers.get('Client-Id'),
                                            client_password=headers.get('Client-Password'))
    if not (is_known_client and ner_model_name != ""):
        raise HTTPException(status_code=403, detail="you are not allowed")

    request_id = requests.add_request_ner(input_json)
    # The prediction is queued, not awaited: the client only gets the id back.
    background_tasks.add_task(ner_response.predict_json, input_json, request_id)
    return BaseResponse(status="input received successfully", id=request_id)
|
|
|
|
|
|
# POST /oie: register the payload, queue OIE prediction as a background task
# and reply immediately (HTTP 201) with the id assigned to the request.
# The request body reuses InputNerType.
# Replies 403 when database.check_client rejects the Client-Id /
# Client-Password headers or when no OIE model name is configured.
@app.post("/oie", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def oie(input_json: InputNerType, background_tasks: BackgroundTasks, request: Request) -> BaseResponse:
    headers = request.headers
    is_known_client = database.check_client(client_id=headers.get('Client-Id'),
                                            client_password=headers.get('Client-Password'))
    if not (is_known_client and oie_model_name != ""):
        raise HTTPException(status_code=403, detail="you are not allowed")

    request_id = requests.add_request_oie(input_json)
    # The prediction is queued, not awaited: the client only gets the id back.
    background_tasks.add_task(oie_response.predict_json, input_json, request_id)
    return BaseResponse(status="input received successfully", id=request_id)
|
|
|
|
|
|
# POST /nre: register the payload, queue NRE prediction as a background task
# and reply immediately (HTTP 201) with the id assigned to the request.
# Replies 403 when database.check_client rejects the Client-Id /
# Client-Password headers or when no NRE model name is configured.
#
# NOTE(review): this handler is also named `oie`, so it rebinds the
# module-level name used by the POST /oie handler. The name is left as-is to
# keep the module's interface unchanged - consider renaming to `nre`.
@app.post("/nre", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def oie(input_json: InputNREType, background_tasks: BackgroundTasks, request: Request) -> BaseResponse:
    headers = request.headers
    is_known_client = database.check_client(client_id=headers.get('Client-Id'),
                                            client_password=headers.get('Client-Password'))
    if not (is_known_client and nre_model_name != ""):
        raise HTTPException(status_code=403, detail="you are not allowed")

    request_id = requests.add_request_nre(input_json)
    # The prediction is queued, not awaited: the client only gets the id back.
    background_tasks.add_task(nre_response.predict_json, input_json, request_id)
    return BaseResponse(status="input received successfully", id=request_id)
|
|
|
|
|
|
# POST /summary/short: register the payload, queue short summarisation as a
# background task and reply immediately (HTTP 201) with the assigned id.
# Replies 403 when database.check_client rejects the Client-Id /
# Client-Password headers or when no short-summary model name is configured.
@app.post("/summary/short", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def short_summary(input_json: InputSummaryType, background_tasks: BackgroundTasks,
                        request: Request) -> BaseResponse:
    headers = request.headers
    is_known_client = database.check_client(client_id=headers.get('Client-Id'),
                                            client_password=headers.get('Client-Password'))
    if not (is_known_client and short_model_name != ""):
        raise HTTPException(status_code=403, detail="you are not allowed")

    request_id = requests.add_request_summary(input_json)
    # The summarisation is queued, not awaited: the client only gets the id back.
    background_tasks.add_task(short_response.get_result, input_json, request_id)
    return BaseResponse(status="input received successfully", id=request_id)
|
|
|
|
|
|
# POST /summary/long: register the payload, queue long summarisation as a
# background task and reply immediately (HTTP 201) with the assigned id.
# Replies 403 when database.check_client rejects the Client-Id /
# Client-Password headers or when no long-summary model name is configured.
@app.post("/summary/long", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def long_summary(input_json: InputSummaryType, background_tasks: BackgroundTasks,
                       request: Request) -> BaseResponse:
    headers = request.headers
    is_known_client = database.check_client(client_id=headers.get('Client-Id'),
                                            client_password=headers.get('Client-Password'))
    if not (is_known_client and long_model_name != ""):
        raise HTTPException(status_code=403, detail="you are not allowed")

    request_id = requests.add_request_summary(input_json)
    # The summarisation is queued, not awaited: the client only gets the id back.
    background_tasks.add_task(long_response.get_result, input_json, request_id)
    return BaseResponse(status="input received successfully", id=request_id)
|
|
|
|
|
|
# Script entry point: when this module is executed directly, serve `app`
# with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
|