
import inspect

from acquisitions.base import BiLevelAcquisition
from acquisitions.rand import RandomSampling
from acquisitions.bilbo import BiLevelUpperConfidenceBound
from acquisitions.thompson import BiLevelThompsonSampling
from acquisitions.mes import BiLevelMaxValueEntropySearch
from acquisitions.jes import BiLevelJointEntropySearch
from utils import RFFModelList



def get_acqf(
    name: str,
    model_Y_upper: RFFModelList,
    model_Y_lower: RFFModelList,
    model_C_upper: RFFModelList | None = None,
    model_C_lower: RFFModelList | None = None,
    **kwargs
) -> BiLevelAcquisition:
    """Build the acquisition function registered under ``name``.

    The four model arguments are always passed to the selected class's
    constructor under their own names. Entries of ``kwargs`` are forwarded
    only when their key matches a parameter name of that class's
    ``__init__``; any other entry is silently dropped, so one shared option
    dict can be handed to every acquisition class.

    Args:
        name: Registry key; must be the class name of one of the supported
            acquisition functions (e.g. ``"RandomSampling"``).
        model_Y_upper: Forwarded as ``model_Y_upper``.
        model_Y_lower: Forwarded as ``model_Y_lower``.
        model_C_upper: Forwarded as ``model_C_upper``; defaults to ``None``.
            Presumably the upper-level constraint model -- confirm against
            the acquisition classes.
        model_C_lower: Forwarded as ``model_C_lower``; defaults to ``None``.
            Presumably the lower-level constraint model -- confirm against
            the acquisition classes.
        **kwargs: Optional extra constructor arguments, filtered as
            described above.

    Returns:
        An instance of the acquisition class selected by ``name``.

    Raises:
        KeyError: If ``name`` is not a registered acquisition function. The
            message lists the valid names.
    """
    registry = {
        "RandomSampling": RandomSampling,
        "BiLevelUpperConfidenceBound": BiLevelUpperConfidenceBound,
        "BiLevelThompsonSampling": BiLevelThompsonSampling,
        "BiLevelMaxValueEntropySearch": BiLevelMaxValueEntropySearch,
        "BiLevelJointEntropySearch": BiLevelJointEntropySearch,
    }
    try:
        acqf_cls = registry[name]
    except KeyError:
        # Keep the KeyError type for existing callers, but say what is valid
        # instead of surfacing only the offending key.
        available = ", ".join(sorted(registry))
        raise KeyError(
            f"Unknown acquisition function {name!r}; "
            f"expected one of: {available}"
        ) from None

    # Parameter names declared by the constructor (excluding ``self``);
    # used to decide which optional kwargs this class can receive.
    accepted = {
        p.name
        for p in inspect.signature(acqf_cls.__init__).parameters.values()
        if p.name != "self"
    }

    params = {
        "model_Y_upper": model_Y_upper,
        "model_Y_lower": model_Y_lower,
        "model_C_upper": model_C_upper,
        "model_C_lower": model_C_lower,
    }
    # Forward only the extras the constructor names; the rest are ignored.
    params.update({k: v for k, v in kwargs.items() if k in accepted})
    return acqf_cls(**params)