# baselines
import importlib
from .model_utils import get_template, load_model_and_tokenizer, load_vllm_model, _init_ray

# Registry of supported method names -> dotted path of the module to import.
# `get_method_class` imports the module lazily and then fetches the attribute
# whose name equals the key, so each key must match a class name defined in
# the module it points to.
_method_mapping = {
    "DirectRequest": "baselines.direct_request",
    "GCG": "baselines.gcg",
    "EnsembleGCG": "baselines.gcg_ensemble",
    "HumanJailbreaks": "baselines.human_jailbreaks",
    "ZeroShot": "baselines.zeroshot",
    "FewShot": "baselines.fewshot",
    "PEZ": "baselines.pez",
    "GBDA": "baselines.gbda",
    "PAIR": "baselines.pair",
    "TAP": "baselines.tap",
    "AutoDAN": "baselines.autodan",
    "AutoPrompt": "baselines.autoprompt",
    "UAT": "baselines.uat",
    'PAP': "baselines.pap",
    # The three MultiModalPGD* names below all resolve to the same module.
    "MultiModalPGD": "baselines.multimodalpgd",
    "MultiModalPGDBlankImage": "baselines.multimodalpgd",
    "MultiModalPGDPatch": "baselines.multimodalpgd",
    "MultiModalRenderText": "baselines.multimodalrendertext",
    "MultiModalDirectRequest": "baselines.multimodaldirectrequest",
    "GPTFuzz": "baselines.gptfuzz",
    "ArtPrompt": "baselines.artprompt",
}


def get_method_class(method):
    """
    Resolve a method name to the class that implements it.

    The name is looked up in ``_method_mapping`` to obtain a module path, the
    module is imported on demand, and the attribute named ``method`` is read
    from it. The class object itself is returned (nothing is instantiated), so
    callers can use it to reach static methods.

    Args:
        method: Name of the method; must be a key of ``_method_mapping``.

    Returns:
        The attribute called ``method`` found in the mapped module.

    Raises:
        ValueError: If ``method`` is not a registered key.
    """
    target_module = _method_mapping.get(method)
    if target_module is None:
        raise ValueError(f"Can not find method {method}")
    # The implementing class is expected to carry the same name as its key.
    return getattr(importlib.import_module(target_module), method)


def init_method(method_class, method_config):
    """
    Instantiate ``method_class`` from ``method_config`` after echoing the config.

    If ``method_class.use_ray`` is truthy, ``_init_ray(num_cpus=8)`` is called
    first. The configuration is then printed to stdout, one ``key: value``
    line per entry with a short separator after each, and finally the class is
    constructed with the config entries passed as keyword arguments.

    Args:
        method_class: Class to instantiate; must expose a ``use_ray`` attribute.
        method_config: Mapping of keyword arguments for the constructor.

    Returns:
        The newly created ``method_class`` instance.
    """
    header_rule = "=" * 80
    entry_rule = "=" * 20

    if method_class.use_ray:
        _init_ray(num_cpus=8)

    # Dump the configuration so each run records the settings it used.
    print(header_rule)
    for name, value in method_config.items():
        print(f"{name}: {value}")
        print(entry_rule)
    print("\n\n")

    return method_class(**method_config)

