import time
from collections import Counter
from typing import Dict, List, Optional, Union

import mlflow
import numpy as np
import pandas as pd
import torch
from autogluon.tabular import TabularPredictor
from autogluon.timeseries.dataset import TimeSeriesDataFrame
from autogluon.timeseries.models.ensemble.abstract_timeseries_ensemble import (
    AbstractTimeSeriesEnsembleModel,
)
from tqdm import tqdm

import atse

from .data import tensor_to_timeseriesdataframe, timeseriesdataframe_to_tensor
from .linear_stacker_regressor import LinearStackerRegressor
from .lrschedulers import ReduceLRWhenUnstable
from .metrics import AbstractMetric
from .scalers import AVAILABLE_SCALERS
from .tabular_models import MODEL_TYPES


class TensorBasedTimeSeriesEnsembleModel(AbstractTimeSeriesEnsembleModel):
    """
    Base class for time series ensemble models that operate on tensors.

    This class handles the transformation from and to TimeSeriesDataFrame objects,
    as well as some other things such as isotonization, detecting and ignoring
    failing base models, and scaling the base model predictions and ensemble
    predictions.

    Subclasses implement `_fit_ensemble` and `_predict`. Both receive the base model
    predictions as an array whose last axis indexes the active base models, in the
    order given by `self._active_models`.
    """

    def __init__(
        self,
        metric: AbstractMetric,
        isotonization="sort",
        ignore_models: Optional[List[str]] = None,
        detect_and_ignore_failures: bool = True,
        scaler: Optional[str] = None,
        greedy_preselected: bool = False,
        verbose: bool = True,
        sparsify: bool = False,
        prune_below: float = 0.05,
        **kwargs,
    ):
        """
        Parameters
        ----------
        metric: AbstractMetric
            The metric to use for scoring the models.
        isotonization: str
            The isotonization method to use (i.e. the algorithm to prevent quantile crossing).
            Currently only "sort" is supported. Default is "sort".
        ignore_models: List[str]
            A list of models to ignore. Default is None.
        detect_and_ignore_failures: bool
            Whether to detect and ignore "failed models". A model has failed if its loss is larger
            than 10x the median loss of all the models, or if its loss is not finite. This can be very
            important for the regression-based ensembles, as moving the weight from such a "failed model"
            to zero can require a long training time. Default is True.
        scaler: str
            Optionally specify a scaler to rescale the input and output data, e.g. by normalizing it with
            the empirical mean and variance. See also `atse.AVAILABLE_SCALERS` for a list of available scalers.
            Default is None.
        greedy_preselected: bool
            By setting this to True, the ensemble will first run a `atse.models.GreedyEnsemble` to generate
            a sparse set of weights, and then only keep those models that have positive weight. This can be
            used to induce sparsity in any model and can thereby lead to faster inference times.
            Default is False.
        verbose: bool
            Whether to print verbose output. Default is True.
        sparsify: bool
            Whether to (attempt to) sparsify the model after training it.
            Default is False.
            See also `prune_below`.
        prune_below: float
            Specifies at which level of importance the model should be sparsified.
            Default is 0.05.

        Raises
        ------
        ValueError
            If `scaler` is neither None nor a key of `AVAILABLE_SCALERS`.
        """
        self.metric = metric
        self.isotonization = isotonization
        self.ignore_models = ignore_models or []
        self.detect_and_ignore_failures = detect_and_ignore_failures
        self.verbose = verbose
        self.sparsify = sparsify
        self.prune_below = prune_below
        self.scaler_str = scaler
        if scaler in AVAILABLE_SCALERS:
            self.scaler = AVAILABLE_SCALERS[scaler]()
        elif scaler is None:
            self.scaler = None
        else:
            raise ValueError(f"Unsupported scaler: {scaler}")

        self.greedy_preselected = greedy_preselected
        if self.greedy_preselected:
            self.greedy_model = GreedyEnsemble(metric=metric, **kwargs)

        # The following can be overwritten by subclasses if they want to rely on all
        # available quantile information when generating point forecasts.
        self.use_all_quantiles_for_point_forecasts = False

        super().__init__(**kwargs)

    def _get_quantile_levels(
        self,
        model_predictions_per_window: Dict[str, List[TimeSeriesDataFrame]],
    ) -> List[float]:
        """Extract the quantile levels from the columns of the first model prediction.

        Every column except "mean" is interpreted as a quantile level.
        """
        _prediction_tsdf = list(model_predictions_per_window.values())[0][0]
        quantile_levels = [float(col) for col in _prediction_tsdf.columns if col != "mean"]
        return quantile_levels

    def _get_prediction_length(
        self,
        model_predictions_per_window: Dict[str, List[TimeSeriesDataFrame]],
    ) -> int:
        """Extract the prediction length from the index of the first model prediction.

        The prediction length is the number of rows belonging to the first item.
        """
        _prediction_tsdf = list(model_predictions_per_window.values())[0][0]
        # Use the item of the first row rather than `index.levels[0][0]`: MultiIndex levels can
        # retain values that no longer occur in the data, which would make `.loc` raise a KeyError.
        first_item = _prediction_tsdf.index.get_level_values(0)[0]
        prediction_length = _prediction_tsdf.loc[first_item].shape[0]
        return prediction_length

    def _detect_failing_models(
        self,
        model_predictions_per_window: Dict[str, List[TimeSeriesDataFrame]],
        labels: np.ndarray,
        quantile_levels: List[float],
    ) -> List[str]:
        """Detect failing models by comparing their validation loss to the median loss of all models.

        A model is declared failed, and ignored completely, if its loss is more than 10x the
        median loss, or if its loss is not finite (NaN/inf). The median is taken over the finite
        losses only. A single NaN would otherwise make the median NaN and silently disable the
        detection for all models.

        Returns
        -------
        List[str]
            Names of the failing models (empty if detection is disabled).
        """
        if not self.detect_and_ignore_failures:
            return []
        predictions = timeseriesdataframe_to_tensor(model_predictions_per_window)
        model_names = list(model_predictions_per_window.keys())
        losses = np.array(
            [
                self.metric.compute_metric(labels, predictions[..., i], quantile_levels=quantile_levels)
                for i in range(predictions.shape[-1])
            ],
            dtype=float,
        )

        finite = np.isfinite(losses)
        if not finite.any():
            self.print("Could not detect failing models: no model has a finite validation loss")
            return []
        median_loss = np.median(losses[finite])

        failing_models = []
        for model, loss, is_finite in zip(model_names, losses, finite):
            if not is_finite or loss > 10 * median_loss:
                self.print(f"Ignoring model {model} due to high validation loss: {loss} (median: {median_loss})")
                failing_models.append(model)
        return failing_models

    def _rescale_if_needed(self, model_predictions_per_window, labels_per_window, labels_per_window_past):
        """Rescale the labels and model predictions with `self.scaler`.

        For every window, the scaler is fitted on that window's past data and then applied to
        the window's labels and to every model's predictions. New containers are returned; the
        inputs are returned unchanged if no scaler is configured.
        """
        if self.scaler is None:
            return model_predictions_per_window, labels_per_window
        F = len(labels_per_window)
        _labels_per_window = [None] * F
        _model_predictions_per_window = {k: [None] * F for k in model_predictions_per_window}
        for i in range(F):
            self.scaler.fit(labels_per_window_past[i], target=self.target)
            _labels_per_window[i] = self.scaler.transform(labels_per_window[i], columns=[self.target])
            for m in model_predictions_per_window:
                _model_predictions_per_window[m][i] = self.scaler.transform(
                    model_predictions_per_window[m][i],
                    columns=model_predictions_per_window[m][i].columns,
                )
        return _model_predictions_per_window, _labels_per_window

    def _prepare_fit_ensemble_labels(
        self,
        prediction_length: int,
        labels_per_window: List[TimeSeriesDataFrame],
    ):
        """
        Split the given `labels_per_window` into past data and the labels of interest.

        Having this in a separate function is mostly relevant to save time in the portfolio selection,
        as it allows us to pre-compute the potentially very-demanding `slice_by_timestep` operation once
        and re-use it for multiple calls to `fit_ensemble`.

        Returns
        -------
        (labels_per_window, labels_per_window_past)
            The last `prediction_length` steps of each window, and everything before them.
        """
        labels_per_window_past = [y.slice_by_timestep(None, -prediction_length) for y in labels_per_window]
        labels_per_window = [y.slice_by_timestep(-prediction_length, None) for y in labels_per_window]
        return labels_per_window, labels_per_window_past

    def get_model_importances(self):
        """Specify the importance of each model to the ensemble.

        This function is mainly relevant for sparsification.
        By default, each model is equally important.
        """
        return np.ones(len(self.active_models)) / len(self.active_models)

    def make_sparse(self):
        """Sparsify the model (if possible).

        Subclasses return True if the model must be re-trained after sparsification.
        """
        raise NotImplementedError("This model cannot currently be sparsified")

    def fit_ensemble(
        self,
        model_predictions_per_window: Dict[str, List[TimeSeriesDataFrame]],
        labels_per_window: List[TimeSeriesDataFrame],
        time_limit: Optional[int] = None,
        *args,
        do_not_prepare_metric=False,
        do_not_prepare_labels=False,
        labels_per_window_past=None,
        **kwargs,
    ):
        """
        Fit the ensemble model.

        Parameters
        ----------
        model_predictions_per_window: Dict[str, List[TimeSeriesDataFrame]]
            A dictionary mapping model names to lists of model predictions for each window.
        labels_per_window: List[TimeSeriesDataFrame]
            A list of labels for each window.
        time_limit: Optional[int]
            The time limit for fitting the ensemble model; not used right now.
        do_not_prepare_metric: bool
            Whether to skip the metric preparation step, which e.g. saves certain time series
            statistics in order to compute the correct SQL value.
            This should only be done if the metric has already been prepared before!!
            Otherwise, you will get wrong results.
            Default is False.
        do_not_prepare_labels: bool
            Whether to skip the label preparation step, which splits the labels into past and future data.
            This can only be used when also supplying `labels_per_window_past` to the `fit_ensemble` method.
            Default is False.
        labels_per_window_past: List[TimeSeriesDataFrame]
            You can optionally supply a list of past labels for each window here, in order to skip
            the label preparation step (with `do_not_prepare_labels=True`).
            This can save time for repeated calls, such as in the portfolio selection experiment.
            Default is None.
        **kwargs
            Additional keyword arguments to pass to the `_fit_ensemble` method.

        Raises
        ------
        ValueError
            If `labels_per_window_past` is supplied without `do_not_prepare_labels=True`.
        """
        if self.metric.ispointforecast and not self.use_all_quantiles_for_point_forecasts:
            # drop all features that are not the 0.5 quantile
            model_predictions_per_window = {
                k: [tsdf[["0.5"]] for tsdf in tsdfs] for (k, tsdfs) in model_predictions_per_window.items()
            }

        self._active_models = [m for m in model_predictions_per_window.keys() if m not in self.ignore_models]

        if not do_not_prepare_labels:
            if labels_per_window_past is not None:
                raise ValueError("`labels_per_window_past` can only be supplied together with `do_not_prepare_labels=True`")
            labels_per_window, labels_per_window_past = self._prepare_fit_ensemble_labels(
                prediction_length=self._get_prediction_length(model_predictions_per_window),
                labels_per_window=labels_per_window,
            )

        if not do_not_prepare_metric:
            self.metric.save_past_metrics(
                data_past=labels_per_window_past,
                target=self.target,
                seasonal_period=self.eval_metric_seasonal_period,
            )

        self.quantile_levels_base = self._get_quantile_levels(model_predictions_per_window)
        self.quantile_levels_out = self.quantile_levels_base if not self.metric.ispointforecast else [0.5]

        # Keep the unscaled inputs: a sparsification re-train (below) rescales again, so it must
        # receive the original data rather than the already-rescaled one.
        unscaled_model_predictions_per_window = model_predictions_per_window
        unscaled_labels_per_window = labels_per_window

        model_predictions_per_window, labels_per_window = self._rescale_if_needed(
            model_predictions_per_window, labels_per_window, labels_per_window_past
        )

        labels = timeseriesdataframe_to_tensor(labels_per_window)

        if self.detect_and_ignore_failures:
            failing_models = self._detect_failing_models(
                model_predictions_per_window, labels, self.quantile_levels_base
            )
            self._active_models = [m for m in self._active_models if m not in failing_models]

        # The dict keeps its original key order, which matches the order of `self._active_models`.
        model_predictions_per_window = {
            k: v for k, v in model_predictions_per_window.items() if k in self._active_models
        }
        candidate_models = list(model_predictions_per_window.keys())
        predictions = timeseriesdataframe_to_tensor(model_predictions_per_window)

        if self.greedy_preselected:
            self.greedy_model._fit_ensemble(
                predictions=predictions,
                labels=labels,
                quantile_levels=self.quantile_levels_out,
                labels_per_window_past=labels_per_window_past,
                **kwargs,
            )
            selected = np.asarray(self.greedy_model.weights) > 0
            # Maps the candidate-model axis onto the selected models (kept for introspection).
            self.projection_matrix = np.diag(selected)[:, selected]
            # Index instead of multiplying with the projection matrix: NaN/inf predictions of
            # unselected models would otherwise leak into the result (NaN * 0 = NaN).
            predictions = predictions[..., selected]
            # Restrict the active models as well, so that model names and positional weights line up.
            self._active_models = [m for m, keep in zip(self._active_models, selected) if keep]
            model_predictions_per_window = {
                k: v for k, v in model_predictions_per_window.items() if k in self._active_models
            }

        out = self._fit_ensemble(
            predictions=predictions,
            labels=labels,
            quantile_levels=self.quantile_levels_out,
            labels_per_window_past=labels_per_window_past,
            model_names=list(model_predictions_per_window.keys()),
            **kwargs,
        )

        if self.sparsify:
            retrain = self.make_sparse()
            if retrain:
                self.print("Re-train sparsified model")
                return self.fit_ensemble(
                    {k: v for k, v in unscaled_model_predictions_per_window.items() if k in candidate_models},
                    unscaled_labels_per_window,
                    time_limit,
                    *args,
                    do_not_prepare_metric=True,
                    do_not_prepare_labels=True,
                    labels_per_window_past=labels_per_window_past,
                    **kwargs,
                )

        return out

    def _fit_ensemble(self, predictions: np.array, labels: np.array, **kwargs):
        """Fit the ensemble on tensor inputs; to be implemented by subclasses."""
        raise NotImplementedError

    def predict(
        self,
        model_predictions: Dict[str, Union[TimeSeriesDataFrame]],
        data: Optional[TimeSeriesDataFrame] = None,
        **kwargs,
    ) -> TimeSeriesDataFrame:
        """Predict with the ensemble model.

        Parameters
        ----------
        model_predictions: Dict[str, TimeSeriesDataFrame]
            Predictions of the base models, keyed by model name. It must contain every active model;
            additional models are ignored.
        data: TimeSeriesDataFrame
            The past time series. Required if a scaler is used.

        Raises
        ------
        ValueError
            If predictions for an active model are missing, or if a scaler is used and `data` is None.
        """
        if self.metric.ispointforecast and not self.use_all_quantiles_for_point_forecasts:
            # drop all features that are not the 0.5 quantile
            model_predictions = {k: tsdf[["0.5"]] for (k, tsdf) in model_predictions.items()}

        missing_models = [m for m in self._active_models if m not in model_predictions]
        if missing_models:
            raise ValueError(f"Missing predictions for active model(s): {missing_models}")
        # The fitted parameters are positional along the model axis. Stack the predictions in
        # training order, independently of the key order of the incoming dict.
        model_predictions = {m: model_predictions[m] for m in self._active_models}

        if self.scaler is not None:
            if data is None:
                raise ValueError("The timeseries (`data`) must be provided if scaler is used")
            self.scaler.fit(data, target=self.target)
            model_predictions = {k: self.scaler.transform(v, columns=v.columns) for k, v in model_predictions.items()}

        predictions = timeseriesdataframe_to_tensor(model_predictions)
        if self.greedy_preselected and predictions.shape[-1] != self.projection_matrix.shape[1]:
            # `_active_models` still lists all greedy candidates (e.g. model state saved by an
            # older version): keep only the selected ones.
            predictions = predictions[..., self.projection_matrix.any(axis=1)]

        prediction_tensor = self._predict(predictions, data=data, **kwargs)

        assert prediction_tensor.shape[0] == 1
        output_tsdf_template = list(model_predictions.values())[0]
        if self.metric.ispointforecast:
            output_tsdf_template = output_tsdf_template[["0.5"]]
        pred = tensor_to_timeseriesdataframe(output_tsdf_template, prediction_tensor)

        if self.isotonization == "sort":
            # Sort across columns so that the quantile forecasts do not cross.
            _data = np.array(pred)
            _data.sort(axis=1)
            pred[pred.columns] = _data

        if self.scaler is not None:
            pred = self.scaler.inverse_transform(pred)

        return pred

    def _predict(
        self,
        predictions: np.array,
        data: Optional[TimeSeriesDataFrame] = None,
        **kwargs,
    ) -> np.array:
        """Combine the base model predictions; to be implemented by subclasses."""
        raise NotImplementedError

    @property
    def active_models(self):
        """A copy of the names of the models used by the ensemble."""
        return self._active_models.copy()

    def print(self, *args, **kwargs):
        """Forward to the builtin `print` if `self.verbose` is set."""
        if self.verbose:
            print(*args, **kwargs)


