import warnings
import os
import gdown
from typing import Optional, Union, Tuple, Dict
from pathlib import Path
import torch
from torch import Tensor
from torch import nn
import torchattacks
import torchvision.transforms as transforms
import torchvision

# import timm
import timm_copy as timm
from collections import OrderedDict

from robustbench.model_zoo.enums import BenchmarkDataset, ThreatModel
from robustbench.model_zoo import model_dicts as all_models

import models.resnet as resnet
from models.wide_resnet import WideResNet
from models.robustbench_dm_wide_resnet import DMWideResNet, CIFAR100_MEAN, CIFAR100_STD, CIFAR10_MEAN, CIFAR10_STD, Swish, DMPreActResNet
import models.robustbench_wide_resnet as robust_WideResNet
import attack_shap

def download_gdrive_new(gdrive_id, fname_save):
    """Fetch a checkpoint from Google Drive with gdown.

    See https://github.com/wkentaro/gdown.

    Args:
        gdrive_id: Google Drive file id, passed to ``gdown.download`` as ``id``.
        fname_save: Destination of the download. A ``Path`` is converted to
            ``str`` first; any other value is passed through unchanged.
    """
    target = str(fname_save) if isinstance(fname_save, Path) else fname_save
    print(f'Downloading {target} (gdrive_id={gdrive_id}).')
    gdown.download(id=gdrive_id, output=target)

def robustbench_weight_download(model_name, network, dataset, threat_model, model_dir):
    """Download the RobustBench checkpoint of ``model_name`` unless it is already on disk.

    The checkpoint is stored at
    ``{model_dir}/{dataset}/{threat_model}/{model_name}.pt``; the directory is
    created when missing.

    Args:
        model_name: Key into the RobustBench model registry.
        network: Not used by this function; kept so existing call sites keep
            working.
        dataset: Any value accepted by ``BenchmarkDataset``.
        threat_model: Any value accepted by ``ThreatModel``; a ``_3d`` suffix
            in its value is dropped before the lookup.
        model_dir: Root directory of the downloaded checkpoints.

    Raises:
        KeyError: If ``model_name`` is not registered for the given dataset
            and threat model.
    """
    dataset_: BenchmarkDataset = BenchmarkDataset(dataset)
    threat_model = ThreatModel(threat_model).value.replace('_3d', '')
    threat_model_: ThreatModel = ThreatModel(threat_model)

    model_dir_ = f'{model_dir}/{dataset_.value}/{threat_model_.value}'
    model_path = f'{model_dir_}/{model_name}.pt'

    # Resolve the registry entry before touching the filesystem so that an
    # unknown model name fails without leaving an empty directory behind.
    # Only the Google Drive id is needed here, so the network itself is not
    # instantiated.
    entry = all_models[dataset_][threat_model_][model_name]

    os.makedirs(model_dir_, exist_ok=True)
    if not os.path.isfile(model_path):
        download_gdrive_new(entry['gdrive_id'], model_path)

def rm_substr_from_state_dict(state_dict, substr):
    """Return a copy of ``state_dict`` with a leading ``substr`` removed from its keys.

    Typical use is dropping wrapper prefixes such as ``'module.'``. Only keys
    that *start* with ``substr`` are renamed; a key that merely contains it
    somewhere else is copied unchanged, because slicing ``len(substr)``
    characters off its front would corrupt the name.

    Args:
        state_dict: Mapping from parameter name to value.
        substr: Prefix to strip.

    Returns:
        A new ``OrderedDict`` in the iteration order of ``state_dict``. The
        input mapping is not modified.
    """
    prefix_len = len(substr)
    return OrderedDict(
        (key[prefix_len:] if key.startswith(substr) else key, value)
        for key, value in state_dict.items())

def add_substr_to_state_dict(state_dict, substr):
    """Return a copy of ``state_dict`` whose keys are all prefixed with ``substr``.

    Args:
        state_dict: Mapping from parameter name to value.
        substr: Text prepended to every key.

    Returns:
        A new ``OrderedDict`` in the iteration order of ``state_dict``.
    """
    return OrderedDict(
        (substr + name, value) for name, value in state_dict.items())

