# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
import multiprocessing as mp
import tempfile
from pathlib import Path

from gluonts.core import serde
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.model.estimator import Estimator
from gluonts.model.forecast import Forecast, SampleForecast
from gluonts.model.predictor import Predictor
from typing import List, Iterator, cast, Sequence, Optional, Dict, Union
import numpy as np
import mxnet as mx
from gluonts.mx.context import check_gpu_support


class EnsemblePredictor(Predictor):
    """
    Predictor that pools the sample forecasts of several member predictors.

    The ensemble takes ``prediction_length``, ``freq`` and ``lead_time`` from
    the first member and asserts that every other member agrees on them.
    """

    def __init__(self, predictors: List[Predictor]) -> None:
        """
        Parameters
        ----------
        predictors
            Non-empty list of member predictors. All of them must share the
            same ``prediction_length``, ``freq`` and ``lead_time``.
        """
        assert predictors
        self.num_predictors = len(predictors)
        self.predictors = predictors
        p0 = predictors[0]
        super().__init__(
            prediction_length=p0.prediction_length,
            freq=p0.freq,
            lead_time=p0.lead_time,
        )
        for p in predictors:
            assert p.prediction_length == self.prediction_length
            assert p.freq == self.freq
            assert p.lead_time == self.lead_time

    def predict(
        self, dataset: Dataset, num_samples: int = 100, **kwargs
    ) -> Iterator[Forecast]:
        """
        Yield one combined ``SampleForecast`` per item produced by the members.

        ``num_samples`` is split as evenly as possible between the members;
        when it is not divisible by the number of members, the first
        ``num_samples % num_predictors`` members draw one extra sample, so
        the combined forecast always holds exactly ``num_samples`` samples.

        Parameters
        ----------
        dataset
            Dataset handed unchanged to every member's ``predict``.
        num_samples
            Total number of samples in each combined forecast. Must be at
            least the number of member predictors.
        **kwargs
            Accepted for interface compatibility; not forwarded to the
            member predictors.
        """
        assert (
            num_samples >= self.num_predictors
        ), "Cannot use less samples than predictors"
        # Hand out the remainder one sample at a time, starting with the
        # first member. With a zero remainder nobody gets an extra sample.
        base, remainder = divmod(num_samples, self.num_predictors)
        samples_per_predictor = [
            base + 1 if i < remainder else base
            for i in range(self.num_predictors)
        ]

        def combine_fcst(forecasts: Sequence[Forecast]) -> Forecast:
            # Merge the members' forecasts for one item; metadata is taken
            # from the first forecast after checking that all of them agree.
            f0 = forecasts[0]
            for f in forecasts:
                assert isinstance(f, SampleForecast)
                assert f.start_date == f0.start_date
                assert f.freq == f0.freq
                assert f.item_id == f0.item_id
                assert f.samples.shape[1] == f0.prediction_length
                assert f.prediction_length == f0.prediction_length
            forecasts = [cast(SampleForecast, f) for f in forecasts]
            # Stack along the sample axis: (num_samples, prediction_length).
            samples = np.concatenate([f.samples for f in forecasts], axis=0)
            assert samples.shape[1] == f0.prediction_length
            return SampleForecast(
                samples,
                start_date=f0.start_date,
                freq=f0.freq,
                item_id=f0.item_id,
            )

        fcst_it = [
            p.predict(dataset, num_samples=samples_per_predictor[i])
            for i, p in enumerate(self.predictors)
        ]
        # Advance all member iterators in lockstep, one item at a time.
        for forecasts in zip(*fcst_it):
            yield combine_fcst(forecasts)

    def serialize(self, path: Path) -> None:
        """
        Serialize the ensemble into ``path``.

        Each member is written to its own ``predictor_<index>`` subdirectory,
        which is created here and therefore must not exist yet.
        """
        super().serialize(path)
        for i, pred in enumerate(self.predictors):
            bp = path / f"predictor_{i}"
            bp.mkdir()
            pred.serialize(bp)

    @classmethod
    def deserialize(cls, path: Path) -> "EnsemblePredictor":
        """
        Rebuild an ensemble from the ``predictor_*`` subdirectories of
        ``path``, restoring the members in ascending index order.
        """
        bps = list(path.glob("predictor_*"))
        # Sort numerically on the index suffix so that e.g. predictor_10
        # comes after predictor_2.
        bps.sort(key=lambda p: int(str(p).split("_")[-1]))
        predictors = [Predictor.deserialize(bp) for bp in bps]
        return EnsemblePredictor(predictors)

    def __eq__(self, other: Predictor):
        """
        Two ensembles are equal when they have the same type,
        ``prediction_length`` and ``freq``, and their members compare equal
        pairwise in order. ``lead_time`` is not compared here.
        """
        if type(self) != type(other):
            return False
        if self.prediction_length != other.prediction_length:
            return False
        if self.freq != other.freq:
            return False
        other = cast("EnsemblePredictor", other)
        if len(self.predictors) != len(other.predictors):
            return False
        return all(p == q for p, q in zip(self.predictors, other.predictors))


