import math
from typing import Iterator, Sequence, Iterable, List, Dict, Tuple, cast

import numpy as np

import mxnet as mx
import pandas as pd

from gluonts import transform
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset, DataEntry
from gluonts.dataset.field_names import FieldName
from gluonts.evaluation.backtest import make_evaluation_predictions
from gluonts.gluonts_tqdm import tqdm
from gluonts.model.forecast import Forecast, SampleForecast, QuantileForecast
from gluonts.mx.model.estimator import GluonEstimator
from gluonts.mx.trainer.model_averaging import AveragingStrategy
from gluonts.support.util import maybe_len
from gluonts.transform import (
    InstanceSampler,
    UniformSplitSampler,
    TransformedDataset,
    InstanceSplitter,
    Chain,
    Transformation,
    AddTimeFeatures,
    AddAgeFeature,
    AddConstFeature,
)


class MinHistorySampler(InstanceSampler):
    """
    Instance sampler that skips the first ``min_history`` points of a series
    and returns at most one index per call.

    The sampler keeps a running average of the lengths of the windows it has
    been asked to sample from. On each call it builds a ``UniformSplitSampler``
    with probability ``num_instances / average window length`` and, if that
    sampler yields any candidates, picks one of them uniformly at random.

    Parameters
    ----------
    min_history
        Lowest index that may be sampled.
    num_instances
        Numerator of the probability handed to ``UniformSplitSampler``.
    """

    @validated()
    def __init__(self, min_history: int, num_instances: float):
        self.min_history = min_history
        self.num_instances = num_instances

        # Running statistics over every non-empty window seen so far;
        # mutated by ``__call__``.
        self.total_length = 0
        self.n = 0

    def __call__(self, ts: np.ndarray, a: int, b: int) -> np.ndarray:
        start = max(a, self.min_history)
        if start >= b:
            # not enough history in this series: sample nothing
            return np.array([], dtype=np.int32)

        self.n += 1
        self.total_length += b - start + 1
        mean_window = self.total_length / self.n

        candidates = UniformSplitSampler(self.num_instances / mean_window)(
            ts, start, b
        )
        if len(candidates) == 0:
            return candidates

        pick = np.random.randint(0, len(candidates))
        return candidates[pick : pick + 1]


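# Workaround ("hack") classes for using custom standalone parameters inside an
# mx.gluon.HybridBlock: each class below registers a single parameter and
# returns it, cast to float64, from hybrid_forward.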
class Reg(mx.gluon.HybridBlock):
    """
    Block holding a single learnable value ``reg`` of shape ``(1,)``.

    ``hybrid_forward`` ignores its input ``x`` and returns ``reg`` clipped to
    ``[0, 1000000]`` and cast to float64.

    Parameters
    ----------
    init_val
        Constant the parameter is initialized with.
    """

    def __init__(self, init_val):
        super().__init__()
        initializer = mx.init.Constant(init_val)
        self.reg = self.params.get("reg", shape=(1,), init=initializer)

    def hybrid_forward(self, F, x, reg):
        # The argument name ``reg`` must match the registered parameter:
        # Gluon hands parameters to hybrid_forward as keyword arguments.
        clipped = F.clip(reg, a_min=0, a_max=1000000)
        return F.cast(clipped, dtype="float64")


class Bias(mx.gluon.HybridBlock):
    """
    Block holding a single learnable ``bias`` parameter of shape ``dim``.

    ``hybrid_forward`` ignores its input ``x`` and returns ``bias`` cast to
    float64.

    Parameters
    ----------
    init_val
        Constant the parameter is initialized with.
    dim
        Shape of the parameter, passed to ``params.get`` unchanged (an int or
        a tuple of ints).
    """

    def __init__(self, init_val, dim):
        super().__init__()
        initializer = mx.init.Constant(init_val)
        # ``dim`` is forwarded as-is; note that ``(dim)`` would not be a tuple
        self.bias = self.params.get("bias", shape=dim, init=initializer)

    def hybrid_forward(self, F, x, bias):
        # The argument name ``bias`` must match the registered parameter:
        # Gluon hands parameters to hybrid_forward as keyword arguments.
        as_double = F.cast(bias, dtype="float64")
        return as_double


class AbstractPredictorWrapper:
    """Interface for objects that wrap a predictor and post-process its forecasts."""

    def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]:
        """Yield forecasts for ``dataset``; subclasses must implement this."""
        raise NotImplementedError()


class PredictorWrapper(AbstractPredictorWrapper):
    """
    Wrap a predictor so that every forecast is cut to its first
    ``new_prediction_length`` time steps.

    Attributes that are not found on the wrapper itself are looked up on the
    wrapped predictor.

    Parameters
    ----------
    predictor
        The predictor whose ``predict`` output is sliced.
    new_prediction_length
        Number of leading forecast steps to keep; stored as
        ``prediction_length``.
    """

    def __init__(self, predictor, new_prediction_length):
        self._wrapped_predictor = predictor
        self.prediction_length = new_prediction_length

    def __getattr__(self, attr):
        # Python calls __getattr__ only after normal lookup failed, so
        # attributes stored in the instance ``__dict__`` never get here.
        # Read the wrapped predictor straight from ``__dict__``: going through
        # ``self._wrapped_predictor`` would re-enter __getattr__ (and recurse
        # until RecursionError) on instances created without __init__, which
        # is what ``copy`` and ``pickle`` do before restoring the state.
        try:
            wrapped = self.__dict__["_wrapped_predictor"]
        except KeyError:
            raise AttributeError(attr) from None
        return getattr(wrapped, attr)

    def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]:
        """Yield the wrapped predictor's forecasts, sliced to ``prediction_length``."""
        keep = list(range(self.prediction_length))
        for forecast in self._wrapped_predictor.predict(dataset, **kwargs):
            yield forecast_slice(forecast, keep)


def forecast_slice(forecast: Forecast, indices):
    """
    Restrict ``forecast`` to the time steps at positions ``indices``.

    The result has the same type as ``forecast``. Its start date is the
    timestamp of step ``indices[0]``, for both sample and quantile forecasts.
    The new forecast re-derives its dates from that start and its frequency.
    This presumably assumes ``indices`` are increasing and contiguous
    (TODO confirm with callers).

    Parameters
    ----------
    forecast
        A ``SampleForecast`` or ``QuantileForecast``.
    indices
        Positions along the prediction axis to keep; must be non-empty.

    Returns
    -------
    SampleForecast or QuantileForecast
        The sliced forecast.

    Raises
    ------
    NotImplementedError
        For any other forecast type.
    """
    if isinstance(forecast, SampleForecast):
        # ``forecast.index`` is the same date index that ``mean_ts`` uses,
        # without computing the sample mean.
        return SampleForecast(
            forecast.samples[:, indices],
            forecast.index[indices[0]],
            forecast.freq,
            forecast.item_id,
            forecast.info,
        )
    if isinstance(forecast, QuantileForecast):
        # Shift the start to the first kept step, as for sample forecasts.
        return QuantileForecast(
            forecast.forecast_array[:, indices],
            forecast.index[indices[0]],
            forecast.freq,
            forecast.forecast_keys,
            forecast.item_id,
            forecast.info,
        )
    raise NotImplementedError(
        f"forecast of type {type(forecast)} cannot be sliced "
        "for lack of implementation"
    )


def count_model_params(net) -> int:
    """
    Return the total number of scalar values in ``net``'s parameters.

    Each parameter contributes the product of its ``shape``. A parameter with
    shape ``()`` contributes 1. A shape entry of 0, as used for not-yet-inferred
    dimensions, contributes 0.

    Parameters
    ----------
    net
        Object exposing ``collect_params()`` that returns a mapping of
        parameter name to an object with a ``shape`` attribute.

    Returns
    -------
    int
        A plain Python int. ``np.prod`` returns numpy scalars, and a float for
        empty shapes, so each product is converted explicitly.
    """
    params = net.collect_params()
    return sum(int(np.prod(param.shape)) for param in params.values())


def change_transformation(estimator: GluonEstimator, pick_incomplete=False):
    """
    Patch ``estimator`` so that ``create_transformation()`` returns a chain
    whose ``InstanceSplitter`` steps use the given ``pick_incomplete`` flag.

    The patched chain is built once, from the estimator's current
    transformation. Every later call to ``create_transformation`` returns that
    same ``Chain`` object. The estimator is modified in place; nothing is
    returned.
    """
    patched = change_instance_splitter_pick_incomplete(
        estimator.create_transformation(), pick_incomplete=pick_incomplete
    )

    def _create_transformation():
        return patched

    estimator.create_transformation = _create_transformation


def change_transform_for_predictor(
    transformation: Chain,
    history_length,
    future_length=None,
    time_features=None,
):
    """
    Rebuild ``transformation`` for a new history length / forecast horizon.

    Every ``InstanceSplitter``, ``AddTimeFeatures``, ``AddAgeFeature`` and
    ``AddConstFeature`` in the chain is re-created with the new lengths. All
    other transformations are reused as-is, and the input chain is not
    modified.

    Parameters
    ----------
    transformation
        The chain to adapt.
    history_length
        New ``past_length`` of the instance splitter(s).
    future_length
        New ``future_length`` of the splitter(s) and ``pred_length`` of the
        feature transformations. If ``None``, the ``future_length`` of the
        first ``InstanceSplitter`` in the chain is used. If the chain has no
        splitter, each transformation keeps its own length.
    time_features
        Time features for ``AddTimeFeatures``. If ``None``, the features of
        the existing ``AddTimeFeatures`` step are kept.

    Returns
    -------
    Chain
        A new chain containing the rebuilt transformations.
    """
    if future_length is None:
        # Feature transformations usually come *before* the splitter in the
        # chain, so the horizon must be known before the loop starts.
        # Resolving it only when the splitter is reached would hand them
        # pred_length=None.
        future_length = next(
            (
                tr.future_length
                for tr in transformation.transformations
                if isinstance(tr, InstanceSplitter)
            ),
            None,
        )

    new_transformations = []
    for tr in transformation.transformations:
        if isinstance(tr, InstanceSplitter):
            n_tr = InstanceSplitter(
                target_field=tr.target_field,
                is_pad_field=tr.is_pad_field,
                start_field=tr.start_field,
                forecast_start_field=tr.forecast_start_field,
                train_sampler=tr.train_sampler,
                past_length=history_length,
                future_length=(
                    future_length
                    if future_length is not None
                    else tr.future_length
                ),
                time_series_fields=tr.ts_fields,
                lead_time=tr.lead_time,
                dummy_value=tr.dummy_value,
                # keep the splitter's settings instead of resetting them to
                # the constructor defaults
                pick_incomplete=tr.pick_incomplete,
                output_NTC=tr.output_NTC,
            )
        elif isinstance(tr, AddTimeFeatures):
            n_tr = AddTimeFeatures(
                start_field=tr.start_field,
                target_field=tr.target_field,
                output_field=tr.output_field,
                time_features=(
                    time_features
                    if time_features is not None
                    else tr.date_features
                ),
                pred_length=(
                    future_length
                    if future_length is not None
                    else tr.pred_length
                ),
            )
        elif isinstance(tr, AddAgeFeature):
            n_tr = AddAgeFeature(
                target_field=tr.target_field,
                # NOTE(review): the output field is hard-coded rather than
                # copied from ``tr``; confirm the original step used FEAT_AGE.
                output_field=FieldName.FEAT_AGE,
                pred_length=(
                    future_length
                    if future_length is not None
                    else tr.pred_length
                ),
                log_scale=tr.log_scale,
                dtype=tr.dtype,
            )
        elif isinstance(tr, AddConstFeature):
            n_tr = AddConstFeature(
                target_field=tr.target_field,
                output_field=tr.output_field,
                pred_length=(
                    future_length
                    if future_length is not None
                    else tr.pred_length
                ),
                const=tr.const,
                dtype=tr.dtype,
            )
        else:
            n_tr = tr
        new_transformations.append(n_tr)

    return Chain(new_transformations)


def change_instance_splitter_pick_incomplete(
    transformation: Chain, pick_incomplete=False
):
    """
    Return a copy of ``transformation`` with every ``InstanceSplitter``
    re-created using the given ``pick_incomplete`` flag.

    All other settings of each splitter are copied over. Non-splitter
    transformations are reused as-is.
    """

    def _rebuild(step):
        # Only instance splitters are replaced; everything else passes through.
        if not isinstance(step, InstanceSplitter):
            return step
        return InstanceSplitter(
            target_field=step.target_field,
            is_pad_field=step.is_pad_field,
            start_field=step.start_field,
            forecast_start_field=step.forecast_start_field,
            train_sampler=step.train_sampler,
            past_length=step.past_length,
            future_length=step.future_length,
            time_series_fields=step.ts_fields,
            lead_time=step.lead_time,
            dummy_value=step.dummy_value,
            pick_incomplete=pick_incomplete,
            output_NTC=step.output_NTC,
        )

    return Chain([_rebuild(step) for step in transformation.transformations])


def shorten_dataset(dataset: Dataset, new_length, lead_time):
    """
    Lazily keep only the last ``new_length + lead_time`` points of each target.

    Each entry is shallow-copied and its ``"target"`` is sliced along the last
    axis. All other fields are left unchanged.

    NOTE(review): the entry's start timestamp is not advanced, so after
    truncation it no longer corresponds to the first remaining target value.
    Confirm that downstream code does not rely on it.

    Parameters
    ----------
    dataset
        Dataset whose entries hold a numpy array under ``"target"``.
    new_length
        Number of points to keep, in addition to ``lead_time``.
    lead_time
        Extra trailing points to keep.

    Returns
    -------
    TransformedDataset
        ``dataset`` wrapped with the truncation as an ``AdhocTransform``.
    """
    keep = new_length + lead_time

    def truncate_target(data):
        data = data.copy()
        target = data["target"]
        # Slice from an explicit start index: ``target[..., -keep:]`` would
        # return the whole series when keep == 0, because -0 == 0.
        start = max(target.shape[-1] - keep, 0)
        data["target"] = target[..., start:]
        return data

    dataset_trunc = TransformedDataset(
        dataset, transformation=transform.AdhocTransform(truncate_target)
    )

    return dataset_trunc


class AugmentTrain(Transformation):
    """
    Create synthetic training entries by mixing pairs of training series.

    With ``is_train=False`` the input passes through unchanged. With
    ``is_train=True`` the input is read once and cached; later calls ignore
    their ``data_it``. An endless stream of mixed entries is then yielded.
    Each entry is built from two randomly chosen series (``d1`` is the
    shorter one) in these steps:

    1. align them with a random shift that is a multiple of ``time_shift``,
       keeping an overlap of at least ``min_ts_length`` points;
    2. cut both to the same length and divide each by its mean;
    3. blend them with weight ``lam ~ Beta(alpha, alpha)``, then rescale by the
       weighted geometric mean ``mean2**lam * mean1**(1 - lam)``.

    The mixed entry is a shallow copy of ``d1`` with ``"target"`` replaced.

    Parameters
    ----------
    categorical_fields
        Stored as an attribute; not used by this class.
    alpha
        Parameter of the symmetric Beta distribution of the mixing weight.
    time_shift
        Positive step, in time points, of the random alignment shift.
    min_ts_length
        Minimum overlap in time points; must be a multiple of ``time_shift``.
        Every series must be strictly longer than this.

    Raises
    ------
    ValueError
        If ``time_shift`` is not positive or does not divide
        ``min_ts_length``. While iterating in training mode, also raised if
        there is no data or a series is not longer than ``min_ts_length``.
    """

    def __init__(
        self,
        categorical_fields: Sequence[str],
        alpha: float,
        time_shift: int,
        min_ts_length: int,
    ):
        if time_shift <= 0:
            raise ValueError(f"time_shift must be positive, got {time_shift}")
        if min_ts_length % time_shift != 0:
            raise ValueError(
                f"min_ts_length ({min_ts_length}) must be a multiple of "
                f"time_shift ({time_shift})"
            )
        self.categorical_fields = categorical_fields
        self.alpha = alpha
        self.time_shift = time_shift
        self.min_ts_length = min_ts_length
        # training entries, materialized on the first training-mode call
        self._cache = None

    def __call__(
        self, data_it: Iterable[DataEntry], is_train: bool
    ) -> Iterator[DataEntry]:
        if not is_train:
            print("not using augmentation")
            yield from data_it
            return

        if self._cache is None:
            from tqdm import tqdm

            self._cache = list(
                tqdm(data_it, desc="loading data for augmentation")
            )

        data = self._cache
        if not data:
            raise ValueError("AugmentTrain needs at least one training series")

        # suffix _s denotes quantities measured in units of ``time_shift``
        T = self.time_shift
        min_len_s = self.min_ts_length // T

        def inv_box_cox(y, bc_lam):
            # NOTE: for bc_lam == 1 this is exp(log(y + 1)) == y + 1, so the
            # blend below carries a +1 offset.
            return np.exp(np.log(bc_lam * y + 1.0) / bc_lam)

        # Box-Cox parameters; fixed at 1.0 (drawing them from U(0.1, 2.0) is
        # a possible variant).
        bc1 = 1.0
        bc2 = 1.0

        while True:
            i1, i2 = np.random.randint(len(data), size=2)
            d1, d2 = data[i1], data[i2]

            # make d1 the shorter (or equally long) series
            if len(d1["target"]) > len(d2["target"]):
                d1, d2 = d2, d1

            # |--------|                  target1
            # |---------------------|     target2
            target1 = d1["target"]
            target2 = d2["target"]
            for target in (target1, target2):
                if len(target) <= self.min_ts_length:
                    raise ValueError(
                        f"series of length {len(target)} is not longer than "
                        f"min_ts_length ({self.min_ts_length})"
                    )

            len1_s = len(target1) // T
            len2_s = len(target2) // T

            # number of shifts target1 can move left / right while keeping
            # at least ``min_ts_length`` points of overlap
            n_left = max(len1_s - min_len_s, 0)
            n_right = max(len2_s - min_len_s, 0)

            # uniform integer in [-n_left, n_right] (inclusive); this is what
            # the deprecated np.random.random_integers(-n_left, n_right) drew
            shift = np.random.randint(-n_left, n_right + 1)
            if shift <= 0:
                # NOTE(review): the mixed entry keeps d1's other fields (such
                # as its start) although target1 may lose points at the front.
                target1 = target1[abs(shift) * T :]
            else:
                target2 = target2[shift * T :]
            overlap = min(len(target1), len(target2))

            target1 = target1[:overlap].copy()
            target2 = target2[:overlap].copy()

            mix = d1.copy()
            lam = np.random.beta(self.alpha, self.alpha)

            mean1 = np.nanmean(target1).clip(1e-12, np.inf)
            mean2 = np.nanmean(target2).clip(1e-12, np.inf)

            target1 /= mean1
            target2 /= mean2

            target_mix = lam * inv_box_cox(target2, bc2) + (
                1.0 - lam
            ) * inv_box_cox(target1, bc1)
            # rescale by the weighted geometric mean of the two original means
            log_scale = lam * np.log(mean2) + (1.0 - lam) * np.log(mean1)

            mix["target"] = np.exp(log_scale) * target_mix
            yield mix


class SelectLastN(AveragingStrategy):
    """
    Averaging strategy that uniformly averages the ``num_models`` checkpoints
    with the highest ``epoch_no``.

    Ties on ``epoch_no`` are broken by ``params_path`` (descending).
    NOTE(review): if no checkpoint is selected, computing the weights divides
    by zero and raises ``ZeroDivisionError``.
    """

    def select_checkpoints(
        self, checkpoints: List[Dict]
    ) -> Tuple[List[str], List[float]]:
        """Return the paths of the newest checkpoints and their equal weights."""
        ranked = sorted(
            ((ckpt["epoch_no"], ckpt["params_path"]) for ckpt in checkpoints),
            reverse=True,
        )
        selected = ranked[: self.num_models]

        paths = [path for _, path in selected]
        count = len(selected)
        weights = [1 / count] * count

        return paths, weights


def get_forecasts_and_series(
    ds, predictor, prediction_length, num_samples
) -> Tuple[List[QuantileForecast], List[pd.Series]]:
    """
    Generate forecasts for ``ds`` and post-process them for evaluation.

    Non-quantile forecasts are converted to quantile forecasts with the keys
    ``"mean"`` and ``"0.5"``. Every forecast is truncated to its first
    ``prediction_length`` steps, and negative values are clipped to 0.

    Returns
    -------
    (forecasts, series)
        The processed forecasts and, in the same order, the series produced
        by ``make_evaluation_predictions``.
    """
    forecast_it, ts_it = make_evaluation_predictions(
        ds,
        predictor=predictor,
        num_samples=num_samples,
        prediction_length=prediction_length,
    )

    def _postprocess(fcst):
        # Convert to a quantile forecast, truncate to prediction_length and
        # clip negative values to 0.
        if isinstance(fcst, QuantileForecast):
            quantile_fc = fcst
        else:
            quantile_fc = cast(SampleForecast, fcst).to_quantile_forecast(
                ["mean", "0.5"]
            )
        values = quantile_fc.forecast_array[:, :prediction_length].clip(
            min=0.0
        )
        return QuantileForecast(
            values,
            quantile_fc.start_date,
            quantile_fc.freq,
            quantile_fc.forecast_keys,
            quantile_fc.item_id,
            quantile_fc.info,
        )

    forecasts, series = [], []
    progress = tqdm(
        zip(forecast_it, ts_it),
        total=maybe_len(ds),
        desc="computing forecasts",
    )
    for fcst, ts in progress:
        series.append(ts)
        forecasts.append(_postprocess(fcst))
    return forecasts, series