# This file is the implementation of FTSAM defense.
# FTSAM: Enhancing Fine-Tuning Based Backdoor Defense with Sharpness-Aware Minimization [ICCV, 2023] (https://arxiv.org/abs/2304.11823)

# Basic structure:
# 1. load the backdoored attack data and backdoored test data
# 2. load the backdoored model
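# 3. for each round, sample a clean batch from the given clean subset:
#   a. perturb the weights to maximize the loss L within a ball of radius rho
#   b. do the outer minimization using the gradient taken at the perturbed weights
# 4. evaluate the model before and after repairing and report CA (clean accuracy) and ASR (attack success rate)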

import os
import random
from tqdm import tqdm
import contextlib

import torch
import torch.nn.functional as F
from torch.utils.data import Subset, Dataset, DataLoader
import torch.nn as nn
from torch.distributed import ReduceOp
from torch.nn.modules.batchnorm import _BatchNorm

from .base import Base

def smooth_crossentropy(pred, gold, smoothing=0.1):
    """
    Per-sample KL divergence between label-smoothed targets and the softmax of ``pred``.

    Args:
        pred (torch.Tensor): Logits; the class dimension is dim 1.
        gold (torch.Tensor): Integer class indices, one per row of ``pred``.
        smoothing (float): Probability mass spread evenly over the non-target classes. Default: 0.1.

    Returns:
        torch.Tensor: The loss summed over the last dimension (no reduction over samples).
    """
    num_classes = pred.size(1)
    off_target_prob = smoothing / (num_classes - 1)

    # Every class starts with the off-target probability; the true class is
    # then overwritten with the remaining mass.
    soft_target = torch.full_like(pred, fill_value=off_target_prob)
    soft_target.scatter_(dim=1, index=gold.unsqueeze(1), value=1.0 - smoothing)

    pointwise_kl = F.kl_div(
        input=F.log_softmax(pred, dim=1),
        target=soft_target,
        reduction='none',
    )
    return pointwise_kl.sum(-1)

class ProportionScheduler:
    """Scheduler whose output follows the learning rate of a PyTorch lr scheduler linearly."""

    def __init__(self, pytorch_lr_scheduler, max_lr, min_lr, max_value, min_value):
        """
        Create a scheduler whose value evolves proportionally to ``pytorch_lr_scheduler``, i.e.
        (value - min_value) / (max_value - min_value) = (lr - min_lr) / (max_lr - min_lr)

        Args:
            pytorch_lr_scheduler: Scheduler to follow. Its ``_last_lr`` is read when present,
                otherwise the lr of its optimizer's first parameter group.
            max_lr: Learning rate mapped to ``max_value``.
            min_lr: Learning rate mapped to ``min_value``.
            max_value: Upper end of the output range.
            min_value: Lower end of the output range.
        """
        self.t = 0
        self.pytorch_lr_scheduler = pytorch_lr_scheduler
        self.max_lr, self.min_lr = max_lr, min_lr
        self.max_value, self.min_value = max_value, min_value

        range_error = (
            "Current scheduler for `value` is scheduled to evolve proportionally to `lr`,"
            "e.g. `(lr - min_lr) / (max_lr - min_lr) = (value - min_value) / (max_value - min_value)`. Please check `max_lr >= min_lr` and `max_value >= min_value`;"
            "if `max_lr==min_lr` hence `lr` is constant with step, please set 'max_value == min_value' so 'value' is constant with step."
        )
        # A constant lr is only accepted together with a constant value.
        assert (max_lr > min_lr) or ((max_lr == min_lr) and (max_value == min_value)), range_error

        assert max_value >= min_value

        # Step once right away so that self._last_lr exists before lr() is called.
        self.step()

    def lr(self):
        """Return the most recently computed value."""
        return self._last_lr[0]

    def step(self):
        """Advance the step counter, recompute the value from the current lr and return it."""
        self.t += 1

        scheduler = self.pytorch_lr_scheduler
        if hasattr(scheduler, "_last_lr"):
            current_lr = scheduler._last_lr[0]
        else:
            current_lr = scheduler.optimizer.param_groups[0]['lr']

        if self.max_lr > self.min_lr:
            value_span = self.max_value - self.min_value
            new_value = self.min_value + value_span * (current_lr - self.min_lr) / (self.max_lr - self.min_lr)
        else:
            # Constant lr: the value is constant as well.
            new_value = self.max_value

        self._last_lr = [new_value]
        return new_value


class SAM(torch.optim.Optimizer):
    """
    Sharpness-aware optimizer wrapper around a base optimizer.

    Each call to :meth:`step` runs two forward/backward passes: one at the current
    weights and one at weights perturbed along the gradient direction (scaled by the
    current ``rho_t``). The gradient of the second pass, after :meth:`gradient_decompose`,
    is applied to the original weights by ``base_optimizer``.

    Args:
        params: Parameters (or parameter groups) handed to ``torch.optim.Optimizer``.
        base_optimizer: Optimizer that performs the actual update. Its ``param_groups``
            replace the ones of this wrapper.
        model (nn.Module): Model run by the closure created in :meth:`set_closure`.
        sam_alpha (float): Coefficient passed to :meth:`gradient_decompose` in :meth:`step`.
        rho_scheduler: Object whose ``step()`` returns the perturbation radius.
        adaptive (bool): If True, weight the gradient norm by ``|w|`` and the perturbation by ``w ** 2``.
            Default: False.
        perturb_eps (float): Small constant added to denominators. Default: 1e-12.
        grad_reduce (str): ``'mean'`` or ``'sum'``; reduction used when gradients are
            synchronized across distributed workers. Default: ``'mean'``.
        **kwargs: Additional entries stored in the optimizer defaults.

    Raises:
        ValueError: If ``grad_reduce`` is neither ``'mean'`` nor ``'sum'``.
    """

    def __init__(self, params, base_optimizer, model, sam_alpha, rho_scheduler, adaptive=False, perturb_eps=1e-12, grad_reduce='mean', **kwargs):
        defaults = dict(adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)
        self.model = model
        self.base_optimizer = base_optimizer
        # Both optimizers operate on the very same parameter groups.
        self.param_groups = self.base_optimizer.param_groups
        self.adaptive = adaptive
        self.rho_scheduler = rho_scheduler
        self.perturb_eps = perturb_eps
        self.alpha = sam_alpha

        # initialize self.rho_t
        self.update_rho_t()

        # set up reduction for gradient across workers
        reduce_mode = grad_reduce.lower()
        if reduce_mode == 'mean':
            if hasattr(ReduceOp, 'AVG'):
                self.grad_reduce = ReduceOp.AVG
                self.manual_average = False
            else:  # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes
                self.grad_reduce = ReduceOp.SUM
                self.manual_average = True
        elif reduce_mode == 'sum':
            self.grad_reduce = ReduceOp.SUM
            self.manual_average = False
        else:
            raise ValueError('"grad_reduce" should be one of ["mean", "sum"].')

    def disable_running_stats(self, model):
        """Set the momentum of every batch-norm layer of ``model`` to 0, remembering the old value."""
        def _disable(module):
            if isinstance(module, _BatchNorm):
                module.backup_momentum = module.momentum
                module.momentum = 0

        model.apply(_disable)

    def enable_running_stats(self, model):
        """Restore the batch-norm momentum saved by :meth:`disable_running_stats`."""
        def _enable(module):
            if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"):
                module.momentum = module.backup_momentum

        model.apply(_enable)

    @torch.no_grad()
    def update_rho_t(self):
        """Step the rho scheduler, store its result in ``self.rho_t`` and return it."""
        self.rho_t = self.rho_scheduler.step()
        return self.rho_t

    @torch.no_grad()
    def perturb_weights(self, rho=0.0):
        """
        Move every parameter that has a gradient to ``w + e(w)``.

        ``e(w)`` is the gradient scaled by ``rho / (grad_norm + perturb_eps)`` (and by ``w ** 2``
        when ``adaptive``). The current gradient is saved as ``old_g`` and the perturbation as
        ``e_w`` in the optimizer state.
        """
        grad_norm = self._grad_norm(weight_adaptive=self.adaptive)
        # The scale does not depend on the group, so it is computed once.
        scale = rho / (grad_norm + self.perturb_eps)

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    # This parameter is not perturbed now; make sure no perturbation
                    # left over from elsewhere can be subtracted by unperturb().
                    self.state[p].pop('e_w', None)
                    continue
                self.state[p]["old_g"] = p.grad.data.clone()
                e_w = p.grad * scale.to(p)
                if self.adaptive:
                    e_w *= torch.pow(p, 2)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]['e_w'] = e_w

    @torch.no_grad()
    def unperturb(self):
        """
        Undo :meth:`perturb_weights` by subtracting the stored ``e_w`` from each parameter.

        The stored perturbation is consumed, so it is subtracted exactly once. Keeping it
        around would shift a parameter again in a later step in which that parameter
        receives no gradient (and therefore is not perturbed).
        """
        for group in self.param_groups:
            for p in group['params']:
                e_w = self.state[p].pop('e_w', None)
                if e_w is not None:
                    p.data.sub_(e_w)

    @torch.no_grad()
    def gradient_decompose(self, alpha=0.0):
        """
        Subtract ``alpha`` times the part of the saved gradient ``old_g`` that is
        orthogonal to the current gradient from the current gradient (in place).
        """
        # inner product between the saved and the current gradient
        inner_prod = 0.0
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                inner_prod += torch.sum(
                    self.state[p]['old_g'] * p.grad.data
                )

        # norms of both gradients
        new_grad_norm = self._grad_norm()
        old_grad_norm = self._grad_norm(by='old_g')

        # cosine of the angle between them
        cosine = inner_prod / (new_grad_norm * old_grad_norm + self.perturb_eps)

        # remove the projection of old_g onto the current gradient, keep the rest
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                vertical = self.state[p]['old_g'] - cosine * old_grad_norm * p.grad.data / (new_grad_norm + self.perturb_eps)
                p.grad.data.add_(vertical, alpha=-alpha)

    @torch.no_grad()
    def _sync_grad(self):
        """All-reduce the final gradients across workers when a process group is initialized."""
        if not torch.distributed.is_initialized():
            return

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
                if self.manual_average:
                    # ReduceOp.AVG is unavailable: sum first, then divide by the world size.
                    world_size = torch.distributed.get_world_size()
                    p.grad.div_(float(world_size))
        return

    @torch.no_grad()
    def _grad_norm(self, by=None, weight_adaptive=False):
        """
        Return the global L2 norm over all parameters that currently have a gradient.

        Args:
            by (str | None): State key to take the tensors from (e.g. ``'old_g'``);
                the current gradients are used when empty.
            weight_adaptive (bool): If True, multiply each tensor by ``|w|`` first.
        """
        per_param_norms = []
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                tensor = self.state[p][by] if by else p.grad
                weight = torch.abs(p.data) if weight_adaptive else 1.0
                per_param_norms.append((weight * tensor).norm(p=2))
        return torch.norm(torch.stack(per_param_norms), p=2)

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` and re-share the resulting parameter groups with the base optimizer."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

    def maybe_no_sync(self):
        """
        Return ``model.no_sync()`` when running distributed, otherwise a no-op context manager.

        A model without a ``no_sync`` method also gets the no-op context, so that an
        initialized process group does not cause an AttributeError for such a model;
        :meth:`_sync_grad` reduces the gradients explicitly in either case.
        """
        if torch.distributed.is_initialized() and hasattr(self.model, 'no_sync'):
            return self.model.no_sync()
        return contextlib.ExitStack()

    @torch.no_grad()
    def set_closure(self, loss_fn, inputs, targets, **kwargs):
        """
        Create ``self.forward_backward_func`` for the given batch.

        The stored function takes no arguments; it zeroes the gradients of the base
        optimizer, runs ``self.model(inputs)``, evaluates ``loss_fn(outputs, targets, **kwargs)``,
        back-propagates and returns ``(outputs, loss_value)``.
        """

        def get_grad():
            self.base_optimizer.zero_grad()
            # step() runs under no_grad, so gradients are re-enabled for the forward pass.
            with torch.enable_grad():
                outputs = self.model(inputs)
                loss = loss_fn(outputs, targets, **kwargs)
            loss_value = loss.data.clone().detach()
            loss.backward()
            return outputs, loss_value

        self.forward_backward_func = get_grad

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one sharpness-aware update.

        Args:
            closure: Optional function performing forward and backward and returning
                ``(outputs, loss_value)``; defaults to the one from :meth:`set_closure`.

        Returns:
            tuple: ``(outputs, loss_value)`` of the first (unperturbed) pass.
        """
        get_grad = closure if closure else self.forward_backward_func

        with self.maybe_no_sync():
            # gradient at the current weights
            outputs, loss_value = get_grad()

            # perturb weights
            self.perturb_weights(rho=self.rho_t)

            # disable running stats for second pass
            self.disable_running_stats(self.model)

            # gradient at the perturbed weights
            get_grad()

            # decompose and get new update direction
            self.gradient_decompose(self.alpha)

            # back to the original weights
            self.unperturb()

        # synchronize gradients across workers
        self._sync_grad()

        # update with new directions
        self.base_optimizer.step()

        # enable running stats
        self.enable_running_stats(self.model)

        return outputs, loss_value


class FTSAM(Base):
    """
    Repair a backdoor model via Fine-Tuning Sharpness-Aware Minimization (FTSAM).

    Args:
        model (nn.Module): Backdoor model to be repaired.
        poisoned_trainset (Dataset): Poisoned trainset. Stored on the instance; not read by the methods of this class.
        poisoned_testset (Dataset): Poisoned testset, used to compute ASR.
        clean_trainset (Dataset): Clean trainset from which the fine-tuning subset is drawn.
        clean_testset (Dataset): Clean testset, used to compute CA.
        num_classes (int): Number of classes, used for the per-class sampling of the clean subset.
        device (str | torch.device): Device used for evaluation and fine-tuning. Default: 'cpu'.
        seed (int): Global seed for random numbers. Default: 666.
        deterministic (bool): Sets whether PyTorch operations must use "deterministic" algorithms.
            That is, algorithms which, given the same input, and when run on the same software and hardware,
            always produce the same output. When enabled, operations will use deterministic algorithms when available,
            and if only nondeterministic algorithms are available they will throw a RuntimeError when called. Default: False.
        loss (nn.Module): Stored as ``self.loss``. Note that :meth:`train` optimizes a
            label-smoothed cross-entropy and does not read this attribute.
        epochs (int): Number of passes over the clean subset during fine-tuning. Default: 100.
        ratio (float): Fraction of each class of ``clean_trainset`` used for fine-tuning. Default: 0.05.
    """
    def __init__(
        self,
        model,
        poisoned_trainset,
        poisoned_testset,
        clean_trainset,
        clean_testset,
        num_classes,
        device: str | torch.device = 'cpu',
        seed: int = 666, 
        deterministic: bool = False,
        loss: nn.Module = nn.CrossEntropyLoss(),
        epochs: int = 100,
        ratio: float = 0.05
    ):
        super().__init__(seed, deterministic)

        self.model = model

        self.poisoned_trainset = poisoned_trainset
        self.poisoned_testset = poisoned_testset
        self.clean_trainset = clean_trainset
        self.clean_testset = clean_testset
        self.num_classes = num_classes

        self.device = device
        self.loss = loss
        self.epochs = epochs
        self.ratio = ratio

        # Only trainable parameters are handed to the optimizer that applies the updates.
        self.base_optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=0.01,
            momentum=0.9,
            weight_decay=5e-4
        )
        # NOTE: train() steps this scheduler once per batch, so T_max counts batches.
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.base_optimizer,
            T_max=100,
        )
        # max_value == min_value, hence rho stays at 2.0 whatever the learning rate is.
        self.rho_scheduler = ProportionScheduler(
            pytorch_lr_scheduler=self.lr_scheduler,
            max_lr=0.01,
            min_lr=0.0,
            max_value=2.0,
            min_value=2.0
        )

        self.optimizer = SAM(
            params=self.model.parameters(),
            base_optimizer=self.base_optimizer,
            model=self.model,
            sam_alpha=0.0,
            rho_scheduler=self.rho_scheduler,
            adaptive=False,
            perturb_eps=1e-12,
            grad_reduce='mean'
        )

    @classmethod
    def index_choicer_by_class(cls, dataset: Dataset, ratio: float, num_classes: int):
        """
        Randomly pick a fraction ``ratio`` of the sample indices of every class.

        Args:
            dataset (Dataset): Dataset whose items carry the class label at position 1.
            ratio (float): Fraction of each class to keep (rounded down per class).
            num_classes (int): Number of classes; labels index into a list of this length.

        Returns:
            list[int]: Selected indices, grouped class by class.
        """
        indices_per_class = [[] for _ in range(num_classes)]

        # Traverse the dataset and record every index under its label.
        for sample_idx in range(len(dataset)):
            indices_per_class[dataset[sample_idx][1]].append(sample_idx)

        indices = []
        for class_indices in indices_per_class:
            random.shuffle(class_indices)
            keep = int(len(class_indices) * ratio)
            indices.extend(class_indices[:keep])

        return indices

    @classmethod
    def eval(cls, model: nn.Module, clean_testset: Dataset, poisoned_testset: Dataset, device: str, batch_size: int=128):
        """
        Evaluate ``model`` on the clean and on the poisoned testset.

        The model is moved to ``device`` and put into evaluation mode.

        Returns:
            tuple[float, float]: ``(CA, ASR)``: the fraction of clean samples classified as
            their label, and the fraction of poisoned samples classified as the label
            provided by ``poisoned_testset``.
        """
        test_model = model.to(device)
        test_model.eval()

        clean_loader = DataLoader(clean_testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
        poisoned_loader = DataLoader(poisoned_testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

        def model_output(model, data_loader, device):
            # Collect the raw outputs and the labels of a whole loader on the CPU.
            outputs = []
            labels = []

            with torch.no_grad():
                for batch_img, batch_label in data_loader:
                    batch_img = batch_img.to(device)
                    batch_label = batch_label.to(device)

                    batch_output = model(batch_img)

                    outputs.append(batch_output.cpu())
                    labels.append(batch_label.cpu())

            return torch.cat(outputs, dim=0), torch.cat(labels, dim=0)

        def accuracy(outputs, labels):
            return (outputs.argmax(dim=1) == labels).sum().item() / labels.size(0)

        # compute CA
        clean_outputs, clean_labels = model_output(test_model, clean_loader, device)
        CA = accuracy(clean_outputs, clean_labels)

        # compute ASR
        poisoned_outputs, poisoned_labels = model_output(test_model, poisoned_loader, device)
        ASR = accuracy(poisoned_outputs, poisoned_labels)

        return CA, ASR

    def train(self, dataset, model, optimizer: SAM, scheduler, batch_size=128) -> nn.Module:
        """
        Fine-tune ``model`` on ``dataset`` for ``self.epochs`` epochs with the SAM optimizer.

        Args:
            dataset (Dataset): Clean data yielding ``(img, target)`` pairs.
            model (nn.Module): Model to fine-tune.
            optimizer (SAM): Optimizer performing the two-pass update.
            scheduler: Learning-rate scheduler, stepped once per batch.
            batch_size (int): Batch size of the training loader. Default: 128.

        Returns:
            nn.Module: The fine-tuned model (the same object as ``model``).
        """
        # Do not rely on an earlier eval() call having placed the model on the device.
        model.to(self.device)
        model.train()

        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

        def loss_fn(preds, targets):
            return smooth_crossentropy(preds, targets, smoothing=0.1).mean()

        for _ in tqdm(range(self.epochs), desc="in training epochs"):
            for img, target in tqdm(data_loader, desc="traverse data loader"):
                img = img.to(self.device)
                target = target.to(self.device)

                optimizer.set_closure(loss_fn, img, target)
                optimizer.step()

                # Advance the learning rate first, then let rho follow it.
                scheduler.step()
                optimizer.update_rho_t()

        return model

    def repair(self):
        """Evaluate the model, fine-tune it on a per-class clean subset and evaluate it again."""
        former_CA, former_ASR = self.eval(self.model, self.clean_testset, self.poisoned_testset, self.device)

        print('==========Before FTSAM repairing==========')
        print(f'CA: {former_CA}, ASR: {former_ASR}')

        clean_indices = self.index_choicer_by_class(self.clean_trainset, self.ratio, self.num_classes)
        clean_set = Subset(self.clean_trainset, clean_indices)

        self.model = self.train(clean_set, self.model, self.optimizer, self.lr_scheduler)

        latter_CA, latter_ASR = self.eval(self.model, self.clean_testset, self.poisoned_testset, self.device)

        print('==========After FTSAM repairing==========')
        print(f'CA: {latter_CA}, ASR: {latter_ASR}')

    def get_model(self):
        """Return the current (possibly repaired) model."""
        return self.model
            

# This is the raw code of FTSAM, as a reference for the current version of BackdoorBox

# from pprint import  pformat
# import yaml
# import logging
# import time
# from defense.base import defense
# from utils.defense_utils.sam import SAM, ProportionScheduler
# from utils.defense_utils.sam import smooth_crossentropy

# from utils.aggregate_block.train_settings_generate import argparser_criterion, argparser_opt_scheduler
# from utils.trainer_cls import Metric_Aggregator
# from utils.choose_index import choose_index,choose_by_class
# from utils.aggregate_block.fix_random import fix_random
# from utils.aggregate_block.model_trainer_generate import generate_cls_model
# from utils.aggregate_block.dataset_and_transform_generate import get_input_shape, get_num_classes, get_transform
# from utils.save_load_attack import load_attack_result, save_defense_result
# from utils.bd_dataset_v2 import prepro_cls_DatasetBD_v2



# class AverageMeter(object):
#     """Computes and stores the average and current value"""
#     def __init__(self):
#         self.reset()

#     def reset(self):
#         self.val = 0
#         self.avg = 0
#         self.sum = 0
#         self.count = 0

#     def update(self, val, n=1):
#         self.val = val
#         self.sum += val * n
#         self.count += n
#         self.avg = self.sum / self.count

# def accuracy(output, target, topk=(1,)): # output: (256,10); target: (256)
#     """Computes the accuracy over the k top predictions for the specified values of k"""
#     with torch.no_grad():
#         maxk = max(topk) # 5
#         batch_size = target.size(0)

#         _, pred = output.topk(maxk, 1, True, True) # pred: (256,5)
#         pred = pred.t() # (5,256)
#         correct = pred.eq(target.view(1, -1).expand_as(pred)) # (5,256)

#         res = []

#         for k in topk:
#             # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
#             correct_k = torch.flatten(correct[:k]).float().sum(0, keepdim=True)
#             res.append(correct_k.mul_(1.0 / batch_size))
#         return res

# def given_dataloader_test(
#         model,
#         test_dataloader,
#         criterion,
#         non_blocking : bool = False,
#         device = "cpu",
#         verbose : int = 0
# ):
#     model.to(device, non_blocking=non_blocking)
#     model.eval()
#     metrics = {
#         'test_correct': 0,
#         'test_loss_sum_over_batch': 0,
#         'test_total': 0,
#     }
#     criterion = criterion.to(device, non_blocking=non_blocking)

#     if verbose == 1:
#         batch_predict_list, batch_label_list = [], []

#     with torch.no_grad():
#         for batch_idx, (x, target, *additional_info) in enumerate(test_dataloader):
#             x = x.to(device, non_blocking=non_blocking)
#             target = target.to(device, non_blocking=non_blocking)
#             pred = model(x)
#             loss = criterion(pred, target.long())

#             _, predicted = torch.max(pred, -1)
#             correct = predicted.eq(target).sum()

#             if verbose == 1:
#                 batch_predict_list.append(predicted.detach().clone().cpu())
#                 batch_label_list.append(target.detach().clone().cpu())

#             metrics['test_correct'] += correct.item()
#             metrics['test_loss_sum_over_batch'] += loss.item()
#             metrics['test_total'] += target.size(0)

#     metrics['test_loss_avg_over_batch'] = metrics['test_loss_sum_over_batch']/len(test_dataloader)
#     metrics['test_acc'] = metrics['test_correct'] / metrics['test_total']

#     if verbose == 0:
#         return metrics, None, None
#     elif verbose == 1:
#         return metrics, torch.cat(batch_predict_list), torch.cat(batch_label_list)

# class dsam(defense):

#     def __init__(self,args):
#         with open(args.yaml_path, 'r') as f:
#             defaults = yaml.safe_load(f)

#         defaults.update({k:v for k,v in args.__dict__.items() if v is not None})

#         args.__dict__ = defaults

#         args.terminal_info = sys.argv

#         args.num_classes = get_num_classes(args.dataset)
#         args.input_height, args.input_width, args.input_channel = get_input_shape(args.dataset)
#         args.img_size = (args.input_height, args.input_width, args.input_channel)
#         args.dataset_path = f"{args.dataset_path}/{args.dataset}"

#         self.args = args

#         if 'result_file' in args.__dict__ :
#             if args.result_file is not None:
#                 self.set_result(args.result_file)

#     def add_arguments(parser):
#         parser.add_argument('--device', type=str, help='cuda, cpu')
#         parser.add_argument("-pm","--pin_memory", type=lambda x: str(x) in ['True', 'true', '1'], help = "dataloader pin_memory")
#         parser.add_argument("-nb","--non_blocking", type=lambda x: str(x) in ['True', 'true', '1'], help = ".to(), set the non_blocking = ?")
#         parser.add_argument("-pf", '--prefetch', type=lambda x: str(x) in ['True', 'true', '1'], help='use prefetch')
#         parser.add_argument('--amp', type=lambda x: str(x) in ['True','true','1'])

#         parser.add_argument('--checkpoint_load', type=str, help='the location of load model')
#         parser.add_argument('--checkpoint_save', type=str, help='the location of checkpoint where model is saved')
#         parser.add_argument('--log', type=str, help='the location of log')
#         parser.add_argument("--dataset_path", type=str, help='the location of data')
#         parser.add_argument('--dataset', type=str, help='mnist, cifar10, cifar100, gtrsb, tiny') 
#         parser.add_argument('--result_file', type=str, help='the location of result')
    
#         parser.add_argument('--epochs', type=int)
#         parser.add_argument('--batch_size', type=int)
#         parser.add_argument("--num_workers", type=float)
#         parser.add_argument('--lr', type=float)
#         parser.add_argument('--lr_scheduler', type=str, help='the scheduler of lr')
#         parser.add_argument('--steplr_stepsize', type=int)
#         parser.add_argument('--steplr_gamma', type=float)
#         parser.add_argument('--steplr_milestones', type=list)
#         parser.add_argument('--model', type=str, help='resnet18')
        
#         parser.add_argument('--client_optimizer', type=int)
#         parser.add_argument('--sgd_momentum', type=float)
#         parser.add_argument('--wd', type=float, help='weight decay of sgd')
#         parser.add_argument('--frequency_save', type=int,
#                         help=' frequency_save, 0 is never')
#         parser.add_argument('--print_freq', default=1, type=int,help=' print_freq')
#         parser.add_argument('--random_seed', type=int, help='random seed')
#         parser.add_argument('--yaml_path', type=str, default="./config/defense/ft-sam/config.yaml", help='the path of yaml')
#         parser.add_argument('--bd_yaml_path', type=str, default=None, help='the path of yaml')

#         #set the parameter for the dsam defense
#         parser.add_argument('--ratio', type=float, help='the ratio of clean data loader')
#         parser.add_argument('--index', type=str, help='index of clean data')

#         parser.add_argument("--rho", default=2.0, type=float, help="Rho parameter for SAM.")
#         parser.add_argument("--adaptive", action='store_false', help="True if you want to use the Adaptive SAM.")
#         parser.add_argument("--label_smoothing", default=0.1, type=float, help="Use 0.0 for no label smoothing.")
#         parser.add_argument("--rho_max", default=2.0, type=float, help="Rho parameter for SAM.")
#         parser.add_argument("--rho_min", default=2.0, type=float, help="Rho parameter for SAM.")
#         parser.add_argument("--alpha", default=0.0, type=float, help="Rho parameter for SAM.")
#         parser.add_argument("--checkpoint_path", default=None, type=str, help="specify the checkpoint")

#     def set_result(self, result_file):
#         attack_file = 'record/' + result_file
#         # save_path = 'record/' + result_file + f'/defense/epochs_{args.epochs}_dsam_{args.ratio}_lr_{args.lr}_rho_{args.rho}/'
#         save_path = 'record/' + result_file + f'/defense/ft-sam/'
#         self.args.save_path = save_path
#         if self.args.checkpoint_save is None:
#             self.args.checkpoint_save = save_path + 'checkpoint/'
#             if not (os.path.exists(self.args.checkpoint_save)):
#                 os.makedirs(self.args.checkpoint_save) 
#         if self.args.log is None:
#             self.args.log = save_path + 'log/'
#             if not (os.path.exists(self.args.log)):
#                 os.makedirs(self.args.log)  
#         self.result = load_attack_result(attack_file + '/attack_result.pt')

#     def set_logger(self):
#         args = self.args
#         logFormatter = logging.Formatter(
#             fmt='%(asctime)s [%(levelname)-8s] [%(filename)s:%(lineno)d] %(message)s',
#             datefmt='%Y-%m-%d:%H:%M:%S',
#         )
#         logger = logging.getLogger()

#         fileHandler = logging.FileHandler(args.log + '/' + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + '.log')
#         fileHandler.setFormatter(logFormatter)
#         logger.addHandler(fileHandler)

#         consoleHandler = logging.StreamHandler()
#         consoleHandler.setFormatter(logFormatter)
#         logger.addHandler(consoleHandler)

#         logger.setLevel(logging.INFO)
#         logging.info(pformat(args.__dict__))
    
#     def set_devices(self):
#         self.device = torch.device(
#             (
#                 f"cuda:{[int(i) for i in self.args.device[5:].split(',')][0]}" if "," in self.args.device else self.args.device
#                 # since DataParallel only allow .to("cuda")
#             ) if torch.cuda.is_available() else "cpu"
#         )

#     def eval_step(self, model, clean_test_loader, bd_test_loader, args):
#         clean_metrics, clean_epoch_predict_list, clean_epoch_label_list = given_dataloader_test(
#             model,
#             clean_test_loader,
#             criterion=torch.nn.CrossEntropyLoss(),
#             non_blocking=args.non_blocking,
#             device=self.device,
#             verbose=0,
#         )
#         clean_test_loss_avg_over_batch = clean_metrics['test_loss_avg_over_batch']
#         test_acc = clean_metrics['test_acc']
#         bd_metrics, bd_epoch_predict_list, bd_epoch_label_list = given_dataloader_test(
#             model,
#             bd_test_loader,
#             criterion=torch.nn.CrossEntropyLoss(),
#             non_blocking=args.non_blocking,
#             device=self.device,
#             verbose=0,
#         )
#         bd_test_loss_avg_over_batch = bd_metrics['test_loss_avg_over_batch']
#         test_asr = bd_metrics['test_acc']

#         bd_test_loader.dataset.wrapped_dataset.getitem_all_switch = True  # change to return the original label instead
#         ra_metrics, ra_epoch_predict_list, ra_epoch_label_list = given_dataloader_test(
#             model,
#             bd_test_loader,
#             criterion=torch.nn.CrossEntropyLoss(),
#             non_blocking=args.non_blocking,
#             device=self.device,
#             verbose=0,
#         )
#         ra_test_loss_avg_over_batch = ra_metrics['test_loss_avg_over_batch']
#         test_ra = ra_metrics['test_acc']
#         bd_test_loader.dataset.wrapped_dataset.getitem_all_switch = False  # switch back

#         return clean_test_loss_avg_over_batch, \
# 				bd_test_loss_avg_over_batch, \
# 				ra_test_loss_avg_over_batch, \
# 				test_acc, \
# 				test_asr, \
# 				test_ra

#     def _train_sam(self, args, train_loader, model, optimizer, scheduler,criterion, epoch):
#         model.train()
#         losses = AverageMeter()
#         top1 = AverageMeter()

#         for idx, (img, target, *flag) in enumerate(train_loader, start=1):
#             img = img.to(args.device)
#             target = target.to(args.device)
#             bsz = target.shape[0]
#             def loss_fn(predictions, targets):
#                 return smooth_crossentropy(predictions, targets, smoothing=args.label_smoothing).mean()
#             optimizer.set_closure(loss_fn, img, target)
#             predictions, loss = optimizer.step()
#             with torch.no_grad():
#                 correct = torch.argmax(predictions.data, 1) == target
#                 correct = correct.sum()
#                 scheduler.step()
#                 optimizer.update_rho_t()

#             # update metric

#             losses.update(loss.item(), bsz)
#             top1.update(correct.detach().cpu().numpy()/bsz, bsz)
#             # acc1, acc5 = accuracy(output, target, topk=(1, 5))
#             # top1.update(acc1[0].detach().cpu().numpy(), bsz)
#             if (idx + 1) % args.print_freq == 0:
#                 logging.info(f'Train: [{epoch}][{idx + 1}/{len(train_loader)}]\t \
#                     loss {losses.val} ({losses.avg}\t \
#                     Acc@1 {top1.val} ({top1.avg}')
#                 sys.stdout.flush()
   
#         del loss, img
#         torch.cuda.empty_cache()
#         return losses.avg, top1.avg, model

#     def train_sam(self, model,train_dataloader,
#                                    clean_test_dataloader,
#                                    bd_test_dataloader,
#                                    total_epoch_num,
#                                    criterion,
#                                    optimizer,
#                                    scheduler,
#                                    amp,
#                                    device,
#                                    frequency_save,
#                                    save_folder_path,
#                                    save_prefix,
#                                    prefetch,
#                                    prefetch_transform_attr_name,
#                                    non_blocking,
#                                    ):
        
      
#         criterion = criterion.to(args.device)

#         # Training and Testing
#         train_loss_list = []
#         train_mix_acc_list = []
#         clean_test_loss_list = []
#         bd_test_loss_list = []
#         test_acc_list = []
#         test_asr_list = []
#         test_ra_list = []
#         agg = Metric_Aggregator()


#         for epoch in tqdm(range(1, args.epochs+1)):
#             train_epoch_loss_avg_over_batch, \
#             train_mix_acc, \
#             model = self._train_sam(args, train_dataloader, model, optimizer, scheduler,criterion, epoch)

#             clean_test_loss_avg_over_batch, \
#             bd_test_loss_avg_over_batch, \
#             ra_test_loss_avg_over_batch, \
#             test_acc, \
#             test_asr, \
#             test_ra = self.eval_step(
#                 model,
#                 clean_test_dataloader,
#                 bd_test_dataloader,
#                 args,
#             )
#             train_loss_list.append(train_epoch_loss_avg_over_batch)
#             train_mix_acc_list.append(train_mix_acc)
            
#             clean_test_loss_list.append(clean_test_loss_avg_over_batch)
#             bd_test_loss_list.append(bd_test_loss_avg_over_batch)
#             test_acc_list.append(test_acc)
#             test_asr_list.append(test_asr)
#             test_ra_list.append(test_ra)
#             agg(
#                     {
#                         "train_epoch_loss_avg_over_batch": train_epoch_loss_avg_over_batch,
#                         "train_acc": train_mix_acc,
#                         "clean_test_loss_avg_over_batch": clean_test_loss_avg_over_batch,
#                         "bd_test_loss_avg_over_batch" : bd_test_loss_avg_over_batch,
#                         "test_acc" : test_acc,
#                         "test_asr" : test_asr,
#                         "test_ra" : test_ra,
#                     }
#             )
#             agg.to_dataframe().to_csv(f"{args.log}d-sam_df.csv")

#         agg.summary().to_csv(f"{args.log}d-sam_df_summary.csv")

#         return model


#     def mitigation(self):
#         args=self.args
#         self.set_devices()
#         fix_random(self.args.random_seed)

#         # Prepare model, optimizer, scheduler
#         model = generate_cls_model(self.args.model,self.args.num_classes)
    
       
#         if hasattr(args,"checkpoint_path") and args.checkpoint_path != None:
#             file_path = 'record/' + args.checkpoint_path 
#             checkpoint_path = load_attack_result(file_path + '/defense_result.pt')
#             model.load_state_dict(checkpoint_path['model'])
#         else:
#             model.load_state_dict(self.result['model'])

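#         # NOTE(review): the multi-GPU branch wraps `self.model`, while the rest of this
#         # method works on the local `model` -- confirm which one is intended before re-enabling.
#         if "," in self.args.device:
#             self.model = torch.nn.DataParallel(
#                 self.model,
#                 device_ids=[int(i) for i in args.device[5:].split(",")]  # e.g. "cuda:2,3,7" -> [2,3,7]
#             )
#         else:
#             model.to(self.args.device)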
#         base_optimizer, scheduler = argparser_opt_scheduler(model, self.args)
   
#         rho_scheduler = ProportionScheduler(pytorch_lr_scheduler=scheduler, max_lr=args.lr, min_lr=0.0,
#             max_value=args.rho_max, min_value=args.rho_min)
#         optimizer = SAM(params=model.parameters(), base_optimizer=base_optimizer, model=model, sam_alpha=args.alpha, rho_scheduler=rho_scheduler, adaptive=args.adaptive)
        

#         # criterion = nn.CrossEntropyLoss()
#         criterion = argparser_criterion(args)

#         train_tran = get_transform(self.args.dataset, *([self.args.input_height,self.args.input_width]) , train = True)
#         clean_dataset = prepro_cls_DatasetBD_v2(self.result['clean_train'].wrapped_dataset)
#         # data_all_length = len(clean_dataset)
#         # ran_idx = choose_index(self.args, data_all_length) 
#         ran_idx = choose_by_class(args,clean_dataset)
#         log_index = self.args.log + 'index.txt'
#         np.savetxt(log_index, ran_idx, fmt='%d')
#         clean_dataset.subset(ran_idx)
#         data_set_without_tran = clean_dataset
#         data_set_o = self.result['clean_train']
#         data_set_o.wrapped_dataset = data_set_without_tran
#         data_set_o.wrap_img_transform = train_tran
#         data_loader = torch.utils.data.DataLoader(data_set_o, batch_size=self.args.batch_size, num_workers=self.args.num_workers, shuffle=True, pin_memory=args.pin_memory)
#         trainloader = data_loader
        
#         test_tran = get_transform(self.args.dataset, *([self.args.input_height,self.args.input_width]) , train = False)
#         data_bd_testset = self.result['bd_test']
#         data_bd_testset.wrap_img_transform = test_tran
#         data_bd_loader = torch.utils.data.DataLoader(data_bd_testset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,drop_last=False, shuffle=True,pin_memory=args.pin_memory)

#         data_clean_testset = self.result['clean_test']
#         data_clean_testset.wrap_img_transform = test_tran
#         data_clean_loader = torch.utils.data.DataLoader(data_clean_testset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,drop_last=False, shuffle=True,pin_memory=args.pin_memory)

#         self.train_sam(
#             model,
#             trainloader,
#             data_clean_loader,
#             data_bd_loader,
#             args.epochs,
#             criterion=criterion,
#             optimizer=optimizer,
#             scheduler=scheduler,
#             device=self.device,
#             frequency_save=args.frequency_save,
#             save_folder_path=args.save_path,
#             save_prefix='dsam',
#             amp=args.amp,
#             prefetch=args.prefetch,
#             prefetch_transform_attr_name="ori_image_transform_in_loading", # since we use the preprocess_bd_dataset
#             non_blocking=args.non_blocking,
#         )
        
#         result = {}
#         result['model'] = model

#         save_defense_result(
#             model_name=args.model,
#             num_classes=args.num_classes,
#             model=model.cpu().state_dict(),
#             save_path=args.save_path,
#         )
#         return result

#     def defense(self,result_file):
#         self.set_result(result_file)
#         self.set_logger()
#         result = self.mitigation()
#         return result
    
# if __name__ == '__main__':
#     parser = argparse.ArgumentParser(description=sys.argv[0])
#     dsam.add_arguments(parser)
#     args = parser.parse_args()
#     dsam_method = dsam(args)
#     result = dsam_method.defense(args.result_file)