def _train_single_model(args):
    """
    Train one ensemble member; meant to be mapped over a process pool.

    ``args`` is a single tuple ``(i, estimator_json, train_data, valid_data,
    seed)`` so that the function can be used with ``Pool.map``. The estimator
    arrives as JSON and is rebuilt with ``serde.load_json``.

    Returns a dict with ``pred_name`` (``"predictor-<i>"``) and
    ``predictor_path``, a freshly created temporary directory holding the
    serialized predictor. The directory is not removed here.
    """
    i, estimator_json, train_data, valid_data, seed = args
    estimator = serde.load_json(estimator_json)

    # Seed both numpy and mxnet before training starts.
    np.random.seed(seed)
    mx.random.seed(seed)

    logging.getLogger(__name__).info(
        f"Training estimator {i + 1} in ensemble."
    )
    # NOTE(review): num_workers=0 presumably keeps the estimator from
    # spawning its own worker processes inside a pool worker; confirm.
    trained = estimator.train(train_data, valid_data, num_workers=0)

    # The predictor is handed back to the parent process through disk.
    out_dir = Path(tempfile.mkdtemp())
    trained.serialize(out_dir)

    return {
        "pred_name": f"predictor-{i}",
        "predictor_path": out_dir,
    }


def parallel_train_models(
    estimators: List[Estimator],
    train_data: Dataset,
    valid_data: Optional[Dataset] = None,
    num_workers: int = mp.cpu_count(),
    seeds: Optional[List[int]] = None,
) -> List[Predictor]:
    """
    Train the estimators in parallel

    Each estimator is dumped to JSON, trained in a worker process of a
    ``multiprocessing.Pool`` and returned to this process as a predictor
    serialized into a temporary directory. The predictors are deserialized
    here and the temporary directories are removed afterwards.

    Parameters
    ----------
    estimators
        Estimators to train; they must be serializable with
        ``serde.dump_json``.
    train_data
        Training dataset passed to every estimator.
    valid_data
        Optional validation dataset passed to every estimator.
    num_workers
        Size of the process pool.
    seeds
        One random seed per estimator. When omitted (or empty) the seeds
        ``0 .. len(estimators) - 1`` are used.

    Returns
    -------
    List[Predictor]
        The trained predictors, in the same order as ``estimators``.
    """
    # Local import: only needed for cleaning up the workers' output.
    import shutil

    estimator_jsons = [serde.dump_json(estim) for estim in estimators]
    n_models = len(estimator_jsons)

    seeds = seeds if seeds else list(range(n_models))
    inputs = [
        (i, estimator_jsons[i], train_data, valid_data, seeds[i])
        for i in range(n_models)
    ]

    # Pool.map blocks until every model is trained and keeps input order.
    with mp.Pool(num_workers) as pool:
        results = pool.map(_train_single_model, inputs)

    predictors = []
    try:
        for res in results:
            predictors.append(
                Predictor.deserialize(Path(res["predictor_path"]))
            )
    finally:
        # The workers created these directories with mkdtemp, which never
        # deletes them on its own; remove them even if loading failed.
        for res in results:
            shutil.rmtree(res["predictor_path"], ignore_errors=True)
    return predictors


