from typing import List, Optional
from tsbench.config import Config
from tsbench.experiments.metrics import Performance
from tsbench.surrogate import Surrogate
from .base import Recommender
from .generator import CandidateGenerator
from .registry import register_recommender


@register_recommender("surrogate")
class SurrogateRecommender(Recommender):
    """
    Recommender whose performance estimates for candidate configurations come from a surrogate
    model's predictions (see `_get_performances`).
    """

    def __init__(
        self,
        surrogate: Surrogate,
        generator: Optional[CandidateGenerator] = None,
        maximize: Optional[List[str]] = None,
        minimize: Optional[List[str]] = None,
        focus: Optional[str] = None,
    ) -> None:
        """
        Args:
            surrogate: Model that predicts performance metrics for configurations. It is trained
                whenever `fit` is called on this recommender.
            generator: Produces the candidate configurations to recommend from. When omitted, the
                replay candidate generator is used.
            maximize: Names of the performance metrics that should be as large as possible.
            minimize: Names of the performance metrics that should be as small as possible.
            focus: The preferred metric; it must appear in `maximize` or `minimize`. When omitted,
                the first entry of `minimize` is used, falling back to the first entry of
                `maximize` if no metric is to be minimized.
        """
        # Arguments are forwarded positionally, exactly as the base class receives them.
        super().__init__(generator, maximize, minimize, focus)
        self.surrogate = surrogate

    def fit(self, configs: List[Config], performances: List[Performance]) -> None:
        """
        Fits the base recommender first, then trains the surrogate on the same data.

        Args:
            configs: Configurations for which performances were recorded.
            performances: Recorded performances; presumably index-aligned with `configs`
                (TODO confirm against the base `Recommender.fit` contract).
        """
        super().fit(configs, performances)
        self.surrogate.fit(configs, performances)

    def _get_performances(self, configs: List[Config]) -> List[Performance]:
        # Performances are not looked up from data; they are whatever the surrogate predicts.
        predicted = self.surrogate.predict(configs)
        return predicted