def _safe_load_state_dict(model: nn.Module, model_name: str,
                          state_dict: Dict[str, torch.Tensor],
                          dataset_: BenchmarkDataset) -> nn.Module:
    """Load ``state_dict`` into ``model``, tolerating a fixed set of key mismatches.

    A strict load is attempted first. When it raises ``RuntimeError``, the load
    is retried with ``strict=False`` only if both of the following hold:

    * ``model_name`` is one of the models listed below, or the dataset is
      ImageNet, and
    * the error text contains one of the accepted message fragments.

    Any other ``RuntimeError`` is re-raised. The value returned by
    ``load_state_dict`` is printed.

    Args:
        model: Module that receives the weights.
        model_name: Name used for the allow-list lookup.
        state_dict: Weights to load.
        dataset_: Dataset the model belongs to.

    Returns:
        ``model`` itself, after loading.
    """
    tolerated_models = {
        "Andriushchenko2020Understanding",
        "Augustin2020Adversarial",
        "Engstrom2019Robustness",
        "Pang2020Boosting",
        "Rice2020Overfitting",
        "Rony2019Decoupling",
        "Wong2020Fast",
        "Hendrycks2020AugMix_WRN",
        "Hendrycks2020AugMix_ResNeXt",
        "Kireev2021Effectiveness_Gauss50percent",
        "Kireev2021Effectiveness_AugMixNoJSD",
        "Kireev2021Effectiveness_RLAT",
        "Kireev2021Effectiveness_RLATAugMixNoJSD",
        "Kireev2021Effectiveness_RLATAugMix",
        "Chen2020Efficient",
        "Wu2020Adversarial",
        "Augustin2020Adversarial_34_10",
        "Augustin2020Adversarial_34_10_extra",
        "Diffenderfer2021Winning_LRR",
        "Diffenderfer2021Winning_LRR_CARD_Deck",
        "Diffenderfer2021Winning_Binary",
        "Diffenderfer2021Winning_Binary_CARD_Deck",
        "Huang2022Revisiting_WRN-A4",
        "Bai2024MixedNUTS",
    }

    tolerated_fragments = [
        'Missing key(s) in state_dict: "mu", "sigma".',
        'Unexpected key(s) in state_dict: "model_preact_hl1.1.weight"',
        'Missing key(s) in state_dict: "normalize.mean", "normalize.std"',
        'Unexpected key(s) in state_dict: "conv1.scores"',
        'Missing key(s) in state_dict: "mean", "std".',
    ]

    try:
        load_result = model.load_state_dict(state_dict, strict=True)
    except RuntimeError as e:
        error_text = str(e)
        relaxable_model = (model_name in tolerated_models
                           or dataset_ == BenchmarkDataset.imagenet)
        accepted_error = any(
            fragment in error_text for fragment in tolerated_fragments)
        if not (relaxable_model and accepted_error):
            raise e
        load_result = model.load_state_dict(state_dict, strict=False)
    print(load_result)
    return model