class SimpleAverage(TensorBasedTimeSeriesEnsembleModel):
    """Unweighted combination of the base model forecasts, either their mean or their median."""

    def __init__(self, kind: str = "mean", **kwargs):
        """
        Parameters
        ----------
        kind: str
            Which average to take over the models: `"mean"` or `"median"`.

        Raises
        ------
        ValueError
            If `kind` is not one of the supported aggregations.
        """
        super().__init__(**kwargs)
        self.kind = kind
        supported = ("mean", "median")
        if kind not in supported:
            raise ValueError(f"Unknown aggregation kind {self.kind}")
        self.name = f"SimpleAverage({kind})"

    def _fit_ensemble(self, predictions: np.ndarray, labels: np.ndarray, **kwargs):
        """Nothing to learn: the average does not depend on the training data."""
        return None

    def _predict(self, predictions: np.ndarray, **kwargs):
        """Aggregate along the last (model) axis."""
        if self.kind == "median":
            return np.median(predictions, axis=-1)
        if self.kind == "mean":
            return predictions.mean(axis=-1)
        raise ValueError(f"Unknown aggregation kind {self.kind}")


class BestValidationModel(TensorBasedTimeSeriesEnsembleModel):
    """Selects the model with the best validation score."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.name = "BestValidationModel"

    def _fit_ensemble(
        self,
        predictions: np.ndarray,
        labels: np.ndarray,
        quantile_levels: List[float],
        **kwargs,
    ):
        """Store in `self.best_model_idx` the index of the model with the lowest validation loss.

        Models with a NaN loss are never selected unless all losses are NaN. With a plain `<`
        comparison, a NaN loss for the first model would otherwise never be beaten. Ties resolve
        to the first model.
        """
        losses = np.array(
            [
                self.metric.compute_metric(
                    y_true=labels,
                    q_pred=predictions[..., i],
                    quantile_levels=quantile_levels,
                )
                for i in range(predictions.shape[-1])
            ],
            dtype=float,
        )
        losses = np.where(np.isnan(losses), np.inf, losses)
        self.best_model_idx = int(np.argmin(losses))

    def _predict(self, predictions: np.ndarray, **kwargs):
        """Return the predictions of the selected model."""
        return predictions[..., self.best_model_idx]

    def get_model_importances(self):
        """One-hot importances: the selected model has importance 1, all others 0."""
        importances = np.zeros(len(self._active_models))
        importances[self.best_model_idx] = 1
        return importances

    def make_sparse(self):
        # This model is already sparse so just silently pass here; no need to retrain
        return False

    @property
    def active_models(self):
        """Only the selected model is used at prediction time."""
        return [self._active_models[self.best_model_idx]]


class PerformanceWeightedAverage(TensorBasedTimeSeriesEnsembleModel):
    """Averages the models based on their performance on the validation data."""

    def __init__(self, kind="sqr", normalize_losses=False, **kwargs):
        """
        Parameters
        ----------
        kind: str
            The kind of weighting to use. Can be one of:
            - "inv": inverse weighting with g(loss) = 1/loss
            - "sqr": square inverse weighting with g(loss) = (1/loss)^2 (default)
            - "exp": exponential weighting with g(loss) = exp(1/loss)
        normalize_losses: bool
            If set to true, we first normalize all losses such that they sum to one,
            with losses = losses / losses.sum().
        """
        super().__init__(**kwargs)
        if kind not in ["sqr", "inv", "exp"]:
            raise ValueError(f"Unknown weighting kind {kind}")
        self.name = f"PerformanceWeightedAverage({kind})"
        self.kind = kind
        self.normalize_losses = normalize_losses

    def _fit_ensemble(
        self,
        predictions: np.ndarray,
        labels: np.ndarray,
        quantile_levels: List[float],
        **kwargs,
    ):
        """Compute each model's validation loss and turn the losses into normalized weights.

        The raw (un-normalized) losses are stored in `self.losses`, the weights in `self.weights`.
        """
        num_models = predictions.shape[-1]
        losses = np.array(
            [
                self.metric.compute_metric(
                    y_true=labels,
                    q_pred=predictions[..., i],
                    quantile_levels=quantile_levels,
                )
                for i in range(num_models)
            ]
        )
        self.losses = losses

        if self.normalize_losses:
            losses = losses / losses.sum()
        if self.kind == "inv":
            self.weights = 1 / losses
        elif self.kind == "sqr":
            self.weights = (1 / losses) ** 2
        elif self.kind == "exp":
            # exp(1/loss) overflows to inf for small losses (inf/inf = NaN after normalization).
            # Subtracting the maximum exponent leaves the normalized weights unchanged.
            inverse_losses = 1 / losses
            self.weights = np.exp(inverse_losses - inverse_losses.max())
        else:
            raise ValueError(f"Unknown weighting kind {self.kind}")
        self.weights = self.weights / self.weights.sum()

    def make_sparse(self):
        """Zero out the weights of models with importance below `prune_below` and renormalize.

        The most important model is always kept, so that the weights cannot all become zero
        (which would produce NaN weights after renormalization). No re-training is needed.
        """
        importances = np.asarray(self.get_model_importances())
        keep = ~(importances < self.prune_below)
        keep[np.argmax(importances)] = True
        self.weights = np.where(keep, self.weights, 0.0)
        self.weights = self.weights / self.weights.sum()
        return False

    def get_model_importances(self):
        """The (normalized) weights serve as importances."""
        return self.weights

    def _predict(self, predictions: np.ndarray, **kwargs):
        """Weighted sum over the last (model) axis."""
        return predictions @ self.weights

    @property
    def active_models(self):
        """Models with a strictly positive weight."""
        return [m for m, w in zip(self._active_models, self.weights) if w > 0]


class GreedyEnsemble(TensorBasedTimeSeriesEnsembleModel):
    """Weighted average computed with the greedy optimization (Caruana et al., 2004)."""

    def __init__(
        self,
        ensemble_size: int = 100,
        early_stopping: bool = False,
        tqdm: bool = False,
        log: bool = False,
        **kwargs,
    ):
        """
        Parameters
        ----------
        ensemble_size: int
            The size of the ensemble to be computed, i.e. the number of greedy selection steps.
            Must be at least 1.
        early_stopping: bool
            If set to true, the greedy optimization will stop early if the ensemble performance would decrease.
            Default is False.
        tqdm: bool
            Whether to show a progress bar.
            Default is False.
        log: bool
            Currently unused by this class.

        Raises
        ------
        ValueError
            If `ensemble_size` is smaller than 1 (it would produce NaN weights).
        """
        if ensemble_size < 1:
            raise ValueError(f"`ensemble_size` must be at least 1, got {ensemble_size}")
        super().__init__(**kwargs)
        self.name = f"GreedyEnsemble({ensemble_size}{', earlystopping' if early_stopping else ''})"
        self.ensemble_size = ensemble_size
        self.early_stopping = early_stopping
        self.tqdm = tqdm

    def _fit_ensemble(
        self,
        predictions: np.ndarray,
        labels: np.ndarray,
        quantile_levels: List[float],
        *args,
        **kwargs,
    ):
        """Greedily add (with replacement) the model that most improves the ensemble score.

        A model's weight is the fraction of steps in which it was chosen; the result is stored
        in `self.weights`. Ties between equally good models are broken at random via
        `np.random.choice`.
        """
        chosen_models = []
        ensemble_prediction = 0 * predictions[..., 0]
        # Start at -inf: a finite sentinel (previously -999999) let early stopping fire before any
        # model was chosen whenever all losses exceeded it, which produced NaN weights.
        best_score = -np.inf
        n_models = predictions.shape[-1]
        bar = range(self.ensemble_size)
        if self.tqdm:
            bar = tqdm(bar)
        for _ in bar:
            s = len(chosen_models)
            weighted_ensemble_prediction = (s / float(s + 1)) * ensemble_prediction
            scores = np.zeros(n_models)
            for m in range(n_models):
                _pred = weighted_ensemble_prediction + (1.0 / float(s + 1)) * predictions[..., m]
                scores[m] = -self.metric.compute_metric(
                    q_pred=_pred,
                    y_true=labels,
                    quantile_levels=quantile_levels,
                )
            # NaN scores are the worst possible candidates. Mapping them to -inf also avoids an
            # empty candidate set (and a crash in np.random.choice) if every score is NaN.
            scores[np.isnan(scores)] = -np.inf

            if self.early_stopping and (scores < best_score).all():
                # Adding this to the ensemble does not improve anything
                break

            best_score = scores.max()
            all_best = np.flatnonzero(scores == best_score)
            best = np.random.choice(all_best)

            chosen_models.append(best)
            ensemble_prediction *= s / (s + 1)
            ensemble_prediction += predictions[..., best] / (s + 1)

        ensemble_size = len(chosen_models)
        weights = np.zeros((n_models,), dtype=float)
        for model_idx, model_count in Counter(chosen_models).items():
            weights[model_idx] = model_count / ensemble_size
        if np.sum(weights) < 1:
            weights = weights / np.sum(weights)
        self.weights = weights

    def get_model_importances(self):
        """The greedy weights serve as importances."""
        return self.weights

    def _predict(self, predictions: np.ndarray, **kwargs):
        """Weighted sum over the model axis of a (fold, item, time, quantile, model) array."""
        prediction = np.einsum("fitqm,m->fitq", predictions, self.weights)
        return prediction

    def make_sparse(self):
        # this model is already sparse so just silently pass here; no need to retrain
        return False

    @property
    def active_models(self):
        """Models that were chosen at least once."""
        return [m for m, w in zip(self._active_models, self.weights) if w > 0]


class LinearEnsemble(TensorBasedTimeSeriesEnsembleModel):
    """Linear stacker models"""

    def __init__(
        self,
        *args,
        # model kwargs:
        weights_per: str = "m",
        weight_transform: str = "softmax",
        # optimizer kwargs:
        optimizer: str = "adam",
        optimizer_kwargs: Optional[Dict] = None,
        epochs: Optional[int] = None,
        lrscheduler: Optional[bool] = True,
        early_stopping: Optional[bool] = True,
        tolerance_grad: Optional[float] = 1e-5,
        tolerance_change: Optional[float] = 1e-8,
        time_limit: Optional[int] = None,
        # utilities:
        tqdm: Optional[bool] = False,
        log: Optional[bool] = False,
        # experimental kwargs (EXPERIMENTAL):
        rank: Optional[int] = None,
        clusters: Optional[int] = None,
        **kwargs,
    ):
        """Linear stacker models for time series forecasting.

        Parameters
        ----------
        weights_per: str
            How to compute the weights. Can be any combination of
            - "m": one weight per model
            - "q": one weight per quantile; can be used twice
            - "i": one weight per item
            - "t": one weight per timestep
            Examples:
            - "mq": one weight per model and quantile
            - "miq": one weight per model, quantile and item
            - "mtq": one weight per model, quantile and timestep
            - "miqq": one weight per model and "across-quantiles", i.e. all quantile forecasts are considered in order to compute the forecast
            "m" always needs to be chosen.
        weight_transform: str
            Can be one of:
            - "softmax": apply a softmax -> weights are all positive and sum to one
            - "exp": apply an exponential transformation -> weights are all positive
            - "norm": normalize the weights -> weights sum to one
            - None: no transformation
        optimizer: str
            The optimizer to use for the regression (case-insensitive). Can be one of:
            - "adam": Adam optimizer
            - "lbfgs": L-BFGS optimizer
            - "lbfgs+ls": L-BFGS optimizer with line search
        optimizer_kwargs: dict
            The keyword arguments to pass to the optimizer. They override the per-optimizer
            defaults. Default is None (use the defaults only).
        epochs: int
            The number of epochs to train the regression model for.
            If not provided, a default value will be used (100000 for Adam, 1000 for L-BFGS).
        lrscheduler: bool
            By default this model uses a custom learning rate scheduler, which
            decreases the learning rate whenever the loss starts oscillating.
            Setting this to False disables the scheduler.
        early_stopping: bool
            If set to true, the optimization will stop once a stopping criterion is met.
            See also `tolerance_grad` and `tolerance_change` which provide tolerances
            for this stopping criterion.
            Default is True.
        tolerance_grad: float
            The tolerance for the gradient.
            Default is 1e-5.
        tolerance_change: float
            The tolerance for the change in the loss.
            Default is 1e-8.
        time_limit: int
            Stored as `self.time_limit`; presumably enforced by the training loop (not shown here).
        tqdm: bool
            Whether to show a progress bar.
            Default is False.
        log: bool
            Whether to log the ensemble weights to mlflow.
            Default is False.
        rank: int
            EXPERIMENTAL!
            Can be used to specify the rank of a low-rank weight parameterization.
            No convincing results so far. Use at your own risk!
        clusters: int
            EXPERIMENTAL!
            Can be used to specify the number of "clusters" of a particular weight
            parameterization. No convincing results so far. Use at your own risk!

        Do also check the docstring of the parent class
        `TensorBasedTimeSeriesEnsembleModel` as it provides many other useful options,
        such as `sparsify` or `ignore_models`!

        Raises
        ------
        ValueError
            If `lr` is passed directly, if both `rank` and `clusters` are set, or if
            `optimizer` is not one of the supported optimizers.
        """
        # First let's catch a common mistake: passing the learning rate as a top-level kwarg.
        if "lr" in kwargs:
            s = "The learning rate needs to be passed inside `optimizer_kwargs`"
            raise ValueError(s)

        super().__init__(**kwargs)
        weight_transform_str = "noactivation" if weight_transform is None else weight_transform
        self.name = f"LinearEnsemble({weights_per}, {weight_transform_str}"
        if self.sparsify:
            self.name += ", sparsified"
        self.name += ")"

        # The main model description
        self.weights_per = weights_per
        self.weight_transform = weight_transform
        if not (rank is None or clusters is None):
            raise ValueError("Only one of `rank` and `clusters` can be set!")
        self.rank = rank  # EXPERIMENTAL
        self.clusters = clusters  # EXPERIMENTAL

        # Optimizer-related things
        self.optimizer_str = optimizer.lower()
        available_optimizers = {
            "adam": torch.optim.Adam,
            "lbfgs": torch.optim.LBFGS,
            "lbfgs+ls": torch.optim.LBFGS,
        }
        if self.optimizer_str not in available_optimizers:
            raise ValueError(
                f"Unsupported optimizer: {optimizer}. Choose one of {sorted(available_optimizers)}"
            )
        self.optimizer = available_optimizers[self.optimizer_str]
        default_optimizer_kwargs = {
            "adam": {
                "lr": 1e-1,
            },
            "lbfgs": {
                "lr": 1.0,
                "tolerance_grad": tolerance_grad / 10,
                "tolerance_change": tolerance_change / 10,
            },
            "lbfgs+ls": {
                "line_search_fn": "strong_wolfe",
                "tolerance_grad": tolerance_grad / 10,
                "tolerance_change": tolerance_change / 10,
            },
        }[self.optimizer_str]
        # User-supplied kwargs override the defaults.
        self.optimizer_kwargs = default_optimizer_kwargs | ({} if optimizer_kwargs is None else optimizer_kwargs)
        self.epochs = (
            epochs
            if epochs is not None
            else {
                "adam": 100_000,
                "lbfgs": 1_000,
                "lbfgs+ls": 1_000,
            }[self.optimizer_str]
        )
        self.early_stopping = early_stopping
        self.tolerance_grad = tolerance_grad
        self.tolerance_change = tolerance_change
        # Factory that builds the LR scheduler for a given optimizer (or None if disabled).
        self.scheduler = (
            (
                lambda optimizer: ReduceLRWhenUnstable(
                    optimizer,
                    factor=0.5,
                    patience=500,
                    min_fraction=0.25,
                )
            )
            if lrscheduler
            else None
        )
        self.time_limit = time_limit

        # Other parameters
        self.tqdm = tqdm
        self.log = log

        # Cross-quantile weights ("qq") need all quantiles even for point forecasts.
        if "qq" in self.weights_per:
            self.use_all_quantiles_for_point_forecasts = True

    def _gather_flat_grad(self, params):
        """
        Concatenate the gradients of all `params` into a single 1-D tensor.

        A parameter without a gradient contributes zeros, sparse gradients are densified,
        and complex gradients are split into their real and imaginary parts.

        Adapted from https://github.com/pytorch/pytorch/blob/1ff226d88c70e66434ed9ce50deb5ae09e7e766c/torch/optim/lbfgs.py#L262-L274
        """

        def _flatten(param):
            grad = param.grad
            if grad is None:
                flat = param.new_zeros(param.numel())
            elif grad.is_sparse:
                flat = grad.to_dense().view(-1)
            else:
                flat = grad.view(-1)
            return torch.view_as_real(flat).view(-1) if torch.is_complex(flat) else flat

        return torch.cat([_flatten(param) for param in params], 0)

    def _gather_max_param(self, params):
        out = -torch.inf
        for p in params:
            out = max(out, p.max())
        return out

    def _fit_ensemble(
        self,
        predictions: np.ndarray,
        labels: np.ndarray,
        quantile_levels: List[float],
        **kwargs,
    ):
        """Fit a ``LinearStackerRegressor`` to the base model predictions.

        The regressor's parameters are optimized with ``self.optimizer`` on the
        loss returned by ``self.metric.compute_metric``. Training runs for at most
        ``self.epochs`` optimizer steps. It stops earlier when ``self.time_limit``
        seconds have elapsed or, if ``self.early_stopping`` is set, when either

        * the largest absolute gradient entry is below ``self.tolerance_grad`` and
          the relative loss change is below 1e-4, or
        * the relative loss change is below ``self.tolerance_change`` in two
          consecutive steps.

        Parameters
        ----------
        predictions: np.ndarray
            Base model predictions. They are converted to a float64 tensor and fed
            to the regressor; the regressor is also constructed from this array's shape.
        labels: np.ndarray
            Targets passed to the metric as ``y_true``.
        quantile_levels: List[float]
            Quantile levels passed to the metric.
        """
        _predictions = torch.tensor(predictions, dtype=torch.float64)
        _labels = torch.tensor(labels, dtype=torch.float64)

        self.regressor = LinearStackerRegressor(
            predictions.shape,
            n_output_quantiles=len(self.quantile_levels_out),
            weights_per=self.weights_per,
            rank=self.rank,
            clusters=self.clusters,
            weight_transform=self.weight_transform,
        )

        optimizer = self.optimizer(self.regressor.parameters(), **self.optimizer_kwargs)
        scheduler = self.scheduler(optimizer) if self.scheduler is not None else None

        def closure():
            """Compute the predictions and return the resulting loss.

            Putting this into a closure is necessary to use LBFGS as it calls this function multiple times.
            """
            optimizer.zero_grad()
            pred = self.regressor(_predictions)
            loss = self.metric.compute_metric(
                q_pred=pred,
                y_true=_labels,
                quantile_levels=quantile_levels,
            )
            loss.backward()
            return loss

        last_loss = torch.inf
        n_small_loss_change = 0
        # Tracked explicitly so the final mlflow log also works when zero epochs run
        # (the loop variable would otherwise be undefined).
        n_epochs_run = 0
        start_time = time.time()
        with tqdm(range(self.epochs), disable=not self.tqdm) as bar:
            for epoch in bar:
                # The actual optimization step:
                loss = optimizer.step(closure)
                n_epochs_run = epoch + 1

                # Update the scheduler (if there is one):
                if scheduler is not None:
                    scheduler.step(loss)

                # Log the loss to mlflow:
                if self.log and mlflow.active_run():
                    mlflow.log_metric("loss", loss.item(), step=epoch)

                # Convergence diagnostics. The loss is detached so that this
                # bookkeeping does not record extra autograd operations.
                loss_value = loss.detach()
                max_grad = self._gather_flat_grad(self.regressor.parameters()).abs().max()
                rel_loss_diff = (last_loss - loss_value) / loss_value
                bar.set_description(
                    " | ".join(
                        (
                            f"loss: {loss.item():.4f}",
                            f"max_grad: {max_grad:.2e}",
                            f"rel loss diff: {rel_loss_diff:.2e}",
                            f"lr: {optimizer.param_groups[0]['lr']:.2e}",
                        )
                    )
                )

                # Time limit:
                if self.time_limit is not None and time.time() - start_time > self.time_limit:
                    break

                # Early stopping:
                if self.early_stopping:
                    small_grad = max_grad < self.tolerance_grad
                    not_huge_loss_change = rel_loss_diff.abs() < 1e-4
                    if small_grad and not_huge_loss_change:
                        break

                    # Require two consecutive small changes before stopping.
                    small_loss_change = rel_loss_diff.abs() < self.tolerance_change
                    if small_loss_change:
                        n_small_loss_change += 1
                        if n_small_loss_change >= 2:
                            break
                    else:
                        n_small_loss_change = 0
                last_loss = loss_value

        if self.log and mlflow.active_run():
            mlflow.log_metric("epochs_until_convergence", n_epochs_run)

    def get_model_importances(self):
        """Return one importance value per model as a numpy array.

        The weight tensor is flattened to ``(-1, n)``, where ``n`` is the size of
        its last axis. Each of the ``n`` entries gets the mean absolute weight over
        all other positions. The result is normalized so the values sum to one.
        """
        weight_tensor = self.regressor.weight_model()
        n_models = weight_tensor.shape[-1]
        mean_abs = weight_tensor.reshape(-1, n_models).abs().mean(axis=0)
        normalized = mean_abs / mean_abs.sum()
        return normalized.detach().numpy()

    def make_sparse(self, retrain=True):
        """Sparsify the model by ignoring all models with low-enough importance.

        Every active model whose importance (from ``get_model_importances``) is
        below ``self.prune_below`` is removed from ``self._active_models`` and
        appended to ``self.ignore_models``. The most important model is always
        kept, so pruning can never leave the ensemble without any model.
        Afterwards ``self.sparsify`` is set to False to avoid re-sparsification.

        Parameters
        ----------
        retrain: bool
            Not used by this method; kept for interface compatibility.

        Returns
        -------
        bool
            Always True.
        """
        importances = np.asarray(self.get_model_importances())
        # Importances are normalized to sum to one, so with many models all of them
        # can fall below `prune_below`; never drop the single most important one.
        keep = int(np.argmax(importances)) if importances.size else -1
        # Walk indices from the back so popping index i leaves the positions of the
        # not-yet-visited (smaller) indices unchanged.
        for i in reversed(range(len(importances))):
            if i != keep and importances[i] < self.prune_below:
                self.ignore_models.append(self._active_models.pop(i))
                self.print(f"Drop model {self.ignore_models[-1]}")

        self.sparsify = False  # to avoid re-sparsification
        return True

    def _predict(self, predictions: np.ndarray, **kwargs):
        _predictions = torch.tensor(predictions, requires_grad=False, dtype=torch.float64)
        return self.regressor(_predictions).detach().numpy()

    def plot_weights(self, *args, **kwargs):
        """Plot the regressor's weights.

        All positional and keyword arguments are forwarded unchanged to
        ``self.regressor.weight_model.plot``, whose return value is returned.
        """
        weight_model = self.regressor.weight_model
        plot = weight_model.plot
        return plot(*args, **kwargs)


class StackedEnsemble(TensorBasedTimeSeriesEnsembleModel):
    """Two-level stacked ensemble.

    Several "L2" ensemble models (``base_models``) combine the L1 base model
    predictions. An "L3" ``stacker_model`` then combines the L2 outputs and,
    with ``skip_connections``, the L1 predictions as well. The L2 models are fit
    on all but the last window and the L3 model on the L2 predictions for the
    last window. If ``retrain`` is set, the used L2 models are then refit on all
    windows.
    """

    def __init__(
        self,
        metric,
        target,
        base_models,
        stacker_model,
        retrain: bool = True,
        base_model_kwargs: Optional[dict] = None,
        skip_connections: bool = False,
        **kwargs,
    ):
        """
        Parameters
        ----------
        stacker_model: tuple (model_name, model_kwargs)
            The stacker model to be used. (L3 model)
        base_models: list of tuples (model_name, model_kwargs)
            The base models to be used in the ensemble. (L2 models)
        base_model_kwargs: dict
            This can be used for convenience to pass a set of shared kwargs to the
            base models. Default is None (no shared kwargs).
        retrain: bool
            Whether to retrain the base models or not after fitting the L3 model.
            Default is True.
        skip_connections: bool
            Whether to also pass the L1 model outputs to the L3 model.
            Default is False.
        """
        super().__init__(metric=metric, target=target, **kwargs)
        # A None default avoids sharing one mutable dict across instances.
        shared_kwargs = base_model_kwargs if base_model_kwargs is not None else {}
        self.name = "StackedEnsemble"
        _base_models = [atse.MODELS[k](metric=metric, target=target, **shared_kwargs, **v) for k, v in base_models]
        # Keys include the position so they stay unique even if two models share a name.
        self.base_models = {(i, model.name): model for i, model in enumerate(_base_models)}
        self.active_base_models = list(self.base_models)
        self.stacker_model = atse.MODELS[stacker_model[0]](
            metric=metric,
            target=target,
            **stacker_model[1],
            detect_and_ignore_failures=False,
        )
        self.retrain = retrain
        self.skip_connections = skip_connections

    def _train_val_split(
        self,
        model_predictions_per_window: Dict[str, List[TimeSeriesDataFrame]],
        labels_per_window: List[TimeSeriesDataFrame],
    ):
        """Split the training data into "train" and "validation".

        "train" (all windows but the last) will be used to train the L2 models.
        "validation" (the last window) will be used to train the L3 model.
        If `self.retrain` is set, L2 models will be retrained on both "train" and "val".

        Raises
        ------
        ValueError
            If fewer than two windows are given, since "train" would then be empty.
        """
        if len(labels_per_window) < 2:
            raise ValueError(
                "StackedEnsemble needs at least two windows: all but the last are used to fit "
                "the base models and the last one to fit the stacker model."
            )
        train_predictions = {k: v[:-1] for k, v in model_predictions_per_window.items()}
        train_labels = labels_per_window[:-1]
        val_predictions = {k: v[-1] for k, v in model_predictions_per_window.items()}
        val_labels = labels_per_window[-1]
        return {
            "train": (train_predictions, train_labels),
            "val": (val_predictions, val_labels),
        }

    def _stacker_input_importances(self) -> dict:
        """Map each input used by the stacker model to its importance.

        The stacker's ``get_model_importances()`` is index-aligned with its
        ``_active_models``. Inputs the stacker does not use are absent from the
        mapping. Looking importances up by key, rather than by position in
        ``self.base_models``, stays correct when the stacker also receives L1
        inputs (skip connections) or has dropped some inputs.
        """
        return dict(zip(self.stacker_model._active_models, self.stacker_model.get_model_importances()))

    def _update_active_base_models(self):
        """
        Update the list of "active" models in order to save time during re-training and inference.

        A base model is active if the stacker assigns it a positive importance.
        """
        input_importances = self._stacker_input_importances()
        self.active_base_models = [k for k in self.base_models if input_importances.get(k, 0) > 0]

    def fit_ensemble(
        self,
        model_predictions_per_window: Dict[str, List[TimeSeriesDataFrame]],
        labels_per_window: List[TimeSeriesDataFrame],
        **kwargs,
    ):
        """Fit the L2 base models, then the L3 stacker, then optionally refit the L2 models.

        ``kwargs`` are forwarded to the stacker model's ``fit_ensemble``.
        """
        # Remove ignored models from model predictions
        model_predictions_per_window = {
            m: p for m, p in model_predictions_per_window.items() if m not in self.ignore_models
        }
        self._active_models = list(model_predictions_per_window.keys())
        prediction_length = self._get_prediction_length(model_predictions_per_window)

        splits = self._train_val_split(model_predictions_per_window, labels_per_window)

        # 1. Train all ensembles on the training data
        train_predictions, train_labels = splits["train"]
        # 1.1. Prepare some quantities to make things faster
        train_labels_per_window, train_labels_per_window_past = self._prepare_fit_ensemble_labels(
            prediction_length=self._get_prediction_length(train_predictions),
            labels_per_window=train_labels,
        )
        self.metric.save_past_metrics(
            data_past=train_labels_per_window_past,
            target=self.target,
            seasonal_period=self.eval_metric_seasonal_period,
        )
        # 1.2. Train the models
        n_models = len(self.base_models)
        self.print(f"Train {n_models} base models")
        for i, (mname, model) in enumerate(self.base_models.items()):
            self.print(f"[{i+1}/{n_models}] {mname}")
            model.metric = self.metric
            model.fit_ensemble(
                model_predictions_per_window=train_predictions,
                labels_per_window=train_labels_per_window,
                labels_per_window_past=train_labels_per_window_past,
                do_not_prepare_labels=True,
                do_not_prepare_metric=True,
            )

        # 2. Compute predictions on the validation data
        val_predictions, val_labels = splits["val"]
        val_labels_past = val_labels.slice_by_timestep(None, -prediction_length)
        preds = {}
        for mname, model in self.base_models.items():
            pred = model.predict(val_predictions, val_labels_past)
            preds[mname] = [pred]

        if self.skip_connections:
            preds = preds | {k: [v] for k, v in val_predictions.items()}

        # 3. Fit the stacker model
        self.print(f"Train stacker model ({self.stacker_model.name})")
        self.stacker_model.fit_ensemble(preds, [val_labels], **kwargs)
        self._update_active_base_models()

        # 4. If retrain=True, retrain the base models on the full data
        if self.retrain:
            # 4.1. Prepare some quantities to make things faster
            labels_per_window, labels_per_window_past = self._prepare_fit_ensemble_labels(
                prediction_length=prediction_length,
                labels_per_window=labels_per_window,
            )
            self.metric.save_past_metrics(
                data_past=labels_per_window_past,
                target=self.target,
                seasonal_period=self.eval_metric_seasonal_period,
            )
            # 4.2. Retrain the models (only those the stacker actually uses)
            n_models = len(self.active_base_models)
            self.print(f"Retrain {n_models} base models")
            for i, mkey in enumerate(self.active_base_models):
                self.print(f"[{i+1}/{n_models}]: {mkey}")
                model = self.base_models[mkey]
                model.metric = self.metric
                model.fit_ensemble(
                    model_predictions_per_window=model_predictions_per_window,
                    labels_per_window=labels_per_window,
                    labels_per_window_past=labels_per_window_past,
                    do_not_prepare_labels=True,
                    do_not_prepare_metric=True,
                )

    def predict(
        self,
        model_predictions: Dict[str, Union[TimeSeriesDataFrame]],
        data: Optional[TimeSeriesDataFrame] = None,
        **kwargs,
    ) -> TimeSeriesDataFrame:
        """Predict with the active L2 models and combine their outputs with the stacker.

        Inactive L2 models are not evaluated. A copy of the first L1 prediction
        is passed in their place as a placeholder.
        """
        preds = {}
        empty_pred = list(model_predictions.values())[0].copy()
        for mname, model in self.base_models.items():
            if mname in self.active_base_models:
                preds[mname] = model.predict(model_predictions, data)
            else:
                preds[mname] = empty_pred
        if self.skip_connections:
            preds = preds | model_predictions
        return self.stacker_model.predict(preds, data, **kwargs)

    @property
    def active_models(self):
        """Unique L1 models used directly or through an L2 model, in first-seen order."""
        l3_model_inputs = self.stacker_model.active_models
        # some of these are l2 models, some are l1 models
        l2s = [m for m in l3_model_inputs if m in self.base_models]
        l1s = [m for m in l3_model_inputs if m not in self.base_models]
        # now each of the l2s also have a set of active l1 models
        for l2 in l2s:
            l1s += self.base_models[l2].active_models
        # deduplicate the l1s while keeping a deterministic order
        return list(dict.fromkeys(l1s))

    def get_l2_models_with_importances(self):
        """Return ``{l2_model_name: stacker importance}`` for every L2 model.

        L2 models the stacker does not use get importance 0.

        Raises
        ------
        ValueError
            If two L2 models share a name (the result is keyed by name).
        """
        l2_model_names = [m_name for _, m_name in self.base_models]
        if len(l2_model_names) != len(set(l2_model_names)):
            raise ValueError("L2 model names must be unique to key their importances by name.")
        input_importances = self._stacker_input_importances()
        return {m_name: input_importances.get((m_id, m_name), 0) for m_id, m_name in self.base_models}

    def get_model_importances(self):
        """Return one importance per entry of ``self._active_models`` (the L1 models).

        Each L2 model's L1 importances are scaled by that L2 model's importance in
        the stacker. With skip connections, the stacker's direct importances for
        L1 models are added as well.
        """
        l1_importances = {k: 0 for k in self._active_models}
        for key, i3 in self._stacker_input_importances().items():
            if key in self.base_models:
                l2_model = self.base_models[key]
                for l1_model, i2 in zip(l2_model._active_models, l2_model.get_model_importances()):
                    l1_importances[l1_model] += i3 * i2
            else:
                # Skip connection: the stacker weighs this L1 model directly.
                l1_importances[key] += i3
        return [l1_importances[m] for m in self._active_models]


class AGTabularStackerModel(TensorBasedTimeSeriesEnsembleModel):
    """Stacker model which uses tabular models from autogluon tabular"""

    def __init__(
        self,
        stacker_model_str: str,
        hyperparameters: Optional[Dict] = None,
        per_quantile: bool = False,
        early_stopping: bool = False,
        time_limit: Optional[int] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        stacker_model_str: str
            The name of the autogluon stacker model to be used.
            For a list of available names, you can check
            `autogluon.tabular.trainer.model_presets.presets.MODEL_TYPES`.
        per_quantile: bool
            Whether to train a separate model for each quantile. Default is False.
        early_stopping: bool
            Whether to use the last validation window for early stopping. Default is False.
        hyperparameters: dict
            A dictionary of hyperparameters to be passed to the stacker model.
            Default is None.
        time_limit: int
            Training time limit in seconds passed to the model's ``fit``. With
            ``per_quantile``, it is split evenly across the per-quantile models.
            Default is None.
        """
        super().__init__(**kwargs)
        self.stacker_model_str = stacker_model_str
        self.per_quantile = per_quantile
        self.early_stopping = early_stopping
        self.time_limit = time_limit
        if time_limit is not None and not early_stopping:
            # Warn the user that the time limit will be ignored
            self.print("Warning: time_limit is set but early_stopping is False. " "The time limit will be ignored.")
        self.hyperparameters = hyperparameters if hyperparameters is not None else {}
        self.name = f"AGTabularStackerModel({self.stacker_model_str}"
        if self.scaler_str is not None:
            self.name += f", {self.scaler_str}"
        if per_quantile:
            self.name += ", pq"
        if early_stopping:
            self.name += ", es"
        if hyperparameters is not None:
            self.name += f", {hyperparameters}"
        self.name += ")"

        if not self.per_quantile:
            self.use_all_quantiles_for_point_forecasts = True

    def _get_stacker_model(self, quantile_levels: list[float]):
        """Instantiate a fresh, unfitted quantile model of type ``self.stacker_model_str``.

        ``quantile_levels`` is placed into ``ag_args_fit``. Any ``ag_args_fit``
        entries supplied through ``self.hyperparameters`` are kept, but the
        requested quantile levels always take precedence.
        """
        hyperparameters = dict(self.hyperparameters)
        ag_args_fit = dict(hyperparameters.pop("ag_args_fit", None) or {})
        ag_args_fit["quantile_levels"] = quantile_levels
        return MODEL_TYPES[self.stacker_model_str](
            problem_type="quantile",
            hyperparameters={"ag_args_fit": ag_args_fit, **hyperparameters},
        )

    def _fit_ensemble(
        self,
        predictions: np.ndarray,
        labels: np.ndarray,
        quantile_levels: List[float],
        **kwargs,
    ):
        """Fit the tabular stacker model(s).

        ``predictions`` is unpacked as ``(F, I, H, Q, M)`` and flattened to one
        row per ``(F, I, H)`` position with ``Q * M`` feature columns. Rows whose
        label is NaN are dropped. With early stopping and more than one window,
        the last window is used as validation data.
        """
        F, I, H, Q, M = predictions.shape
        # Early stopping is disabled if only one fold is available
        if self.early_stopping and F > 1:
            X = pd.DataFrame(predictions[:-1, ...].reshape((F - 1) * I * H, Q * M))
            y = pd.DataFrame(labels[:-1, ...].reshape((F - 1) * I * H))
            X_val = pd.DataFrame(predictions[-1:, ...].reshape(1 * I * H, Q * M))
            y_val = pd.DataFrame(labels[-1:, ...].reshape(1 * I * H))
        else:
            X = pd.DataFrame(predictions.reshape(F * I * H, Q * M))
            y = pd.DataFrame(labels.reshape(F * I * H))
            X_val = None
            y_val = None

        notna_mask = y.notna().to_numpy().ravel()
        X = X[notna_mask]
        y = y[notna_mask]
        if X_val is not None:
            notna_mask_val = y_val.notna().to_numpy().ravel()
            X_val = X_val[notna_mask_val]
            y_val = y_val[notna_mask_val]

        if not self.per_quantile:
            self.stacker_model = self._get_stacker_model(quantile_levels=quantile_levels)
            self.stacker_model.fit(
                X=X,
                y=y,
                X_val=X_val,
                y_val=y_val,
                time_limit=self.time_limit,
                verbosity=4,
            )
        else:
            # Each row's Q * M columns are laid out as (Q, M). Reshape per row rather
            # than per (I, H) block, which no longer holds once NaN rows are removed.
            X_by_quantile = X.to_numpy().reshape(-1, Q, M)
            X_val_by_quantile = X_val.to_numpy().reshape(-1, Q, M) if X_val is not None else None
            # Split the overall time budget evenly among the models actually trained.
            time_limit = self.time_limit / len(quantile_levels) if self.time_limit is not None else None
            self.stacker_models = []
            for i, q in enumerate(quantile_levels):
                model = self._get_stacker_model(quantile_levels=[q])
                # Keep the index shared with `y` / `y_val`, as in the joint branch.
                _X = pd.DataFrame(X_by_quantile[:, i, :], index=X.index)
                _X_val = None
                if X_val_by_quantile is not None:
                    _X_val = pd.DataFrame(X_val_by_quantile[:, i, :], index=X_val.index)
                model.fit(X=_X, y=y, X_val=_X_val, y_val=y_val, time_limit=time_limit)
                self.stacker_models.append(model)

    def _predict(self, predictions: np.ndarray, **kwargs):
        """Predict with the fitted stacker model(s); returns an array of shape (F, I, H, Q_out)."""
        F, I, H, Q_in, M = predictions.shape
        if self.per_quantile:
            preds = []
            for i, _ in enumerate(self.quantile_levels_out):
                X = pd.DataFrame(predictions[:, :, :, i, :].reshape(F * I * H, M))
                pred = self.stacker_models[i].predict(X)
                if isinstance(pred, pd.DataFrame):
                    pred = pred.values
                preds.append(pred.reshape(F, I, H))
            return np.stack(preds, axis=-1)

        X = pd.DataFrame(predictions.reshape(F * I * H, Q_in * M))
        pred = self.stacker_model.predict(X)
        if isinstance(pred, pd.DataFrame):
            pred = pred.values
        Q_out = len(self.quantile_levels_out)
        return pred.reshape(F, I, H, Q_out)


