# Copyright 2018, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
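# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing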
# permissions and limitations under the License.

# Standard library imports
from collections import ChainMap
from typing import Optional, Tuple, List

# Third-party imports
import numpy as np
from mxnet import gluon

# First-party imports
from gluonts.distribution import Distribution
from gluonts.distribution.bijection import AffineTransformation
from gluonts.distribution.bijection_output import BijectionOutput
from gluonts.distribution.distribution_output import (
from gluonts.distribution.transformed_distribution import (
from gluonts.model.common import Tensor
from gluonts.core.component import validated

class TransformedDistributionOutput(DistributionOutput):
    r"""
    Class to connect a network to a distribution that is transformed
    by a sequence of learnable bijections.

    The network outputs are laid out positionally: first the arguments of
    the base distribution, then the arguments of each transformation, in
    the order in which the transformations appear in ``transforms_output``.

    Parameters
    ----------
    base_distr_output
        Output object of the base distribution. Its ``args_dim``,
        ``domain_map`` and ``distr_cls`` attributes are used.
    transforms_output
        Output objects of the bijections. Their ``args_dim``,
        ``domain_map``, ``bij_cls`` and ``event_shape`` attributes are used.
    """

    @validated()
    def __init__(
        self,
        base_distr_output: DistributionOutput,
        transforms_output: List[BijectionOutput],
    ) -> None:
        super().__init__()
        self.base_distr_output = base_distr_output
        self.transforms_output = transforms_output

        self.base_distr_args_dim = base_distr_output.args_dim
        self.transforms_args_dim = [
            transform.args_dim for transform in transforms_output
        ]

        def _fuse(t1: Tuple, t2: Tuple) -> Tuple:
            # Return the longer of the two shapes after checking that the
            # shorter one is a suffix of it.
            if len(t1) > len(t2):
                t1, t2 = t2, t1
            # from here on len(t2) >= len(t1)
            # Slice from an explicit start index: ``t2[-len(t1):]`` would
            # return all of ``t2`` when ``t1`` is empty (``-0 == 0``) and
            # wrongly reject fusing ``()`` with a non-empty shape.
            assert t2[len(t2) - len(t1) :] == t1, (
                f"incompatible event shapes: {t1} and {t2}"
            )
            return t2

        # Only the transformations' event shapes are fused here.
        self._event_shape: Tuple[int, ...] = ()
        for to in self.transforms_output:
            self._event_shape = _fuse(self._event_shape, to.event_shape)

    def get_args_proj(self, prefix: Optional[str] = None) -> ArgProj:
        """
        Build the projection from network features to distribution arguments.

        Parameters
        ----------
        prefix
            Forwarded unchanged to ``ArgProj``.

        Returns
        -------
        ArgProj
            Projection whose ``args_dim`` lists the base distribution
            arguments followed by the arguments of each transformation.

        Raises
        ------
        ValueError
            If two argument names coincide. The arguments are split by
            position in ``_split_args``, so a collapsed key would shift
            every subsequent slice.
        """
        # Merge in declaration order. ``dict(ChainMap(*maps))`` iterates the
        # maps last-to-first, which reverses the transformations relative to
        # the order assumed by ``_split_args``.
        args_dim = dict(self.base_distr_args_dim)
        for transform_args_dim in self.transforms_args_dim:
            for arg_name, dim in transform_args_dim.items():
                if arg_name in args_dim:
                    raise ValueError(
                        f"duplicate argument name '{arg_name}' among the "
                        "base distribution and its transformations"
                    )
                args_dim[arg_name] = dim

        return ArgProj(
            args_dim=args_dim,
            domain_map=gluon.nn.HybridLambda(self.domain_map),
            prefix=prefix,
        )

    def _split_args(self, args):
        """
        Split a flat argument sequence into base and per-transform groups.

        Parameters
        ----------
        args
            Flat sequence holding the base distribution arguments followed
            by the arguments of each transformation.

        Returns
        -------
        tuple
            ``(distr_args, transforms_args)``, where ``transforms_args`` is
            a list with one slice of ``args`` per transformation.
        """
        # Since hybrid_forward does not support dictionary,
        # we have to separate the raw outputs of the network based on the
        # indices and map them to the learnable parameters
        offset = len(self.base_distr_args_dim)
        distr_args = args[:offset]

        # walk through the transformations, consuming as many arguments as
        # each one declares
        transforms_args = []
        for transform_args_dim in self.transforms_args_dim:
            end = offset + len(transform_args_dim)
            transforms_args.append(args[offset:end])
            offset = end

        return distr_args, transforms_args

    def domain_map(self, F, *args: Tensor):
        """
        Map raw network outputs to the domains of the respective parameters.

        Parameters
        ----------
        F
            Passed through to the ``domain_map`` of the base distribution
            output and of every transformation output.
        args
            Raw outputs, ordered as expected by ``_split_args``.

        Returns
        -------
        tuple
            Flat tuple with the mapped base distribution parameters followed
            by the mapped parameters of each transformation.
        """
        distr_args, transforms_args = self._split_args(args)

        distr_params = self.base_distr_output.domain_map(F, *distr_args)
        transforms_params = [
            transform_output.domain_map(F, *transform_args)
            for transform_output, transform_args in zip(
                self.transforms_output, transforms_args
            )
        ]

        # flatten the nested tuple
        return sum(tuple([distr_params] + transforms_params), ())

    def distribution(
        self, distr_args, scale: Optional[Tensor] = None, **kwargs
    ) -> Distribution:
        """
        Construct the transformed distribution from its arguments.

        Parameters
        ----------
        distr_args
            Flat sequence of arguments, ordered as expected by
            ``_split_args``.
        scale
            If not ``None``, an ``AffineTransformation`` with this scale is
            applied on top of the transformed distribution.
        kwargs
            Accepted but not used.

        Returns
        -------
        Distribution
            The base distribution with all transformations applied,
            additionally rescaled when ``scale`` is given.
        """
        distr_args, transforms_args = self._split_args(distr_args)

        distr = self.base_distr_output.distr_cls(*distr_args)
        transforms = [
            transform_output.bij_cls(*bij_args)
            for transform_output, bij_args in zip(
                self.transforms_output, transforms_args
            )
        ]

        trans_distr = TransformedDistribution(distr, transforms)

        # Apply scaling as well at the end if scale is not None!
        if scale is None:
            return trans_distr

        # Pass the extra bijection as a list, the same way the learnable
        # transformations are passed above.
        return TransformedDistribution(
            trans_distr, [AffineTransformation(scale=scale)]
        )

    @property
    def event_shape(self) -> Tuple:
        """Event shape obtained by fusing the transformations' shapes."""
        return self._event_shape