# NLP_tutorial/3-NLP_services/main.py
# 2025-04-09 09:39:40 +03:30
#
# 200 lines
# 9.1 KiB
# Python

from __future__ import annotations
import uvicorn
import torch
from fastapi import FastAPI, HTTPException, Request
from starlette.status import HTTP_201_CREATED
from src.database_handler import Database
from src.requests_data import Requests
from model.request_models import InputNerType, InputSummaryType, InputNREType
from model.response_models import BaseResponse, ResponseNerModel, ResponseSummaryModel, ResponseOIEModel, \
ResponseNREModel
from src.response_ner import NerResponse, ResponsesNer
from fastapi import BackgroundTasks
from src.response_nre import NREResponse, ResponsesNRE
from src.response_oie import OIEResponse, ResponsesOIE
from src.response_summary import ShortResponse, LongResponse, ResponsesSummary, NormalizerTexts
print("torch version : ", torch.__version__)
app = FastAPI()

# ---------------------------------------------------------------------------
# Service configuration (hard-coded).
#
# These settings were at one point read from environment variables
# (NER_MODEL_NAME, NORMALIZE, SHORT_MODEL_NAME, LONG_MODEL_NAME,
# POS_MODEL_NAME, BATCH_SIZE, DEVICE); they are now fixed values here.
#
# ner / oie / nre / pos model names:
#   ""         -> the component is disabled.
#   "default"  -> replaced (case-insensitively) by the bundled checkpoint
#                 path under ./data/ during model initialization.
#   otherwise  -> used as-is as the model path.
# short / long summary model names:
#   ""         -> no summary model is loaded and the POST endpoint refuses requests.
#   otherwise  -> passed to the response object's load_model().
#   Summary models previously configured for this service:
#     short_model_name = "csebuetnlp/mT5_multilingual_XLSum"
#     long_model_name = "alireza7/ARMAN-MSR-persian-base-PN-summary"
# ---------------------------------------------------------------------------
oie_model_name = ""
ner_model_name = "default"
pos_model_name = "default"
nre_model_name = ""
# BERT weights directory handed to OIEResponse (only used when OIE is enabled).
bert_oie = './data/mbert-base-parsinlu-entailment'
short_model_name = ""
long_model_name = ""
# Compared case-insensitively against "false"; any other value turns on the
# POS-based normalizer for summary inputs (provided pos_model_name is set).
normalize = "False"
batch_size = 4
device = "cpu"
def _resolve_model_path(model_name: str, bundled_path: str) -> str:
    """Return *bundled_path* when *model_name* is "default" (any case), else *model_name* unchanged."""
    return bundled_path if model_name.lower() == "default" else model_name


# Named-entity recognition: an empty name leaves the service disabled (None).
ner_model_name = _resolve_model_path(ner_model_name, "./data/ner-model.pt")
ner_response = (
    NerResponse(ner_model_name, batch_size=batch_size, device=device)
    if ner_model_name != ""
    else None
)

# Open information extraction: also needs the BERT weights in ``bert_oie``.
oie_model_name = _resolve_model_path(oie_model_name, "./data/mbertEntail-2GBEn.bin")
oie_response = (
    OIEResponse(oie_model_name, BERT=bert_oie, batch_size=batch_size, device=device)
    if oie_model_name != ""
    else None
)

# Relation extraction (NRE).
nre_model_name = _resolve_model_path(nre_model_name, "./data/model.ckpt.pth.tar")
nre_response = NREResponse(nre_model_name) if nre_model_name != "" else None

# Summarization: both response objects are always created; a POS-based
# normalizer is attached only when normalization is on and a POS model is set.
if normalize.lower() != "false" and pos_model_name != "":
    pos_model_name = _resolve_model_path(pos_model_name, "./data/pos-model.pt")
    normalizer_pos = NormalizerTexts(pos_model_name)
else:
    normalizer_pos = None
short_response = ShortResponse(normalizer_input=normalizer_pos, device=device)
long_response = LongResponse(normalizer_input=normalizer_pos, device=device)
# Weights are loaded only for summary models that are actually configured.
if short_model_name != "":
    short_response.load_model(short_model_name)
if long_model_name != "":
    long_response.load_model(long_model_name)

# Shared request registry and client-credential store used by every endpoint.
requests: Requests = Requests()
database: Database = Database()
@app.get("/")
async def root():
    """Answer GET / with a fixed greeting (presumably a simple reachability check)."""
    greeting = {"message": "Hello World"}
    return greeting