class EnsembleEstimator(Estimator):
    """
    Estimator that trains several member estimators and wraps the resulting
    predictors in an ``EnsemblePredictor``.
    """

    @validated()
    def __init__(
        self,
        base_estimator: Optional[Estimator] = None,
        variants: Union[None, int, Dict[str, List]] = None,
        estimators: Optional[List[Estimator]] = None,
        num_parallel_training: int = mp.cpu_count(),
    ):
        """
        Create an ensemble of models. On CPU the training of the models is parallelized.
        There are two ways for creating the ensemble.
        Either provide a list of estimators:

            EnsembleEstimator(estimators=[my_estimator1, my_estimator2, ...])

        or provide a base estimator and variants. variants can be an int or a dictionary
        with lists of parameter values.

            EnsembleEstimator(
              base_estimator=DeepAREstimator(...),
              variants={"context_length": [50, 100]}
            )

        With an int, the base estimator itself is used that many times
        (at least two). With a dictionary, all value lists must have the same
        length (at least two); the i-th ensemble member is a clone of the base
        estimator with the i-th value of every list applied.

        All members must share ``freq``, ``lead_time`` and
        ``prediction_length``. ``num_parallel_training`` is the number of
        worker processes used for parallel training; a value of 1 or less
        trains the members sequentially.
        """

        if base_estimator is not None:
            assert variants
            assert estimators is None
            if isinstance(variants, int):
                assert variants > 1
                # No parameter overrides: every member is the base estimator.
                overwrites = [{}] * variants
            else:
                num_variants = None
                for values in variants.values():
                    if num_variants is None:
                        num_variants = len(values)
                    assert (
                        len(values) == num_variants
                    ), "All values in variants have to have the same length"
                assert (
                    num_variants > 1
                ), "Please provide two or more variants"
                # One override dict per member, taking the i-th value of
                # every parameter list.
                overwrites = [
                    {k: v[i] for k, v in variants.items()}
                    for i in range(num_variants)
                ]
            self.estimators = [
                serde.flat.clone(base_estimator, ov) if ov else base_estimator
                for ov in overwrites
            ]
        else:
            assert estimators is not None
            assert variants is None
            self.estimators = estimators
        e0 = self.estimators[0]
        self.freq = e0.freq
        self.prediction_length = e0.prediction_length
        super().__init__(lead_time=e0.lead_time)
        for estim in self.estimators:
            assert estim.freq == e0.freq
            assert estim.lead_time == e0.lead_time
            assert estim.prediction_length == e0.prediction_length
        self.num_parallel_training = num_parallel_training

    def train(
        self,
        training_data: Dataset,
        validation_data: Optional[Dataset] = None,
        callback=None,
    ) -> EnsemblePredictor:
        """
        Train every member estimator and return the combined predictor.

        Members are trained one after another when ``check_gpu_support()``
        is true or ``num_parallel_training <= 1``; otherwise they are trained
        through ``parallel_train_models`` with ``num_parallel_training``
        workers.

        Parameters
        ----------
        training_data
            Dataset every member is trained on.
        validation_data
            Optional validation dataset passed to every member.
        callback
            Accepted but not used by this method.
        """
        logger = logging.getLogger(__name__)
        logger.info(f"Training ensemble of {len(self.estimators)} models.")
        # Query once; the result drives both the branch and the log message.
        gpu_available = check_gpu_support()
        if gpu_available or self.num_parallel_training <= 1:
            if gpu_available:
                logger.info("GPU found. Training models sequentially")
            else:
                logger.info("Training models sequentially")
            predictors = []
            for i, e in enumerate(self.estimators):
                logger.info(f"Training estimator {i + 1} in ensemble.")
                predictors.append(e.train(training_data, validation_data))
        else:
            logger.info(
                f"Running on CPU device. Training {self.num_parallel_training} models in parallel"
            )
            predictors = parallel_train_models(
                self.estimators,
                train_data=training_data,
                valid_data=validation_data,
                num_workers=self.num_parallel_training,
            )
        return EnsemblePredictor(predictors)
