# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# TODO: Fix typing
# type: ignore

from typing import List, Optional

# Standard library imports
import numpy as np

# Third-party imports
from mxnet.gluon import HybridBlock

# First-party imports
from gluonts.core.component import DType, validated
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.stat import calculate_dataset_statistics
from gluonts.mx.model.estimator import GluonEstimator
from gluonts.model.predictor import Predictor
from gluonts.mx.model.predictor import RepresentableBlockPredictor

from gluonts.mx.trainer import Trainer
from gluonts.mx.util import copy_parameters
from gluonts.time_feature import (
    TimeFeature,
    get_lags_for_frequency,
    time_features_from_frequency_str,
)
from gluonts.transform import (
    AddAgeFeature,
    AddObservedValuesIndicator,
    AddTimeFeatures,
    AsNumpyArray,
    Chain,
    ExpectedNumInstanceSampler,
    InstanceSplitter,
    RemoveFields,
    SetField,
    Transformation,
    VstackFeatures,
    InstanceSampler,
)
from gluonts.transform.feature import (
    DummyValueImputation,
    MissingValueImputation,
    AddConstFeature,
)

# Relative imports
from ._network import (
    MetaARDetPredictionNetwork,
    MetaARDetTrainingNetwork,
)
from ...meta_tools import change_transform_for_predictor


