"""
This file implements the defense method called D-ST from Effective Backdoor Defense by Exploiting Sensitivity of Poisoned Samples.
It trains a !!!secure model!!! from scratch with a poisoned dataset.
This file is modified based on the following source:
link :  https://github.com/SCLBD/Effective_backdoor_defense
The defense method is called d-st.


The updates include:
    1. data preprocess and dataset setting
    2. model setting
    3. args and config
    4. save process
    5. new standard: robust accuracy
basic structure for defense method:
    1. basic setting: args
    2. attack result(model, train data, test data)
    3. d-st defense: mainly two steps: sd and st (Sample-Distinguishment and two-stage Secure Training)
        a. train a backdoored model from scratch using poisoned dataset without any data augmentations
        b. fine-tune the backdoored model with intra-class loss L_intra.
        (sd:)
        c. calculate values of the FCT metric for all training samples.
        d. calculate thresholds for choosing clean and poisoned samples.
        e. separate training samples into clean samples D_c, poisoned samples D_p, and uncertain samples D_u.
        (st:)
        f. train the feature extractor via semi-supervised contrastive learning.
        g. train the classifier via minimizing a mixed cross-entropy loss.
    4. test the result and get ASR, ACC, RC 

"""

import argparse
import os, sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import copy
import math

sys.path.append("../")
sys.path.append(os.getcwd())

from pprint import pformat
import yaml
import logging
import time
from defense.base import defense

from utils.aggregate_block.train_settings_generate import (
    argparser_criterion,
    argparser_opt_scheduler,
)
from utils.trainer_cls import (
    BackdoorModelTrainer,
    Metric_Aggregator,
    PureCleanModelTrainer,
)
from utils.choose_index import choose_index
from utils.aggregate_block.fix_random import fix_random
from utils.aggregate_block.model_trainer_generate import generate_cls_model
from utils.log_assist import get_git_info
from utils.aggregate_block.dataset_and_transform_generate import (
    get_input_shape,
    get_num_classes,
    get_transform,
)
from utils.save_load_attack import load_attack_result, save_defense_result

## d-st utils
from utils.defense_utils.dst.dataloader_bd import (
    get_transform_st,
    TransformThree,
    normalization,
)
from utils.defense_utils.dst.sd import (
    calculate_consistency,
    calculate_gamma,
    separate_samples,
)
from utils.defense_utils.dst.dataloader_bd import get_st_train_loader
from utils.defense_utils.dst.models.resnet_super import SupConResNet, LinearClassifier
from utils.defense_utils.dst.st_loss import SupConLoss_Consistency
from utils.defense_utils.dst.utils_st import *


def train_epoch(arg, trainloader, model, optimizer, scheduler, criterion, epoch):
    """Train ``model`` for one epoch on the (poisoned) training loader.

    Every batch is unpacked as ``(inputs, labels, _, is_bd, gt_labels)``. Only
    ``inputs[0]`` is used and it is passed through ``normalization(arg, ...)``
    before the forward pass. The fourth field is treated as a poison indicator
    (truthy = backdoored sample), the same way ``finetune_epoch`` reads it.

    Args:
        arg: run configuration; ``arg.device`` is read here.
        trainloader: iterable yielding the 5-tuples described above.
        model: network being optimised.
        optimizer: optimiser stepped once per batch.
        scheduler: LR scheduler stepped once at the end of the epoch.
        criterion: loss called as ``criterion(outputs, labels)``.
        epoch: epoch index, used only in the log line.

    Returns:
        ``(mean loss per batch, accuracy w.r.t. the given labels,
        accuracy on poisoned samples w.r.t. the given labels,
        accuracy w.r.t. ``gt_labels``)``. The poisoned-sample accuracy is
        ``nan`` when the epoch contained no poisoned sample.
    """
    model.train()

    total_clean, total_poison = 0, 0
    total_clean_correct, total_attack_correct, total_robust_correct = 0, 0, 0
    train_loss = 0

    for i, (inputs, labels, _, is_bd, gt_labels) in enumerate(trainloader):
        inputs = normalization(arg, inputs[0])  # Normalize
        inputs, labels, gt_labels = (
            inputs.to(arg.device),
            labels.to(arg.device),
            gt_labels.to(arg.device),
        )
        # Positions of the backdoored samples inside this batch.
        poison_idx = torch.where(is_bd == True)[0]  # noqa: E712 (element-wise tensor comparison)

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        predictions = torch.argmax(outputs, dim=1)
        train_loss += loss.item()
        total_clean_correct += torch.sum(predictions == labels)
        total_attack_correct += torch.sum(predictions[poison_idx] == labels[poison_idx])
        total_robust_correct += torch.sum(predictions == gt_labels)
        total_clean += inputs.shape[0]
        total_poison += poison_idx.shape[0]

    avg_acc_clean = (total_clean_correct / total_clean).item()
    # Without any poisoned sample the ratio is undefined; report nan explicitly.
    avg_acc_attack = (total_attack_correct / total_poison).item() if total_poison > 0 else float("nan")
    avg_acc_robust = (total_robust_correct / total_clean).item()
    logging.info(
        f"Epoch: {epoch} | Loss: {train_loss / (i + 1)} | "
        f"Train ACC: {avg_acc_clean} ({total_clean_correct}/{total_clean}) | "
        f"Train ASR: {avg_acc_attack} ({total_attack_correct}/{total_poison}) | "
        f"Train R-ACC: {avg_acc_robust} ({total_robust_correct}/{total_clean})"
    )
    del loss, inputs, outputs
    torch.cuda.empty_cache()
    scheduler.step()
    return train_loss / (i + 1), avg_acc_clean, avg_acc_attack, avg_acc_robust


