# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# TODO: Fix typing
# type: ignore


from typing import List, Optional, Tuple, Union, Sequence

# Third-party imports
import mxnet as mx

from gluonts.meta_tools import Reg, Bias
from gluonts.ridge_solver import solve_ridge_regression
from gluonts.time_feature import get_seasonality
from mxnet.gluon.rnn import ZoneoutCell
from mxnet.gluon.contrib.rnn import VariationalDropoutCell

# Third-party imports (numerical)
import numpy as np

from gluonts.core.component import DType, validated
from gluonts.mx import Tensor

# First-party imports
from gluonts.mx.block.feature import FeatureEmbedder
from gluonts.mx.block.scaler import MeanScaler, NOPScaler
from gluonts.mx.distribution import Distribution
from gluonts.mx.distribution.distribution import getF
from gluonts.mx.util import weighted_average
from gluonts.mx.block.dropout import VariationalZoneoutCell, RNNZoneoutCell
from gluonts.mx.block.regularization import (
    ActivationRegularizationLoss,
    TemporalActivationRegularizationLoss,
)


def is_nd(tensor: Tensor) -> bool:
    """Tell whether ``tensor`` is an imperative ``mx.nd.NDArray`` (not a Symbol)."""
    ndarray_type = mx.nd.NDArray
    return isinstance(tensor, ndarray_type)


def prod(xs):
    """Multiply together the items of ``xs``; an empty iterable yields 1."""
    total = 1
    for item in xs:
        total *= item
    return total


