import itertools
import math
import random
from typing import Callable, Dict, List, Optional, Tuple, Type
import numpy as np
import pandas as pd
from tsbench.config import Config, DatasetConfig, ModelConfig, TrainConfig
from tsbench.config.model.models import SeasonalNaiveModelConfig
from tsbench.evaluation.utils import loocv_split, num_fitting_processes, run_parallel
from tsbench.experiments.metrics.performance import Metric, Performance
from tsbench.experiments.tracking import Tracker
from tsbench.forecasts import ensemble_forecasts, EnsembleWeighting, evaluate_forecasts, Evaluation
from tsbench.surrogate import Surrogate


class EnsembleEvaluator:
    """
    The ensemble evaluator allows for evaluating the performance of ensembles. Optionally, it uses
    a surrogate to predict the performance of ensemble members.
    """

    def __init__(
        self,
        tracker: Tracker,
        surrogate: Optional[Surrogate] = None,
        ensemble_size: Optional[int] = 10,
        ensemble_weighting: EnsembleWeighting = "uniform",
        config_class: Optional[Type[ModelConfig]] = None,
        max_latency: Optional[float] = None,
    ):
        """
        Args:
            tracker: The tracker from which to obtain pretrained models and forecasts.
            surrogate: The surrogate to use for predicting performance measures.
            ensemble_size: The number of models to use when building an ensemble. If not provided,
                uses as many models as possible.
            ensemble_weighting: The type of ensemble weighting to use for averaging forecasts.
            config_class: The class of models to ensemble. If this is provided, fewer models than
                the given ensemble size might be selected.
            max_latency: The maximum latency of the ensemble. Models are selected greedily. If no
                model can be selected such that the latency constraint is satisfied, the Seasonal
                Naïve model is selected nonetheless. The latency would then exceed the constraint.
        """
        assert (
            max_latency is None or max_latency >= 0.001
        ), "Latency constraint must not be smaller than 1 ms."

        self.tracker = tracker
        self.surrogate = surrogate
        self.ensemble_size = ensemble_size
        self.ensemble_weighting = ensemble_weighting
        self.config_class = config_class
        self.max_latency = max_latency

    def run(self) -> Tuple[pd.DataFrame, Dict[str, List[ModelConfig]]]:
        """
        Runs the evaluation on the data provided via the tracker. The data obtained from the
        tracker is partitioned by the dataset and we run "grouped LOOCV" to compute performance
        metrics on datasets. Metrics on each dataset are then returned as data frame. This function
        runs with as many processes as possible.

        Returns:
            The metrics on the individual datasets.
            The model choices for each dataset.
        """
        results = run_parallel(
            self._run_on_dataset,
            data=list(loocv_split(self.tracker)),
            num_processes=num_fitting_processes(),
        )

        # Every result is a pair of (single-row data frame, {dataset name: ensemble members})
        performances = [frame for frame, _ in results]
        member_mapping = dict(
            itertools.chain.from_iterable(members.items() for _, members in results)
        )

        df = pd.concat(performances).set_index("test_dataset")
        return df, member_mapping

    def _run_on_dataset(
        self,
        data: Tuple[
            Tuple[List[Config], List[Performance]], Tuple[List[Config], List[Performance]]
        ],
    ) -> Tuple[pd.DataFrame, Dict[str, List[ModelConfig]]]:
        """
        Evaluates the ensemble for a single train/test split.

        Args:
            data: The split as `((X_train, y_train), (X_test, y_test))`. The name of the dataset
                of the first test configuration is used to label the result.

        Returns:
            A data frame with the ensemble performance and a `test_dataset` column.
            A mapping from the test dataset's name to the selected ensemble members.
        """
        # Extract the data
        (X_train, y_train), (X_test, y_test) = data
        dataset_name = X_test[0].dataset.name

        # Compute the metrics
        performance, members = self._performance_on_dataset(X_train, y_train, X_test, y_test)

        # Transform into output
        df = Performance.to_dataframe([performance]).assign(test_dataset=dataset_name)
        return df, {dataset_name: members}

    def _performance_on_dataset(
        self,
        X_train: List[Config],
        y_train: List[Performance],
        X_test: List[Config],
        y_test: List[Performance],
    ) -> Tuple[Performance, List[ModelConfig]]:
        """
        Selects the ensemble members among the test configurations and evaluates the ensemble.

        Args:
            X_train: The configurations used for fitting the surrogate (if any).
            y_train: The performances belonging to `X_train`.
            X_test: The candidate configurations; all are expected to share one dataset.
            y_test: The performances belonging to `X_test`. They are used for selecting and
                weighting members if no surrogate is set.

        Returns:
            The performance of the ensemble.
            The model configurations of the ensemble members.

        Raises:
            ValueError: If no ensemble member can be selected at all.
        """
        # We remember the dataset before filtering such that it remains available even if the
        # model constraint removes every candidate
        dataset = X_test[0].dataset

        # If there are model constraints, we restrict the test set
        if self.config_class is not None:
            indices = [
                i
                for i, c in enumerate(X_test)
                if isinstance(c.model, self.config_class)
                and (not isinstance(c.model, TrainConfig) or c.model.training_fraction == 1)
            ]
            X_test = [X_test[i] for i in indices]
            y_test = [y_test[i] for i in indices]
            if not X_test:
                raise ValueError(
                    f"No configuration of type '{self.config_class.__name__}' is available "
                    f"for dataset '{dataset.name}'."
                )

        # Then, we obtain the performance metrics to use for selecting and weighting ensemble
        # members
        if self.surrogate is not None:
            self.surrogate.fit(X_train, y_train)
            performances = self.surrogate.predict(X_test)
        else:
            performances = y_test

        # After that, we sort the performances and pick the ensemble members. An ensemble size
        # of `None` (or 0) keeps all models returned by the sorting.
        order = self._sort_performances(performances)
        choices = order[: (self.ensemble_size or len(order))]

        # In case no model is selected, seasonal naïve is selected (this should always be
        # fast enough to satisfy the latency constraint unless our measurement is faulty or
        # the implementation inefficient)
        if not choices:
            choices = [
                i for i, c in enumerate(X_test) if isinstance(c.model, SeasonalNaiveModelConfig)
            ]
        if not choices:
            # Without this check, an empty ensemble only fails deep inside the sampling code
            raise ValueError(
                f"No ensemble member could be selected for dataset '{dataset.name}' and no "
                "Seasonal Naïve configuration is available as fallback."
            )

        # Eventually, we get the ensemble members and compute the ensemble performance
        members = [X_test[i].model for i in choices]
        performance = self._get_ensemble_performance(
            members,
            dataset=dataset,
            member_performances=[performances[i] for i in choices],
        )

        return performance, members

    def _sort_performances(self, performances: List[Performance]) -> List[int]:
        """
        Orders the models by their mean weighted quantile loss and, if a latency constraint is
        set, greedily picks the ones that fit into the latency budget.

        Args:
            performances: The (true or predicted) performances of the candidate models.

        Returns:
            Indices into `performances`. Without latency constraint, these are all indices sorted
            by ascending loss. Otherwise, only the greedily selected indices are returned in the
            order in which they were picked.
        """
        # Compute the losses to order the models by their performance
        losses = [p.mean_weighted_quantile_loss.mean for p in performances]
        remaining = np.argsort(losses).tolist()

        # When there is no constraint, the order of the loss is just the order of the ensemble
        # members to pick
        if self.max_latency is None:
            return remaining

        # Otherwise, we greedily select models until the constraint is not met anymore. After
        # every pick, the scan restarts at the best remaining model.
        selected = []
        used_latency = 0
        while True:
            position = next(
                (
                    pos
                    for pos, index in enumerate(remaining)
                    if used_latency + performances[index].latency.mean <= self.max_latency
                ),
                None,
            )
            if position is None:
                return selected
            index = remaining.pop(position)
            selected.append(index)
            used_latency += performances[index].latency.mean

    def _get_ensemble_performance(
        self,
        models: List[ModelConfig],
        dataset: DatasetConfig,
        member_performances: List[Performance],
    ) -> Performance:
        """
        Evaluates ensembles built from the forecasts of the given models on the given dataset.

        Args:
            models: The ensemble members.
            dataset: The dataset on which to evaluate the ensemble.
            member_performances: The performances of the members, in the order of `models`. Their
                losses are passed to the ensembling for weighting and their parameter counts,
                latencies and training times are combined for the ensemble.

        Returns:
            The performance computed from up to 10 randomly composed ensembles.
        """
        # First, we need to get the forecasts for all models
        forecasts = [self.tracker.get_forecasts(Config(m, dataset)) for m in models]

        # Then, we want to construct min(#available_choices, 10) different ensembles by randomly
        # choosing models from the configurations without replacement. The number of available
        # choices is computed with Python integers as it grows exponentially with the number of
        # members and must not overflow.
        pool_sizes = [len(f) for f in forecasts]
        num_available = 1
        for size in pool_sizes:
            num_available *= size
        num_choices = min(num_available, 10)

        # We draw index tuples directly and reject duplicates rather than enumerating the whole
        # Cartesian product: at most 10 tuples are required while the product may be huge.
        model_combinations = []
        seen = set()
        while len(model_combinations) < num_choices:
            candidate = tuple(random.randrange(size) for size in pool_sizes)
            if candidate not in seen:
                seen.add(candidate)
                model_combinations.append(candidate)

        # Then, we evaluate each of the ensembles
        member_losses = [p.mean_weighted_quantile_loss.mean for p in member_performances]
        evaluations = []
        for combination in model_combinations:
            ensembled_forecast = ensemble_forecasts(
                [f[i] for i, f in zip(combination, forecasts)],
                self.ensemble_weighting,
                member_losses,
            )
            evaluation = evaluate_forecasts(ensembled_forecast, dataset.data.test().evaluation())
            evaluations.append(evaluation)

        # And eventually, we build the resulting performance object
        performance = Evaluation.performance(evaluations)
        performance.num_model_parameters = self._combine_metrics(
            member_performances, lambda p: p.num_model_parameters
        )
        performance.latency = self._combine_metrics(member_performances, lambda p: p.latency)
        performance.training_time = self._combine_metrics(
            member_performances, lambda p: p.training_time
        )
        return performance

    def _combine_metrics(
        self, performances: List[Performance], metric: Callable[[Performance], Metric]
    ) -> Metric:
        """
        Combines a metric of all ensemble members into a single metric.

        Args:
            performances: The performances of the ensemble members.
            metric: A function extracting the metric of interest from a performance.

        Returns:
            A metric whose mean is the sum of the members' means and whose standard deviation is
            the square root of the sum of the members' squared standard deviations.
        """
        values = [metric(p) for p in performances]
        return Metric(
            mean=sum(v.mean for v in values),
            std=math.sqrt(sum(v.std ** 2 for v in values)),
        )