from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional
from tsbench.config import Config
from tsbench.experiments.metrics import Performance
from tsbench.experiments.tracking import Tracker


class Surrogate(ABC):
    """
    This class defines the interface for any surrogate model which attempts to predict performance
    metrics from model configurations. Subclasses may decide to only predict some performance
    metrics.

    Subclasses implement `fit` and `_predict`. The public `predict` method calls `_predict` and,
    when a tracker was supplied, overwrites latency and number of model parameters with the values
    obtained from the tracker.
    """

    def __init__(self, tracker: Optional[Tracker] = None):
        """
        Args:
            tracker: An optional tracker that can be used to impute latency and number of model
                parameters into model performances.
        """
        self.tracker = tracker

    @abstractmethod
    def fit(self, X: List[Config], y: List[Performance]) -> None:
        """
        Uses the provided data to fit a model which is able to predict the target variables from
        the input.

        Args:
            X: The input configurations.
            y: The performance values associated with the input configurations.
        """

    def predict(self, X: List[Config]) -> List[Performance]:
        """
        Predicts the target variables for the given inputs. Typically requires `fit` to be called
        first.

        Args:
            X: The configurations for which to predict performance metrics.

        Returns:
            The predicted performance metrics for the input configurations, one per configuration
            and in the same order. If a tracker is set, the returned objects are the ones produced
            by `_predict`, modified in place.

        Raises:
            ValueError: If `_predict` does not return exactly one performance per configuration.
        """
        performances = self._predict(X)
        if len(performances) != len(X):
            # `zip` below would silently drop unmatched items, leaving some predictions without
            # imputed values. Fail loudly instead of returning a misaligned result.
            raise ValueError(
                f"Surrogate returned {len(performances)} predictions for {len(X)} configurations."
            )
        if self.tracker is None:
            return performances

        # If the tracker is defined, latency and model parameters are set to the true values as
        # they can be simulated easily.
        for x, predicted_performance in zip(X, performances):
            true_performance = self.tracker.get_performance(x)
            # In-place update: the objects returned by `_predict` are mutated and returned.
            predicted_performance.latency = true_performance.latency
            predicted_performance.num_model_parameters = true_performance.num_model_parameters
        return performances

    @abstractmethod
    def _predict(self, X: List[Config]) -> List[Performance]:
        """
        Subclass hook that produces the raw predictions used by `predict`.

        Args:
            X: The configurations for which to predict performance metrics.

        Returns:
            One predicted performance per input configuration, in the same order as `X`.
        """


class DatasetFeaturesMixin:
    """
    Marker mixin for surrogates. Inheriting from it signals that the surrogate may (optionally)
    make use of dataset features. It adds no attributes or methods of its own.
    """
