"""FastAPI service exposing NER, OIE, NRE and summarization endpoints.

Each POST endpoint records the input via ``Requests``, schedules the model
call as a background task and returns the request id; the matching GET
endpoint returns the stored result for that id. Every endpoint except ``/``
requires ``Client-Id`` / ``Client-Password`` headers, validated with
``Database.check_client``.
"""
from __future__ import annotations

import uvicorn
import torch
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from starlette.status import HTTP_201_CREATED

from src.database_handler import Database
from src.requests_data import Requests
from model.request_models import InputNerType, InputSummaryType, InputNREType
from model.response_models import (
    BaseResponse,
    ResponseNerModel,
    ResponseSummaryModel,
    ResponseOIEModel,
    ResponseNREModel,
)
from src.response_ner import NerResponse, ResponsesNer
from src.response_nre import NREResponse, ResponsesNRE
from src.response_oie import OIEResponse, ResponsesOIE
from src.response_summary import ShortResponse, LongResponse, ResponsesSummary, NormalizerTexts

print("torch version : ", torch.__version__)

app = FastAPI()

# --- Service configuration ---------------------------------------------------
# Values are hard-coded here (an earlier revision read them from the
# environment variables NER_MODEL_NAME, NORMALIZE, SHORT_MODEL_NAME,
# LONG_MODEL_NAME, POS_MODEL_NAME, BATCH_SIZE and DEVICE; that code had been
# commented out).
# An empty model name disables the corresponding service; the sentinel
# "default" (case-insensitive) selects the bundled checkpoint under ./data.
# Summary models previously configured here:
#   short: "csebuetnlp/mT5_multilingual_XLSum"
#   long:  "alireza7/ARMAN-MSR-persian-base-PN-summary"
ner_model_name = "default"
oie_model_name = ""
nre_model_name = ""
pos_model_name = "default"
bert_oie = "./data/mbert-base-parsinlu-entailment"
short_model_name = ""
long_model_name = ""
normalize = "False"  # string flag; anything other than "false" (any case) enables it
batch_size = 4
device = "cpu"


def _resolve_model_path(name: str, default_path: str) -> str:
    """Return *default_path* when *name* is the sentinel "default", else *name*."""
    return default_path if name.lower() == "default" else name


# --- Model initialisation (runs at import time) ------------------------------
if ner_model_name == "":
    ner_response = None
else:
    ner_model_name = _resolve_model_path(ner_model_name, "./data/ner-model.pt")
    ner_response = NerResponse(ner_model_name, batch_size=batch_size, device=device)

if oie_model_name == "":
    oie_response = None
else:
    oie_model_name = _resolve_model_path(oie_model_name, "./data/mbertEntail-2GBEn.bin")
    oie_response = OIEResponse(oie_model_name, BERT=bert_oie, batch_size=batch_size, device=device)

if nre_model_name == "":
    nre_response = None
else:
    nre_model_name = _resolve_model_path(nre_model_name, "./data/model.ckpt.pth.tar")
    nre_response = NREResponse(nre_model_name)

# The summarizer objects always exist; POS-based text normalization is optional
# and shared by both of them.
if normalize.lower() == "false" or pos_model_name == "":
    normalizer_pos = None
else:
    pos_model_name = _resolve_model_path(pos_model_name, "./data/pos-model.pt")
    normalizer_pos = NormalizerTexts(pos_model_name)
short_response = ShortResponse(normalizer_input=normalizer_pos, device=device)
long_response = LongResponse(normalizer_input=normalizer_pos, device=device)
if short_model_name != "":
    short_response.load_model(short_model_name)
if long_model_name != "":
    long_response.load_model(long_model_name)

# NOTE: ``requests`` shadows the name of the third-party HTTP library; it is
# kept for compatibility with any code that imports it from this module.
requests: Requests = Requests()
database: Database = Database()


# --- Shared handler helpers ----------------------------------------------------
def _require_access(request: Request, service_available: bool = True) -> None:
    """Raise HTTP 403 unless the client credentials are valid and the service is enabled.

    Credentials are read from the ``Client-Id`` and ``Client-Password``
    headers and checked first; ``service_available`` is only consulted when
    they pass. A disabled service yields the same 403 as bad credentials, as
    in the original endpoints.

    NOTE(review): ``check_client`` is called synchronously from async
    handlers; if it performs blocking I/O it will block the event loop —
    confirm.
    """
    authorized = database.check_client(
        client_id=request.headers.get("Client-Id"),
        client_password=request.headers.get("Client-Password"),
    )
    if not (authorized and service_available):
        raise HTTPException(status_code=403, detail="you are not allowed")


