# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import copy
import logging
import os
from itertools import product
from pathlib import Path
from typing import Iterator, List, Optional
import multiprocessing as mp

import mxnet as mx
import numpy as np
from pydantic import ValidationError

from gluonts.core import fqname_for
from gluonts.core.component import from_hyperparameters, validated
from gluonts.core.exception import GluonTSHyperparametersError
from gluonts.core.serde import dump_json, load_json
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import DataBatch
from gluonts.model.estimator import Estimator
from gluonts.model.forecast import Forecast, SampleForecast
from gluonts.model.predictor import Predictor
from gluonts.mx.model.predictor import RepresentableBlockPredictor
from gluonts.mx.trainer import Trainer

from ._estimator import NBEATSEstimator
from ._network import VALID_LOSS_FUNCTIONS

# Generic ensemble base classes that the N-BEATS ensemble classes below specialize.
from ..ensemble import EnsembleEstimator, EnsemblePredictor


class NBEATSEnsembleEstimator(EnsembleEstimator):
    """
    An ensemble N-BEATS Estimator (approximately) as described
    in the paper:  https://arxiv.org/abs/1905.10437.

    The three meta parameters 'meta_context_length', 'meta_loss_function' and 'meta_bagging_size'
    together define the way the sub-models are assembled together.
    The total number of models used for the ensemble is::

        |meta_context_length| x |meta_loss_function| x meta_bagging_size

    Noteworthy differences in this implementation compared to the paper:
    * The parameter L_H is not implemented; we sample training sequences
    using the default method in GluonTS using the "InstanceSplitter".

    Parameters
    ----------
    freq
        Time granularity of the data.
    prediction_length
        Length of the prediction. Also known as 'horizon'.
    meta_context_length
        The different 'context_length' (also known as 'lookback period')
        to use for training the models.
        The 'context_length' is the number of time units that condition the predictions.
        Default and recommended value: [multiplier * prediction_length for multiplier in range(2, 8)]
    meta_loss_function
        The different 'loss_function' (also known as metric) to use for training the models.
        Unlike other models in GluonTS this network does not use a distribution.
        Default: every entry of ``VALID_LOSS_FUNCTIONS``.
        Recommended value: ["sMAPE", "MASE", "MAPE"]
    meta_bagging_size
        The number of models that share the parameter combination of 'context_length'
        and 'loss_function'. Each of these models gets a different random initialization.
        Default and recommended value: 10
    trainer
        Trainer object to be used (default: Trainer()). Each sub-model
        receives its own deep copy.
    num_stacks
        The number of stacks the network should contain.
        Default and recommended value for generic mode: 30
        Recommended value for interpretable mode: 2
    num_blocks
        The number of blocks per stack.
        A list of ints of length 1 or 'num_stacks'.
        Default and recommended value for generic mode: [1]
        Recommended value for interpretable mode: [3]
    num_block_layers
        Number of fully connected layers with ReLu activation per block.
        A list of ints of length 1 or 'num_stacks'.
        Default and recommended value for generic mode: [4]
        Recommended value for interpretable mode: [4]
    widths
        Widths of the fully connected layers with ReLu activation in the blocks.
        A list of ints of length 1 or 'num_stacks'.
        Default and recommended value for generic mode: [512]
        Recommended value for interpretable mode: [256, 2048]
    sharing
        Whether the weights are shared with the other blocks per stack.
        A list of bools of length 1 or 'num_stacks'.
        Default and recommended value for generic mode: [False]
        Recommended value for interpretable mode: [True]
    expansion_coefficient_lengths
        If the type is "G" (generic), then the length of the expansion coefficient.
        If type is "T" (trend), then it corresponds to the degree of the polynomial.
        If the type is "S" (seasonal) then it is not used.
        A list of ints of length 1 or 'num_stacks'.
        Default value for generic mode: [32]
        Recommended value for interpretable mode: [3]
    stack_types
        One of the following values: "G" (generic), "S" (seasonal) or "T" (trend).
        A list of strings of length 1 or 'num_stacks'.
        Default and recommended value for generic mode: ["G"]
        Recommended value for interpretable mode: ["T","S"]
    num_parallel_training
        Forwarded unchanged to ``EnsembleEstimator``; presumably the number of
        sub-models trained concurrently (TODO confirm in the base class).
        Defaults to ``mp.cpu_count()``, evaluated once when this module is
        imported.
    **kwargs
        Arguments passed down to the individual estimators.
    """

    # The validated() decorator makes sure that parameters are checked by
    # Pydantic and allows to serialize/print models. Note that all parameters
    # have defaults except for `freq` and `prediction_length`, which is
    # recommended in GluonTS to allow to compare models easily.
    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        meta_context_length: Optional[List[int]] = None,
        meta_loss_function: Optional[List[str]] = None,
        meta_bagging_size: int = 10,
        trainer: Trainer = Trainer(),
        num_stacks: int = 30,
        widths: Optional[List[int]] = None,
        num_blocks: Optional[List[int]] = None,
        num_block_layers: Optional[List[int]] = None,
        expansion_coefficient_lengths: Optional[List[int]] = None,
        sharing: Optional[List[bool]] = None,
        stack_types: Optional[List[str]] = None,
        num_parallel_training: int = mp.cpu_count(),
        **kwargs,
    ) -> None:
        assert (
            prediction_length > 0
        ), "The value of `prediction_length` should be > 0"

        self.freq = freq
        self.prediction_length = prediction_length

        assert meta_loss_function is None or all(
            loss_function in VALID_LOSS_FUNCTIONS
            for loss_function in meta_loss_function
        ), f"Each loss function has to be one of the following: {VALID_LOSS_FUNCTIONS}."
        assert meta_context_length is None or all(
            context_length > 0 for context_length in meta_context_length
        ), "The value of each `context_length` should be > 0"
        assert (
            meta_bagging_size is None or meta_bagging_size > 0
        ), "The value of `meta_bagging_size` should be > 0"

        # Default lookback windows: 2x, 3x, ..., 7x the forecast horizon.
        self.meta_context_length = (
            meta_context_length
            if meta_context_length is not None
            else [multiplier * prediction_length for multiplier in range(2, 8)]
        )
        # Copy the module-level constant so that mutating this attribute can
        # never alter VALID_LOSS_FUNCTIONS for every other instance.
        self.meta_loss_function = (
            meta_loss_function
            if meta_loss_function is not None
            else list(VALID_LOSS_FUNCTIONS)
        )
        self.meta_bagging_size = meta_bagging_size

        # The following arguments are validated in the NBEATSEstimator:
        self.trainer = trainer
        self.num_stacks = num_stacks
        self.widths = widths
        self.num_blocks = num_blocks
        self.num_block_layers = num_block_layers
        self.expansion_coefficient_lengths = expansion_coefficient_lengths
        self.sharing = sharing
        self.stack_types = stack_types

        # Actually instantiate the different models
        estimators = self._estimator_factory(**kwargs)
        super().__init__(
            estimators=estimators, num_parallel_training=num_parallel_training
        )

    def _estimator_factory(self, **kwargs):
        """
        Build one ``NBEATSEstimator`` per element of the Cartesian product
        ``meta_context_length x meta_loss_function x range(meta_bagging_size)``.

        Every sub-model gets a deep copy of ``self.trainer`` so that no
        trainer state is shared between sub-models. ``kwargs`` are forwarded
        unchanged to every ``NBEATSEstimator``.

        Returns
        -------
        List[NBEATSEstimator]
            The sub-models, in ``itertools.product`` order.
        """
        estimators = []
        for context_length, loss_function, _bag_index in product(
            self.meta_context_length,
            self.meta_loss_function,
            range(self.meta_bagging_size),
        ):
            # The bag index is not used: models within a bag differ only
            # because they are, by default, always randomly initialized.
            estimators.append(
                NBEATSEstimator(
                    freq=self.freq,
                    prediction_length=self.prediction_length,
                    context_length=context_length,
                    trainer=copy.deepcopy(self.trainer),
                    num_stacks=self.num_stacks,
                    widths=self.widths,
                    num_blocks=self.num_blocks,
                    num_block_layers=self.num_block_layers,
                    expansion_coefficient_lengths=self.expansion_coefficient_lengths,
                    sharing=self.sharing,
                    stack_types=self.stack_types,
                    loss_function=loss_function,
                    **kwargs,
                )
            )
        return estimators

    @classmethod
    def from_hyperparameters(
        cls, **hyperparameters
    ) -> "NBEATSEnsembleEstimator":
        """
        Construct an ensemble estimator from a flat hyperparameter mapping.

        A ``Trainer`` is first built from the same hyperparameters. The
        remaining values are then validated through the pydantic ``Model``
        that ``@validated()`` attaches to ``__init__``.

        Raises
        ------
        AttributeError
            If ``__init__`` carries no ``Model`` (i.e. is not ``@validated()``).
        GluonTSHyperparametersError
            If pydantic validation of the hyperparameters fails.
        """
        Model = getattr(cls.__init__, "Model", None)

        if not Model:
            raise AttributeError(
                f"Cannot find attribute Model attached to the "
                f"{fqname_for(cls)}. Most probably you have forgotten to mark "
                f"the class constructor as @validated()."
            )

        try:
            trainer = from_hyperparameters(Trainer, **hyperparameters)
            return cls(
                **Model(**{**hyperparameters, "trainer": trainer}).__dict__
            )
        except ValidationError as e:
            raise GluonTSHyperparametersError from e


class NBEATSEnsemblePredictor(EnsemblePredictor):
    """
    Ensemble predictor for N-BEATS sub-models.

    The number of samples requested from the base ensemble is always
    ``len(self.predictors)``. The caller's ``num_samples`` only decides
    whether a one-time warning is logged.
    """

    # Class-wide default; flipped to True on the instance after the first
    # warning so that each predictor instance warns at most once.
    _logged_warning = False

    def predict(
        self, dataset: Dataset, num_samples: int = 100, **kwargs
    ) -> Iterator[Forecast]:
        """
        Forecast ``dataset`` with one "sample" per wrapped predictor.

        ``num_samples`` is overridden by ``len(self.predictors)``. A warning
        is logged the first time it differs.
        NOTE(review): ``**kwargs`` are accepted but not forwarded to the base
        ``predict``; confirm whether that is intended.
        """
        ensemble_size = len(self.predictors)
        should_warn = num_samples != ensemble_size and not self._logged_warning
        if should_warn:
            logging.getLogger(__name__).warning(
                "NBEATS is not using samples. Using one 'sample' per predictor."
            )
            self._logged_warning = True
        return super().predict(dataset, num_samples=ensemble_size)
