
import inspect

from problems.base import BiLevelBaseProblem
from problems.synthetic import (
    GPPriorSamples, BraninGoldstein, CamelBranin, DixonBranin
)
from problems.benchmark import (
    SMD01, SMD02, SMD03, SMD04, SMD05, SMD06,
    SMD07, SMD08, SMD09, SMD10, SMD11, SMD12,
)
from problems.real_world import (
    HighEntropyAlloySmall, HighEntropyAlloyMedium, HighEntropyAlloyLarge,
    EnergyMarket,
)



def get_problem(
    name: str,
    noise_std: float | list[float] | None = None,
    has_candidates: bool = False,
    num_discretize: int | list[int] | None = None,
    **kwargs
) -> BiLevelBaseProblem:
    """Instantiate the registered bi-level problem called ``name``.

    ``noise_std``, ``has_candidates`` and ``num_discretize`` are always
    forwarded unchanged to the problem constructor, so every registered
    class is expected to accept them. Extra keyword arguments are forwarded
    only when the selected class's ``__init__`` declares a keyword-bindable
    parameter with the same name; all others are silently ignored, which
    lets callers pass one shared option set to different problem classes.

    Args:
        name: Key of the problem class in the registry below
            (e.g. ``"SMD01"`` or ``"EnergyMarket"``).
        noise_std: Passed through to the constructor as ``noise_std``.
        has_candidates: Passed through to the constructor as
            ``has_candidates``.
        num_discretize: Passed through to the constructor as
            ``num_discretize``.
        **kwargs: Optional problem-specific constructor arguments; names
            the constructor does not declare are dropped.

    Returns:
        A new instance of the problem class registered under ``name``.

    Raises:
        KeyError: If ``name`` is not a registered problem name.
    """
    registry = {
        # synthetic
        "GPPriorSamples": GPPriorSamples,
        "BraninGoldstein": BraninGoldstein,
        "CamelBranin": CamelBranin,
        "DixonBranin": DixonBranin,
        # benchmark
        "SMD01": SMD01,
        "SMD02": SMD02,
        "SMD03": SMD03,
        "SMD04": SMD04,
        "SMD05": SMD05,
        "SMD06": SMD06,
        "SMD07": SMD07,
        "SMD08": SMD08,
        "SMD09": SMD09,
        "SMD10": SMD10,
        "SMD11": SMD11,
        "SMD12": SMD12,
        # real-world
        "HighEntropyAlloySmall": HighEntropyAlloySmall,
        "HighEntropyAlloyMedium": HighEntropyAlloyMedium,
        "HighEntropyAlloyLarge": HighEntropyAlloyLarge,
        "EnergyMarket": EnergyMarket,
    }
    try:
        problem_cls = registry[name]
    except KeyError:
        # Same exception type as a plain dict lookup, but tell the caller
        # which names are actually valid.
        available = ", ".join(registry)
        raise KeyError(
            f"Unknown problem {name!r}; available problems: {available}"
        ) from None

    # Only parameters that can be bound by keyword count as "accepted":
    # the names of *args / **kwargs collectors and positional-only
    # parameters must not let a same-named key through the filter.
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    init_params = inspect.signature(problem_cls.__init__).parameters.values()
    accepted = {
        p.name
        for p in init_params
        if p.name != "self" and p.kind in keyword_kinds
    }

    params = {
        "noise_std": noise_std,
        "has_candidates": has_candidates,
        "num_discretize": num_discretize,
    }
    params.update(
        {key: value for key, value in kwargs.items() if key in accepted}
    )
    return problem_cls(**params)