@app.get("/ner/{item_id}", response_model=ResponseNerModel)
async def read_item(item_id: int, request: Request) -> ResponseNerModel:
    """Return the stored NER result for request *item_id*.

    ``item_id`` is typed as ``int`` so FastAPI validates the path segment and
    answers non-numeric ids with a 422 instead of an unhandled ``ValueError``
    (HTTP 500) from converting it by hand.

    Raises:
        HTTPException(403): the Client-Id / Client-Password headers are
            rejected, or the NER service is disabled (``ner_response`` is None).
        HTTPException(404): ``ResponsesNer`` has no entry for *item_id*.
    """
    # NOTE(review): several GET handlers in this module are named ``read_item``;
    # each route is registered by its decorator, but the module-level name only
    # refers to the last definition.
    authorized = database.check_client(client_id=request.headers.get('Client-Id'),
                                       client_password=request.headers.get('Client-Password'))
    if not authorized or ner_response is None:
        raise HTTPException(status_code=403, detail="you are not allowed")
    if not ResponsesNer.check_id(item_id):
        raise HTTPException(status_code=404, detail="Item not found")
    return ResponsesNer.get_response(item_id)
@app.get("/oie/{item_id}", response_model=ResponseOIEModel)
async def read_item(item_id: int, request: Request) -> ResponseOIEModel:
    """Return the stored open-information-extraction result for *item_id*.

    ``item_id`` is typed as ``int`` so FastAPI validates the path segment and
    answers non-numeric ids with a 422 instead of an unhandled ``ValueError``
    (HTTP 500) from converting it by hand.

    Raises:
        HTTPException(403): the Client-Id / Client-Password headers are
            rejected, or the OIE service is disabled (``oie_response`` is None).
        HTTPException(404): ``ResponsesOIE`` has no entry for *item_id*.
    """
    # NOTE(review): this handler shares the name ``read_item`` with other GET
    # handlers; routing is unaffected because the decorator registers it.
    authorized = database.check_client(client_id=request.headers.get('Client-Id'),
                                       client_password=request.headers.get('Client-Password'))
    if not authorized or oie_response is None:
        raise HTTPException(status_code=403, detail="you are not allowed")
    if not ResponsesOIE.check_id(item_id):
        raise HTTPException(status_code=404, detail="Item not found")
    return ResponsesOIE.get_response(item_id)
@app.get("/nre/{item_id}", response_model=ResponseNREModel)
async def read_item(item_id: int, request: Request) -> ResponseNREModel:
    """Return the stored relation-extraction (NRE) result for *item_id*.

    ``item_id`` is typed as ``int`` so FastAPI validates the path segment and
    answers non-numeric ids with a 422 instead of an unhandled ``ValueError``
    (HTTP 500) from converting it by hand.

    Raises:
        HTTPException(403): the Client-Id / Client-Password headers are
            rejected, or the NRE service is disabled (``nre_response`` is None).
        HTTPException(404): ``ResponsesNRE`` has no entry for *item_id*.
    """
    # NOTE(review): this handler shares the name ``read_item`` with other GET
    # handlers; routing is unaffected because the decorator registers it.
    authorized = database.check_client(client_id=request.headers.get('Client-Id'),
                                       client_password=request.headers.get('Client-Password'))
    if not authorized or nre_response is None:
        raise HTTPException(status_code=403, detail="you are not allowed")
    if not ResponsesNRE.check_id(item_id):
        raise HTTPException(status_code=404, detail="Item not found")
    return ResponsesNRE.get_response(item_id)
@app.get("/summary/{item_id}", response_model=ResponseSummaryModel)
async def read_item(item_id: int, request: Request) -> ResponseSummaryModel:
    """Return the stored summary (short or long) for request *item_id*.

    ``item_id`` is typed as ``int`` so FastAPI validates the path segment and
    answers non-numeric ids with a 422 instead of an unhandled ``ValueError``
    (HTTP 500) from converting it by hand. Unlike the other GET handlers this
    one does not check whether a summary model is configured.

    Raises:
        HTTPException(403): the Client-Id / Client-Password headers are rejected.
        HTTPException(404): ``ResponsesSummary`` has no entry for *item_id*.
    """
    # NOTE(review): this handler shares the name ``read_item`` with other GET
    # handlers; routing is unaffected because the decorator registers it.
    authorized = database.check_client(client_id=request.headers.get('Client-Id'),
                                       client_password=request.headers.get('Client-Password'))
    if not authorized:
        raise HTTPException(status_code=403, detail="you are not allowed")
    if not ResponsesSummary.check_id(item_id):
        raise HTTPException(status_code=404, detail="Item not found")
    return ResponsesSummary.get_response(item_id)
