from .trainer import Trainer
from .ep_trainer import EPTrainer
from .attn_trainer import AttnTrainer
from .attn_ep_trainer import AttnEPTrainer
from .clean_trainer import CleanTrainer


# Registry of available trainers: maps a lower-case trainer name to the class
# that is instantiated for it. Several names intentionally share one class
# (e.g. every plain attack name maps to ``Trainer``). ``load_trainer`` lower-
# cases ``config["name"]`` before looking it up here, so keys must stay
# lower-case.
TRAINERS = dict(
    clean=CleanTrainer,
    # Names served by the generic ``Trainer``.
    base=Trainer,
    badnets=Trainer,
    addsent=Trainer,
    trojanlm=Trainer,
    ep=EPTrainer,
    synbkd=Trainer,
    stylebkd=Trainer,
    # ``attn``-prefixed names, served by the attention-based trainers.
    attn=AttnTrainer,
    attn_badnets=AttnTrainer,
    attn_addsent=AttnTrainer,
    attn_ep=AttnEPTrainer,
    attn_stylebkd=AttnTrainer,
    attn_synbkd=AttnTrainer,
)



def load_trainer(config):
    """Instantiate the trainer class registered under ``config["name"]``.

    The lookup is case-insensitive: ``config["name"]`` is lower-cased before
    being matched against the keys of ``TRAINERS``. The whole ``config``
    mapping, including the ``"name"`` entry itself, is forwarded to the
    selected class's constructor as keyword arguments.

    Args:
        config: Mapping of trainer options; must contain a string ``"name"``.

    Returns:
        A new instance of the trainer class selected by ``config["name"]``.

    Raises:
        KeyError: If ``config`` has no ``"name"`` entry, or if the name does
            not match any registered trainer. In the latter case the message
            lists the valid names.
    """
    name = config["name"].lower()
    try:
        trainer_cls = TRAINERS[name]
    except KeyError:
        # Still a KeyError, so callers that already catch it keep working;
        # the message just tells the user which names are valid.
        raise KeyError(
            f"Unknown trainer {config['name']!r}; "
            f"expected one of: {', '.join(sorted(TRAINERS))}"
        ) from None
    # Construct outside the try so a KeyError raised inside a trainer's
    # __init__ is not misreported as an unknown trainer name.
    return trainer_cls(**config)
