# Source code for the GluonTS model-serving module.

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# Standard library imports
import logging
import multiprocessing
from ipaddress import IPv4Address
from typing import Optional, Type, Union

# Third-party imports
from flask import Flask
from import BaseApplication
from pydantic import BaseSettings

# First-party imports
import gluonts
from gluonts.core import fqname_for
from gluonts.core.component import check_gpu_support
from gluonts.model.estimator import Estimator
from gluonts.model.predictor import Predictor
from import ServeEnv

from .app import make_app

    format="%(asctime)s [%(levelname)s] %(name)s %(message)s",
    datefmt="[%Y-%m-%d %H:%M:%S]",
logger = logging.getLogger(__name__)

MB = 1024 * 1024

class Settings(BaseSettings):
    """Runtime configuration of the model server.

    Because ``Config.env_prefix`` is empty, every field can be overridden by
    an environment variable carrying the field's own name (pydantic matches
    it case-insensitively by default), e.g. ``MODEL_SERVER_WORKERS`` sets
    ``model_server_workers``.
    """

    # see: pydantic (v1) documentation, "Settings management"
    class Config:
        env_prefix = ""

    # Explicit worker count; when unset (or 0) `number_of_workers` falls back
    # to its own heuristics.
    model_server_workers: Optional[int] = None
    # Defaults to 6 MiB expressed in bytes.
    # NOTE(review): presumably forwarded as Flask's MAX_CONTENT_LENGTH — confirm.
    max_content_length: int = 6 * MB

    # Bind on all IPv4 interfaces. (`IPv4Address("")` would raise
    # AddressValueError as soon as this class is defined.)
    sagemaker_server_address: IPv4Address = IPv4Address("0.0.0.0")
    sagemaker_server_port: int = 8080
    sagemaker_server_timeout: int = 100

    sagemaker_batch: bool = False
    sagemaker_batch_strategy: str = "SINGLE_RECORD"

    sagemaker_max_payload_in_mb: int = 6
    # 2**32 - 1, i.e. effectively no limit unless overridden.
    sagemaker_max_concurrent_transforms: int = 2 ** 32 - 1

    @property
    def sagemaker_server_bind(self) -> str:
        """Return the ``"address:port"`` string to bind the server to."""
        return f"{self.sagemaker_server_address}:{self.sagemaker_server_port}"

    @property
    def number_of_workers(self) -> int:
        """Return how many worker processes to start.

        Precedence:

        1. ``model_server_workers`` when it is set to a truthy value;
        2. in batch mode, ``sagemaker_max_concurrent_transforms`` when it is
           smaller than the CPU count;
        3. otherwise the CPU count.

        The decision is logged each time the property is read.
        """
        cpu_count = multiprocessing.cpu_count()

        if self.model_server_workers:
            logger.info(
                "Using %s workers "
                "(set by MODEL_SERVER_WORKERS environment variable).",
                self.model_server_workers,
            )
            return self.model_server_workers

        if (
            self.sagemaker_batch
            and self.sagemaker_max_concurrent_transforms < cpu_count
        ):
            logger.info(
                "Using %s workers "
                "(set by MaxConcurrentTransforms parameter in batch mode).",
                self.sagemaker_max_concurrent_transforms,
            )
            return self.sagemaker_max_concurrent_transforms

        logger.info("Using %s workers", cpu_count)
        return cpu_count
class Application(BaseApplication):
    """Gunicorn application that serves an already-built Flask app.

    Parameters
    ----------
    app
        The Flask application that ``load`` hands to gunicorn.
    config
        Mapping of gunicorn setting names to values. Keys gunicorn does not
        know and ``None`` values are skipped by ``load_config``.
    """

    def __init__(self, app, config) -> None:
        # Both attributes must exist before the base constructor runs:
        # gunicorn's BaseApplication.__init__ loads the configuration, which
        # calls load_config() and therefore reads self.config.
        self.app = app
        self.config = config
        super().__init__()

    def load_config(self) -> None:
        """Copy recognised, non-None entries of ``config`` into ``self.cfg``."""
        for key, value in self.config.items():
            if key in self.cfg.settings and value is not None:
                self.cfg.set(key, value)

    def init(self, parser, opts, args):
        """No-op override; configuration is supplied via ``load_config``."""

    def load(self) -> Flask:
        """Return the WSGI application that gunicorn workers should serve."""
        return self.app

    def stop(self, *args, **kwargs):
        """Log that the scoring service is shutting down."""
        # NOTE(review): confirm which gunicorn code path invokes this hook.
        logger.info("Shutting down GluonTS scoring service")
def make_gunicorn_app(
    env: ServeEnv,
    forecaster_type: Optional[Type[Union[Estimator, Predictor]]],
    settings: Settings,
) -> Application:
    """Build the gunicorn application that serves GluonTS forecasts.

    Parameters
    ----------
    env
        Serving environment. Its ``batch_config`` is always forwarded to
        ``make_app``. In static mode its ``path.model`` is also used to
        deserialize the predictor.
    forecaster_type
        If given ("dynamic" mode), every request builds a fresh predictor via
        ``forecaster_type.from_hyperparameters(**request["configuration"])``.
        If ``None`` ("static" mode), one predictor is deserialized from
        ``env.path.model`` and shared by all requests.
    settings
        Server settings: bind address, worker count, timeout and the
        batch-transform execution parameters.

    Returns
    -------
    Application
        A gunicorn application wrapping the Flask app returned by ``make_app``.

    Raises
    ------
    ValueError
        In static mode, if ``env`` is ``None`` (there is no model to load).
    """
    check_gpu_support()

    if forecaster_type is not None:
        logger.info("Using dynamic predictor factory")

        ctor = forecaster_type.from_hyperparameters
        forecaster_fq_name = fqname_for(forecaster_type)
        forecaster_version = forecaster_type.__version__

        def predictor_factory(request) -> Predictor:
            # A new predictor per request, configured by the request payload.
            return ctor(**request["configuration"])

    else:
        logger.info("Using static predictor factory")

        # Explicit check rather than `assert`, which is stripped under -O.
        if env is None:
            raise ValueError(
                "A ServeEnv is required to load a serialized predictor."
            )
        predictor = Predictor.deserialize(env.path.model)
        forecaster_fq_name = fqname_for(type(predictor))
        forecaster_version = predictor.__version__

        def predictor_factory(request) -> Predictor:
            # The request is ignored: all requests share the same predictor.
            return predictor

    logger.info("Using gluonts v%s", gluonts.__version__)
    logger.info(
        "Using forecaster %s v%s", forecaster_fq_name, forecaster_version
    )

    # `settings.number_of_workers` recomputes the CPU count and reports its
    # decision on every access; read it once so both uses agree.
    workers = settings.number_of_workers

    execution_params = {
        "MaxConcurrentTransforms": workers,
        "BatchStrategy": settings.sagemaker_batch_strategy,
        "MaxPayloadInMB": settings.sagemaker_max_payload_in_mb,
    }

    flask_app = make_app(
        predictor_factory,
        execution_params,
        batch_transform_config=env.batch_config,
    )

    return Application(
        app=flask_app,
        config={
            "bind": settings.sagemaker_server_bind,
            "workers": workers,
            "timeout": settings.sagemaker_server_timeout,
        },
    )