from collections import defaultdict
from typing import Dict, List, Optional
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from tsbench.config import Config, ModelConfig
from tsbench.config.model.models import SeasonalNaiveModelConfig
from tsbench.experiments.metrics import Performance
from tsbench.experiments.tracking import Tracker
from .base import DatasetFeaturesMixin, Surrogate
from .registry import register_surrogate
from .transformers import ConfigTransformer, PerformanceTransformer


@register_surrogate("nonparametric")
class NonparametricSurrogate(Surrogate, DatasetFeaturesMixin):
    """
    The nonparametric surrogate predicts a model's performance on a new dataset as the average
    performance across all known datasets. If dataset features are enabled, the average is
    weighted by the inverse Euclidean distance between the features of the new dataset and the
    features of each known dataset. Performances are either predicted as ranks or actual values.
    """

    def __init__(
        self,
        use_ranks: bool = False,
        use_simple_dataset_features: bool = True,
        use_seasonal_naive_performance: bool = False,
        use_catch22_features: bool = False,
        predict: Optional[List[str]] = None,
        tracker: Optional[Tracker] = None,
    ):
        """
        Args:
            use_ranks: Whether to average ranks or performances directly.
            use_simple_dataset_features: Whether to use simple dataset statistics as dataset
                features for the weighted average.
            use_seasonal_naive_performance: Whether to use the Seasonal Naïve nCRPS as dataset
                features. The tracker is forwarded to the config transformer (presumably it is
                required to look these performances up -- confirm against the transformer).
            use_catch22_features: Whether to use catch22 features as dataset features.
            predict: The metrics to predict. All if not provided.
            tracker: An optional tracker that can be used to impute latency and number of model
                parameters into model performances.

        If none of the three feature flags is set, predictions are plain (unweighted) averages.
        """
        super().__init__(tracker)

        self.use_ranks = use_ranks
        # Any kind of dataset feature switches the prediction to a distance-weighted average.
        self.use_dataset_features = any(
            (
                use_simple_dataset_features,
                use_seasonal_naive_performance,
                use_catch22_features,
            )
        )
        if self.use_dataset_features:
            # Only dataset-level features are needed, model features are switched off.
            self.config_transformer = ConfigTransformer(
                add_model_features=False,
                add_dataset_statistics=use_simple_dataset_features,
                add_seasonal_naive_performance=use_seasonal_naive_performance,
                add_catch22_features=use_catch22_features,
                tracker=tracker,
            )
        self.performance_transformer = PerformanceTransformer(
            metrics=predict,
        )

        # Fitted properties
        self.model_performances_: Dict[ModelConfig, np.ndarray]
        self.dataset_features_: np.ndarray

    def fit(self, X: List[Config], y: List[Performance]) -> None:
        """
        Stores, for every model configuration, its (transformed) performances on all datasets
        and, if dataset features are enabled, the features of all datasets.

        Both the per-model performances and the dataset features are ordered by dataset name.
        NOTE: the weighted prediction relies on these two orderings lining up, i.e. it assumes
        that every model has exactly one performance for every dataset in `X`.

        Args:
            X: The configurations (model and dataset) that were evaluated.
            y: The performances corresponding to the configurations in `X`.
        """
        y_numpy = self.performance_transformer.fit_transform(y)

        # For each model configuration, we collect all performances along with their dataset
        performances = defaultdict(list)
        datasets = set()
        for config, performance in zip(X, y_numpy):
            datasets.add(config.dataset)
            performances[config.model].append((config.dataset, performance))

        # Then, we assign the model performances, sorted by dataset name
        self.model_performances_ = {
            model: np.stack(
                [
                    performance
                    for _, performance in sorted(entries, key=lambda entry: entry[0].name)
                ]
            )
            for model, entries in performances.items()
        }
        if self.use_ranks:
            # Applying argsort twice along the model axis turns the values into zero-based ranks
            # of each model among all models, separately per dataset and metric.
            ranks = np.stack(list(self.model_performances_.values())).argsort(0).argsort(0)
            self.model_performances_ = {
                model: ranks[i] for i, model in enumerate(self.model_performances_)
            }

        # We use the seasonal naive model config here since it is ignored anyway
        if self.use_dataset_features:
            self.dataset_features_ = self.config_transformer.fit_transform(
                [
                    Config(SeasonalNaiveModelConfig, dataset)
                    for dataset in sorted(datasets, key=lambda dataset: dataset.name)
                ]
            )

    @staticmethod
    def _inverse_distance_weights(distances: np.ndarray) -> np.ndarray:
        """
        Turns distances into normalized inverse-distance weights of the same shape.

        If at least one distance is exactly zero, the entire weight is distributed uniformly
        among the zero-distance entries. Normalizing an infinite similarity would otherwise
        evaluate `inf / inf` and yield NaN weights.
        """
        exact_matches = distances == 0
        if exact_matches.any():
            return exact_matches / exact_matches.sum()
        similarity = 1 / distances
        return similarity / similarity.sum()

    def _predict(self, X: List[Config]) -> List[Performance]:
        """
        Predicts the performance of each configuration by averaging the performances that the
        configuration's model achieved on the datasets seen during `fit`.

        Args:
            X: The configurations to predict performances for. Their models must have been seen
                during `fit` (an unknown model raises a `KeyError`).

        Returns:
            The predicted performances, one per configuration.
        """
        if self.use_dataset_features:
            embeddings = self.config_transformer.transform(X)

        results = []
        for i, x in enumerate(X):
            performance = self.model_performances_[x.model]
            if not self.use_dataset_features:
                results.append(performance.mean(0))
                continue

            # Distance between every known dataset and the dataset of the queried config
            dataset_embedding = embeddings[i][None, :]
            distances = euclidean_distances(self.dataset_features_, dataset_embedding)
            # The weights have one row per known dataset and broadcast over the metrics
            weights = self._inverse_distance_weights(distances)
            results.append((performance * weights).sum(0))
        return self.performance_transformer.inverse_transform(np.stack(results))
