from typing import Dict, List, Tuple
from tsbench.config import Config, ModelConfig
from tsbench.evaluation.utils import num_fitting_processes, run_parallel
from tsbench.evaluation.utils.loocv import loocv_split
from tsbench.experiments.metrics import Performance
from tsbench.experiments.tracking import Tracker
from tsbench.recommender import Recommender


class RecommenderEvaluator:
    """
    The recommender evaluator evaluates the performance of recommenders.
    """

    def __init__(self, tracker: Tracker, recommender: Recommender, num_recommendations: int = 10):
        """
        Args:
            tracker: The tracker from which to obtain data and model performances.
            recommender: The recommender to use for obtaining models.
            num_recommendations: The number of recommendations to perform.
        """
        self.tracker = tracker
        self.recommender = recommender
        self.num_recommendations = num_recommendations

    def run(self) -> List[Dict[str, ModelConfig]]:
        """
        Runs the evaluation on all datasets and returns the selected models for each dataset.
        The config evaluator can be used to construct a data frame of performances from the
        configurations.

        Returns:
            The recommended models. The outer list provides the index of the recommendations, i.e.
                the first item of the list provides all the first recommendations of the
                recommender, etc. The list always has `num_recommendations` entries; if the
                recommender returned fewer models for some dataset, that dataset is absent from
                the dictionaries at the positions it could not fill.
        """
        results = run_parallel(
            self._run_on_dataset,
            data=list(loocv_split(self.tracker)),
            num_processes=num_fitting_processes(),
        )

        # Merge the single-entry dictionaries returned per split into one mapping from dataset
        # name to that dataset's ordered list of recommended models.
        recommendations = {k: v for r in results for k, v in r.items()}

        # Transpose into one dictionary per recommendation rank. `recommend` is invoked with
        # `max_count`, which presumably caps rather than fixes the number of results, so a list
        # may be shorter than `num_recommendations`. Guard the index instead of raising
        # `IndexError` after all the (expensive) fitting has already completed.
        return [
            {k: v[i] for k, v in recommendations.items() if i < len(v)}
            for i in range(self.num_recommendations)
        ]

    def _run_on_dataset(
        self,
        data: Tuple[
            Tuple[List[Config], List[Performance]], Tuple[List[Config], List[Performance]]
        ],
    ) -> Dict[str, List[ModelConfig]]:
        """
        Fits the recommender on one leave-one-out split and recommends models for the held-out
        dataset.

        Args:
            data: A pair `((X_train, y_train), (X_test, y_test))` as produced by `loocv_split`.
                The test performances are not used here. The held-out dataset is taken from the
                first test configuration, so `X_test` is assumed to be non-empty and to contain
                configurations of a single dataset.

        Returns:
            A single-entry dictionary mapping the held-out dataset's name to the configurations
            of the recommended models, in the order returned by the recommender (at most
            `num_recommendations` of them).
        """
        # Extract the data
        (X_train, y_train), (X_test, _) = data
        dataset = X_test[0].dataset

        # Fit the recommender and predict. Note that this fits `self.recommender` in place; it
        # is refit from scratch for every split.
        self.recommender.fit(X_train, y_train)
        recommendations = self.recommender.recommend(dataset, max_count=self.num_recommendations)

        # Return the recommendations
        return {dataset.name: [r.config for r in recommendations]}