class MetaARDetNetwork(mx.gluon.HybridBlock):
    """
    Base network shared by the MetaARDet training and prediction networks.

    Lagged (scaled) target values, a static feature vector (log-scale plus,
    optionally, embedded categorical and real static features) and,
    optionally, time features are encoded per time step by the encoder
    selected with ``encoder_type`` ("RNN", "linear", "feedforward" or
    "identity").

    When ``is_adaptive`` is True the output projection is not a learned layer:
    on every call to ``predictions`` it is rebuilt from the solution of a
    ridge regression of the context targets on the encoder representations.
    Otherwise a learned ``Dense(units=1)`` projection is used.
    """

    @validated()
    def __init__(
        self,
        num_layers: int,
        num_cells: int,
        output_dim: int,
        cell_type: str,
        history_length: int,
        context_length: int,
        prediction_length: int,
        dropout_rate: float,
        cardinality: List[int],
        embedding_dimension: List[int],
        lags_seq: List[int],
        loss: str = "sMAPE",
        dropoutcell_type: str = "ZoneoutCell",
        scaling: bool = True,
        biased_regularization=True,
        weighted_least_squares=None,
        dtype: DType = np.float32,
        freq: str = "H",
        loss_denom_eps: float = 1.0e-4,
        is_adaptive: bool = True,
        iterated_forecasts: bool = True,
        use_shared_linear: bool = False,
        encoder_type: str = "RNN",
        use_time_features: bool = False,
        use_additional_features: bool = False,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        num_layers
            number of RNN cells (encoder_type="RNN") or Dense layers
            (encoder_type="feedforward").
        num_cells
            hidden size of the encoder layers when the shared linear head is
            active; otherwise ``output_dim`` is used as hidden size.
        output_dim
            size of the representation fed to the output projection
            (recomputed for encoder_type="identity").
        cell_type
            "lstm" or "gru".
        history_length, context_length, prediction_length
            lengths of the available history, of the encoded context and of
            the forecast horizon.
        dropout_rate
            zoneout / dropout rate applied to the RNN states (0 disables it).
        cardinality, embedding_dimension
            categorical feature cardinalities and embedding sizes; only used
            when ``use_additional_features`` is True.
        lags_seq
            unique lag indices; a sorted copy is kept in ``self.lags_seq``.
        loss
            name of the training loss ("sMAPE", "MAPE", "MASE", "MSE",
            "MAE").
        dropoutcell_type
            one of "ZoneoutCell", "RNNZoneoutCell", "VariationalDropoutCell",
            "VariationalZoneoutCell".
        scaling
            use a ``MeanScaler`` if True, otherwise a ``NOPScaler``.
        biased_regularization
            if True, a ``Bias`` block provides the bias passed to the ridge
            solver.
        weighted_least_squares
            None, "MAPE" or "MASE": sample weighting of the ridge regression.
        dtype
            dtype the encoder blocks are cast to.
        freq
            data frequency, used to derive the MASE seasonality.
        loss_denom_eps
            lower clip applied to loss denominators.
        is_adaptive
            fit the output projection by ridge regression on every call.
        iterated_forecasts
            if True forecast by unrolling the decoder autoregressively,
            otherwise project encoder outputs over the future range.
        use_shared_linear
            add a Dense(output_dim) head after the encoder (only honoured for
            adaptive networks with an RNN encoder).
        encoder_type
            "RNN", "linear", "feedforward" or "identity".
        use_time_features, use_additional_features
            whether time features / static categorical and real features are
            fed to the encoder.
        """
        super().__init__(**kwargs)
        self.num_layers = num_layers
        self.num_cells = num_cells
        self.output_dim = output_dim
        self.cell_type = cell_type
        self.history_length = history_length
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.dropoutcell_type = dropoutcell_type
        self.dropout_rate = dropout_rate
        self.cardinality = cardinality
        self.embedding_dimension = embedding_dimension
        self.num_cat = len(cardinality)
        self.scaling = scaling
        self.dtype = dtype
        self.weighted_least_squares = weighted_least_squares
        self.freq = freq
        self.loss_denom_eps = loss_denom_eps
        self.periodicity = get_seasonality(self.freq)
        self.is_adaptive = is_adaptive
        self.iterated_forecast = iterated_forecasts
        # The shared linear head is only used on top of encoders other than
        # identity / feedforward / linear (i.e. the RNN encoder) of an
        # adaptive network.
        self.use_shared_linear = (
            use_shared_linear
            and is_adaptive
            and encoder_type not in ("identity", "feedforward", "linear")
        )
        self.encoder_type = encoder_type
        self.use_additional_features = use_additional_features
        self.use_time_features = use_time_features

        assert len(cardinality) == len(
            embedding_dimension
        ), "embedding_dimension should be a list with the same size as cardinality"

        assert len(set(lags_seq)) == len(
            lags_seq
        ), "no duplicated lags allowed!"

        # Keep a sorted copy instead of sorting the caller's list in place.
        self.lags_seq = sorted(lags_seq)

        self.loss = loss
        RnnCell = {"lstm": mx.gluon.rnn.LSTMCell, "gru": mx.gluon.rnn.GRUCell}[
            self.cell_type
        ]

        self.target_shape = ()  # only univariate TS supported for now

        # TODO: is the following restriction needed?
        assert (
            len(self.target_shape) <= 1
        ), "Argument `target_shape` should be a tuple with 1 element at most"

        Dropout = {
            "ZoneoutCell": ZoneoutCell,
            "RNNZoneoutCell": RNNZoneoutCell,
            "VariationalDropoutCell": VariationalDropoutCell,
            "VariationalZoneoutCell": VariationalZoneoutCell,
        }[self.dropoutcell_type]

        # for decoding the lags are shifted by one, at the first time-step
        # of the decoder a lag of one corresponds to the last target value
        self.shifted_lags = [l - 1 for l in self.lags_seq]

        # Adaptive networks get their projection from build_out_proj at
        # prediction time; only non-adaptive ones learn a Dense layer.
        self.out_proj = (
            None
            if self.is_adaptive
            else mx.gluon.nn.Dense(units=1, prefix="output_", flatten=False)
        )

        with self.name_scope():

            if encoder_type == "RNN":
                self.rnn = mx.gluon.rnn.HybridSequentialRNNCell()
                for k in range(num_layers):
                    hidden_size = (
                        output_dim
                        if not self.use_shared_linear
                        else num_cells
                    )
                    cell = RnnCell(hidden_size=hidden_size)
                    cell = mx.gluon.rnn.ResidualCell(cell) if k > 0 else cell
                    # we found that adding dropout to outputs doesn't improve the performance, so we only drop states
                    if "Zoneout" in self.dropoutcell_type:
                        cell = (
                            Dropout(cell, zoneout_states=dropout_rate)
                            if dropout_rate > 0.0
                            else cell
                        )
                    elif "Dropout" in self.dropoutcell_type:
                        cell = (
                            Dropout(cell, drop_states=dropout_rate)
                            if dropout_rate > 0.0
                            else cell
                        )
                    self.rnn.add(cell)
                self.rnn.cast(dtype=dtype)

            elif encoder_type == "linear":
                self.transform = mx.gluon.nn.Dense(
                    units=self.output_dim, flatten=False
                )
                self.transform.cast(dtype=dtype)

            elif encoder_type == "feedforward":
                hidden_size = (
                    output_dim
                    if not self.use_shared_linear
                    else num_cells
                )
                self.transform = mx.gluon.nn.HybridSequential()
                for k in range(num_layers):
                    self.transform.add(
                        mx.gluon.nn.Dense(
                            units=hidden_size, activation="relu", flatten=False
                        )
                    )
                self.transform.cast(dtype=dtype)

            elif encoder_type == "identity":
                self.transform = mx.gluon.nn.HybridLambda(lambda F, x: x)
                out_dim = 1  # accounts for the scale
                # NOTE(review): the increments below hard-code 2 time features
                # and 1 extra static dimension — confirm against the estimator.
                if use_time_features:
                    out_dim += 2
                if use_additional_features:
                    out_dim += 1
                self.output_dim = len(self.lags_seq) + out_dim
                self.transform.cast(dtype=dtype)

            if self.use_shared_linear:
                self.shared_linear_out = mx.gluon.nn.Dense(
                    output_dim, flatten=False
                )

            # One bias entry per representation dimension plus the intercept.
            self.bias = (
                Bias(0.0, self.output_dim + 1)
                if biased_regularization
                else None
            )
            self.reg = Reg(1.0)

            if self.use_additional_features:
                self.embedder = FeatureEmbedder(
                    cardinalities=cardinality,
                    embedding_dims=embedding_dimension,
                    dtype=self.dtype,
                )
            else:
                self.embedder = lambda x: x

            if scaling:
                self.scaler = MeanScaler(keepdims=True)
            else:
                self.scaler = NOPScaler(keepdims=True)

    def compute_ts_representation(
        self,
        F,
        lags: Tensor,
        static_feat: Tensor,
        time_feat: Optional[Tensor],
        begin_state: Optional[List[Tensor]],
        length: int,
    ):
        """
        Encode lagged targets together with static (and optionally time)
        features.

        Parameters
        ----------
        F
            MXNet function namespace (``mx.nd`` or ``mx.sym``).
        lags
            lagged target values; reshaped here to
            (batch_size, length, num_lags * prod(target_shape)).
        static_feat
            static features of shape (batch_size, num_static_features),
            repeated along the time axis.
        time_feat
            time features for the ``length`` steps; only concatenated when
            ``use_time_features`` is True.
        begin_state
            initial RNN state; zeros are used when None. Ignored by non-RNN
            encoders.
        length
            number of time steps to encode.

        Returns
        -------
        Tuple
            (outputs, state): the per-step representations and the final RNN
            state (None for non-RNN encoders).
        """

        # (batch_size, subsequences_length, num_features + 1)
        repeated_static_feat = static_feat.expand_dims(axis=1).repeat(
            axis=1, repeats=length
        )

        # from (batch_size, sub_seq_len, *target_shape, num_lags)
        # to (batch_size, sub_seq_len, prod(target_shape) * num_lags)
        input_lags = F.reshape(
            data=lags,
            shape=(
                -1,
                length,
                len(self.lags_seq) * prod(self.target_shape),
            ),
        )

        # (batch_size, sub_seq_len, input_dim)
        inputs = [input_lags, repeated_static_feat]
        if self.use_time_features:
            inputs.append(time_feat)

        inputs = F.concat(*inputs, dim=-1)

        if self.encoder_type == "RNN":
            if begin_state is None:
                # The batch size is only known eagerly; symbolic graphs use 0.
                begin_state = self.rnn.begin_state(
                    func=F.zeros,
                    dtype=self.dtype,
                    batch_size=inputs.shape[0] if is_nd(inputs) else 0,
                )

            # unroll encoder
            outputs, state = self.rnn.unroll(
                inputs=inputs,
                length=length,
                layout="NTC",
                merge_outputs=True,
                begin_state=begin_state,
            )
        else:
            outputs = self.transform(inputs)
            state = None

        if self.use_shared_linear:
            outputs = self.shared_linear_out(outputs)

        return outputs, state

    @staticmethod
    def get_lagged_subsequences(
        F,
        sequence: Tensor,
        sequence_length: int,
        indices: List[int],
        subsequences_length: int = 1,
    ) -> Tensor:
        """
        Returns lagged subsequences of a given sequence.

        Parameters
        ----------
        F
            MXNet function namespace (``mx.nd`` or ``mx.sym``).
        sequence : Tensor
            the sequence from which lagged subsequences should be extracted.
            Shape: (N, T, C).
        sequence_length : int
            length of sequence in the T (time) dimension (axis = 1).
        indices : List[int]
            list of non-negative lag indices to be used.
        subsequences_length : int
            length of the subsequences to be extracted.

        Returns
        -------
        lagged : Tensor
            a tensor of shape (N, S, C, I), where S = subsequences_length and
            I = len(indices), containing lagged subsequences. Specifically,
            lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
        """
        # we must have: sequence_length - lag_index - subsequences_length >= 0
        # for all lag_index, hence the following assert
        assert max(indices) + subsequences_length <= sequence_length, (
            f"lags cannot go further than history length, "
            f"found lag {max(indices)} while history length is only "
            f"{sequence_length}"
        )
        assert all(lag_index >= 0 for lag_index in indices)
        if is_nd(sequence):
            assert (
                sequence.shape[1] == sequence_length
            ), f"{sequence.shape[1]} == {sequence_length}"

        lagged_values = []
        for lag_index in indices:
            begin_index = -lag_index - subsequences_length
            # A lag of 0 must slice up to the end of the sequence; end=-0
            # would produce an empty slice.
            end_index = -lag_index if lag_index > 0 else None
            lagged_values.append(
                F.slice_axis(
                    sequence, axis=1, begin=begin_index, end=end_index
                )
            )

        return F.stack(*lagged_values, axis=-1)

    def run_encoder(
        self,
        F,
        feat_static_cat: Tensor,  # (batch_size, num_features)
        feat_static_real: Tensor,  # (batch_size, num_features)
        past_time_feat: Tensor,  # (batch_size, history_length, num_features)
        past_target: Tensor,  # (batch_size, history_length, *target_shape)
        past_observed_values: Tensor,  # (batch_size, history_length, *target_shape)
        future_time_feat: Optional[
            Tensor
        ],  # (batch_size, prediction_length, num_features)
        future_target: Optional[
            Tensor
        ],  # (batch_size, prediction_length, *target_shape)
        scale: Tensor,
    ) -> Tuple[Tensor, List, Tensor]:
        """
        Unrolls the encoder over the past and, when both ``future_time_feat``
        and ``future_target`` are given, also over the future range.

        All tensor arguments should have NTC layout. ``past_target`` and
        ``future_target`` are expected to be already divided by ``scale``
        (``predictions`` does this before calling); ``scale`` itself is only
        used to build the log-scale static feature. ``past_observed_values``
        is accepted but not used here.

        Returns
        -------
        Tuple
            (outputs, state, static_feat): the encoder outputs, the final
            encoder state (None for non-RNN encoders) and the static feature
            vector that was fed to the encoder.
        """

        if future_time_feat is None or future_target is None:
            time_feat = past_time_feat.slice_axis(
                axis=1,
                begin=self.history_length - self.context_length,
                end=None,
            )
            sequence = past_target
            sequence_length = self.history_length
            subsequences_length = self.context_length
        else:
            time_feat = F.concat(
                past_time_feat.slice_axis(
                    axis=1,
                    begin=self.history_length - self.context_length,
                    end=None,
                ),
                future_time_feat,
                dim=1,
            )
            sequence = F.concat(past_target, future_target, dim=1)
            sequence_length = self.history_length + self.prediction_length
            subsequences_length = self.context_length + self.prediction_length

        # (batch_size, sub_seq_len, *target_shape, num_lags)
        # The targets are already scaled by the caller, so the lags are too.
        lags_scaled = self.get_lagged_subsequences(
            F=F,
            sequence=sequence,
            sequence_length=sequence_length,
            indices=self.lags_seq,
            subsequences_length=subsequences_length,
        )

        # (batch_size, num_features)
        embedded_cat = self.embedder(feat_static_cat)

        # in addition to embedding features, use the log scale as it can help
        # prediction too
        static_feat = [
            F.log(scale)
            if len(self.target_shape) == 0
            else F.log(scale.squeeze(axis=1))
        ]

        if self.use_additional_features:
            static_feat.extend([embedded_cat, feat_static_real])

        # (batch_size, num_features + prod(target_shape))
        static_feat = F.concat(*static_feat, dim=1)

        outputs, state = self.compute_ts_representation(
            F, lags=lags_scaled, static_feat=static_feat, time_feat=time_feat,
            begin_state=None, length=subsequences_length
        )

        # outputs: (batch_size, seq_len, num_cells)
        # state: list of (batch_size, num_cells) tensors
        # static_feat: (batch_size, num_features + prod(target_shape))
        return outputs, state, static_feat

    def run_decoder(
        self,
        F,
        static_feat: Tensor,
        past_target: Tensor,
        time_feat: Tensor,
        scale: Tensor,
        begin_states: List,
    ) -> Tuple[Tensor, Tensor]:
        """
        Produces the forecast by unrolling the decoder one step at a time,
        feeding each prediction back as the newest target value.

        Uses ``self.out_proj``, which for adaptive networks must have been
        set by ``build_out_proj`` (``predictions`` does so before calling).

        Parameters
        ----------
        static_feat : Tensor
            static features. Shape: (batch_size, num_static_features).
        past_target : Tensor
            scaled target history. Shape: (batch_size, history_length).
        time_feat : Tensor
            time features. Shape: (batch_size, prediction_length, num_time_features).
        scale : Tensor
            scale of each element in the batch; the predictions are
            multiplied by it (with broadcasting) before being returned.
        begin_states : List
            initial states of the RNN layers (ignored for non-RNN encoders).

        Returns
        -------
        Tuple[Tensor, Tensor]
            samples: the rescaled forecast, reshaped to
            (-1, prediction_length, *target_shape);
            outputs: the encoder outputs of the last decoding step.
        """

        states = (
            [s for s in begin_states] if self.encoder_type == "RNN" else None
        )
        future_samples = []

        # for each future time-units we draw new samples for this time-unit and update the state
        for k in range(self.prediction_length):
            # (batch_size * num_samples, 1, *target_shape, num_lags)
            # past_target is already scaled, so the lags are too.
            lags_scaled = self.get_lagged_subsequences(
                F=F,
                sequence=past_target,
                sequence_length=self.history_length + k,
                indices=self.shifted_lags,
                subsequences_length=1,
            )

            # output shape: (batch_size * num_samples, 1, num_cells)
            # state shape: (batch_size * num_samples, num_cells)
            outputs, states = self.compute_ts_representation(
                F, lags=lags_scaled, time_feat=time_feat.slice_axis(axis=1, begin=k, end=k + 1),
                static_feat=static_feat,
                begin_state=states, length=1
            )

            # (batch_size * num_samples, 1, *target_shape)
            new_samples = self.out_proj(outputs)[:, :, 0]

            # (batch_size * num_samples, seq_len, *target_shape)
            past_target = F.stop_gradient(F.concat(past_target, new_samples, dim=1))
            future_samples.append(F.broadcast_mul(new_samples, scale))

        # (batch_size * num_samples, prediction_length, *target_shape)
        samples = F.concat(*future_samples, dim=1)

        # (batch_size, num_samples, prediction_length, *target_shape)
        samples = samples.reshape(
            shape=(-1, self.prediction_length) + self.target_shape
        )

        return samples, outputs

    def get_representations_and_targets(
        self,
        F,
        feat_static_cat: Tensor,  # (batch_size, num_features)
        feat_static_real: Tensor,  # (batch_size, num_features)
        past_time_feat: Tensor,  # (batch_size, history_length, num_features)
        past_target: Tensor,  # (batch_size, history_length, *target_shape)
        past_observed_values: Tensor,  # (batch_size, history_length, *target_shape)
        future_time_feat: Tensor,  # (batch_size, prediction_length, num_features)
        future_target: Optional[
            Tensor
        ],  # (batch_size, prediction_length, *target_shape)
        scale: Tensor
    ):
        """
        Runs the encoder and assembles the ridge-regression inputs.

        With ``iterated_forecast`` True only the past is encoded and the
        targets are the last ``context_length`` past values; otherwise the
        encoder is also unrolled over the future range and the targets cover
        context plus prediction range (``future_target`` is then required).

        Returns
        -------
        Tuple
            (target, outputs, static_feat, state). ``outputs`` has a column of
            ones prepended on the last axis (the intercept term) and is
            multiplied by the observed-values mask (future steps count as
            observed).

        Raises
        ------
        ValueError
            if ``iterated_forecast`` is False and ``future_target`` is None.
        """

        if self.iterated_forecast:
            # Iterated forecasts never encode the future range.
            future_target = None
        elif future_target is None:
            raise ValueError(
                "future_target is required when iterated_forecasts is False"
            )

        outputs, state, static_feat = self.run_encoder(
            F,
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
            future_time_feat,
            future_target,
            scale,
        )

        if not self.iterated_forecast:
            target = F.concat(
                past_target[:, -self.context_length :], future_target, dim=1
            )

            observed_outputs = F.concat(
                past_observed_values[:, -self.context_length :],
                F.ones_like(future_target),
                dim=1,
            )
        else:
            target = past_target[:, -self.context_length :]
            observed_outputs = past_observed_values[:, -self.context_length :]

        # Prepend a constant-one feature so the regression learns an intercept.
        outputs = F.concat(
            F.ones_like(outputs[:, :, 0]).expand_dims(-1), outputs, dim=-1
        )

        # Zero the representations of unobserved steps.
        outputs = F.broadcast_mul(outputs, observed_outputs.expand_dims(-1))

        return target, outputs, static_feat, state

    def predictions(
        self,
        feat_static_cat: Tensor,
        feat_static_real: Tensor,
        past_time_feat: Tensor,
        past_target: Tensor,
        past_observed_values: Tensor,
        future_time_feat: Tensor,
        future_target: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Returns the model's forecast for the prediction range.

        The targets are divided by the scale computed on the last
        ``context_length`` past values. For adaptive networks the output
        projection is first fitted by ridge regression on the context part of
        the representations/targets. The forecast is then produced either by
        autoregressively unrolling the decoder (``iterated_forecast``) or by
        projecting the encoder outputs over the future range, which requires
        ``future_target``.

        Input arguments are the same as for the hybrid_forward method.

        Returns
        -------
        Tuple[Tensor, Tensor, Tensor]
            pred: the forecast multiplied back by the scale (expected shape
            (batch_size, prediction_length) for univariate targets);
            scale: the scale computed by ``self.scaler``;
            rnn_out: the representations for the prediction range (last-step
            decoder outputs when iterating, otherwise ``h_test``), used by the
            training network for regularization.
        """

        F = getF(feat_static_cat)

        _, scale = self.scaler(
            past_target.slice_axis(
                axis=1, begin=-self.context_length, end=None
            ),
            past_observed_values.slice_axis(
                axis=1, begin=-self.context_length, end=None
            ),
        )

        # Scale down the input
        past_target = F.broadcast_div(past_target, scale)
        if future_target is not None:
            future_target = F.broadcast_div(future_target, scale)

        # z are the targets while h are the representations
        z, h, static_feat, state = self.get_representations_and_targets(
            F=F,
            feat_static_cat=feat_static_cat,
            feat_static_real=feat_static_real,
            past_time_feat=past_time_feat,
            past_target=past_target,
            past_observed_values=past_observed_values,
            future_time_feat=future_time_feat,
            future_target=future_target,
            scale=scale,
        )

        h_test, z_test = None, None
        h_train, z_train = h, z

        if not self.iterated_forecast:
            # The intercept column is dropped from h_test because
            # build_out_proj adds the intercept separately.
            h_train, h_test = (
                h[:, : -self.prediction_length],
                h[:, -self.prediction_length :, 1:],
            )

            z_train, z_test = (
                z[:, : -self.prediction_length],
                z[:, -self.prediction_length :],
            )

        if self.is_adaptive:
            sample_weights = None
            if self.weighted_least_squares == "MAPE":
                # Per-step weights must align with the regression targets
                # z_train (z also holds the future part when not iterating).
                sample_weights = F.square(z_train)
                # Avoid division by zero: zero targets get weight 1.
                sample_weights = 1 / (sample_weights + (sample_weights == 0))

            elif self.weighted_least_squares == "MASE":
                # NOTE(review): uses z, which also contains the future targets
                # when iterated_forecast is False — confirm this is intended.
                seasonal_error = self.compute_mase_denominator(
                    F, z, self.periodicity
                )
                sample_weights = F.square(1.0 / seasonal_error).expand_dims(-1)

            bias = self.bias(F.zeros(1)) if self.bias is not None else None

            # NOTE(review): the solver mode is fixed to "inverse"; a
            # ``solver_mode`` attribute set by a subclass is not consulted.
            w_star = solve_ridge_regression(
                F,
                h_train,
                z_train.expand_dims(-1),
                self.reg(F.zeros(1)),
                "inverse",
                bias=bias,
                sample_weights=sample_weights,
                rescale=False,
            )

            # Column 0 multiplies the constant-one feature: the intercept.
            self.build_out_proj(F, w_star[:, 1:], w_star[:, 0])

        if self.iterated_forecast:
            pred, rnn_out = self.run_decoder(
                F,
                static_feat=static_feat,
                past_target=past_target,
                time_feat=future_time_feat,
                scale=scale,
                begin_states=state,
            )

        else:
            # (batch_size * num_samples, 1, *target_shape)
            pred = F.broadcast_mul(self.out_proj(h_test)[:, :, 0], scale)

            rnn_out = h_test

        return pred, scale, rnn_out

    def mse_loss(self, F, forecast: Tensor, future_target: Tensor) -> Tensor:
        """Element-wise squared error."""
        return F.square(future_target - forecast)

    def mae_loss(self, F, forecast: Tensor, future_target: Tensor) -> Tensor:
        """Element-wise absolute error."""
        return F.abs(future_target - forecast)

    def smape_loss(self, F, forecast: Tensor, future_target: Tensor) -> Tensor:
        r"""
        Element-wise symmetric MAPE term:

        .. math::

            2 |Y - \hat{Y}| / (|Y| + |\hat{Y}|)

        Averaging over the horizon, as in the sMAPE definition of
        https://arxiv.org/abs/1905.10437, is left to the caller. The
        denominator is excluded from the gradient and clipped from below at
        ``loss_denom_eps``.
        """

        # Stop gradient required for stable learning
        denominator = F.stop_gradient(
            F.abs(future_target) + F.abs(forecast)
        ).clip(self.loss_denom_eps, np.inf)

        smape = 2.0 * F.broadcast_div(
            F.abs(future_target - forecast), denominator
        )
        return smape

    def mape_loss(self, F, forecast: Tensor, future_target: Tensor) -> Tensor:
        r"""
        Element-wise absolute percentage error (as a fraction, not in %):

        .. math::

            |Y - \hat{Y}| / |Y|

        The denominator is clipped from below at ``loss_denom_eps``; averaging
        over the horizon is left to the caller (see
        https://arxiv.org/abs/1905.10437).
        """

        denominator = F.abs(future_target).clip(self.loss_denom_eps, np.inf)
        mape = F.broadcast_div(F.abs(future_target - forecast), denominator)
        return mape

    def mase_loss(
        self,
        F,
        forecast: Tensor,
        future_target: Tensor,
        past_target: Tensor,
        periodicity: int,
    ) -> Tensor:
        r"""
        Element-wise MASE term:

        .. math::

            |Y - \hat{Y}| / \mathrm{seasonal\_error}

        where the per-series seasonal error is computed by
        ``compute_mase_denominator`` on past and future target concatenated
        along the time axis. For (batch_size, T) inputs the result has shape
        (batch_size, T), like the other losses; averaging over the horizon is
        left to the caller (see https://arxiv.org/abs/1905.10437).
        """

        whole_target = F.concat(past_target, future_target, dim=1)
        # (batch_size,) -> (batch_size, 1): MXNet broadcasting aligns trailing
        # axes, so without the extra axis the batch axis of the seasonal error
        # would be matched against the time axis of the errors.
        seasonal_error = self.compute_mase_denominator(
            F, whole_target, periodicity
        ).expand_dims(-1)
        mase = F.broadcast_div(
            F.abs(future_target - forecast), seasonal_error
        )
        return mase

    def compute_mase_denominator(
        self,
        F,
        target: Tensor,
        periodicity: int,
    ):
        """
        Mean absolute seasonal difference of ``target`` along axis 1, i.e.
        mean(|target[t] - target[t - periodicity]|), clipped from below at
        ``loss_denom_eps``. The time axis is reduced away.
        """

        seasonal_error = F.mean(
            F.abs(
                F.slice_axis(target, axis=1, begin=periodicity, end=None)
                - F.slice_axis(target, axis=1, begin=0, end=-periodicity)
            ),
            axis=1,
        )
        return seasonal_error.clip(self.loss_denom_eps, np.inf)

    def build_out_proj(self, F, w, b):
        """
        Replace ``self.out_proj`` by the linear map x -> gemm2(x, w) + b.

        Only adaptive networks are affected; for non-adaptive networks the
        learned Dense projection is kept. The projection is rebuilt on every
        call, so it reflects the weights of the most recent batch.
        """
        if not self.is_adaptive:
            return

        def out_proj(x):
            return F.linalg.gemm2(x, w) + b.expand_dims(-1)

        self.out_proj = out_proj


class MetaARDetTrainingNetwork(MetaARDetNetwork):
    """
    MetaARDet network that computes the training loss.

    Optionally adds activation regularization (weight ``alpha``) and temporal
    activation regularization (weight ``beta``) on the representations
    returned by ``predictions``.
    """

    @validated()
    def __init__(self, alpha: float = 0, beta: float = 0, **kwargs) -> None:
        """
        Parameters
        ----------
        alpha
            weight of the activation regularization loss; disabled when 0.
        beta
            weight of the temporal activation regularization loss; disabled
            when 0.
        **kwargs
            forwarded to ``MetaARDetNetwork``.
        """
        super().__init__(**kwargs)

        # regularization weights
        self.alpha = alpha
        self.beta = beta

        # The regularization blocks only exist when their weight is non-zero;
        # hybrid_forward checks the same condition before using them.
        if alpha:
            self.ar_loss = ActivationRegularizationLoss(
                alpha, time_axis=1, batch_axis=0
            )
        if beta:
            self.tar_loss = TemporalActivationRegularizationLoss(
                beta, time_axis=1, batch_axis=0
            )

    # noinspection PyMethodOverriding,PyPep8Naming
    def hybrid_forward(
        self,
        F,
        feat_static_cat: Tensor,
        feat_static_real: Tensor,
        past_time_feat: Tensor,
        past_target: Tensor,
        past_observed_values: Tensor,
        future_time_feat: Tensor,
        future_target: Tensor,
        future_observed_values: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Computes the loss for training MetaARDet; all input tensors
        representing time series have NTC layout.

        Parameters
        ----------
        F
        feat_static_cat : (batch_size, num_features)
        feat_static_real : (batch_size, num_features)
        past_time_feat : (batch_size, history_length, num_features)
        past_target : (batch_size, history_length, *target_shape)
        past_observed_values : (batch_size, history_length, *target_shape)
        future_time_feat : (batch_size, prediction_length, num_features)
        future_target : (batch_size, prediction_length, *target_shape)
        future_observed_values : (batch_size, prediction_length, *target_shape)

        Returns
        -------
        Tuple[Tensor, Tensor]
            weighted_loss: the loss selected by ``self.loss`` averaged over
            the time axis with the future observed values as weights, plus
            the activation regularization terms when ``alpha`` / ``beta``
            are non-zero (only this value carries the regularization);
            loss: the same loss with unobserved steps set to zero, averaged
            over the last axis, without regularization.
        """

        target = future_target

        pred, scale, rnn_outputs = self.predictions(
            feat_static_cat=feat_static_cat,
            feat_static_real=feat_static_real,
            past_time_feat=past_time_feat,
            past_target=past_target,
            past_observed_values=past_observed_values,
            future_time_feat=future_time_feat,
            future_target=future_target,
        )

        # Compare forecast and targets on the scaled level.
        pred = F.broadcast_div(pred, scale)
        target = F.broadcast_div(target, scale)
        past_target = F.broadcast_div(past_target, scale)

        # (batch_size, seq_len)
        if self.loss == "MAPE":
            loss = self.mape_loss(F, pred, target)
        elif self.loss == "MASE":
            loss = self.mase_loss(
                F, pred, target, past_target, self.periodicity
            )
        elif self.loss == "sMAPE":
            loss = self.smape_loss(F, pred, target)
        elif self.loss == "MSE":
            loss = self.mse_loss(F, pred, target)
        elif self.loss == "MAE":
            loss = self.mae_loss(F, pred, target)
        else:
            raise NotImplementedError(
                "loss string {} not valid".format(self.loss)
            )

        observed_values = future_observed_values

        # A multivariate step only counts if all its dimensions are observed.
        loss_weights = (
            observed_values
            if (len(self.target_shape) == 0)
            else observed_values.min(axis=-1, keepdims=False)
        )

        weighted_loss = weighted_average(
            F=F, x=loss, weights=loss_weights, axis=1
        )

        # need to mask possible nans and -inf
        loss = F.where(
            condition=loss_weights, x=loss, y=F.zeros_like(loss)
        ).mean(axis=-1)

        # rnn_outputs is already merged into a single tensor
        assert not isinstance(rnn_outputs, list)
        # it seems that the trainer only uses the first return value for backward
        # so we only add regularization to weighted_loss
        if self.alpha:
            ar_loss = self.ar_loss(rnn_outputs)
            weighted_loss = weighted_loss + ar_loss
        if self.beta:
            tar_loss = self.tar_loss(rnn_outputs)
            weighted_loss = weighted_loss + tar_loss

        return weighted_loss, loss


class MetaARDetPredictionNetwork(MetaARDetNetwork):
    """MetaARDet network that produces forecasts for the prediction range."""

    @validated()
    def __init__(
        self,
        num_parallel_samples: int = 100,
        solver_mode: Optional[str] = None,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        num_parallel_samples
            stored as ``self.num_parallel_samples``.
        solver_mode
            stored as ``self.solver_mode``; ``None`` is replaced by
            ``"inverse"``.
        **kwargs
            forwarded to ``MetaARDetNetwork``.
        """
        super().__init__(**kwargs)
        self.num_parallel_samples = num_parallel_samples

        self.solver_mode = "inverse" if solver_mode is None else solver_mode

    # noinspection PyMethodOverriding,PyPep8Naming
    def hybrid_forward(
        self,
        F,
        feat_static_cat: Tensor,  # (batch_size, num_features)
        feat_static_real: Tensor,  # (batch_size, num_features)
        past_time_feat: Tensor,  # (batch_size, history_length, num_features)
        past_target: Tensor,  # (batch_size, history_length, *target_shape)
        past_observed_values: Tensor,  # (batch_size, history_length, *target_shape)
        future_time_feat: Tensor,  # (batch_size, prediction_length, num_features)
    ) -> Tensor:
        """
        Predicts the forecast; all tensors should have NTC layout.

        No future target is available here, so ``predictions`` is called
        with ``future_target=None``.

        Parameters
        ----------
        F
        feat_static_cat : (batch_size, num_features)
        feat_static_real : (batch_size, num_features)
        past_time_feat : (batch_size, history_length, num_features)
        past_target : (batch_size, history_length, *target_shape)
        past_observed_values : (batch_size, history_length, *target_shape)
        future_time_feat : (batch_size, prediction_length, num_features)

        Returns
        -------
        Tensor
            the forecast reshaped to (-1, 1, prediction_length, *target_shape),
            i.e. one sample path per series.
        """

        # Only the forecast is needed; scale and representations are unused.
        pred, _scale, _rnn_outputs = self.predictions(
            feat_static_cat=feat_static_cat,
            feat_static_real=feat_static_real,
            past_time_feat=past_time_feat,
            past_target=past_target,
            past_observed_values=past_observed_values,
            future_time_feat=future_time_feat,
            future_target=None,
        )

        return pred.reshape(
            shape=(-1, 1, self.prediction_length) + self.target_shape
        )
