import random
from typing import List, Optional
import numpy as np
from tsbench.config import Config
from tsbench.experiments.metrics import Performance
from tsbench.experiments.tracking import Tracker
from .base import Surrogate
from .registry import register_surrogate


@register_surrogate("random")
class RandomSurrogate(Surrogate):
    """
    The random surrogate simply predicts random performance metrics to act as a baseline.

    Every requested value (both the ``<metric>_mean`` and ``<metric>_std`` entry) is drawn
    independently via ``random.random()``, i.e. uniformly from [0, 1). Values that were not
    requested are set to NaN. The module-level ``random`` generator is used, so seeding it with
    ``random.seed`` makes predictions reproducible.
    """

    def __init__(self, predict: Optional[List[str]] = None, tracker: Optional[Tracker] = None):
        """
        Args:
            predict: The metrics to predict. All if not provided. Entries may be metric names as
                returned by ``Performance.metrics()`` (selecting both their mean and standard
                deviation) or fully qualified keys such as ``"<metric>_mean"`` /
                ``"<metric>_std"`` (selecting only that entry).
            tracker: An optional tracker that can be used to impute latency and number of model
                parameters into model performances.
        """
        super().__init__(tracker)
        self.metrics = predict

    def fit(self, X: List[Config], y: List[Performance]) -> None:
        """
        No-op: random predictions do not depend on any training data.

        Args:
            X: The training configurations (ignored).
            y: The corresponding performances (ignored).
        """
        return None

    def _predict(self, X: List[Config]) -> List[Performance]:
        """
        Produces one random performance per configuration.

        Args:
            X: The configurations to predict for. Only their count is used; the configurations'
                contents do not influence the output.

        Returns:
            A list with one ``Performance`` per entry of ``X``. Each value is a fresh draw from
            ``random.random()`` if it was requested, and NaN otherwise.
        """
        # Build the lookup once. `None` means "predict everything".
        requested = None if self.metrics is None else set(self.metrics)

        def _is_requested(metric: str, key: str) -> bool:
            # Accept both plain metric names (as documented) and fully qualified keys, so that
            # passing e.g. ["ncrps"] does not silently yield all-NaN predictions.
            return requested is None or metric in requested or key in requested

        return [
            Performance.from_dict(
                {
                    key: random.random() if _is_requested(m, key) else np.nan
                    for m in Performance.metrics()
                    for key in (f"{m}_mean", f"{m}_std")
                }
            )
            for _ in X
        ]