@app.post("/ner", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def ner(input_json: InputNerType, background_tasks: BackgroundTasks, request: Request) -> BaseResponse:
    """Accept an NER job and run it after the response is sent.

    The input is registered via ``requests.add_request_ner`` and
    ``ner_response.predict_json(input_json, request_id)`` is scheduled as a
    background task; the returned id can be polled at GET /ner/{item_id}.

    Raises:
        HTTPException(403): credentials rejected or ``ner_model_name`` is empty.
    """
    headers = request.headers
    allowed = database.check_client(client_id=headers.get('Client-Id'),
                                    client_password=headers.get('Client-Password'))
    if not (allowed and ner_model_name != ""):
        raise HTTPException(status_code=403, detail="you are not allowed")

    request_id = requests.add_request_ner(input_json)
    background_tasks.add_task(ner_response.predict_json, input_json, request_id)
    return BaseResponse(status="input received successfully", id=request_id)
@app.post("/oie", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def oie(input_json: InputNerType, background_tasks: BackgroundTasks, request: Request) -> BaseResponse:
    """Accept an open-information-extraction job and run it in the background.

    The input is registered via ``requests.add_request_oie`` and
    ``oie_response.predict_json(input_json, request_id)`` is scheduled as a
    background task; the returned id can be polled at GET /oie/{item_id}.

    Raises:
        HTTPException(403): credentials rejected or ``oie_model_name`` is empty.
    """
    # NOTE(review): the body is declared as InputNerType — presumably OIE shares
    # the NER input schema; confirm against model.request_models.
    headers = request.headers
    allowed = database.check_client(client_id=headers.get('Client-Id'),
                                    client_password=headers.get('Client-Password'))
    if not (allowed and oie_model_name != ""):
        raise HTTPException(status_code=403, detail="you are not allowed")

    request_id = requests.add_request_oie(input_json)
    background_tasks.add_task(oie_response.predict_json, input_json, request_id)
    return BaseResponse(status="input received successfully", id=request_id)
@app.post("/nre", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def oie(input_json: InputNREType, background_tasks: BackgroundTasks, request: Request) -> BaseResponse:
    """Accept a relation-extraction (NRE) job and run it in the background.

    The input is registered via ``requests.add_request_nre`` and
    ``nre_response.predict_json(input_json, request_id)`` is scheduled as a
    background task; the returned id can be polled at GET /nre/{item_id}.

    Raises:
        HTTPException(403): credentials rejected or ``nre_model_name`` is empty.
    """
    # NOTE(review): this POST /nre handler reuses the function name ``oie`` from
    # the POST /oie handler. The route registers correctly (the decorator runs at
    # definition time), but the module attribute ``oie`` now refers to this
    # function — consider renaming it to ``nre``.
    headers = request.headers
    allowed = database.check_client(client_id=headers.get('Client-Id'),
                                    client_password=headers.get('Client-Password'))
    if not (allowed and nre_model_name != ""):
        raise HTTPException(status_code=403, detail="you are not allowed")

    request_id = requests.add_request_nre(input_json)
    background_tasks.add_task(nre_response.predict_json, input_json, request_id)
    return BaseResponse(status="input received successfully", id=request_id)
@app.post("/summary/short", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def short_summary(input_json: InputSummaryType, background_tasks: BackgroundTasks,
                        request: Request) -> BaseResponse:
    """Accept a short-summary job and run it in the background.

    The input is registered via ``requests.add_request_summary`` and
    ``short_response.get_result(input_json, request_id)`` is scheduled as a
    background task; the returned id can be polled at GET /summary/{item_id}.

    Raises:
        HTTPException(403): credentials rejected or ``short_model_name`` is empty.
    """
    headers = request.headers
    allowed = database.check_client(client_id=headers.get('Client-Id'),
                                    client_password=headers.get('Client-Password'))
    if not (allowed and short_model_name != ""):
        raise HTTPException(status_code=403, detail="you are not allowed")

    request_id = requests.add_request_summary(input_json)
    background_tasks.add_task(short_response.get_result, input_json, request_id)
    return BaseResponse(status="input received successfully", id=request_id)
@app.post("/summary/long", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def long_summary(input_json: InputSummaryType, background_tasks: BackgroundTasks,
                       request: Request) -> BaseResponse:
    """Accept a long-summary job and run it in the background.

    The input is registered via ``requests.add_request_summary`` and
    ``long_response.get_result(input_json, request_id)`` is scheduled as a
    background task; the returned id can be polled at GET /summary/{item_id}.

    Raises:
        HTTPException(403): credentials rejected or ``long_model_name`` is empty.
    """
    headers = request.headers
    allowed = database.check_client(client_id=headers.get('Client-Id'),
                                    client_password=headers.get('Client-Password'))
    if not (allowed and long_model_name != ""):
        raise HTTPException(status_code=403, detail="you are not allowed")

    request_id = requests.add_request_summary(input_json)
    background_tasks.add_task(long_response.get_result, input_json, request_id)
    return BaseResponse(status="input received successfully", id=request_id)
if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on every network
    # interface (0.0.0.0), port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")