# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from gluonts.core.component import validated
from gluonts.torch.distributions import (
    DistributionOutput,
    StudentTOutput,
)
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood

from gluonts.torch.model.fedformer.FEDformer import FEDformer


class FEDformerModel(nn.Module):
    """
    Module implementing the FEDformer model (Zhou et al., "FEDformer:
    Frequency Enhanced Decomposed Transformer for Long-term Series
    Forecasting", ICML 2022) with a GluonTS distribution head.

    The (optionally mean-scaled) past target is concatenated with the
    past-only dynamic features and fed to the FEDformer backbone, whose
    output is projected onto the parameters of ``distr_output``.

    Parameters
    ----------
    freq
        String indicating the sampling frequency of the data to be processed.
        It is lower-cased and forwarded to the FEDformer backbone.
    context_length
        Number of past time steps fed to the encoder (FEDformer ``seq_len``).
        ``context_length // 2`` is used as the backbone's ``label_len``.
    prediction_length
        Number of time points to predict.
    num_feat_dynamic_real
        Total number of dynamic real feature columns in ``past_time_feat``
        and ``future_time_feat``.
    num_future_feat
        Number of leading dynamic-feature columns used as time marks for the
        encoder and decoder. The remaining
        ``num_feat_dynamic_real - num_future_feat`` columns of
        ``past_time_feat`` are fed to the encoder together with the target.
    num_feat_static_real
        Number of static real features that will be provided to ``forward``.
    num_feat_static_cat
        Number of static categorical features that will be provided to
        ``forward``.
    cardinality
        List of cardinalities, one for each static categorical feature.
    embedding_dimension
        Dimension of the embedding space, one for each static categorical
        feature. Defaults to ``min(32, (cat + 1) // 2)`` for each cardinality.
    n_block
        Number of encoder layers of the FEDformer backbone (``e_layers``).
    hidden_size
        Model width of the FEDformer backbone, used both as ``d_model`` and
        ``d_ff``.
    dropout_rate
        Dropout rate to be applied at training time.
    n_head
        Number of attention heads of the FEDformer backbone.
    distr_output
        Type of distribution to be output by the model at each time step;
        its ``event_shape`` must be ``()``.
    scaling
        Whether to apply mean scaling to the observations (target).
    """

    @validated()
    def __init__(
        self,
        freq: str,
        context_length: int,
        prediction_length: int,
        num_feat_dynamic_real: int,
        num_future_feat: int,
        num_feat_static_real: int,
        num_feat_static_cat: int,
        cardinality: List[int],
        embedding_dimension: Optional[List[int]] = None,
        n_block: int = 2,
        hidden_size: int = 128,
        dropout_rate: float = 0.1,
        n_head: int = 8,
        distr_output: DistributionOutput = StudentTOutput(),
        scaling: bool = True,
    ) -> None:
        super().__init__()

        # Only distributions with scalar events are supported.
        assert distr_output.event_shape == ()

        self.context_length = context_length
        self.prediction_length = prediction_length
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_future_feat = num_future_feat
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.embedding_dimension = (
            embedding_dimension
            if embedding_dimension is not None or cardinality is None
            else [min(32, (cat + 1) // 2) for cat in cardinality]
        )
        # Kept as a public attribute alongside the ``_past_length`` property
        # for callers that read it directly.
        self.past_length = self.context_length
        # NOTE: the embedder is kept so that checkpoints stay compatible,
        # but ``forward`` does not currently feed static features to the
        # backbone.
        self.embedder = FeatureEmbedder(
            cardinalities=cardinality,
            embedding_dims=self.embedding_dimension,
        )

        if scaling:
            self.scaler = MeanScaler(dim=-1, keepdim=True)
        else:
            self.scaler = NOPScaler(dim=-1, keepdim=True)

        # Encoder channels: the scaled target (1) plus the past-only dynamic
        # features (all dynamic columns except the time-mark columns).
        feature_size = self.num_feat_dynamic_real - self.num_future_feat + 1

        self.distr_output = distr_output
        self.args_proj = distr_output.get_args_proj(feature_size)

        self.fedformer_encoder = FEDformer(
            freq=freq.lower(),
            seq_len=context_length,
            label_len=context_length // 2,
            pred_len=prediction_length,
            enc_in=feature_size,
            dec_in=feature_size,
            c_out=feature_size,
            d_model=hidden_size,
            n_heads=n_head,
            d_ff=hidden_size,
            dropout=dropout_rate,
            e_layers=n_block,
            activation="relu",
        )

    @property
    def _past_length(self) -> int:
        """Number of past time steps the model consumes (``context_length``)."""
        return self.context_length

    def input_shapes(self, batch_size: int = 1) -> Dict[str, Tuple[int, ...]]:
        """
        Return the expected shape of every ``forward`` input.

        Parameters
        ----------
        batch_size
            Leading (batch) dimension to use in every shape.
        """
        return {
            "feat_static_cat": (batch_size, self.num_feat_static_cat),
            "feat_static_real": (batch_size, self.num_feat_static_real),
            "past_time_feat": (
                batch_size,
                self._past_length,
                self.num_feat_dynamic_real,
            ),
            "past_target": (batch_size, self._past_length),
            "past_observed_values": (batch_size, self._past_length),
            "future_time_feat": (
                batch_size,
                self.prediction_length,
                self.num_feat_dynamic_real,
            ),
        }

    def input_types(self) -> Dict[str, torch.dtype]:
        """Return the expected dtype of every ``forward`` input."""
        return {
            "feat_static_cat": torch.long,
            "feat_static_real": torch.float,
            "past_time_feat": torch.float,
            "past_target": torch.float,
            "past_observed_values": torch.float,
            "future_time_feat": torch.float,
        }

    def forward(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: torch.Tensor,
    ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]:
        """
        Invoke the model on input data and return the distribution
        parameters for the prediction window.

        Parameters
        ----------
        feat_static_cat
            Tensor of static categorical features,
            shape: ``(batch_size, num_feat_static_cat)``.
            Currently not used by the FEDformer backbone.
        feat_static_real
            Tensor of static real features,
            shape: ``(batch_size, num_feat_static_real)``.
            Currently not used by the FEDformer backbone.
        past_time_feat
            Tensor of dynamic real features in the past,
            shape: ``(batch_size, past_length, num_feat_dynamic_real)``.
            The first ``num_future_feat`` columns are used as time marks;
            the rest are fed to the encoder alongside the target.
        past_target
            Tensor of past target values,
            shape: ``(batch_size, past_length)``.
        past_observed_values
            Tensor of observed values indicators,
            shape: ``(batch_size, past_length)``.
        future_time_feat
            Tensor of dynamic real features in the future,
            shape: ``(batch_size, prediction_length, num_feat_dynamic_real)``.
            Only the first ``num_future_feat`` columns are used.

        Returns
        -------
        Tuple of ``(distr_args, loc, scale)``: the distribution parameters
        produced by ``args_proj`` from the backbone output, a zero ``loc``
        shaped like ``scale``, and the ``scale`` computed by the scaler from
        the past target.
        """
        # Encoder time marks: the leading "future-known" columns.
        x_mark_enc = past_time_feat[:, :, : self.num_future_feat]
        # Decoder time marks: the tail of the past marks followed by the
        # future marks.
        # NOTE(review): ``-self.context_length // 2 + 1`` parses as
        # ``((-context_length) // 2) + 1``. That keeps context_length // 2 - 1
        # past steps for even context_length >= 4, context_length // 2 for
        # odd context_length >= 3, and the whole past for context_length <= 2.
        # The backbone's ``label_len`` is context_length // 2. Confirm the
        # expected length against FEDformer.forward.
        x_mark_dec = torch.cat(
            (
                past_time_feat[
                    :, -self.context_length // 2 + 1 :, : self.num_future_feat
                ],
                future_time_feat[:, :, : self.num_future_feat],
            ),
            dim=1,
        )

        # Past-only dynamic features (everything after the time-mark columns).
        past_feature = past_time_feat[:, :, self.num_future_feat :]

        _, scale = self.scaler(past_target, past_observed_values)
        scaled_past_target = past_target / scale

        # Static features (embedded categoricals, static reals, log-scale)
        # were previously assembled here but the result was never used, so
        # that dead computation has been dropped.

        # Encoder input: scaled target as the first channel, then the
        # past-only dynamic features.
        past_feature = torch.cat(
            (scaled_past_target.unsqueeze(-1), past_feature), dim=-1
        )
        output = self.fedformer_encoder(
            x_enc=past_feature,
            x_mark_enc=x_mark_enc,
            x_mark_dec=x_mark_dec,
        )
        distr_args = self.args_proj(output)
        return distr_args, torch.zeros_like(scale), scale
