from collections import defaultdict
from typing import Dict, List, Optional
import numpy as np
from tsbench.config import Config, DatasetConfig, ModelConfig
from tsbench.experiments.metrics import Performance
from .base import Recommender
from .generator import CandidateGenerator
from .recommendation import Recommendation
from .registry import register_recommender


@register_recommender("modelfree")
class ModelFreeRecommender(Recommender):
    """
    Recommender that greedily selects a set of configurations with the lowest joint error. The
    joint error of a set is the per-dataset minimum of the objective over the set's
    configurations, averaged over all datasets seen during fitting. The model-free recommender is
    only able to minimize a single objective. Recommendations are independent of the dataset to
    recommend for.
    """

    def __init__(
        self,
        generator: Optional[CandidateGenerator] = None,
        maximize: Optional[List[str]] = None,
        minimize: Optional[List[str]] = None,
        focus: Optional[str] = None,
    ):
        """
        Args:
            generator: The generator that generates configurations for recommendations. By default,
                this is the replay candidate generator.
            maximize: The list of performance metrics to maximize.
            minimize: The list of performance metrics to minimize.
            focus: The metric to prefer. Must be either in the list of the metrics to maximize or
                minimize. If not provided, the first metric to minimize is chosen or, if none such
                metric is provided, the first metric to maximize.
        """
        super().__init__(generator, maximize, minimize, focus)
        assert (
            len(self.minimize) == 1 and len(self.maximize) == 0
        ), "Model-free recommender can only minimize a single objective."

        # Objective values per model configuration, one entry per dataset (sorted by dataset name).
        self.metrics: Dict[ModelConfig, np.ndarray]

    def fit(self, configs: List[Config], performances: List[Performance]) -> None:
        """
        Records, for every model configuration, the objective value on each evaluated dataset.

        Args:
            configs: The evaluated configurations (model and dataset).
            performances: The performances of the configurations, aligned with ``configs``.
        """
        super().fit(configs, performances)

        # Sort by dataset name so that every model configuration lists its datasets in the same
        # order. This only aligns the arrays if each model configuration was evaluated on the same
        # set of datasets.
        ordering = np.argsort([c.dataset.name for c in configs])
        performance_df = Performance.to_dataframe(performances)

        # Extract the (single) objective to minimize, grouped by model configuration
        metrics: Dict[ModelConfig, List[float]] = defaultdict(list)
        for i in ordering:
            metrics[configs[i].model].append(performance_df.iloc[i][self.minimize].item())
        self.metrics = {k: np.array(v) for k, v in metrics.items()}

    def recommend(
        self,
        dataset: DatasetConfig,
        candidates: Optional[List[ModelConfig]] = None,
        max_count: int = 10,
    ) -> List[Recommendation]:
        """
        Greedily builds a set of model configurations. At each step, it adds the configuration
        that minimizes the mean, across fitted datasets, of the per-dataset minimum objective over
        all configurations picked so far.

        Args:
            dataset: The dataset to recommend for. It is unused because recommendations are
                dataset-agnostic.
            candidates: Optional model configurations that are passed to the generator.
            max_count: The maximum number of recommendations. Fewer are returned if the generator
                yields fewer configurations.

        Returns:
            The recommended configurations in the order they were picked, each with a placeholder
            (NaN) performance.
        """
        # Get model configurations
        model_configs = self.generator.generate(candidates)
        assert all(
            c in self.metrics for c in model_configs
        ), "Model-free recommender can only provide recommendations for known configurations."

        # The per-dataset comparison requires aligned arrays; guard against silent broadcasting.
        if len({self.metrics[c].shape for c in model_configs}) > 1:
            raise ValueError(
                "Model-free recommender requires all configurations to be evaluated on the "
                "same datasets."
            )

        available_choices = list(range(len(model_configs)))
        result: List[ModelConfig] = []
        # Per-dataset best objective among the configurations picked so far (None before any pick)
        best: Optional[np.ndarray] = None

        # Stop once candidates run out: np.argmin raises on an empty sequence.
        for _ in range(min(max_count, len(model_configs))):
            # Joint loss of the current selection extended by each remaining candidate
            losses = []
            for choice in available_choices:
                performance = self.metrics[model_configs[choice]]
                joint = performance if best is None else np.minimum(best, performance)
                losses.append(joint.mean())

            # Pick the candidate with the lowest loss (first one on ties)
            lowest = int(np.argmin(losses))
            index = available_choices.pop(lowest)
            chosen = self.metrics[model_configs[index]]
            best = chosen if best is None else np.minimum(best, chosen)
            result.append(model_configs[index])

        return [Recommendation(r, _dummy_performance()) for r in result]


# -------------------------------------------------------------------------------------------------


def _dummy_performance() -> Performance:
    """
    Builds a placeholder performance in which the mean and standard deviation of every metric
    reported by ``Performance.metrics()`` is NaN.
    """
    keys = [
        key
        for metric in Performance.metrics()
        for key in (f"{metric}_mean", f"{metric}_std")
    ]
    return Performance.from_dict(dict.fromkeys(keys, np.nan))