class ImageNormalizer(nn.Module):
    """Module computing ``(input - mean) / std`` with per-channel statistics."""

    def __init__(self, mean: Tuple[float, float, float],
                 std: Tuple[float, float, float]) -> None:
        """Register ``mean`` and ``std`` as buffers of shape ``(1, 3, 1, 1)``."""
        super().__init__()

        for buffer_name, stats in (('mean', mean), ('std', std)):
            self.register_buffer(buffer_name,
                                 torch.as_tensor(stats).view(1, 3, 1, 1))

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` shifted by the mean buffer and divided by the std buffer."""
        # presumably `input` is (N, 3, H, W) so the buffers broadcast — confirm with callers
        centered = input - self.mean
        return centered / self.std

    def __repr__(self):
        mean = self.mean.squeeze()  # type: ignore
        std = self.std.squeeze()  # type: ignore
        return f'ImageNormalizer(mean={mean}, std={std})'
    
def normalize_model(model: nn.Module, mean: Tuple[float, float, float],
                    std: Tuple[float, float, float]) -> nn.Module:
    """Wrap ``model`` so that its input is normalised first.

    Args:
        model: Network to wrap.
        mean: Per-channel mean handed to ``ImageNormalizer``.
        std: Per-channel standard deviation handed to ``ImageNormalizer``.

    Returns:
        An ``nn.Sequential`` with two named children, in this order:
        ``normalize`` (the ``ImageNormalizer``) and ``model``.
    """
    wrapped = nn.Sequential()
    wrapped.add_module('normalize', ImageNormalizer(mean, std))
    wrapped.add_module('model', model)
    return wrapped

# Names of the augmentation baselines, whose checkpoints live under `<dataset>/Aug`.
_AUG_MODEL_NAMES = ('baseline', 'mixup', 'cutmix')

# Last component of `model_dir` -> depth of the augmentation-baseline WideResNet.
_AUG_WRN_DEPTH_BY_DIR = {'wrn28-10': 28, 'wrn34-10': 34}


def _aug_wrn_factory(depth, num_classes, mean, std):
    """Return a zero-argument builder for the augmentation-baseline WideResNet."""
    def build():
        if depth is None:
            raise ValueError(
                "Cannot infer the WideResNet depth: the last component of "
                "`model_dir` must be 'wrn28-10' or 'wrn34-10'.")
        return WideResNet(depth=depth, num_classes=num_classes, widen_factor=10,
                          dropRate=0.3, mean=mean, std=std)
    return build


def _rb_wrn_factory(depth, num_classes, **extra):
    """Return a zero-argument builder for a RobustBench WideResNet (widen factor 10)."""
    return lambda: robust_WideResNet.WideResNet(
        depth=depth, num_classes=num_classes, widen_factor=10, **extra)


def _dm_wrn_factory(depth, num_classes, activation_fn, mean, std):
    """Return a zero-argument builder for a DMWideResNet of width 10."""
    return lambda: DMWideResNet(num_classes=num_classes,
                                depth=depth,
                                width=10,
                                activation_fn=activation_fn,
                                mean=mean,
                                std=std)


def _imagenet_model_table(mu, sigma):
    """Return the ImageNet registry: name -> {'model', 'gdrive_id', 'preprocessing'}."""
    def resnet50():
        return normalize_model(resnet.resnet50(), mu, sigma)

    def timm_factory(arch):
        return lambda: normalize_model(
            timm.create_model(arch, pretrained=False), mu, sigma)

    specs = [
        # (model name, builder, gdrive id, preprocessing)
        # Wong2020Fast requires resolution 288 x 288.
        ('Wong2020Fast', resnet50,
         '1deM2ZNS5tf3S_-eRURJi-IlvUL8WJQ_w', 'Crop288'),
        ('Engstrom2019Robustness', resnet50,
         '1T2Fvi1eCJTeAOEzrH_4TAIwO8HTOYVyn', 'Res256Crop224'),
        ('Salman2020Do_R50', resnet50,
         '1TmT5oGa1UvVjM3d-XeSj_XmKqBNRUg8r', 'Res256Crop224'),
        ('baseline', resnet50, '', 'Res256Crop224'),
        ('mixup', resnet50, '', 'Res256Crop224'),
        ('cutmix', resnet50, '', 'Res256Crop224'),
        ('Liu2023Comprehensive_Swin-B',
         timm_factory('swin_base_patch4_window7_224'),
         '1-4mtxQCkThJUVdS3wvQ6NnmMZuySqR3c', 'BicubicRes256Crop224'),
        ('Liu2023Comprehensive_Swin-L',
         timm_factory('swin_large_patch4_window7_224'),
         '1-57sQfcrsDsslfDR18nRD7FnpQmsSBk7', 'BicubicRes256Crop224'),
        ('Mo2022When_Swin-B',
         timm_factory('swin_base_patch4_window7_224'),
         '1-SXi4Z2X6Zo_j8EO4slJcBMXNej8fKUd', 'Res224'),
        ('Mo2022When_ViT-B',
         timm_factory('vit_base_patch16_224'),
         '1-dUFdvDBflqMsMLjZv3wlPJTm-Jm7net', 'Res224'),
    ]
    return OrderedDict(
        (name, {'model': build, 'gdrive_id': gdrive_id, 'preprocessing': prep})
        for name, build, gdrive_id, prep in specs)


def _cifar10_model_table(aug_depth, num_classes, img_mean, img_std):
    """Return the CIFAR-10 registry: name -> {'model', 'gdrive_id'}."""
    aug = _aug_wrn_factory(aug_depth, num_classes, img_mean, img_std)

    def rb(depth, **extra):
        return _rb_wrn_factory(depth, num_classes, **extra)

    def dm(depth, activation_fn):
        # These entries always use 10 classes, independent of `num_classes`.
        return _dm_wrn_factory(depth, 10, activation_fn,
                               CIFAR10_MEAN, CIFAR10_STD)

    builders = [
        ('baseline', aug),
        ('cutmix', aug),
        ('mixup', aug),
        ('Wang2020Improving', rb(28, sub_block1=True)),
        ('Wu2020Adversarial_extra', rb(28, sub_block1=True)),
        ('Wang2023Better_WRN-28-10', dm(28, nn.SiLU)),
        ('Xu2023Exploring_WRN-28-10', dm(28, nn.SiLU)),
        ('Pang2022Robustness_WRN28_10', dm(28, Swish)),
        ('Rade2021Helper_ddpm', dm(28, Swish)),
        ('Sridhar2021Robust', rb(28, sub_block1=True)),
        ('Zhang2020Geometry', rb(28, sub_block1=True)),
        ('Gowal2021Improving_28_10_ddpm_100m', dm(28, Swish)),
        ('Sehwag2020Hydra', rb(28, sub_block1=True)),
        ('Carmon2019Unlabeled', rb(28, sub_block1=True)),
        ('Zhang2019Theoretically', rb(34, sub_block1=True)),
        ('Rade2021Helper_extra', dm(34, Swish)),
        ('Chen2024Data_WRN_34_10', rb(34)),
        ('Sehwag2021Proxy', rb(34, sub_block1=False)),
        ('Addepalli2021Towards_WRN34', rb(34, sub_block1=True)),
        ('Addepalli2022Efficient_WRN_34_10', rb(34)),
        ('Cui2020Learnable_34_10', rb(34, sub_block1=True)),
        ('Zhang2020Attacks', rb(34, sub_block1=True)),
        ('Huang2020Self', rb(34, sub_block1=True)),
        ('Wu2020Adversarial', rb(34)),
        ('Zhang2019You', rb(34, sub_block1=True)),
    ]
    return OrderedDict(
        (name, {'model': build, 'gdrive_id': ''}) for name, build in builders)


def _cifar100_model_table(aug_depth, num_classes, img_mean, img_std):
    """Return the CIFAR-100 registry: name -> {'model', 'gdrive_id'}."""
    aug = _aug_wrn_factory(aug_depth, num_classes, img_mean, img_std)

    def rb(depth, **extra):
        return _rb_wrn_factory(depth, num_classes, **extra)

    def dm(depth, activation_fn):
        return _dm_wrn_factory(depth, num_classes, activation_fn,
                               CIFAR100_MEAN, CIFAR100_STD)

    builders = [
        ('baseline', aug),
        ('cutmix', aug),
        ('mixup', aug),
        ('Cui2023Decoupled_WRN-28-10', dm(28, nn.SiLU)),
        ('Wang2023Better_WRN-28-10', dm(28, nn.SiLU)),
        ('Rebuffi2021Fixing_28_10_cutmix_ddpm', dm(28, Swish)),
        ('Pang2022Robustness_WRN28_10', dm(28, Swish)),
        ('Cui2023Decoupled_WRN-34-10_autoaug', rb(34, sub_block1=False)),
        ('Addepalli2022Efficient_WRN_34_10', rb(34, sub_block1=False)),
        ('Cui2023Decoupled_WRN-34-10', rb(34, sub_block1=False)),
        ('Cui2020Learnable_34_10_LBGAT9_eps_8_255', rb(34, sub_block1=False)),
        ('Sehwag2021Proxy', rb(34, sub_block1=False)),
        ('Jia2022LAS-AT_34_10', rb(34, sub_block1=True)),
        ('Chen2021LTD_WRN34_10', rb(34, sub_block1=True)),
        ('Addepalli2021Towards_WRN34', rb(34, sub_block1=True)),
        ('Cui2020Learnable_34_10_LBGAT6', rb(34, sub_block1=True)),
    ]
    return OrderedDict(
        (name, {'model': build, 'gdrive_id': ''}) for name, build in builders)


def load_model(model_name: str,
               model_dir: Union[str, Path] = './pretrained_weights',
               dataset: Union[str,BenchmarkDataset] = BenchmarkDataset.cifar_10,
               threat_model: Union[str, ThreatModel] = ThreatModel.Linf,
               custom_checkpoint: str = "",
               norm: Optional[str] = None,
               num_classes: int = 10) -> nn.Module:
    """Build ``model_name`` and load its weights from a local checkpoint.

    If a timm model named ``{model}_{dataset}_{threat_model}`` is registered it
    is created through timm and returned directly. Otherwise the architecture
    is taken from the per-dataset tables in this module and the weights are
    read from ``model_dir``:

    * ``baseline`` / ``mixup`` / ``cutmix``:
      ``{model_dir}/{dataset}/Aug/{model_name}_model_best.pth.tar``
    * every other model:
      ``{model_dir}/{dataset}/{threat_model}/{model_name}.pt``

    Args:
        model_name: Name of the model to load.
        model_dir: Root directory of the checkpoints. For the CIFAR
            ``baseline`` / ``mixup`` / ``cutmix`` models its last component
            (``wrn28-10`` or ``wrn34-10``) selects the WideResNet depth.
        dataset: Dataset, as a ``BenchmarkDataset`` member or its string value.
        threat_model: Threat model, as a ``ThreatModel`` member or its string
            value; a ``_3d`` suffix in the value is dropped.
        custom_checkpoint: Passed to timm as ``checkpoint_path``.
        norm: Deprecated alternative to ``threat_model``.
        num_classes: Number of output classes of the created network.

    Returns:
        The model in evaluation mode.

    Raises:
        NotImplementedError: If the dataset has no model table here, or the
            selected entry holds a list of checkpoints.
        KeyError: If ``model_name`` is not in the table of the dataset.
        ValueError: If the entry has no ``gdrive_id``, or the WideResNet depth
            of an augmentation baseline cannot be inferred from ``model_dir``.
    """
    dataset_: BenchmarkDataset = BenchmarkDataset(dataset)
    # Branch on the enum's value so that enum members (including the default
    # argument) and plain strings are handled the same way.
    dataset_name = dataset_.value

    if norm is None:
        # since there is only `corruptions` folder for models in the Model Zoo
        threat_model = ThreatModel(threat_model).value.replace('_3d', '')
        threat_model_: ThreatModel = ThreatModel(threat_model)
    else:
        threat_model_ = ThreatModel(norm)
        warnings.warn(
            "`norm` has been deprecated and will be removed in a future version.",
            DeprecationWarning)

    lower_model_name = model_name.lower().replace('-', '_')
    timm_model_name = f"{lower_model_name}_{dataset_name.lower()}_{threat_model_.value.lower()}"

    if timm.is_model(timm_model_name):
        return timm.create_model(timm_model_name,
                                 num_classes=num_classes,
                                 pretrained=True,
                                 checkpoint_path=custom_checkpoint).eval()

    if model_name in _AUG_MODEL_NAMES:
        model_dir_ = Path(model_dir) / dataset_name / 'Aug'
        model_path = model_dir_ / f'{model_name}_model_best.pth.tar'
    else:
        model_dir_ = Path(model_dir) / dataset_name / threat_model_.value
        model_path = model_dir_ / f'{model_name}.pt'

    if dataset_name == 'imagenet':
        models = _imagenet_model_table((0.485, 0.456, 0.406),
                                       (0.229, 0.224, 0.225))
    elif dataset_name in ('cifar10', 'cifar100'):
        img_mean = [x / 255.0 for x in [125.3, 123.0, 113.9]]
        img_std = [x / 255.0 for x in [63.0, 62.1, 66.7]]
        # Path(...).name works for both str and Path inputs.
        aug_depth = _AUG_WRN_DEPTH_BY_DIR.get(Path(model_dir).name)
        if dataset_name == 'cifar10':
            models = _cifar10_model_table(aug_depth, num_classes,
                                          img_mean, img_std)
        else:
            models = _cifar100_model_table(aug_depth, num_classes,
                                           img_mean, img_std)
    else:
        raise NotImplementedError(
            f"No model table is defined for dataset `{dataset_name}`.")

    if model_name not in models:
        raise KeyError(
            f"Model `{model_name}` is not available for dataset `{dataset_name}`.")
    entry = models[model_name]

    if entry['gdrive_id'] is None:
        raise ValueError(
            f"Model `{model_name}` nor {timm_model_name} aren't a timm model and has no `gdrive_id` specified."
        )

    if isinstance(entry['gdrive_id'], list):
        # Fail loudly instead of silently returning None.
        raise NotImplementedError(
            f"Model `{model_name}` consists of several checkpoints, which is not supported.")

    model = entry['model']()
    if dataset_ == BenchmarkDataset.imagenet and 'Standard' in model_name:
        return model.eval()

    model_dir_.mkdir(parents=True, exist_ok=True)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

    if 'Kireev2021Effectiveness' in model_name or model_name == 'Andriushchenko2020Understanding':
        # we take the last model (choices: 'last', 'best')
        checkpoint = checkpoint['last']

    # The weights are either nested under 'state_dict' or the checkpoint is
    # the state dict itself.
    try:
        raw_state_dict = checkpoint['state_dict']
    except (KeyError, IndexError, TypeError):
        raw_state_dict = checkpoint
    # 'module.' removal is needed for the model of `Carmon2019Unlabeled`,
    # 'model.' removal for the model of `Chen2020Efficient`.
    state_dict = rm_substr_from_state_dict(raw_state_dict, 'module.')
    state_dict = rm_substr_from_state_dict(state_dict, 'model.')

    if dataset_ == BenchmarkDataset.imagenet:
        # Adapt checkpoint to the model definition in newer versions of timm.
        if model_name in [
                'Liu2023Comprehensive_Swin-B',
                'Liu2023Comprehensive_Swin-L',
                'Mo2022When_Swin-B',
        ]:
            # Best effort: keep the unfiltered state dict when the filter is
            # unavailable or fails, but report it instead of hiding it.
            try:
                from timm.models.swin_transformer import checkpoint_filter_fn
                state_dict = checkpoint_filter_fn(state_dict, model.model)
            except Exception as exc:
                warnings.warn(
                    f"Could not adapt the checkpoint of `{model_name}` with "
                    f"timm's `checkpoint_filter_fn`: {exc!r}")

        # Some models need input normalization, which is added as extra layer.
        if model_name not in [
                'Singh2023Revisiting_ConvNeXt-T-ConvStem',
                'Singh2023Revisiting_ViT-B-ConvStem',
                'Singh2023Revisiting_ConvNeXt-S-ConvStem',
                'Singh2023Revisiting_ConvNeXt-B-ConvStem',
                'Singh2023Revisiting_ConvNeXt-L-ConvStem',
                'Peng2023Robust',
                'Chen2024Data_WRN_50_2',
        ]:
            state_dict = add_substr_to_state_dict(state_dict, 'model.')

    model = _safe_load_state_dict(model, model_name, state_dict, dataset_)

    return model.eval()
        
def attack_loader(args, net):
    """Create the attack selected by ``args.attack`` for the network ``net``.

    Args:
        args: Namespace-like object. ``args.attack`` selects the attack;
            depending on the attack, ``args.eps``, ``args.steps``,
            ``args.n_classes`` and ``args.att_alpha`` are read as well.
        net: Model handed to the attack as ``model``.

    Returns:
        The attack object, or ``None`` when ``args.attack`` matches none of the
        names below.
    """
    # Builders are lazy so that only the attributes of the selected attack are
    # read from `args`.
    builders = {
        # Gradient Clamping based Attack (White-box)
        "fgsm": lambda: torchattacks.FGSM(model=net, eps=args.eps),
        "bim": lambda: torchattacks.BIM(model=net, eps=args.eps, alpha=1/255),
        "pgd": lambda: torchattacks.PGD(
            model=net, eps=args.eps, alpha=args.eps/args.steps*2.3,
            steps=args.steps, random_start=True),
        "cw": lambda: torchattacks.CW(model=net, c=0.1, lr=0.1, steps=200),
        "auto": lambda: torchattacks.APGD(model=net, eps=args.eps),
        "fab": lambda: torchattacks.FAB(
            model=net, eps=args.eps, n_classes=args.n_classes),
        # Black-box
        "square": lambda: torchattacks.Square(model=net, eps=args.eps),
        "spsa": lambda: torchattacks.SPSA(model=net, eps=args.eps),
        "onepixel": lambda: torchattacks.OnePixel(model=net),
        "autoattack": lambda: torchattacks.AutoAttack(model=net),
        # shap attack
        "fgsm_shap": lambda: attack_shap.FGSM(model=net, eps=args.eps),
        "cw_shap": lambda: attack_shap.CW(
            model=net, c=0.1, lr=0.1, steps=200, att_alpha=args.att_alpha),
        "cw_tracking_shap": lambda: attack_shap.CW_tracking(
            model=net, c=0.1, lr=0.1, steps=200, att_alpha=args.att_alpha),
        "cw_v2_shap": lambda: attack_shap.CW_V2(
            model=net, c=0.1, lr=0.1, steps=200),
        "pgd_shap": lambda: attack_shap.PGD(
            model=net, eps=args.eps, alpha=args.eps/args.steps*2.3,
            steps=args.steps, random_start=True, att_alpha=args.att_alpha),
        "auto_pgd_shap": lambda: attack_shap.APGD(
            model=net, eps=args.eps, att_alpha=args.att_alpha),
    }

    build = builders.get(args.attack)
    if build is None:
        return None
    return build()
    
def data_loader(dataset, data_root, split = 'train', batch_size=1, num_workers=4, shuffle=False, mode=None, network=None, att=None):
    """Create a torchvision dataset and a ``DataLoader`` over it.

    Args:
        dataset: ``'imagenet'``, ``'cifar10'`` or ``'cifar100'``.
        data_root: Root directory handed to the torchvision dataset.
        split: ``'train'`` or ``'val'``. Only used for the CIFAR datasets; the
            CIFAR training split is downloaded when missing.
        batch_size: Batch size of the loader.
        num_workers: Number of loader worker processes.
        shuffle: Whether the loader shuffles the data.
        mode: Not used by this function.
        network: Not used by this function.
        att: Not used by this function.

    Returns:
        A ``(dataset, loader)`` tuple.

    Raises:
        ValueError: If ``dataset`` is not supported, or ``split`` is neither
            ``'train'`` nor ``'val'`` for a CIFAR dataset.
    """
    if dataset == 'imagenet':
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            ])

        data = torchvision.datasets.ImageFolder(data_root, transform=transform)

    elif dataset in ('cifar10', 'cifar100'):
        # Validate up front: an unknown split would otherwise leave `data`
        # unbound and fail later with an unrelated UnboundLocalError.
        if split not in ('train', 'val'):
            raise ValueError(
                f"Unsupported split `{split}` for `{dataset}`; expected 'train' or 'val'.")

        transform = transforms.Compose([
            transforms.ToTensor(),
            ])

        if dataset == 'cifar10':
            dataset_cls = torchvision.datasets.CIFAR10
        else:
            dataset_cls = torchvision.datasets.CIFAR100

        if split == 'train':
            data = dataset_cls(data_root, train=True, download=True, transform=transform)
        else:
            data = dataset_cls(data_root, train=False, transform=transform)

    else:
        raise ValueError(
            f"Unsupported dataset `{dataset}`; expected 'imagenet', 'cifar10' or 'cifar100'.")

    loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=shuffle, pin_memory=True, num_workers=num_workers)

    return data, loader