from abc import ABC
from typing import cast, List, Optional
from tsbench.config import Config, DatasetConfig, ModelConfig
from tsbench.experiments.metrics import Performance
from .generator import CandidateGenerator, ReplayCandidateGenerator
from .recommendation import Recommendation
from .utils import argsort_nondominated


class Recommender(ABC):
    """
    A recommender uses a predictor to recommend models and their configurations based on desired
    target outputs. This class implements the general interface.

    Subclasses supply the performance estimates for candidate configurations by implementing
    `_get_performances`.
    """

    def __init__(
        self,
        generator: Optional[CandidateGenerator] = None,
        maximize: Optional[List[str]] = None,
        minimize: Optional[List[str]] = None,
        focus: Optional[str] = None,
    ):
        """
        Args:
            generator: The generator that generates configurations for recommendations. By default,
                this is the replay candidate generator.
            maximize: The list of performance metrics to maximize.
            minimize: The list of performance metrics to minimize.
            focus: The metric to prefer. Must be either in the list of the metrics to maximize or
                minimize. If not provided, the first metric to minimize is chosen or, if none such
                metric is provided, the first metric to maximize.
        """
        # Copy the metric lists so later mutation of the caller's lists cannot affect this object.
        maximize = list(maximize or [])
        minimize = list(minimize or [])
        assert len(maximize) + len(minimize) > 0, "No metrics provided."

        assert focus is None or (
            focus in maximize or focus in minimize
        ), "Focus metric not found in metrics to maximize or minimize."

        # Apply the documented default focus. The assertion above guarantees that at least one of
        # the two lists is non-empty, so the indexing below cannot fail.
        if focus is None:
            focus = minimize[0] if minimize else maximize[0]

        # Initialize attributes
        self.generator = generator or ReplayCandidateGenerator()
        self.maximize = maximize
        self.minimize = minimize
        self.focus = focus

    def fit(self, configs: List[Config], _performances: List[Performance]) -> None:
        """
        Fits the recommender on the provided configurations. This base implementation only fits
        the candidate generator on the unique model configurations; subclasses that fit additional
        state (e.g. a surrogate model) should override it and call it.

        Args:
            configs: The configurations to train on. Only their (deduplicated) model
                configurations are passed to the generator.
            _performances: The performances aligned with `configs`. Ignored by this base
                implementation.
        """
        # dict.fromkeys deduplicates exactly like a set but keeps first-occurrence order, so the
        # generator sees a stable model order (set iteration order depends on element hashes,
        # which may vary between runs).
        unique_models = list(dict.fromkeys(c.model for c in configs))
        self.generator.fit(unique_models)

    def recommend(
        self,
        dataset: DatasetConfig,
        candidates: Optional[List[ModelConfig]] = None,
        max_count: int = 10,
    ) -> List[Recommendation]:
        """
        This method takes a dataset and a set of constraints and outputs a set of recommendations.
        The recommendations provide both the configurations of the recommended model as well as the
        expected performance.

        Args:
            dataset: The configuration of the dataset for which to recommend a model.
            candidates: A list of model configurations that are allowed to be recommended. If
                `None`, any model configuration is permitted.
            max_count: The maximum number of models to recommend.

        Returns:
            The recommendations which (approximately) satisfy the provided constraints, in the
            order produced by the nondominated sort. Empty if the generator yields no candidates.
        """
        model_configs = self.generator.generate(candidates)
        configs = [Config(m, dataset) for m in model_configs]
        if not configs:
            # Nothing to rank; building a data frame from no performances would lack the metric
            # columns selected below.
            return []
        performances = self._get_performances(configs)

        # Keep only the objective columns. Copy explicitly so that negating columns below mutates
        # an independent frame rather than a pandas selection (avoids SettingWithCopyWarning).
        df = Performance.to_dataframe(performances)
        df = df[self.maximize + self.minimize].copy()
        # Negate the metrics to maximize so that lower values are better in every column.
        for metric in self.maximize:
            df[metric] = -df[metric]

        # `focus` is always set by __init__, but a subclass may still reset it to None.
        focus_dim = df.columns.tolist().index(self.focus) if self.focus is not None else None
        argsort = argsort_nondominated(df.to_numpy(), dim=focus_dim, max_items=max_count)

        # Positions in `argsort` index both `configs` and `performances`, which are aligned.
        return [
            Recommendation(configs[i].model, performances[i])
            for i in cast(List[int], argsort)
        ]

    def _get_performances(self, configs: List[Config]) -> List[Performance]:
        """
        Estimates the performance of each provided configuration. Must be implemented by
        subclasses.

        Args:
            configs: The configurations to estimate performances for.

        Returns:
            One performance per configuration, in the same order as `configs` (`recommend`
            indexes both lists with the same position).
        """
        raise NotImplementedError
