from typing import Any, Dict, List, Optional
from tsbench.experiments.tracking import Tracker
from .base import DatasetFeaturesMixin, Surrogate
from .registry import SURROGATE_REGISTRY


def create_surrogate(
    name: str,
    predict: Optional[List[str]],
    tracker: Tracker,
    input_flags: Dict[str, bool],
    **kwargs: Any,
) -> Surrogate:
    """
    Creates a surrogate using the specified parameters.

    The surrogate class is looked up by ``name`` in ``SURROGATE_REGISTRY`` and
    instantiated with ``predict``, ``tracker`` and any extra ``kwargs`` passed
    as keyword arguments.

    Args:
        name: The registry key of the surrogate to create.
        predict: Forwarded unchanged to the surrogate constructor as the
            ``predict`` keyword argument (may be ``None``).
        tracker: Forwarded to the surrogate constructor as the ``tracker``
            keyword argument.
        input_flags: Extra keyword arguments that are forwarded ONLY when the
            surrogate class is a subclass of ``DatasetFeaturesMixin``; they are
            ignored for all other surrogates. They are applied last, so a key
            here overrides ``predict``, ``tracker`` or an identically named
            entry in ``kwargs``.
        **kwargs: Additional keyword arguments forwarded to the constructor.

    Returns:
        The newly constructed surrogate instance.

    Raises:
        ValueError: If ``name`` is not a key of ``SURROGATE_REGISTRY``.
    """
    # Validate with an explicit exception rather than ``assert``: assertions
    # are removed when Python runs with ``-O``, which would silently turn this
    # check into an opaque KeyError from the registry lookup below.
    if name not in SURROGATE_REGISTRY:
        raise ValueError(f"Unknown surrogate {name}.")

    # Build the parameters
    surrogate_cls = SURROGATE_REGISTRY[name]
    args: Dict[str, Any] = {"predict": predict, "tracker": tracker, **kwargs}
    if issubclass(surrogate_cls, DatasetFeaturesMixin):
        # input_flags are only passed to DatasetFeaturesMixin subclasses;
        # other surrogates presumably do not accept them as constructor args.
        args.update(input_flags)

    # Initialize the surrogate
    return surrogate_cls(**args)