def test_epoch(args, testloader, model, criterion, epoch):
    model.eval()

    total_clean = 0
    total_clean_correct = 0
    test_loss = 0

    for i, (inputs, labels, *additional_info) in enumerate(testloader):
        inputs, labels = inputs.to(args.device), labels.to(args.device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        test_loss += loss.item()
        total_clean_correct += torch.sum(torch.argmax(outputs[:], dim=1) == labels[:])
        total_clean += inputs.shape[0]
    avg_acc_clean = (total_clean_correct / total_clean).item()

    return test_loss / (i + 1), avg_acc_clean


def finetune_epoch(arg, trainloader, model, optimizer, scheduler, epoch):
    """Fine-tune ``model`` for one epoch with the class-centre similarity loss.

    For every batch the model output is flattened per sample, averaged per
    label present in the batch, L2-normalised, and the mean of the
    off-diagonal entries of the centre/centre similarity matrix is minimised
    (the diagonal is zeroed but still counted in the mean's denominator).

    Args:
        arg: run configuration; ``arg.device`` and ``arg.num_classes`` are read.
        trainloader: iterable yielding 5-tuples ``(inputs, labels, _, _, _)``;
            only ``inputs[0]`` and ``labels`` are used.
        model: network being fine-tuned; its raw output is used as the feature.
        optimizer: optimiser stepped once per batch.
        scheduler: LR scheduler stepped once at the end of the epoch.
        epoch: unused; kept for signature symmetry with ``train_epoch``.

    Returns:
        The loss averaged over batches.
    """
    model.train()
    train_loss = 0

    for i, (inputs, labels, _, _is_bd, _gt_labels) in enumerate(trainloader):
        inputs = normalization(arg, inputs[0])  # Normalize
        inputs, labels = inputs.to(arg.device), labels.to(arg.device)

        features = model(inputs)
        features = features.view(features.size(0), -1)

        # One centre per class that actually occurs in this batch.
        centers = []
        for j in range(arg.num_classes):
            j_idx = torch.where(labels == j)[0]
            if j_idx.shape[0] == 0:
                continue
            centers.append(torch.mean(features[j_idx], dim=0))

        centers = F.normalize(torch.stack(centers, dim=0), dim=1)
        similarity_matrix = torch.matmul(centers, centers.T)
        # Ignore each centre's similarity with itself.
        diagonal = torch.eye(similarity_matrix.shape[0], dtype=torch.bool, device=similarity_matrix.device)
        loss = torch.mean(similarity_matrix.masked_fill(diagonal, 0.0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    scheduler.step()
    torch.cuda.empty_cache()
    return train_loss / (i + 1)


def _train_extractor(train_loader, model, criterion, optimizer, epoch, args):
    """Run one epoch of contrastive training of the feature extractor.

    Each batch is ``(images, labels, flags)`` where ``images`` holds two views
    (``images[0]`` and ``images[1]``). Both views are pushed through ``model``
    in one forward pass, re-paired per sample and handed to
    ``criterion(features, labels, flags)``.

    Args:
        train_loader: loader yielding the triples described above.
        model: network producing one feature vector per image.
        criterion: contrastive loss taking ``(features, labels, flags)``.
        optimizer: optimiser stepped once per batch; its LR is adjusted by
            ``warmup_learning_rate`` before every step.
        epoch: current epoch, used for warm-up and logging.
        args: run configuration (``device``, ``debug``, ``print_freq``,
            ``epochs`` are read here).

    Returns:
        The running average of the loss over the epoch.
    """
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    end = time.time()
    for idx, (images, labels, flags) in enumerate(train_loader):
        if args.debug and idx == 2:
            break
        data_time.update(time.time() - end)

        images = torch.cat([images[0], images[1]], dim=0)
        # Move straight to the configured device; going through ``.cuda()``
        # first would copy via the default GPU and then copy again.
        images = images.to(args.device, non_blocking=True)
        labels = labels.to(args.device, non_blocking=True)
        flags = flags.to(args.device, non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(args, epoch, idx, len(train_loader), optimizer)

        # compute loss: split the stacked views and pair them per sample
        features = model(images)
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
        loss = criterion(features, labels, flags)

        # update metric
        losses.update(loss.item(), bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % args.print_freq == 0:
            logging.info(
                f"Train: [{epoch}/{args.epochs}][{idx + 1}/{len(train_loader)}]\t \
                BT {batch_time.val} ({batch_time.avg})\t \
                DT {data_time.val} ({data_time.avg})\t \
                loss {losses.val} ({losses.avg})"
            )

            sys.stdout.flush()
    del loss, images, features
    torch.cuda.empty_cache()
    return losses.avg


def _train_classifier(train_loader, model, classifier, criterion, optimizer, epoch, args):
    """Run one epoch of classifier training on top of the frozen encoder.

    The loss is ``criterion`` on the samples flagged ``0`` minus ``0.001``
    times ``criterion`` on the samples flagged ``2``. A term whose subset is
    empty in the current batch is skipped, because a mean-reduced loss over
    zero samples is NaN and would corrupt the classifier weights; a batch with
    neither subset performs no update at all.

    Args:
        train_loader: loader yielding ``(images, labels, flags)``.
        model: network exposing ``model.encoder``; kept in eval mode and only
            used under ``torch.no_grad()``.
        classifier: head being trained on the encoder features.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimiser for ``classifier``; its LR is adjusted by
            ``warmup_learning_rate`` before every step.
        epoch: current epoch, used for warm-up and logging.
        args: run configuration (``device``, ``debug``, ``print_freq``).

    Returns:
        ``(average loss, average top-1 accuracy)`` over the epoch.
    """
    model.eval()
    classifier.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()
    for idx, (images, labels, flags) in enumerate(train_loader):
        if args.debug and idx == 2:
            break
        data_time.update(time.time() - end)
        # ``.to(device)`` also works on CPU-only hosts, unlike ``.cuda()``.
        images = images.to(args.device, non_blocking=True)
        labels = labels.to(args.device, non_blocking=True)
        flags = flags.to(args.device, non_blocking=True)

        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(args, epoch, idx, len(train_loader), optimizer)

        # compute loss
        with torch.no_grad():
            features = model.encoder(images)
        output = classifier(features.detach())

        clean_idx = torch.where(flags == 0)[0]
        poison_idx = torch.where(flags == 2)[0]
        loss = None
        if clean_idx.numel() > 0:
            loss = criterion(output[clean_idx], labels[clean_idx])
        if poison_idx.numel() > 0:
            penalty = criterion(output[poison_idx], labels[poison_idx]) * 0.001
            loss = -penalty if loss is None else loss - penalty
        if loss is None:
            # Neither subset is present in this batch: nothing to optimise.
            end = time.time()
            continue

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update metric
        losses.update(loss.item(), bsz)
        acc1, _ = accuracy(output, labels, topk=(1, 5))
        top1.update(acc1[0].detach().cpu().numpy(), bsz)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % args.print_freq == 0:
            logging.info(
                f"Train: [{epoch}][{idx + 1}/{len(train_loader)}]\t \
                BT {batch_time.val} ({batch_time.avg})\t \
                DT {data_time.val} ({data_time.avg})\t \
                loss {losses.val} ({losses.avg})\t \
                Acc@1 {top1.val} ({top1.avg})"
            )
            sys.stdout.flush()
    del loss, features, images, output
    torch.cuda.empty_cache()
    return losses.avg, top1.avg


def given_dataloader_test(
    model,
    classifier,
    test_dataloader,
    criterion,
    non_blocking: bool = False,
    device="cpu",
    verbose: int = 0,
):
    """Evaluate the ``model.encoder`` + ``classifier`` pair on a dataloader.

    Args:
        model: network exposing an ``encoder`` sub-module.
        classifier: head applied to the encoder output.
        test_dataloader: sized iterable yielding ``(x, target, *extra)``.
        criterion: loss module called as ``criterion(pred, target.long())``.
        non_blocking: forwarded to every ``.to(...)`` call.
        device: device on which the evaluation runs.
        verbose: ``1`` additionally returns the concatenated predictions and
            labels; any other value returns ``None`` for both.

    Returns:
        ``(metrics, predictions, labels)`` where ``metrics`` holds
        ``test_correct``, ``test_total``, ``test_loss_sum_over_batch``,
        ``test_loss_avg_over_batch`` and ``test_acc``.
    """
    model.to(device, non_blocking=non_blocking)
    model.eval()
    # The head takes part in the forward pass too, so it needs the same
    # device and evaluation mode as the encoder.
    classifier.to(device, non_blocking=non_blocking)
    classifier.eval()
    metrics = {
        "test_correct": 0,
        "test_loss_sum_over_batch": 0,
        "test_total": 0,
    }
    criterion = criterion.to(device, non_blocking=non_blocking)

    collect = verbose == 1
    batch_predict_list, batch_label_list = [], []

    with torch.no_grad():
        for batch_idx, (x, target, *additional_info) in enumerate(test_dataloader):
            x = x.to(device, non_blocking=non_blocking)
            target = target.to(device, non_blocking=non_blocking)
            features = model.encoder(x)
            pred = classifier(features.detach())
            loss = criterion(pred, target.long())

            _, predicted = torch.max(pred, -1)
            correct = predicted.eq(target).sum()

            if collect:
                batch_predict_list.append(predicted.detach().clone().cpu())
                batch_label_list.append(target.detach().clone().cpu())

            metrics["test_correct"] += correct.item()
            metrics["test_loss_sum_over_batch"] += loss.item()
            metrics["test_total"] += target.size(0)

    metrics["test_loss_avg_over_batch"] = metrics["test_loss_sum_over_batch"] / len(test_dataloader)
    metrics["test_acc"] = metrics["test_correct"] / metrics["test_total"]

    if collect:
        return metrics, torch.cat(batch_predict_list), torch.cat(batch_label_list)
    return metrics, None, None


def reset_model_from_SupConResNet(args, old_model, classifier):
    """Build a fresh classification model and fill it with the trained weights.

    A new model is created via ``generate_cls_model(args.model, args.num_classes)``,
    its state dict is overlaid with ``old_model.encoder``'s state dict, and the
    weight/bias of its final layer (``linear`` if present, otherwise ``fc``) are
    pointed at ``classifier.fc``'s tensors.

    Returns:
        The newly assembled model.
    """
    rebuilt = generate_cls_model(args.model, args.num_classes)

    merged_state = rebuilt.state_dict()
    encoder_state = old_model.encoder.state_dict()
    merged_state.update(encoder_state)
    rebuilt.load_state_dict(merged_state)

    # The first matching head attribute wins; models with neither are returned untouched.
    for head_name in ("linear", "fc"):
        if hasattr(rebuilt, head_name):
            head = getattr(rebuilt, head_name)
            head.weight.data = classifier.fc.weight.data
            head.bias.data = classifier.fc.bias.data
            break
    return rebuilt


class d_st(defense):
    def __init__(self, args):
        """Merge the YAML defaults with the parsed CLI arguments and store them.

        Values loaded from ``args.yaml_path`` act as defaults; every CLI
        attribute that is not ``None`` overrides them. Dataset-derived fields
        (``num_classes``, input height/width/channel, ``img_size``) are then
        added and the namespace is kept as ``self.args``.
        """
        with open(args.yaml_path, "r") as f:
            merged = yaml.safe_load(f)

        # Command-line values win over the YAML file whenever they were given.
        for key, value in args.__dict__.items():
            if value is not None:
                merged[key] = value
        args.__dict__ = merged

        args.terminal_info = sys.argv

        args.num_classes = get_num_classes(args.dataset)
        height, width, channel = get_input_shape(args.dataset)
        args.input_height, args.input_width, args.input_channel = height, width, channel
        args.img_size = (args.input_height, args.input_width, args.input_channel)
        self.args = args

    def add_arguments(parser):
        parser.add_argument("--device", type=str, help="cuda, cpu")
        parser.add_argument(
            "-pm",
            "--pin_memory",
            type=lambda x: str(x) in ["True", "true", "1"],
            help="dataloader pin_memory",
        )
        parser.add_argument(
            "-nb",
            "--non_blocking",
            type=lambda x: str(x) in ["True", "true", "1"],
            help=".to(), set the non_blocking = ?",
        )
        parser.add_argument(
            "-pf",
            "--prefetch",
            type=lambda x: str(x) in ["True", "true", "1"],
            help="use prefetch",
        )
        parser.add_argument("--amp", type=lambda x: str(x) in ["True", "true", "1"])

        parser.add_argument("--checkpoint_load", type=str, help="the location of load model")
        parser.add_argument(
            "--checkpoint_save",
            type=str,
            help="the location of checkpoint where model is saved",
        )
        parser.add_argument("--log", type=str, help="the location of log")
        parser.add_argument("--dataset_path", type=str, help="the location of data")
        parser.add_argument("--dataset", type=str, help="mnist, cifar10, cifar100, gtrsb, tiny")
        parser.add_argument("--result_file", type=str, help="the location of result")
        parser.add_argument("--random_seed", type=int, help="random seed")
        parser.add_argument(
            "--yaml_path",
            type=str,
            default="./config/defense/d-st/config.yaml",
            help="the path of yaml",
        )

        parser.add_argument("--epochs", type=int)
        parser.add_argument("--batch_size", type=int)
        parser.add_argument("--num_workers", type=float)
        parser.add_argument("--lr", type=float)
        parser.add_argument("--lr_scheduler", type=str, help="the scheduler of lr")
        parser.add_argument("--steplr_stepsize", type=int)
        parser.add_argument("--steplr_gamma", type=float)
        parser.add_argument("--steplr_milestones", type=list)
        parser.add_argument("--model", type=str, help="resnet18")
        parser.add_argument("--target_label", type=int)
        # parser.add_argument('--client_optimizer', type=int)
        parser.add_argument("--sgd_momentum", type=float)
        parser.add_argument("--wd", type=float, help="weight decay of sgd")
        parser.add_argument("--frequency_save", type=int, help=" frequency_save, 0 is never")

        parser.add_argument("--momentum", type=float, help="momentum")
        parser.add_argument("--weight_decay", type=float, help="weight decay")

        # set the parameter for the d-st defense
        parser.add_argument("--continue_step", type=str, default=None, help="the step to continue")
        parser.add_argument("--gamma_low", type=float, default=None, help="<=gamma_low is clean")  # \gamma_c
        parser.add_argument("--gamma_high", type=float, default=None, help=">=gamma_high is poisoned")  # \gamma_p
        parser.add_argument("--clean_ratio", type=float, default=0.20, help="ratio of clean data")  # \alpha_c
        parser.add_argument("--poison_ratio", type=float, default=0.05, help="ratio of poisoned data")  # \alpha_p

        parser.add_argument(
            "--gamma",
            type=float,
            default=0.1,
            help="LR is multiplied by gamma on schedule.",
        )
        parser.add_argument(
            "--schedule",
            type=int,
            nargs="+",
            default=[100, 150],
            help="Decrease learning rate at these epochs.",
        )
        parser.add_argument("--warm", type=int, default=1, help="warm up training phase")

        parser.add_argument("--trans1", type=str, default="rotate")  # the first data augmentation
        parser.add_argument("--trans2", type=str, default="affine")  # the second data augmentation
        parser.add_argument("--debug", action="store_true", default=False, help="debug or not")
        parser.add_argument("--print_freq", type=int, default=10, help="print frequency")
        parser.add_argument("--save_all_process", action="store_true", help="save model in each process")

    def set_result(self, result_file):
        """Prepare the output folders for this run and load the attack result.

        Everything lives under ``record/<result_file>``. The defense output
        folder is always created; the checkpoint and log folders are created
        (and stored on ``self.args``) only when the user did not supply them.
        The attack result is read from ``record/<result_file>/attack_result.pt``
        into ``self.result``.

        Args:
            result_file: name of the attack folder inside ``record/``.
        """
        attack_file = "record/" + result_file
        save_path = "record/" + result_file + "/defense/d-st/"
        # ``exist_ok`` avoids the check-then-create race of ``os.path.exists``.
        os.makedirs(save_path, exist_ok=True)
        self.args.save_path = save_path
        if self.args.checkpoint_save is None:
            self.args.checkpoint_save = save_path + "checkpoint/"
            os.makedirs(self.args.checkpoint_save, exist_ok=True)
        if self.args.log is None:
            self.args.log = save_path + "log/"
            os.makedirs(self.args.log, exist_ok=True)
        self.result = load_attack_result(attack_file + "/attack_result.pt")

    def set_trainer(self, model, mode="normal"):
        if mode == "normal":
            self.trainer = BackdoorModelTrainer(
                model,
            )
        elif mode == "clean":
            self.trainer = PureCleanModelTrainer(
                model,
            )
        elif mode == "nad":
            raise RuntimeError("No trainer support this mode!")

    def set_logger(self):
        """Configure the root logger to write to a timestamped file and the console.

        The log file is created inside ``self.args.log``. After the handlers
        are installed the full argument dictionary is logged, followed by the
        git information when it can be obtained.
        """
        args = self.args
        logFormatter = logging.Formatter(
            fmt="%(asctime)s [%(levelname)-8s] [%(filename)s:%(lineno)d] %(message)s",
            datefmt="%Y-%m-%d:%H:%M:%S",
        )
        logger = logging.getLogger()

        fileHandler = logging.FileHandler(
            args.log + "/" + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + ".log"
        )
        fileHandler.setFormatter(logFormatter)
        logger.addHandler(fileHandler)

        consoleHandler = logging.StreamHandler()
        consoleHandler.setFormatter(logFormatter)
        logger.addHandler(consoleHandler)

        logger.setLevel(logging.INFO)
        logging.info(pformat(args.__dict__))

        # Git information is best-effort, but a bare ``except`` would also
        # swallow KeyboardInterrupt / SystemExit.
        try:
            logging.info(pformat(get_git_info()))
        except Exception:
            logging.info("Getting git info fails.")

    def set_devices(self):
        # self.device = torch.device(
        #     (
        #         f"cuda:{[int(i) for i in self.args.device[5:].split(',')][0]}" if "," in self.args.device else self.args.device
        #         # since DataParallel only allow .to("cuda")
        #     ) if torch.cuda.is_available() else "cpu"
        # )
        self.device = self.args.device

    def set_new_args(self, args, step):
        if step == "train_notrans":
            args.epochs = 2
            args.batch_size = 128
        elif step == "finetune_notrans":
            args.epochs = 10
        elif step == "sscl":
            args.epochs = 200
            args.learning_rate = 0.5
            args.temp = 0.1
            args.batch_size = 512
            args.cosine = True
            if args.cosine:
                args.model_name = "{}_cosine".format(args.model)
            if args.batch_size > 256:
                args.warm = True
            if args.warm:
                args.model_name = "{}_warm".format(args.model)
                args.warmup_from = 0.01
                args.warm_epochs = 10
                if args.cosine:
                    args.lr_decay_rate = 0.1
                    eta_min = args.learning_rate * (args.lr_decay_rate**3)
                    args.warmup_to = (
                        eta_min
                        + (args.learning_rate - eta_min) * (1 + math.cos(math.pi * args.warm_epochs / args.epochs)) / 2
                    )
                else:
                    args.warmup_to = args.learning_rate
                    args.lr_decay_epochs = [700, 800, 900]
        elif step == "mixed_ce":
            args.epochs = 10
            args.learning_rate = 5
            args.batch_size = 512
            args.num_workers = 16
            args.cosine = False
            if args.batch_size > 256:
                args.warm = True
            if args.warm:
                args.model_name = "{}_warm".format(args.model)
                args.warmup_from = 0.01
                args.warm_epochs = 10
                if args.cosine:
                    args.lr_decay_rate = 0.1
                    eta_min = args.learning_rate * (args.lr_decay_rate**3)
                    args.warmup_to = (
                        eta_min
                        + (args.learning_rate - eta_min) * (1 + math.cos(math.pi * args.warm_epochs / args.epochs)) / 2
                    )
                else:
                    args.warmup_to = args.learning_rate
                    args.lr_decay_epochs = [60, 75, 90]
        if args.debug:
            args.epochs = 1
        return args

    def set_model(self, args, model):
        """Place the encoder model, a new linear classifier and a CE loss on the device.

        ``model`` must be a ``SupConResNet``. When ``self.device`` lists several
        GPUs (contains a comma) the model is wrapped in ``DataParallel`` and
        ``self.args.device`` is rewritten to the first listed GPU.

        Returns:
            ``(model, classifier, criterion)``.
        """
        assert isinstance(model, SupConResNet)
        ce_loss = torch.nn.CrossEntropyLoss()
        linear_head = LinearClassifier(feat_dim=args.feature_dim, num_classes=args.num_classes)

        multi_gpu = "," in self.device
        if multi_gpu:
            # eg. "cuda:2,3,7" -> [2,3,7]
            gpu_ids = [int(i) for i in args.device[5:].split(",")]
            model = torch.nn.DataParallel(model, device_ids=gpu_ids)
            self.args.device = f"cuda:{model.device_ids[0]}"
        model.to(self.args.device)

        linear_head = linear_head.to(args.device)
        ce_loss = ce_loss.to(args.device)
        return model, linear_head, ce_loss

    def drop_linear(self, model):  # drop the last nn.Linear layer, which will not be used in the following training
        model_name = self.args.model
        if "preactresnet" in model_name or model_name == "senet18":
            feature_dim = model.linear.in_features
            model.linear = nn.Identity()
        elif model_name.startswith("resnet"):
            feature_dim = model.fc.in_features
            model.fc = nn.Identity()
        elif "vgg" in model_name or "convnext" in model_name:
            feature_dim = list(model.classifier.children())[-1].in_features
            model.classifier = nn.Sequential(*list(model.classifier.children())[:-1])
        elif "vit" in model_name:
            feature_dim = model[1].heads.head.in_features
            model[1].heads.head = nn.Identity()
        else:
            raise NotImplementedError("Not support the model: {}".format(model_name))
        model.register_feature_dim = feature_dim
        return model

    def add_linear(self, old_model, classifier):
        """Rebuild a full classification model from a trained encoder and head.

        A fresh model is generated for ``self.args.model``, its state dict is
        overlaid with ``old_model.encoder``'s weights, and its final layer is
        replaced by ``classifier.fc`` at the architecture-specific location.

        Raises:
            NotImplementedError: when the architecture name is not handled.
        """
        args = self.args
        arch = args.model

        rebuilt = generate_cls_model(arch, args.num_classes)
        merged_state = rebuilt.state_dict()
        encoder_state = old_model.encoder.state_dict()
        merged_state.update(encoder_state)
        rebuilt.load_state_dict(merged_state)

        head = classifier.fc
        if arch == "senet18" or "preactresnet" in arch:
            rebuilt.linear = head
        elif arch.startswith("resnet"):
            rebuilt.fc = head
        elif any(tag in arch for tag in ("vgg", "convnext")):
            # Swap only the last classifier layer for the trained head.
            kept_layers = list(rebuilt.classifier.children())[:-1]
            rebuilt.classifier = nn.Sequential(*kept_layers, head)
        elif "vit" in arch:
            rebuilt[1].heads.head = head
        else:
            raise NotImplementedError("Not support the model: {}".format(arch))
        return rebuilt

    def get_sd_train_loader(self):
        """Return a shuffled loader over the poisoned training set for the sd stage.

        The three transforms from ``get_transform_st(args, train=True)`` are
        combined with ``TransformThree`` and installed as the dataset's
        ``wrap_img_transform`` before the loader is built.
        """
        args = self.args
        first_tf, second_tf, third_tf = get_transform_st(args, train=True)
        bd_train = self.result["bd_train"]
        bd_train.wrap_img_transform = TransformThree(first_tf, second_tf, third_tf)
        loader_options = {
            "batch_size": args.batch_size,
            "num_workers": args.num_workers,
            "shuffle": True,
        }
        return torch.utils.data.DataLoader(bd_train, **loader_options)

    def testloader_wrapper(
        self,
    ):
        """Build the clean and backdoored test loaders with the test-time transform.

        Both datasets from ``self.result`` get the same non-training transform
        installed as ``wrap_img_transform``; the loaders keep sample order,
        keep the last partial batch and use pinned memory.

        Returns:
            ``(clean_test_loader, bd_test_loader)``.
        """
        args = self.args
        eval_transform = get_transform(args.dataset, args.input_height, args.input_width, train=False)

        def build_loader(result_key):
            # Install the transform, then wrap the dataset in a deterministic loader.
            dataset = self.result[result_key]
            dataset.wrap_img_transform = eval_transform
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=args.batch_size,
                num_workers=args.num_workers,
                drop_last=False,
                shuffle=False,
                pin_memory=True,
            )

        bd_test_loader = build_loader("bd_test")
        clean_test_loader = build_loader("clean_test")
        return clean_test_loader, bd_test_loader

    def train_attack_noTrans(
        self,
        bd_trainloader,
        clean_test_loader,
        bd_test_loader,
        model=None,
        optimizer=None,
        scheduler=None,
        finetune=False,
    ):
        """Train a model on the poisoned data, or fine-tune one with the L_intra loss.

        With ``finetune=False`` a new model, optimiser and scheduler are
        created and trained with cross-entropy via ``train_epoch``; after every
        epoch clean accuracy, attack success rate and robust accuracy are
        evaluated with ``test_epoch`` and written to ``train_notrans_df.csv``
        in ``args.log``. With ``finetune=True`` the given model is trained with
        ``finetune_epoch`` and the loss is written to
        ``finetune_notrans_df.csv``.

        Args:
            bd_trainloader: loader over the poisoned training set.
            clean_test_loader: loader over the clean test set.
            bd_test_loader: loader over the backdoored test set; its
                ``dataset.wrapped_dataset.getitem_all_switch`` is toggled
                temporarily for the robust-accuracy pass.
            model, optimizer, scheduler: required when ``finetune`` is true,
                ignored (re-created) otherwise.
            finetune: selects the fine-tuning mode.

        Returns:
            ``(model, optimizer, scheduler)``.

        Raises:
            ValueError: if ``finetune`` is true but model, optimizer or
                scheduler is missing.
        """
        ## update args
        step = "finetune_notrans" if finetune else "train_notrans"
        args = self.set_new_args(self.args, step=step)
        agg = Metric_Aggregator()
        if not finetune:
            # Load models
            logging.info("----------- Network Initialization --------------")
            model = generate_cls_model(args.model, args.num_classes)
            if "," in self.device:
                model = torch.nn.DataParallel(
                    model,
                    device_ids=[int(i) for i in args.device[5:].split(",")],  # eg. "cuda:2,3,7" -> [2,3,7]
                )
                self.args.device = f"cuda:{model.device_ids[0]}"
                model.to(self.args.device)
            else:
                model.to(self.args.device)
            logging.info("finished model init...")
            optimizer, scheduler = argparser_opt_scheduler(model, self.args)
            # define loss functions
            criterion = torch.nn.CrossEntropyLoss().to(args.device)
            # ``os.path.join`` also works when ``args.log`` has no trailing slash.
            csv_path = os.path.join(args.log, "train_notrans_df.csv")

            logging.info("----------- Training from scratch --------------")
            for epoch in tqdm(range(0, args.epochs)):
                tr_loss, tr_acc, _, _ = train_epoch(args, bd_trainloader, model, optimizer, scheduler, criterion, epoch)
                clean_test_loss, clean_test_acc = test_epoch(args, clean_test_loader, model, criterion, epoch)

                bd_test_loss, bd_test_acc = test_epoch(args, bd_test_loader, model, criterion, epoch)
                # Second pass over the backdoored set with the switch on, for
                # robust accuracy; always restore the switch afterwards.
                switchable_dataset = bd_test_loader.dataset.wrapped_dataset
                switchable_dataset.getitem_all_switch = True
                try:
                    _, bd_test_racc = test_epoch(args, bd_test_loader, model, criterion, epoch)
                finally:
                    switchable_dataset.getitem_all_switch = False
                agg(
                    {
                        "train_epoch_loss_avg_over_batch": tr_loss,
                        "train_acc": tr_acc,
                        "clean_test_loss_avg_over_batch": clean_test_loss,
                        "bd_test_loss_avg_over_batch": bd_test_loss,
                        "test_acc": clean_test_acc,
                        "test_asr": bd_test_acc,
                        "test_ra": bd_test_racc,
                    }
                )
                agg.to_dataframe().to_csv(csv_path)
        else:
            if model is None or optimizer is None or scheduler is None:
                raise ValueError("finetune=True requires model, optimizer and scheduler to be provided")
            csv_path = os.path.join(args.log, "finetune_notrans_df.csv")
            logging.info("----------- Finetune the model with L_intra--------------")
            for epoch in tqdm(range(0, args.epochs)):
                tr_loss = finetune_epoch(args, bd_trainloader, model, optimizer, scheduler, epoch)

                agg(
                    {
                        "epoch": epoch,
                        "train_epoch_loss_avg_over_batch": tr_loss,
                    }
                )
                agg.to_dataframe().to_csv(csv_path)
        if args.save_all_process:
            save_file = os.path.join(args.save_path, f"{step}.pt")
            logging.info(f"save path is {save_file}")
            save_model(model, optimizer, args, args.epochs, save_file)
        return model, optimizer, scheduler

    def train_extractor(
        self,
    ):
        """Train the feature extractor with the semi-supervised contrastive loss.

        The step-specific hyper-parameters for ``"sscl"`` are applied, an
        encoder without its final linear layer is wrapped in ``SupConResNet``
        and optimised with ``SupConLoss_Consistency`` for ``args.epochs``
        epochs. The per-epoch loss is written to ``train_extractor_df.csv`` in
        ``args.log``; the final model is saved when ``args.save_all_process``
        is set.

        Returns:
            The trained ``SupConResNet`` model.
        """
        ## update args
        args = self.set_new_args(self.args, step="sscl")
        train_loader = get_st_train_loader(args, self.result["bd_train"], module="sscl")
        encoder = generate_cls_model(args.model, args.num_classes)
        encoder = self.drop_linear(encoder)
        args.feature_dim = encoder.register_feature_dim
        model = SupConResNet(encoder, dim_in=args.feature_dim)
        criterion = SupConLoss_Consistency(temperature=args.temp, device=args.device)
        model = model.to(args.device)
        criterion = criterion.to(args.device)
        optimizer = set_optimizer(args, model, lr=args.learning_rate)
        agg = Metric_Aggregator()
        # ``os.path.join`` also works when ``args.log`` has no trailing slash.
        csv_path = os.path.join(args.log, "train_extractor_df.csv")

        for epoch in range(1, args.epochs + 1):
            adjust_learning_rate(args, optimizer, epoch)
            loss = _train_extractor(train_loader, model, criterion, optimizer, epoch, args)
            agg(
                {
                    "epoch": epoch,
                    "train_epoch_loss_avg_over_batch": loss,
                }
            )
            agg.to_dataframe().to_csv(csv_path)
        torch.cuda.empty_cache()
        if args.save_all_process:
            # save the last model
            save_file = os.path.join(args.save_path, "sscl-last.pt")
            save_model(model, optimizer, args, args.epochs, save_file)
        return model

    def train_classifier(self, model):
        """Train the classifier by minimising the mixed cross-entropy loss (step g).

        Args:
            model: the feature extractor returned by ``train_extractor``; it is
                passed to ``set_model`` together with the ``"mixed_ce"`` args.

        Returns:
            tuple ``(model, classifier)`` as produced by ``set_model`` after
            ``args.epochs`` epochs of training.

        After every epoch the clean / backdoor / robust-accuracy metrics from
        ``eval_step`` are appended to a ``Metric_Aggregator`` and written to
        ``d-st_df.csv``; a summary is written to ``d-st_df_summary.csv`` once
        training finishes.
        """
        ## update args
        args = self.set_new_args(self.args, step="mixed_ce")
        train_loader = get_st_train_loader(args, self.result["bd_train"], module="mixed_ce")
        clean_test_loader, bd_test_loader = self.testloader_wrapper()
        model, classifier, criterion = self.set_model(args, model)
        # only the classifier is handed to the optimizer
        optimizer = set_optimizer(args, classifier, lr=args.learning_rate)

        # The aggregator is the single record of per-epoch metrics.
        agg = Metric_Aggregator()

        for epoch in range(1, args.epochs + 1):
            adjust_learning_rate(args, optimizer, epoch)
            train_epoch_loss_avg_over_batch, train_mix_acc = _train_classifier(
                train_loader, model, classifier, criterion, optimizer, epoch, args
            )

            (
                clean_test_loss_avg_over_batch,
                bd_test_loss_avg_over_batch,
                ra_test_loss_avg_over_batch,
                test_acc,
                test_asr,
                test_ra,
            ) = self.eval_step(
                model,
                classifier,
                clean_test_loader,
                bd_test_loader,
                args,
            )
            agg(
                {
                    "train_epoch_loss_avg_over_batch": train_epoch_loss_avg_over_batch,
                    "train_acc": train_mix_acc,
                    "clean_test_loss_avg_over_batch": clean_test_loss_avg_over_batch,
                    "bd_test_loss_avg_over_batch": bd_test_loss_avg_over_batch,
                    "test_acc": test_acc,
                    "test_asr": test_asr,
                    "test_ra": test_ra,
                }
            )
            # NOTE(review): plain concatenation, so args.save_path is expected to
            # end with a path separator -- confirm where save_path is set.
            agg.to_dataframe().to_csv(f"{args.save_path}d-st_df.csv")

        agg.summary().to_csv(f"{args.save_path}d-st_df_summary.csv")
        if args.save_all_process:
            save_file = os.path.join(args.save_path, "mce-last.pt")
            save_model(classifier, optimizer, args, args.epochs, save_file)
        return model, classifier

    def eval_step(self, model, classifier, clean_test_loader, bd_test_loader, args):
        """Evaluate the extractor/classifier pair on clean and backdoor test data.

        Three passes of ``given_dataloader_test`` are run:

        1. ``clean_test_loader``                      -> clean loss, ``test_acc``
        2. ``bd_test_loader``                         -> backdoor loss, ``test_asr``
        3. ``bd_test_loader`` with ``getitem_all_switch`` enabled on the wrapped
           dataset (so the original label is returned instead)
                                                      -> RA loss, ``test_ra``

        Args:
            model: feature extractor passed through to ``given_dataloader_test``.
            classifier: classification head passed through likewise.
            clean_test_loader: loader of clean test samples.
            bd_test_loader: loader of backdoor test samples; its
                ``dataset.wrapped_dataset.getitem_all_switch`` flag is toggled
                temporarily and always restored to ``False``.
            args: namespace providing ``non_blocking``.

        Returns:
            tuple ``(clean_loss, bd_loss, ra_loss, test_acc, test_asr, test_ra)``.
        """

        def _evaluate(loader):
            # Only the metrics dict is used; predictions and labels are dropped.
            metrics, _, _ = given_dataloader_test(
                model,
                classifier,
                loader,
                criterion=torch.nn.CrossEntropyLoss(),
                non_blocking=args.non_blocking,
                device=self.device,
                verbose=0,
            )
            return metrics["test_loss_avg_over_batch"], metrics["test_acc"]

        clean_test_loss_avg_over_batch, test_acc = _evaluate(clean_test_loader)
        bd_test_loss_avg_over_batch, test_asr = _evaluate(bd_test_loader)

        bd_dataset = bd_test_loader.dataset.wrapped_dataset
        bd_dataset.getitem_all_switch = True  # change to return the original label instead
        try:
            ra_test_loss_avg_over_batch, test_ra = _evaluate(bd_test_loader)
        finally:
            # Always switch back: the loader is reused every epoch, and a flag
            # left on would silently turn later ASR numbers into RA numbers.
            bd_dataset.getitem_all_switch = False

        return (
            clean_test_loss_avg_over_batch,
            bd_test_loss_avg_over_batch,
            ra_test_loss_avg_over_batch,
            test_acc,
            test_asr,
            test_ra,
        )

    def continue_learn(self, args):
        """Rebuild the state needed to resume the pipeline at ``args.continue_step``.

        The pipeline steps are, in order: ``train_notrans``, ``finetune_notrans``,
        ``calculate``, ``separate``, ``sscl`` and ``mixed_ce``.  Only resuming at
        ``"mixed_ce"`` is implemented: the ``SupConResNet`` extractor is
        re-created and its weights are restored from ``sscl-last.pt`` inside
        ``args.save_path``.

        Args:
            args: namespace providing ``continue_step``, ``model``,
                ``num_classes``, ``save_path`` and ``device``.
                ``args.feature_dim`` is set as a side effect.

        Returns:
            The restored model on ``args.device`` when resuming at
            ``"mixed_ce"``; ``None`` for every other step.
        """
        if args.continue_step != "mixed_ce":
            return None

        encoder = generate_cls_model(args.model, args.num_classes)
        # The last registered sub-module is taken to be the classification
        # layer; its input width is the feature dimension of the extractor.
        args.feature_dim = list(encoder.named_modules())[-1][1].in_features
        # Replace the classification layer so the encoder outputs raw features.
        if hasattr(encoder, "linear"):
            encoder.linear = nn.Identity()
        elif hasattr(encoder, "fc"):
            encoder.fc = nn.Identity()
        model = SupConResNet(encoder, dim_in=args.feature_dim)

        ck_path = os.path.join(args.save_path, "sscl-last.pt")
        # Load onto the CPU first so a checkpoint written on a GPU can be
        # restored on any machine; the model is moved to args.device below.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        result = torch.load(ck_path, map_location="cpu")
        model.load_state_dict(result["model"])
        return model.to(args.device)

    def mitigation(self):
        """Run the complete D-ST pipeline and return the defended model.

        Steps (sample distinguishment, then two-stage secure training):
            a. train a backdoored model on the poisoned data without augmentation
            b. fine-tune it with the intra-class loss
            c. compute the FCT metric for all training samples
            d. derive the thresholds ``gamma_low`` / ``gamma_high``
            e. separate the samples into clean / poisoned / uncertain
            f. train a new feature extractor with semi-supervised contrastive learning
            g. train the classifier with the mixed cross-entropy loss

        Returns:
            dict with a single key ``"model"`` holding the final model assembled
            by ``add_linear``.  The model has been moved to the CPU by the time
            this method returns, because its state dict is saved from the CPU.
        """
        args = self.args
        self.set_devices()
        fix_random(self.args.random_seed)
        bd_trainloader = self.get_sd_train_loader()
        clean_test_loader, bd_test_loader = self.testloader_wrapper()
        ##a. train a backdoored model from scratch using poisoned dataset without any data augmentations
        model, optimizer, scheduler = self.train_attack_noTrans(
            bd_trainloader, clean_test_loader, bd_test_loader, finetune=False
        )
        ###b. fine-tune the backdoored model with intra-class loss L_intra
        model = self.drop_linear(model)
        model, optimizer, scheduler = self.train_attack_noTrans(
            bd_trainloader,
            clean_test_loader,
            bd_test_loader,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            finetune=True,
        )
        ###c. calculate values of the FCT metric for all training samples.
        calculate_consistency(args, bd_trainloader, model)
        ###d. calculate thresholds for choosing clean and poisoned samples.
        args.gamma_low, args.gamma_high = calculate_gamma(
            args,
        )
        ###e. separate training samples into clean samples D_c, poisoned samples D_p, and uncertain samples D_u.
        separate_samples(args, bd_trainloader, model)
        ##f. train the feature extractor (from scratch) via semi-supervised contrastive learning.
        model_new = self.train_extractor()
        ###g. train the classifier via minimizing a mixed cross-entropy loss.
        model_new, classifier = self.train_classifier(model_new)
        # return the standard model structure from two subnetworks: SupConResNet+classifier
        model_new = self.add_linear(old_model=model_new, classifier=classifier)
        result = {"model": model_new}

        # .cpu() moves the module in place, so result["model"] is on the CPU too
        save_defense_result(
            model_name=args.model,
            num_classes=args.num_classes,
            model=model_new.cpu().state_dict(),
            save_path=args.save_path,
        )
        return result

    def defense(self, result_file):
        """Public entry point of the defense.

        Registers the attack result identified by ``result_file``, configures
        logging, then runs ``mitigation`` and hands back whatever it returns.
        """
        self.set_result(result_file)
        self.set_logger()
        return self.mitigation()


if __name__ == "__main__":
    # Command-line entry point: collect the D-ST arguments, build the defense
    # object and run it on the requested attack result.
    parser = argparse.ArgumentParser(description=sys.argv[0])
    d_st.add_arguments(parser)
    args = parser.parse_args()
    # The defense object is built before the default below is applied; the
    # constructor receives `args` and may modify it, so the order is kept.
    d_st_method = d_st(args)
    # Fall back to the bundled test attack when no result file was given.
    if getattr(args, "result_file", None) is None:
        args.result_file = "defense_test_badnet"
    result = d_st_method.defense(args.result_file)
