from typing import List, Optional, Tuple
import numpy as np
import pandas as pd
from tsbench.config import Config
from tsbench.evaluation.utils import loocv_split, num_fitting_processes, run_parallel
from tsbench.experiments.metrics import Performance
from tsbench.experiments.tracking import Tracker
from tsbench.surrogate import Surrogate
from .metrics import mrr, ndcg, nrmse, precision_k, smape


class SurrogateEvaluator:
    """
    The surrogate evaluator evaluates the performance of a surrogate model with respect to ranking
    and regression metrics.

    Evaluation uses leave-one-out cross-validation over the datasets available from the tracker:
    for every fold, the surrogate is fit on the training split and scored on the held-out dataset.
    """

    def __init__(self, surrogate: Surrogate, tracker: Tracker, metrics: Optional[List[str]] = None):
        """
        Args:
            surrogate: The surrogate model to evaluate.
            tracker: The tracker from which to obtain the data for evaluation.
            metrics: The performance columns to evaluate, i.e. a selection of the columns produced
                by `Performance.to_dataframe`. If not provided, all columns are evaluated.
        """
        self.surrogate = surrogate
        self.tracker = tracker
        self.metrics = metrics

    def run(self) -> pd.DataFrame:
        """
        Runs the evaluation on the surrogate by applying LOOCV on the datasets being trained on.
        Metrics are provided per test dataset, computed over all of its test configurations.

        Returns:
            A data frame with one row per fold, indexed by the name of the dataset which was left
                out. The columns form a two-level index of (performance column, metric name).
        """
        fold_scores = run_parallel(
            self._run_on_dataset,
            data=list(loocv_split(self.tracker)),
            num_processes=num_fitting_processes(),
        )
        return pd.concat(fold_scores).set_index("test_dataset")

    def _run_on_dataset(
        self,
        data: Tuple[
            Tuple[List[Config], List[Performance]],
            Tuple[List[Config], List[Performance]],
        ],
    ) -> pd.DataFrame:
        """
        Fits the surrogate on the training split of a single LOOCV fold and scores its predictions
        on the test split.

        Args:
            data: A pair ``((X_train, y_train), (X_test, y_test))`` of configurations and their
                observed performances.

        Returns:
            A single-row data frame with the scores of this fold plus a ``test_dataset`` column
            holding the name of the held-out dataset (taken from the first test configuration).
        """
        (X_train, y_train), (X_test, y_test) = data

        # Fit model and predict. NOTE(review): this refits `self.surrogate` in place; presumably
        # `run_parallel` gives each worker its own copy when running in parallel — confirm.
        self.surrogate.fit(X_train, y_train)
        y_pred = self.surrogate.predict(X_test)

        # Compute metrics
        scores = self._score(y_pred, y_test)
        return scores.assign(test_dataset=X_test[0].dataset.name)

    def _score(self, y_pred: List[Performance], y_true: List[Performance]) -> pd.DataFrame:
        """
        Computes regression and ranking metrics of the predicted performances against the true
        performances.

        Args:
            y_pred: The performances predicted by the surrogate.
            y_true: The observed performances, expected in the same order as ``y_pred``.

        Returns:
            A single-row data frame whose columns are a two-level index of
            (performance column, metric name), ordered by metric name first.
        """
        df_pred = Performance.to_dataframe(y_pred)
        df_true = Performance.to_dataframe(y_true)

        if self.metrics is not None:
            df_pred = df_pred[self.metrics]
            df_true = df_true[self.metrics]

        # We extract the NumPy arrays so that indexing is easier. Each metric is computed such that
        # it results in an array of shape [D] where D is the number of performance columns.
        columns = df_pred.columns
        pred = df_pred.to_numpy()
        true = df_true.to_numpy()

        scores = {
            "nrmse": nrmse(pred, true),
            "smape": smape(pred, true),
            "mrr": mrr(pred, true),
            **{f"precision_{k}": precision_k(k, pred, true) for k in (5, 10, 20)},
            "ndcg": ndcg(pred, true),
        }

        # The column index and the concatenated values must use the same (metric-major) order, so
        # the sorted metric names are computed once and shared by both.
        metric_names = sorted(scores)
        column_index = pd.MultiIndex.from_tuples(
            [(c, m) for m in metric_names for c in columns]
        )
        values = np.concatenate([scores[m] for m in metric_names])
        return pd.DataFrame(np.reshape(values, (1, -1)), columns=column_index)