class AGTabularPredictorStackerModel(TensorBasedTimeSeriesEnsembleModel):
    """Stacker model which uses TabularPredictor from autogluon tabular

    EXPERIMENTAL
    """

    def __init__(
        self,
        hyperparameters: Optional[Dict] = None,
        presets: Optional[str] = None,
        time_limit: Optional[int] = None,
        per_quantile: bool = False,
        **kwargs,
    ):
        """
        Parameters
        ----------
        hyperparameters: dict
            A dictionary of hyperparameters that specifies the predictor, passed to `TabularPredictor.fit`.
            Default is None.
        presets: str
            A string that specifies a valid `preset` value for `TabularPredictor.fit`.
            Default is None.
        time_limit: int
            The time limit for training in seconds, passed to `TabularPredictor.fit`.
            With `per_quantile`, it is passed unchanged to each per-quantile predictor.
            Default is None.
        per_quantile: bool
            Whether to train a separate model for each quantile. Default is False.
        """
        super().__init__(**kwargs)
        self.hyperparameters = hyperparameters
        self.presets = presets
        self.time_limit = time_limit
        self.per_quantile = per_quantile
        self.name = "AGTabularPredictorStackerModel"

    def _fit_ensemble(
        self,
        predictions: np.ndarray,
        labels: np.ndarray,
        quantile_levels: List[float],
        **kwargs,
    ):
        """Fit one TabularPredictor on all quantiles, or one per quantile level.

        ``predictions`` is unpacked as ``(F, I, H, Q, M)``; each ``(F, I, H)``
        position becomes one row, with the label stored in a "target" column.
        """
        F, I, H, Q, M = predictions.shape
        y = labels.reshape(F * I * H)

        if not self.per_quantile:
            self.predictor = TabularPredictor(
                label="target",
                problem_type="quantile",
                quantile_levels=quantile_levels,
                verbosity=1,
            )
            df = pd.DataFrame(predictions.reshape(F * I * H, Q * M))
            df["target"] = y
            self.predictor.fit(
                df,
                hyperparameters=self.hyperparameters,
                presets=self.presets,
                time_limit=self.time_limit,
            )
        else:
            self.predictors = [
                TabularPredictor(
                    label="target",
                    problem_type="quantile",
                    quantile_levels=[q],
                    verbosity=1,
                )
                for q in quantile_levels
            ]
            for i, predictor in enumerate(self.predictors):
                # Features for predictor i: the M model predictions of input quantile i.
                df = pd.DataFrame(predictions[:, :, :, i, :].reshape(F * I * H, M))
                df["target"] = y
                predictor.fit(
                    df,
                    hyperparameters=self.hyperparameters,
                    presets=self.presets,
                    time_limit=self.time_limit,
                )

    def _predict(self, predictions: np.ndarray, **kwargs):
        """Predict with the fitted predictor(s); returns an array of shape (F, I, H, n_quantiles)."""
        F, I, H, Q, M = predictions.shape
        if not self.per_quantile:
            df = pd.DataFrame(predictions.reshape(F * I * H, Q * M))
            pred = self.predictor.predict(df)
            if isinstance(pred, pd.DataFrame):
                pred = pred.values
            # The number of output columns equals the number of quantile levels the
            # predictor was fit on, which need not equal the input quantile count Q.
            return pred.reshape(F, I, H, -1)

        preds = []
        # One fitted predictor per quantile level, in fitting order.
        for i, predictor in enumerate(self.predictors):
            df = pd.DataFrame(predictions[:, :, :, i, :].reshape(F * I * H, M))
            pred = predictor.predict(df)
            if isinstance(pred, pd.DataFrame):
                pred = pred.values
            preds.append(pred.reshape(F, I, H))
        return np.stack(preds, axis=-1)
