import foolbox as fb
import foolbox.attacks as fa
import torchattacks

def get_foolbox_model_and_attacks(model, device, norm_mean, norm_std):
    """Wrap *model* for Foolbox and build the per-client Foolbox attack set.

    Args:
        model: PyTorch model to wrap with ``fb.PyTorchModel``.
        device: Device forwarded to ``fb.PyTorchModel``.
        norm_mean: Mean values used by Foolbox's built-in preprocessing.
        norm_std: Std values used by Foolbox's built-in preprocessing.

    Returns:
        A ``(fmodel, attacks)`` tuple: the Foolbox model wrapper and a dict
        mapping client labels to freshly constructed Foolbox attacks.
    """
    # Inputs are declared to lie in [0, 1]; mean/std normalization is applied
    # by Foolbox itself along axis -3 (presumably the channel axis of NCHW
    # batches — confirm against the data pipeline).
    preprocessing = dict(mean=norm_mean, std=norm_std, axis=-3)
    wrapped = fb.PyTorchModel(
        model,
        bounds=(0, 1),
        preprocessing=preprocessing,
        device=device,
    )
    return wrapped, {
        "Client_2_FGSM": fa.FGSM(),
        "Client_3_PGD": fa.LinfPGD(),
        "Client_4_BasicIterative": fa.LinfBasicIterativeAttack(),
        "Client_5_DeepFool": fa.LinfDeepFoolAttack(),
    }

def get_torchattacks_attacks_for_exp2(model, eps=0.1, *, n_classes=10, seed=None):
    """Build the torchattacks suite for experiment 2, all sharing one L-inf budget.

    Args:
        model: PyTorch model to attack.
        eps: Perturbation budget shared by every attack in the suite.
        n_classes: Number of output classes of *model*, forwarded to
            ``torchattacks.AutoAttack`` (per the torchattacks docs, its
            targeted sub-attacks rely on it). The default of 10 equals
            torchattacks' own default, so existing callers are unaffected;
            set it explicitly for models that are not 10-class classifiers.
        seed: Seed forwarded to ``torchattacks.AutoAttack`` so its random
            restarts can be made reproducible. ``None`` keeps the library
            default.

    Returns:
        Dict mapping attack name ("FGSM", "PGD", "BIM", "AutoAttack") to a
        configured torchattacks attack instance.
    """
    return {
        "FGSM": torchattacks.FGSM(model, eps=eps),
        "PGD": torchattacks.PGD(model, eps=eps, alpha=0.01, steps=40),
        # 10 steps of alpha=0.01 accumulate at most 0.1 per coordinate, which
        # just reaches the default eps; with a larger eps BIM cannot reach the
        # edge of its budget.
        "BIM": torchattacks.BIM(model, eps=eps, alpha=0.01, steps=10),
        "AutoAttack": torchattacks.AutoAttack(
            model,
            norm='Linf',
            eps=eps,
            version='standard',
            n_classes=n_classes,
            seed=seed,
        ),
    }

def get_sota_attacks(model):
    """Build a fixed suite of torchattacks attacks with hand-picked settings.

    Args:
        model: PyTorch model to attack.

    Returns:
        Dict mapping attack name ("FGSM", "PGD", "BIM", "DeepFool", "CW") to a
        configured torchattacks attack instance.
    """
    # NOTE(review): each eps-bounded attack uses a different budget
    # (FGSM 0.008, PGD 0.01, BIM 0.03), so their success rates are not
    # directly comparable — confirm this is intentional.
    fgsm = torchattacks.FGSM(model, eps=0.008)
    # NOTE(review): PGD's step size (0.02) exceeds its budget (0.01), so every
    # step overshoots and is projected back to the boundary — verify intent.
    pgd = torchattacks.PGD(model, eps=0.01, alpha=0.02, steps=40)
    bim = torchattacks.BIM(model, eps=0.03, alpha=0.01, steps=10)
    # DeepFool and CW are not given an eps budget here.
    deepfool = torchattacks.DeepFool(model, steps=20)
    cw = torchattacks.CW(model, c=2, kappa=2, steps=500, lr=0.01)
    return {
        "FGSM": fgsm,
        "PGD": pgd,
        "BIM": bim,
        "DeepFool": deepfool,
        "CW": cw,
    }