from .poisoner import Poisoner
from .badnets_poisoner import BadNetsPoisoner
from .ep_poisoner import EPPoisoner
from .sos_poisoner import SOSPoisoner
from .synbkd_poisoner import SynBkdPoisoner
from .stylebkd_poisoner import StyleBkdPoisoner
from .addsent_poisoner import AddSentPoisoner
from .trojanlm_poisoner import TrojanLMPoisoner
from .neuba_poisoner import NeuBAPoisoner
from .por_poisoner import PORPoisoner
from .lwp_poisoner import LWPPoisoner
from .attn_poisoner import AttnPoisoner
from .attn_ep_poisoner import AttnEPPoisoner
from .attn_stylebkd_poisoner import AttnStyleBkdPoisoner
from .attn_synbkd_poisoner import AttnSynBkdPoisoner
from .clean_poisoner import CleanPoisoner



# Registry of available poisoners, keyed by lower-case name; load_poisoner
# lower-cases the configured name before looking it up here. Several names
# deliberately share one class: "ripples" and "lws" both use BadNetsPoisoner,
# and "attn", "attn_badnets" and "attn_addsent" all use AttnPoisoner.
POISONERS = dict(
    clean=CleanPoisoner,
    base=Poisoner,
    badnets=BadNetsPoisoner,
    addsent=AddSentPoisoner,
    # For RIPPLES, the poisoner is the same as for BadNets.
    ripples=BadNetsPoisoner,
    ep=EPPoisoner,
    sos=SOSPoisoner,
    synbkd=SynBkdPoisoner,
    stylebkd=StyleBkdPoisoner,
    trojanlm=TrojanLMPoisoner,
    neuba=NeuBAPoisoner,
    por=PORPoisoner,
    lwp=LWPPoisoner,
    lws=BadNetsPoisoner,
    attn=AttnPoisoner,
    attn_badnets=AttnPoisoner,
    attn_addsent=AttnPoisoner,
    attn_ep=AttnEPPoisoner,
    attn_stylebkd=AttnStyleBkdPoisoner,
    attn_synbkd=AttnSynBkdPoisoner,
)

def load_poisoner(config):
    """Instantiate the poisoner selected by ``config["name"]``.

    The name is lower-cased before it is looked up in ``POISONERS``, so the
    match is case-insensitive. The whole ``config`` mapping (including the
    ``"name"`` entry itself) is forwarded to the chosen class as keyword
    arguments.

    Args:
        config: Mapping that must contain a string ``"name"`` entry; every
            entry is passed to the poisoner constructor as a keyword argument.

    Returns:
        An instance of the poisoner class registered under that name.

    Raises:
        KeyError: If ``config`` has no ``"name"`` entry, or if the name is not
            a key of ``POISONERS``. The message names the problem and, for an
            unknown name, lists the registered choices.
    """
    try:
        name = config["name"]
    except KeyError:
        raise KeyError(
            "poisoner config is missing the required 'name' entry"
        ) from None

    try:
        poisoner_cls = POISONERS[name.lower()]
    except KeyError:
        available = ", ".join(sorted(POISONERS))
        raise KeyError(
            f"unknown poisoner {name!r}; expected one of: {available}"
        ) from None

    # Construct outside the try blocks so a KeyError raised inside a
    # poisoner's __init__ propagates unchanged instead of being reported
    # as an unknown poisoner name.
    return poisoner_cls(**config)