class MetaARDetEstimator(GluonEstimator):
    """
    Construct a MetaARDet estimator. Returns point forecasts.

    Instead of optimizing the linear layer as usual, it computes it in closed form solving a ridge regression
    problem over the observations in the context window (past_targets).

    The class contains several flags used for ablation.

    Parameters
    ----------
    freq
        Frequency of the data to train on and predict
    prediction_length
        Length of the prediction horizon used by the predictor
    train_prediction_length
        Prediction horizon used during training (time/age features, instance
        splitting and the training network). If None, ``prediction_length``
        is used. This argument has no default and must always be passed.
    trainer
        Trainer object to be used (default: Trainer())
    context_length
        Fallback value for ``net_context_length`` and ``test_context_length``
        (default: None, in which case context_length = prediction_length)
    net_context_length
        Context length passed to the training and prediction networks; the
        training history length is ``net_context_length + max(lags_seq)``
        (default: None, in which case net_context_length = context_length)
    num_layers
        Number of RNN layers (default: 2)
    num_cells
        Number of RNN cells for each layer (default: 40)
    output_dim
        Dimension of the representation output (default: 40)
    cell_type
        Type of recurrent cells to use (available: 'lstm' or 'gru';
        default: 'lstm')
    dropoutcell_type
        Type of dropout cells to use
        (available: 'ZoneoutCell', 'RNNZoneoutCell', 'VariationalDropoutCell' or 'VariationalZoneoutCell';
        default: 'ZoneoutCell')
    dropout_rate
        Dropout regularization parameter (default: 0.1)
    use_feat_dynamic_real
        Whether to use the ``feat_dynamic_real`` field from the data
        (default: False). The field is only stacked into the time features
        when ``use_time_features`` is True.
    use_feat_static_cat
        Whether to use the ``feat_static_cat`` field from the data
        (default: False)
    use_feat_static_real
        Whether to use the ``feat_static_real`` field from the data
        (default: False)
    use_time_features
        Whether to add time and age features. If False, a constant zero
        feature is used in their place (default: True)
    cardinality
        Number of values of each categorical feature.
        This must be set if ``use_feat_static_cat == True`` (default: None,
        in which case ``[1]`` is used internally)
    embedding_dimension
        Dimension of the embeddings for categorical features
        (default: [min(50, (cat+1)//2) for cat in cardinality])
    loss
        Loss passed to the networks; one of 'sMAPE', 'MASE', 'MAPE'
        (default: 'sMAPE')
    scaling
        Whether to automatically scale the target values (default: True)
    lags_seq
        Indices of the lagged target values to use as inputs of the RNN
        (default: None, in which case these are automatically determined
        based on freq)
    time_features
        Time features to use as inputs of the RNN (default: None, in which
        case these are automatically determined based on freq)
    num_parallel_samples
        Number of evaluation samples per time series to increase parallelism during inference.
        This is a model optimization that does not affect the accuracy (default: 100)
    imputation_method
        MissingValueImputation used for missing target values
        (default: None, in which case DummyValueImputation(0.0) is used)
    dtype
        Data type used by the transformations and networks
        (default: np.float32)
    alpha
        The scaling coefficient of the activation regularization; only
        passed to the training network (default: 0.0)
    beta
        The scaling coefficient of the temporal activation regularization;
        only passed to the training network (default: 0.0)
    train_sampler
        Instance sampler used to split training instances
        (default: ExpectedNumInstanceSampler(1.0))
    test_solver_mode
        Passed as ``solver_mode`` to the prediction network (default: None)
    test_context_length
        Context length used to build the prediction-time history, whose
        length is ``test_context_length + max(lags_seq)``
        (default: None, in which case context_length is used)
    test_batch_size
        Batch size of the predictor (default: None, in which case
        ``batch_size`` is used)
    biased_regularization
        Passed to the training and prediction networks (default: False)
    weighted_least_squares
        Passed to the training and prediction networks, e.g. 'MAPE' or
        'MASE' (default: None)
    batch_size
        Batch size of the estimator (default: 32)
    is_adaptive_training
        Passed as ``is_adaptive`` to the training network (default: True)
    is_adaptive_prediction
        Passed as ``is_adaptive`` to the prediction network (default: True)
    iterated_forecasts_during_training
        Passed as ``iterated_forecasts`` to the training network; the
        prediction network always uses iterated forecasts (default: True)
    use_shared_linear
        Passed to the training and prediction networks (default: True)
    encoder_type
        Encoder used by the networks, e.g. 'RNN', 'linear', 'identity' or
        'feedforward' (default: 'RNN')
    """

    @validated()
    def __init__(
        self,
        # DeepAR parameters
        freq: str,
        prediction_length: int,
        train_prediction_length: Optional[int],
        trainer: Trainer = Trainer(),
        context_length: Optional[int] = None,
        net_context_length: Optional[int] = None,
        num_layers: int = 2,
        num_cells: int = 40,
        output_dim: int = 40,
        cell_type: str = "lstm",
        dropoutcell_type: str = "ZoneoutCell",
        dropout_rate: float = 0.1,
        use_feat_dynamic_real: bool = False,
        use_feat_static_cat: bool = False,
        use_feat_static_real: bool = False,
        use_time_features: bool = True,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        loss: str = "sMAPE",  # sMAPE, MASE, MAPE
        scaling: bool = True,
        lags_seq: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        num_parallel_samples: int = 100,
        imputation_method: Optional[MissingValueImputation] = None,
        dtype: DType = np.float32,
        alpha: float = 0.0,
        beta: float = 0.0,
        train_sampler: InstanceSampler = ExpectedNumInstanceSampler(1.0),
        # specific parameters
        test_solver_mode: Optional[str] = None,
        test_context_length: Optional[int] = None,
        test_batch_size: Optional[int] = None,
        biased_regularization: Optional[bool] = False,
        weighted_least_squares: Optional[str] = None,  # "MAPE", "MASE"
        batch_size: int = 32,
        is_adaptive_training: bool = True,
        is_adaptive_prediction: bool = True,
        iterated_forecasts_during_training: bool = True,
        use_shared_linear: bool = True,
        encoder_type: str = "RNN",  # "linear", "identity", "feedforward"
    ) -> None:
        super().__init__(trainer=trainer, batch_size=batch_size, dtype=dtype)

        assert (
            prediction_length > 0
        ), "The value of `prediction_length` should be > 0"
        assert (
            context_length is None or context_length > 0
        ), "The value of `context_length` should be > 0"
        assert num_layers > 0, "The value of `num_layers` should be > 0"
        assert num_cells > 0, "The value of `num_cells` should be > 0"
        assert output_dim > 0, "The value of `output_dim` should be > 0"
        supported_dropoutcell_types = [
            "ZoneoutCell",
            "RNNZoneoutCell",
            "VariationalDropoutCell",
            "VariationalZoneoutCell",
        ]
        assert (
            dropoutcell_type in supported_dropoutcell_types
        ), f"`dropoutcell_type` should be one of {supported_dropoutcell_types}"
        assert dropout_rate >= 0, "The value of `dropout_rate` should be >= 0"
        assert (cardinality and use_feat_static_cat) or (
            not (cardinality or use_feat_static_cat)
        ), "You should set `cardinality` if and only if `use_feat_static_cat=True`"
        assert cardinality is None or all(
            [c > 0 for c in cardinality]
        ), "Elements of `cardinality` should be > 0"
        assert embedding_dimension is None or all(
            [e > 0 for e in embedding_dimension]
        ), "Elements of `embedding_dimension` should be > 0"
        assert (
            num_parallel_samples > 0
        ), "The value of `num_parallel_samples` should be > 0"
        assert alpha >= 0, "The value of `alpha` should be >= 0"
        assert beta >= 0, "The value of `beta` should be >= 0"

        self.freq = freq
        self.context_length = (
            context_length if context_length is not None else prediction_length
        )
        # Always resolved here, so it is never None after construction.
        self.net_context_length = (
            net_context_length
            if net_context_length is not None
            else self.context_length
        )
        self.prediction_length = prediction_length
        self.train_prediction_length = (
            train_prediction_length
            if train_prediction_length is not None
            else prediction_length
        )
        self.num_layers = num_layers
        self.num_cells = num_cells
        self.output_dim = output_dim
        self.cell_type = cell_type
        self.dropoutcell_type = dropoutcell_type
        self.dropout_rate = dropout_rate
        self.use_feat_dynamic_real = use_feat_dynamic_real
        self.use_feat_static_cat = use_feat_static_cat
        self.use_feat_static_real = use_feat_static_real
        self.use_time_features = use_time_features
        # A single dummy category is used when static categorical features
        # are disabled; create_transformation sets feat_static_cat to [0.0].
        self.cardinality = (
            cardinality if cardinality and use_feat_static_cat else [1]
        )
        self.embedding_dimension = (
            embedding_dimension
            if embedding_dimension is not None
            else [min(50, (cat + 1) // 2) for cat in self.cardinality]
        )
        self.scaling = scaling
        self.lags_seq = (
            lags_seq
            if lags_seq is not None
            else get_lags_for_frequency(freq_str=freq)
        )
        self.time_features = (
            time_features
            if time_features is not None
            else time_features_from_frequency_str(self.freq)
        )

        # Number of past steps the training network sees: its context window
        # plus enough history to compute the largest lag.
        self.history_length = self.net_context_length + max(self.lags_seq)

        self.num_parallel_samples = num_parallel_samples

        # RG: added values
        self.loss = loss
        self.event_shape = ()
        self.value_in_support = 0.0
        self.test_solver_mode = test_solver_mode
        self.test_context_length = test_context_length
        self.test_batch_size = test_batch_size
        self.biased_regularization = biased_regularization
        self.weighted_least_squares = weighted_least_squares
        self.is_adaptive_training = is_adaptive_training
        self.is_adaptive_prediction = is_adaptive_prediction
        self.iterated_forecasts_during_training = (
            iterated_forecasts_during_training
        )
        self.use_shared_linear = use_shared_linear
        self.encoder_type = encoder_type

        self.imputation_method = (
            imputation_method
            if imputation_method is not None
            else DummyValueImputation(self.value_in_support)
        )

        self.train_sampler = train_sampler
        self.alpha = alpha
        self.beta = beta
        # NOTE(review): only static_cat and dynamic_real are required here;
        # confirm whether `use_feat_static_real` was also meant to be part
        # of this condition.
        self.use_additional_features = (
            use_feat_static_cat and use_feat_dynamic_real
        )

    @classmethod
    def derive_auto_fields(cls, train_iter):
        """
        Derive feature-usage settings from dataset statistics.

        Returns a dict with ``use_feat_dynamic_real``, ``use_feat_static_cat``
        and ``cardinality`` computed by ``calculate_dataset_statistics``.
        """
        stats = calculate_dataset_statistics(train_iter)

        return {
            "use_feat_dynamic_real": stats.num_feat_dynamic_real > 0,
            "use_feat_static_cat": bool(stats.feat_static_cat),
            "cardinality": [len(cats) for cats in stats.feat_static_cat],
        }

    def create_transformation(self) -> Transformation:
        """
        Build the training-time data transformation chain.

        Unused fields are removed or replaced with dummy values, the target is
        converted to a numpy array with an observed-values indicator, time
        features (or a constant zero feature) are added, and training
        instances are split using ``history_length`` past steps and
        ``train_prediction_length`` future steps.
        """
        remove_field_names = [FieldName.FEAT_DYNAMIC_CAT, FieldName.FEAT_TIME]
        if not self.use_feat_static_real:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)
        if not self.use_feat_dynamic_real:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)

        if self.use_time_features:
            add_time_transforms = [
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.train_prediction_length,
                ),
                AddAgeFeature(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_AGE,
                    pred_length=self.train_prediction_length,
                    log_scale=True,
                    dtype=self.dtype,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                    + (
                        [FieldName.FEAT_DYNAMIC_REAL]
                        if self.use_feat_dynamic_real
                        else []
                    ),
                ),
            ]
        else:
            # Without time features, a constant zero series fills FEAT_TIME.
            # NOTE(review): feat_dynamic_real is not stacked in this branch.
            add_time_transforms = [
                AddConstFeature(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    pred_length=self.train_prediction_length,
                    const=0.0,
                    dtype=self.dtype,
                ),
            ]

        return Chain(
            [RemoveFields(field_names=remove_field_names)]
            + (
                [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0.0])]
                if not self.use_feat_static_cat
                else []
            )
            + (
                [
                    SetField(
                        output_field=FieldName.FEAT_STATIC_REAL, value=[0.0]
                    )
                ]
                if not self.use_feat_static_real
                else []
            )
            + [
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_CAT,
                    expected_ndim=1,
                    dtype=self.dtype,
                ),
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_REAL,
                    expected_ndim=1,
                    dtype=self.dtype,
                ),
                AsNumpyArray(
                    field=FieldName.TARGET,
                    # in the following line, we add 1 for the time dimension
                    expected_ndim=1 + len(self.event_shape),
                    dtype=self.dtype,
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                    dtype=self.dtype,
                    imputation_method=self.imputation_method,
                ),
            ]
            + add_time_transforms
            + [
                InstanceSplitter(
                    target_field=FieldName.TARGET,
                    is_pad_field=FieldName.IS_PAD,
                    start_field=FieldName.START,
                    forecast_start_field=FieldName.FORECAST_START,
                    train_sampler=self.train_sampler,
                    past_length=self.history_length,
                    future_length=self.train_prediction_length,
                    time_series_fields=[
                        FieldName.FEAT_TIME,
                        FieldName.OBSERVED_VALUES,
                    ],
                    dummy_value=self.value_in_support,
                ),
            ]
        )

    def create_training_network(self) -> MetaARDetTrainingNetwork:
        """Instantiate the training network from the estimator settings."""
        return MetaARDetTrainingNetwork(
            num_layers=self.num_layers,
            num_cells=self.num_cells,
            output_dim=self.output_dim,
            cell_type=self.cell_type,
            history_length=self.history_length,
            context_length=self.net_context_length,
            prediction_length=self.train_prediction_length,
            loss=self.loss,
            dropoutcell_type=self.dropoutcell_type,
            dropout_rate=self.dropout_rate,
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
            dtype=self.dtype,
            alpha=self.alpha,
            beta=self.beta,
            biased_regularization=self.biased_regularization,
            weighted_least_squares=self.weighted_least_squares,
            freq=self.freq,
            is_adaptive=self.is_adaptive_training,
            iterated_forecasts=self.iterated_forecasts_during_training,
            use_shared_linear=self.use_shared_linear,
            encoder_type=self.encoder_type,
            use_additional_features=self.use_additional_features,
            use_time_features=self.use_time_features,
        )

    def create_predictor(
        self, transformation: Transformation, trained_network: HybridBlock
    ) -> Predictor:
        """
        Build a predictor from the trained network.

        A prediction network is created with the prediction-time history
        length and ``prediction_length``, the trained parameters are copied
        into it, and the transformation is adapted for prediction via
        ``change_transform_for_predictor``.

        Parameters
        ----------
        transformation
            The training transformation returned by ``create_transformation``.
        trained_network
            The trained training network whose parameters are copied.

        Returns
        -------
        Predictor
            A ``RepresentableBlockPredictor`` wrapping the prediction network.
        """
        test_context_length = (
            self.context_length
            if self.test_context_length is None
            else self.test_context_length
        )
        # `net_context_length` is always resolved in `__init__` (falling back
        # to `context_length`), so it is used directly here.
        test_net_context_length = self.net_context_length
        test_history_length = test_context_length + max(self.lags_seq)

        prediction_network = MetaARDetPredictionNetwork(
            num_parallel_samples=self.num_parallel_samples,
            num_layers=self.num_layers,
            num_cells=self.num_cells,
            output_dim=self.output_dim,
            cell_type=self.cell_type,
            history_length=test_history_length,
            context_length=test_net_context_length,
            prediction_length=self.prediction_length,
            loss=self.loss,
            dropoutcell_type=self.dropoutcell_type,
            dropout_rate=self.dropout_rate,
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
            dtype=self.dtype,
            solver_mode=self.test_solver_mode,
            biased_regularization=self.biased_regularization,
            weighted_least_squares=self.weighted_least_squares,
            freq=self.freq,
            is_adaptive=self.is_adaptive_prediction,
            iterated_forecasts=True,
            use_shared_linear=self.use_shared_linear,
            encoder_type=self.encoder_type,
            use_additional_features=self.use_additional_features,
            use_time_features=self.use_time_features,
        )

        # in the following, ignore_extra=True allows to train the non-adaptive model and test the adaptive one
        copy_parameters(trained_network, prediction_network, ignore_extra=True)

        # Presumably rebuilds the splitter/time features for the prediction
        # history and horizon -- see meta_tools.change_transform_for_predictor.
        transformation = change_transform_for_predictor(
            transformation,
            history_length=test_history_length,
            future_length=self.prediction_length,
            time_features=self.time_features,
        )

        # Fall back to the estimator's batch size: it is the value passed to
        # GluonEstimator.__init__ in this class, not an attribute of Trainer.
        test_batch_size = (
            self.test_batch_size
            if self.test_batch_size is not None
            else self.batch_size
        )

        return RepresentableBlockPredictor(
            input_transform=transformation,
            prediction_net=prediction_network,
            batch_size=test_batch_size,
            freq=self.freq,
            prediction_length=self.prediction_length,
            ctx=self.trainer.ctx,
            dtype=self.dtype,
        )
