# baselines
from .autodan import AutoDan
from .autodan_adaptive import AutoDan_adaptive
from .autoprompt import AutoPrompt
from .baseline import RedTeamingMethod
from .direct_request import DirectRequest
from .fewshot import FewShot
from .gbda import GBDA
from .gcg import GCG
from .gcg_adaptive import GCG_adaptive
from .gcg_ensemble import EnsembleGCG
from .human_jailbreaks import HumanJailbreaks
from .model_utils import _STR_DTYPE_TO_TORCH_DTYPE, _init_ray, load_model_and_tokenizer
from .multimodaldirectrequest import MultiModalDirectRequest
from .multimodalpgd import MultiModalPGD, MultiModalPGDBlankImage, MultiModalPGDPatch
from .multimodalrendertext import MultiModalRenderText
from .pair import PAIR
from .pair_adaptive import PAIR_adaptive
from .pap import PAP
from .pez import PEZ
from .pez_adaptive import PEZ_adaptive
from .random_search import RS
from .tap import TAP
from .tap_adaptive import TAP_adaptive
from .uat import UAT
from .zeroshot import ZeroShot

# Registry of available red-teaming methods: maps the name string accepted by
# get_method_class / init_method to the class that implements that method.
# Entries are listed in a fixed order; dict iteration follows this order.
_method_mapping = {
    "DirectRequest": DirectRequest,
    "GCG": GCG,
    "GCG_adaptive": GCG_adaptive,
    "RS": RS,
    "EnsembleGCG": EnsembleGCG,
    "HumanJailbreaks": HumanJailbreaks,
    "ZeroShot": ZeroShot,
    "FewShot": FewShot,
    "PEZ": PEZ,
    "PEZ_adaptive": PEZ_adaptive,
    "GBDA": GBDA,
    "PAIR": PAIR,
    "PAIR_adaptive": PAIR_adaptive,
    "TAP": TAP,
    "TAP_adaptive": TAP_adaptive,
    "AutoDan": AutoDan,
    "AutoDan_adaptive": AutoDan_adaptive,
    "AutoPrompt": AutoPrompt,
    "UAT": UAT,
    "PAP": PAP,
    "MultiModalPGD": MultiModalPGD,
    "MultiModalPGDBlankImage": MultiModalPGDBlankImage,
    "MultiModalPGDPatch": MultiModalPGDPatch,
    "MultiModalRenderText": MultiModalRenderText,
    "MultiModalDirectRequest": MultiModalDirectRequest,
}

def get_method_class(method):
    """
    Look up the class registered under the name ``method``.

    Handy when only class-level (static) attributes or methods are needed,
    without constructing an instance of the method.

    Raises:
        ValueError: If ``method`` is not a key of ``_method_mapping``.
    """
    # Every registered value is a class, so ``None`` can only mean "absent".
    method_class = _method_mapping.get(method)
    if method_class is None:
        raise ValueError(f"Can not find method {method}")
    return method_class

def init_method(method, method_config, num_cpus=32):
    """
    Instantiate the red-teaming method registered under ``method``.

    Args:
        method: Name of the method; must be a key of ``_method_mapping``.
        method_config: Mapping of keyword arguments forwarded to the method
            class constructor. ``None`` is treated as an empty config.
        num_cpus: CPU count passed to ``_init_ray`` when the method class
            sets ``use_ray``. Defaults to 32, the previously hard-coded value.

    Returns:
        A new instance of the registered method class.

    Raises:
        ValueError: If ``method`` is not a registered method name.
    """
    # Reuse the shared lookup so the error behavior matches get_method_class.
    method_class = get_method_class(method)

    if method_class.use_ray:
        # Methods that declare use_ray get Ray started before construction.
        _init_ray(num_cpus=num_cpus)

    # Avoid ``TypeError: argument after ** must be a mapping`` for a None config.
    if method_config is None:
        method_config = {}

    return method_class(**method_config)