def _stored_response(store, item_id: int):
    """Return ``store.get_response(item_id)``, or raise HTTP 404 if the id is unknown."""
    if not store.check_id(item_id):
        raise HTTPException(status_code=404, detail="Item not found")
    return store.get_response(item_id)


def _accepted(request_id) -> BaseResponse:
    """Build the response body returned once a request has been queued."""
    return BaseResponse(status="input received successfully", id=request_id)


# --- Routes ------------------------------------------------------------------------
@app.get("/")
async def root():
    return {"message": "Hello World"}


# ``name=`` keeps the original route names (and thus the generated OpenAPI
# operation ids and docs titles) even though the Python functions now have
# distinct names instead of redefining one another.
@app.get("/ner/{item_id}", response_model=ResponseNerModel, name="read_item")
async def read_ner_item(item_id: int, request: Request) -> ResponseNerModel:
    """Return the stored NER result for *item_id* (403 if NER is disabled)."""
    _require_access(request, ner_response is not None)
    return _stored_response(ResponsesNer, item_id)


@app.get("/oie/{item_id}", response_model=ResponseOIEModel, name="read_item")
async def read_oie_item(item_id: int, request: Request) -> ResponseOIEModel:
    """Return the stored OIE result for *item_id* (403 if OIE is disabled)."""
    _require_access(request, oie_response is not None)
    return _stored_response(ResponsesOIE, item_id)


@app.get("/nre/{item_id}", response_model=ResponseNREModel, name="read_item")
async def read_nre_item(item_id: int, request: Request) -> ResponseNREModel:
    """Return the stored NRE result for *item_id* (403 if NRE is disabled)."""
    _require_access(request, nre_response is not None)
    return _stored_response(ResponsesNRE, item_id)


@app.get("/summary/{item_id}", response_model=ResponseSummaryModel, name="read_item")
async def read_summary_item(item_id: int, request: Request) -> ResponseSummaryModel:
    """Return the stored summary for *item_id* (only credentials are checked)."""
    _require_access(request)
    return _stored_response(ResponsesSummary, item_id)


@app.post("/ner", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def ner(input_json: InputNerType, background_tasks: BackgroundTasks,
              request: Request) -> BaseResponse:
    """Queue an NER job and return its request id."""
    _require_access(request, ner_response is not None)
    request_id = requests.add_request_ner(input_json)
    background_tasks.add_task(ner_response.predict_json, input_json, request_id)
    return _accepted(request_id)


# NOTE(review): OIE reuses the NER input schema (InputNerType) — confirm intended.
@app.post("/oie", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def oie(input_json: InputNerType, background_tasks: BackgroundTasks,
              request: Request) -> BaseResponse:
    """Queue an OIE job and return its request id."""
    _require_access(request, oie_response is not None)
    request_id = requests.add_request_oie(input_json)
    background_tasks.add_task(oie_response.predict_json, input_json, request_id)
    return _accepted(request_id)


@app.post("/nre", response_model=BaseResponse, status_code=HTTP_201_CREATED, name="oie")
async def nre(input_json: InputNREType, background_tasks: BackgroundTasks,
              request: Request) -> BaseResponse:
    """Queue an NRE job and return its request id."""
    _require_access(request, nre_response is not None)
    request_id = requests.add_request_nre(input_json)
    background_tasks.add_task(nre_response.predict_json, input_json, request_id)
    return _accepted(request_id)


@app.post("/summary/short", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def short_summary(input_json: InputSummaryType, background_tasks: BackgroundTasks,
                        request: Request) -> BaseResponse:
    """Queue a short-summary job (403 unless a short model is configured)."""
    _require_access(request, short_model_name != "")
    request_id = requests.add_request_summary(input_json)
    background_tasks.add_task(short_response.get_result, input_json, request_id)
    return _accepted(request_id)


@app.post("/summary/long", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def long_summary(input_json: InputSummaryType, background_tasks: BackgroundTasks,
                       request: Request) -> BaseResponse:
    """Queue a long-summary job (403 unless a long model is configured)."""
    _require_access(request, long_model_name != "")
    request_id = requests.add_request_summary(input_json)
    background_tasks.add_task(long_response.get_result, input_json, request_id)
    return _accepted(request_id)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)