import copy
import os.path as osp

from tqdm import tqdm
import time

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.nn.modules.loss import _Loss

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint, mkdir_if_missing, save_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.evaluation import build_evaluator

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from clip.model import VisionTransformer

# from modules.resnet import MaskModifiedResNet
from modules.visiontransformer import MaskVisionTransformer

from zsrobust.utils import clip_img_preprocessing as preprocessing
from attack.pgd import attack_pgd
from autoattack import AutoAttack

# Module-level instance of CLIP's SimpleTokenizer.
# NOTE(review): not referenced in the visible part of this file — confirm it is still needed.
_tokenizer = _Tokenizer()

def RegLoss(param, k):
    """Return the L1 or L2 norm of ``param`` viewed as a flat vector.

    Args:
        param: tensor to regularize; any shape and memory layout.
        k: norm order, must be 1 or 2.

    Returns:
        A 0-dim tensor holding the ``k``-norm of all elements of ``param``.

    Raises:
        AssertionError: if ``k`` is neither 1 nor 2.
    """
    assert k in [1, 2], f"k must be 1 or 2, got {k!r}"
    # reshape() instead of view(): view(-1) raises on non-contiguous tensors
    # (e.g. a transposed weight), while reshape copies only when it has to.
    flat = param.reshape(-1)
    return torch.norm(flat, k)

def load_clip_to_cpu(cfg):
    """Fetch the CLIP checkpoint named by ``cfg.MODEL.BACKBONE.NAME`` and build it on the CPU.

    The checkpoint is first opened as a TorchScript archive; if that raises
    ``RuntimeError`` it is loaded with ``torch.load`` instead and the result is
    used as the state dict.

    Returns:
        The model returned by ``clip.build_model``.
    """
    checkpoint_path = clip._download(clip._MODELS[cfg.MODEL.BACKBONE.NAME])

    try:
        # loading JIT archive
        jit_model = torch.jit.load(checkpoint_path, map_location='cpu').eval()
    except RuntimeError:
        weights = torch.load(checkpoint_path, map_location='cpu')
    else:
        weights = None

    if not weights:
        # JIT path: take the weights from the scripted model.
        weights = jit_model.state_dict()

    # Prompt-depth/context settings are all zero here.
    design_details = {
        "trainer": 'CoOp',
        "vision_depth": 0,
        "language_depth": 0,
        "vision_ctx": 0,
        "language_ctx": 0,
    }
    return clip.build_model(weights, design_details)


# Prompt templates keyed by dataset name. Each value is a list of format
# strings whose `{}` placeholder is filled with a class name; TextEncoder.forward
# selects the list via cfg.DATASET.NAME and averages the text embeddings of all
# templates in it.
CUSTOM_TEMPLATES = {
    "OxfordPets": ["a photo of a {}, a type of pet."],
    "OxfordFlowers": ["a photo of a {}, a type of flower."],
    "FGVCAircraft": ["a photo of a {}, a type of aircraft."],
    "DescribableTextures": ["{} texture."],
    "EuroSAT": ["a centered satellite photo of {}."],
    "StanfordCars": ["a photo of a {}."],
    "Food101": ["a photo of {}, a type of food."],
    "SUN397": ["a photo of a {}."],
    "Caltech101": ["a photo of a {}."],
    "UCF101": ["a photo of a person doing {}."],
    "ImageNet": ["a photo of a {}."],
    "ImageNetSketch": ["a photo of a {}."],
    "ImageNetV2": ["a photo of a {}."],
    "ImageNetA": ["a photo of a {}."],
    "ImageNetR": ["a photo of a {}."],
    "TinyImageNet": ["a photo of a {}."],
}


class TextEncoder(nn.Module):
    """Produces one normalized CLIP text embedding per class name.

    A separate CLIP model is loaded via ``load_clip_to_cpu`` and moved to CUDA;
    it is cast to float32 when ``cfg.MASK.PREC == "fp32"``.
    """

    def __init__(self, cfg, classnames):
        super().__init__()
        self.cfg = cfg
        self.classnames = classnames

        text_model = load_clip_to_cpu(cfg)
        if cfg.MASK.PREC == "fp32":
            text_model.float()
        self.clip_model = text_model.to('cuda')
        self.dtype = text_model.dtype

    def forward(self):
        """Return a stacked tensor with one unit-norm text feature per class.

        For every class the name is lower-cased with underscores replaced by
        spaces, inserted into each template of the configured dataset, encoded,
        averaged over templates and L2-normalized. No gradients are tracked.
        """
        templates = CUSTOM_TEMPLATES[self.cfg.DATASET.NAME]
        per_class = []
        with torch.no_grad():
            for raw_name in self.classnames:
                name = raw_name.replace('_', ' ').lower()
                tokens = torch.cat([clip.tokenize(tpl.format(name)) for tpl in templates])
                encoded = self.clip_model.encode_text(tokens.to('cuda'))
                # Average over templates, then renormalize the mean embedding.
                averaged = encoded.mean(dim=0)
                averaged = averaged / averaged.norm(dim=-1, keepdim=True)
                per_class.append(averaged)
        return torch.stack(per_class, dim=0).cuda()


class MASKCLIP(nn.Module):
    """CLIP classifier whose image encoder is rebuilt from mask-capable layers.

    A mask-capable copy of the pretrained visual encoder is constructed, the
    pretrained weights are copied into it, and its direct children are also
    exposed through ``self.shared``. The untouched pretrained encoder is kept
    as ``self.origin_image_encoder``, and per-class text features are computed
    once at construction time by ``TextEncoder``.
    """

    def __init__(self, cfg, classnames, clip_model):
        """
        Args:
            cfg: config node; reads ``cfg.MASK.INIT``, ``SCALE``,
                ``THRESHOLD_FN``, ``THRESHOLD`` and (ViT only) ``MASK_MLP``.
            classnames: class names forwarded to ``TextEncoder``.
            clip_model: pretrained CLIP model; its ``visual``, ``logit_scale``,
                ``dtype`` and ``state_dict()`` are used.
        """
        super().__init__()
        self.dtype = clip_model.dtype
        origin_image_encoder = clip_model.visual

        clip_state_dict = clip_model.state_dict()
        # The presence of "visual.proj" is what selects the ViT code path.
        vit = "visual.proj" in clip_state_dict
        self.vit = vit
        (vision_layers, embed_dim, vision_heads, vision_width,
         image_resolution, vision_patch_size) = self.get_params(vit, clip_state_dict)
        if vit:
            self.image_encoder = MaskVisionTransformer(input_resolution=image_resolution, patch_size=vision_patch_size,
                                                       width=vision_width, layers=vision_layers, heads=vision_heads,
                                                       output_dim=embed_dim,
                                                       mask_init=cfg.MASK.INIT, mask_scale=cfg.MASK.SCALE,
                                                       threshold_fn=cfg.MASK.THRESHOLD_FN, threshold=cfg.MASK.THRESHOLD,
                                                       mask_mlp=cfg.MASK.MASK_MLP)
        else:
            # The module-level import of MaskModifiedResNet is commented out in
            # this file, so the bare name would be undefined here. Import it
            # lazily so ViT-only setups do not need the ResNet module.
            # NOTE(review): assumes modules.resnet still provides
            # MaskModifiedResNet — confirm.
            from modules.resnet import MaskModifiedResNet
            self.image_encoder = MaskModifiedResNet(vision_layers, embed_dim, vision_heads,
                                                    input_resolution=image_resolution, width=vision_width,
                                                    mask_init=cfg.MASK.INIT, mask_scale=cfg.MASK.SCALE,
                                                    threshold_fn=cfg.MASK.THRESHOLD_FN, threshold=cfg.MASK.THRESHOLD)

        self.make_model(origin_image_encoder)
        self.origin_image_encoder = origin_image_encoder

        text_encoder = TextEncoder(cfg, classnames)
        self.text_features = text_encoder()

        self.logit_scale = clip_model.logit_scale
        self.threshold = cfg.MASK.THRESHOLD

    def get_params(self, vit, state_dict):
        """Derive the visual-encoder hyper-parameters from a CLIP state dict.

        Args:
            vit: True to read the ViT layout, False for the ResNet layout.
            state_dict: CLIP state dict (keys prefixed with ``visual.``).

        Returns:
            Tuple ``(vision_layers, embed_dim, vision_heads, vision_width,
            image_resolution, vision_patch_size)``; ``vision_patch_size`` is
            ``None`` and ``vision_layers`` is a 4-tuple on the ResNet path.
        """
        if vit:
            vision_width = state_dict["visual.conv1.weight"].shape[0]
            # One attention in-projection weight per transformer block.
            vision_layers = len(
                [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
            vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
            # Positional embedding has one row per patch plus one extra row.
            grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
            image_resolution = vision_patch_size * grid_size
            vision_heads = vision_width // 64
        else:
            # Number of distinct block indices under each of visual.layer1..4.
            counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in
                            [1, 2, 3, 4]]
            vision_layers = tuple(counts)
            vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
            output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
            vision_patch_size = None
            assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
            image_resolution = output_width * 32
            vision_heads = vision_width * 32 // 64

        embed_dim = state_dict["text_projection"].shape[1]
        return vision_layers, embed_dim, vision_heads, vision_width, image_resolution, vision_patch_size

    def make_model(self, origin_image_encoder):
        """Copy pretrained weights into ``self.image_encoder`` and build ``self.shared``.

        Modules of the new and the pretrained encoder are paired positionally
        via ``zip(...modules())``, so both must enumerate their sub-modules in
        the same order. The copy rule is chosen by substring match on the
        string form of the new module's type.
        """
        if self.vit:
            self.image_encoder.class_embedding.data.copy_(origin_image_encoder.class_embedding.data)
            self.image_encoder.positional_embedding.data.copy_(origin_image_encoder.positional_embedding.data)
            self.image_encoder.proj.data.copy_(origin_image_encoder.proj.data)

        # Copy weights from the pretrained to the modified model.
        for module, module_pretrained in zip(self.image_encoder.modules(), origin_image_encoder.modules()):
            kind = str(type(module))
            # Attention is tested first so that element-wise attention classes
            # take this branch rather than the generic 'ElementWise' one.
            if 'MultiheadAttention' in kind:
                module.in_proj_weight.data.copy_(module_pretrained.in_proj_weight.data)
                if module.in_proj_bias is not None:
                    module.in_proj_bias.data.copy_(module_pretrained.in_proj_bias.data)

                module.out_proj.weight.data.copy_(module_pretrained.out_proj.weight.data)
                if module.out_proj.bias is not None:
                    module.out_proj.bias.data.copy_(module_pretrained.out_proj.bias.data)
            elif 'ElementWise' in kind or 'Linear' in kind or 'Conv2d' in kind:
                module.weight.data.copy_(module_pretrained.weight.data)
                if module.bias is not None:
                    module.bias.data.copy_(module_pretrained.bias.data)
            elif 'BatchNorm' in kind:
                module.weight.data.copy_(module_pretrained.weight.data)
                module.bias.data.copy_(module_pretrained.bias.data)
                module.running_mean.copy_(module_pretrained.running_mean)
                module.running_var.copy_(module_pretrained.running_var)
            elif 'LayerNorm' in kind:
                module.weight.data.copy_(module_pretrained.weight.data)
                module.bias.data.copy_(module_pretrained.bias.data)
            elif 'MaskAttentionPool' in kind:
                module.positional_embedding.data.copy_(module_pretrained.positional_embedding.data)
                for attnpool_module, attnpool_module_pretrained in zip(self.image_encoder.attnpool.modules(),
                                                                       origin_image_encoder.attnpool.modules()):
                    if 'ElementWise' in str(type(attnpool_module)):
                        attnpool_module.weight.data.copy_(attnpool_module_pretrained.weight.data)
                        if attnpool_module.bias is not None:
                            attnpool_module.bias.data.copy_(attnpool_module_pretrained.bias.data)

        print('Creating model: Mask layers created.')

        # Expose the encoder's direct children under a single container.
        self.shared = nn.Sequential()
        for name, module in self.image_encoder.named_children():
            self.shared.add_module(name, module)

    def forward(self, image, return_features=False):
        """Classify ``image`` with the masked encoder.

        Returns:
            ``logits``, or ``(logits, image_features)`` when ``return_features``
            is set and ``self.shared`` is in training mode. ``image_features``
            are L2-normalized along the last dimension.
        """
        image_features = self.image_encoder(image.type(self.dtype))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ self.text_features.t()
        if self.shared.training and return_features:
            return logits, image_features
        else:
            return logits

    def original_forward(self, image):
        """Classify ``image`` with the unmodified pretrained image encoder."""
        image_features = self.origin_image_encoder(image.type(self.dtype))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ self.text_features.t()
        return logits

    def compute_sparsity(self, threshold_fn):
        """Return the percentage of ``mask_real`` entries below ``self.threshold``.

        Args:
            threshold_fn: name of the thresholding scheme; only ``'binarizer'``
                is supported.

        Returns:
            A 0-dim tensor in ``[0, 100]``; a zero tensor when ``self.shared``
            holds no ``mask_real`` entries.

        Raises:
            NotImplementedError: for any ``threshold_fn`` other than
                ``'binarizer'`` (previously this surfaced as an unbound-local
                error inside the loop).
        """
        if threshold_fn != 'binarizer':
            raise NotImplementedError(
                f"compute_sparsity only supports threshold_fn='binarizer', got {threshold_fn!r}")

        total_zeros = 0.0
        total_param = 0.0
        for key, value in self.shared.state_dict().items():
            if 'mask_real' not in key:
                continue
            total_zeros += value.lt(self.threshold).sum()
            total_param += value.numel()

        if total_param == 0:
            # No masks found: avoid a ZeroDivisionError and report 0% sparsity.
            return torch.tensor(0.0)
        return total_zeros / total_param * 100.


class KLLoss(_Loss):
    """Temperature-scaled soft-target distillation loss.

    The target distribution is the softmax of an ``alpha``-weighted blend of
    teacher and student logits. The value computed is the cross-entropy of the
    student against that target (no target-entropy term is subtracted), scaled
    by ``T ** 2`` and averaged over the batch.
    """

    def __init__(self, T, alpha=1.):
        super().__init__()
        self.T = T
        self.alpha = alpha

    def forward(self, stu_logits, tea_logits, label):
        """Return the batch-mean loss; ``label`` is accepted but not used."""
        temperature = self.T
        blended_logits = self.alpha * tea_logits + (1 - self.alpha) * stu_logits

        target = F.softmax(blended_logits / temperature, dim=-1)
        student_log_prob = F.log_softmax(stu_logits / temperature, -1)

        per_class = -target * student_log_prob * temperature * temperature
        return per_class.sum(1).mean()

# def js_divergence(p_output, q_output):
#     """Compute Jensen-Shannon Divergence"""
#     M = 0.5 * (p_output + q_output)
#     js = 0.5 * (F.kl_div(M.log(), p_output, reduction='none') +
#                 F.kl_div(M.log(), q_output, reduction='none')).mean(dim=-1)
#     return js

def calculate_elementwise_kl_div(output1, output2):
    """Per-sample KL(softmax(output2) || softmax(output1)) from raw logits.

    Softmax is taken over dim 1 and the point-wise KL terms are summed over the
    last dimension, giving one value per row for 2-D inputs.
    """
    log_pred = F.log_softmax(output1, dim=1)
    target = F.softmax(output2, dim=1)
    pointwise = F.kl_div(log_pred, target, reduction='none')
    return pointwise.sum(dim=-1)

def calculate_cosine_similarity(features1, features2):
    """Row-wise cosine similarity shifted by +1, so values lie in [0, 2]."""
    similarity = F.cosine_similarity(features1, features2)
    return similarity + 1

def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Shannon entropy of ``softmax(x)`` taken over dim 1, one value per row."""
    probs = F.softmax(x, dim=1)
    log_probs = F.log_softmax(x, dim=1)
    return -torch.sum(probs * log_probs, dim=1)

def entropy(p, eps=1e-8):
    """Entropy over dim 1 of probabilities ``p``; ``eps`` is added inside the log."""
    log_p = torch.log(p + eps)
    return -torch.sum(p * log_p, dim=1)

def calculate_js_loss(output_clean,
                      output_adv,
                      clean_image_features,
                      adv_image_features,
                      label,
                      loss_fn,
                      adv_term="cos",
                      eps=1e-8,
                      tau=None):
    """Batch-mean Jensen-Shannon divergence between clean and adversarial predictions.

    Both logit tensors are turned into probabilities with a softmax over dim 1;
    the divergence is ``0.5 * KL(p_clean || m) + 0.5 * KL(p_adv || m)`` with
    ``m`` the equal-weight mixture of the two.

    Args:
        output_clean: logits for the clean input.
        output_adv: logits for the adversarial input.
        clean_image_features, adv_image_features, label, loss_fn, adv_term, tau:
            not used by the current two-distribution formulation; kept so that
            existing call sites keep working.
        eps: lower bound applied to the mixture before taking its log.

    Returns:
        A 0-dim tensor with the mean divergence over the batch.
    """
    prob_clean = F.softmax(output_clean, dim=1)
    prob_adv = F.softmax(output_adv, dim=1)

    c_c, c_a = 0.5, 0.5
    m = c_c * prob_clean + c_a * prob_adv
    # Where both softmax outputs underflow to exactly 0, log(m) is -inf and
    # kl_div evaluates 0 * -inf = NaN. Clamping leaves every entry >= eps
    # untouched and keeps the loss finite.
    log_m = torch.clamp(m, min=eps).log()

    kl_cm = F.kl_div(log_m, prob_clean, reduction='none').sum(dim=-1, keepdim=True)
    kl_am = F.kl_div(log_m, prob_adv, reduction='none').sum(dim=-1, keepdim=True)
    loss = c_c * kl_cm + c_a * kl_am
    return loss.mean()

def calculate_adv_loss(output_clean,
                       output_adv,
                       clean_image_features,
                       adv_image_features,
                       label,
                       loss_fn,
                       adv_term="cos",
                       eps=1e-8,
                       tau=None):
    """Compute the adversarial loss term selected by ``loss_fn``.

    Supported values of ``loss_fn``:
        * ``"tecoa"`` / ``"tecoa_only"``: cross-entropy of ``output_adv`` vs ``label``.
        * ``"kl"``: mean per-sample KL divergence between adversarial and clean logits.
        * ``"cos"``: mean of the (+1 shifted) cosine similarity of the image features.
        * ``"fap"``: mean of the product of the two terms above.

    Returns:
        The loss tensor, or ``None`` when ``loss_fn`` is none of the above.

    Raises:
        NotImplementedError: for ``"cos"`` / ``"fap"`` when ``adv_term != "cos"``.

    ``eps`` and ``tau`` are accepted but not used.
    """
    # case 1: TeCoA
    if loss_fn in ("tecoa", "tecoa_only"):
        return F.cross_entropy(output_adv, label)

    # case 2: FAP (L_KL)
    if loss_fn == "kl":
        return torch.mean(calculate_elementwise_kl_div(output_adv, output_clean))

    # case 3: FAP (L_COS)
    if loss_fn == "cos":
        if adv_term != "cos":
            raise NotImplementedError
        similarity = calculate_cosine_similarity(clean_image_features, adv_image_features)
        return torch.mean(similarity)

    # case 4: FAP (L_KL + L_COS)
    if loss_fn == "fap":
        divergence = calculate_elementwise_kl_div(output_adv, output_clean)
        if adv_term != "cos":
            raise NotImplementedError
        similarity = calculate_cosine_similarity(clean_image_features, adv_image_features)
        return torch.mean(divergence * similarity)

    # Unrecognized loss name: no loss is produced.
    return None

def calculate_alignment_loss(clean_layer_feats, adv_layer_feats, output_clean, labels, l0=9, l1=11):
    """Reliability-weighted MSE between adversarial and clean layer features.

    Layers ``l0`` through ``l1`` (inclusive) are selected along dim 0. The
    squared error is averaged over dims 0, 1 and 3, leaving one value per index
    of dim 2, which is then weighted by a per-sample reliability score.
    NOTE(review): this presumes the stacked features are laid out as
    (layer, token, batch, width) — confirm against the encoder.

    The reliability of a sample is the clean softmax probability of its label,
    divided by the batch mean of that probability; it carries no gradient.
    """
    with torch.no_grad():
        clean_prob = F.softmax(output_clean, dim=-1)
        label_prob = clean_prob.gather(dim=-1, index=labels.view(-1, 1)).squeeze()  # (batch,)
        weights = label_prob / (label_prob.mean() + 1e-8)

    selected_layers = list(range(l0, l1 + 1))
    squared_error = F.mse_loss(adv_layer_feats[selected_layers],
                               clean_layer_feats[selected_layers],
                               reduction='none')
    per_sample = squared_error.mean(dim=[0, 1, 3])
    return (weights * per_sample).mean()

@TRAINER_REGISTRY.register()
class AdvMaskTuning(TrainerX):

    def build_model(self):
        """Build the masked CLIP model together with its optimizer, scheduler and AMP scaler.

        Steps performed:
            1. Load CLIP on the CPU and cast it to float32 for the ``fp32`` and
               ``amp`` precisions.
            2. Wrap it in ``MASKCLIP``; for ``fp16`` selected tensors and
               sub-modules are cast to half precision.
            3. Freeze every parameter of ``self.model.shared`` except those
               whose name contains ``mask_real``.
            4. Build the optimizer and LR scheduler and register them under the
               name ``"shared"``.
            5. Create a ``GradScaler`` when the precision is ``amp``.
            6. Wrap the model in ``nn.DataParallel`` when several GPUs are visible.
            7. Optionally load initial weights from ``cfg.MODEL.INIT_WEIGHTS``.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)

        if cfg.ADVMASK.PREC in ["fp32", "amp"]:
            clip_model.float()  # TODO: what if we want 'fp16' precision? => gradient underflow occurred

        print("Building custom CLIP")
        self.model = MASKCLIP(cfg, classnames, clip_model)
        if cfg.ADVMASK.PREC in ["fp16"]:
            self.model.image_encoder.proj.data = self.model.image_encoder.proj.data.half()
            self.model.text_features.data = self.model.text_features.data.half()
            for name, module in self.model.shared.named_modules():
                class_name = module.__class__.__name__
                if class_name in ['ElementWiseConv2d', 'ElementWiseMultiheadAttention']:
                    module.half()
                # Compare by value: `is` against a string literal tests object
                # identity, which only works by accident of string interning.
                if 'transformer' in name and class_name == 'Linear':
                    module.half()

        # ADD NORMALIZE
        self.preprocessing = preprocessing

        # Only the mask parameters stay trainable.
        for name, param in self.model.shared.named_parameters():
            param.requires_grad_('mask_real' in name)

        # Double check
        enabled = {name for name, param in self.model.shared.named_parameters() if param.requires_grad}
        print(f"Parameters to be updated: {enabled}")

        # Every parameter of `shared` is handed to the optimizer; only the
        # `mask_real` ones have requires_grad=True.
        param_groups = [
            {
                "params": self.model.shared.parameters(),
            },
        ]

        self.model.to(self.device)
        self.optim = build_optimizer(None, cfg.OPTIM, param_groups)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)

        self.register_model("shared", self.model.shared, self.optim, self.sched)

        self.scaler = GradScaler() if cfg.ADVMASK.PREC == "amp" else None

        # Note that multi-gpu training could be slow because CLIP's size is
        # big, which slows down the copy operation in DataParallel
        device_count = torch.cuda.device_count()
        if device_count > 1:
            print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
            self.model = nn.DataParallel(self.model)
        self.model.eval()
        # self.kl_loss = KLLoss(T=self.cfg.MASK.GDR_T, alpha=self.cfg.MASK.FUSE_ALPHA)
        self.count = []

        ## build clean mask model
        self.model2 = None

        if cfg.MODEL.INIT_WEIGHTS:
            self.load_model(cfg.MODEL.INIT_WEIGHTS, cfg.OPTIM.MAX_EPOCH, init_weights=True)


    # def build_model(self):  # for Threshold Tuning
    #     cfg = self.cfg
    #     classnames = self.dm.dataset.classnames
    #
    #     print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
    #     clip_model = load_clip_to_cpu(cfg)
    #
    #     clip_model.float()  # TODO: what if we want 'fp16' precision?
    #
    #     print("Building custom CLIP")
    #     self.model = MASKCLIP(cfg, classnames, clip_model)
    #
    #     # ADD NORMALIZE
    #     self.preprocessing = preprocessing
    #
    #     for name, param in self.model.shared.named_parameters():
    #         param.requires_grad_(False)
    #         if 'mask_real' in name:
    #             param.requires_grad_(True)
    #         if 'threshold' in name:
    #             param.requires_grad_(True)
    #
    #     mask_params = []
    #     thr_params = []
    #
    #     # # Double check
    #     enabled = set()
    #     for name, param in self.model.shared.named_parameters():
    #         if param.requires_grad:
    #             enabled.add(name)
    #             if 'mask_real' in name:
    #                 mask_params.append(param)
    #             elif 'threshold' in name:
    #                 thr_params.append(param)
    #     print(f"Parameters to be updated: {enabled}")
    #
    #     param_groups1 = [
    #         {
    #             "params": mask_params, "lr": cfg.OPTIM.LR,
    #         },
    #         {
    #             "params": thr_params, "lr": 4e-5,
    #         }
    #     ]
    #
    #     trainable_param = sum(p.numel() for p in self.model.shared.parameters() if p.requires_grad)
    #
    #     self.model.to(self.device)
    #     # NOTE: only give mask to the optimizer
    #     self.optim = build_optimizer(None, cfg.OPTIM, param_groups1)
    #     self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
    #
    #     # ################################
    #     # ## optmizer/scheduler for mask threshold
    #     # param_groups2 = [
    #     #     {
    #     #         "params": thr_params,
    #     #     },
    #     # ]
    #     # self.thr_optim = build_optimizer(None, cfg.OPTIM, param_groups2)
    #     # self.thr_sched = build_lr_scheduler(self.thr_optim, cfg.OPTIM)
    #     # ###################################
    #
    #     self.register_model("shared", self.model.shared, self.optim, self.sched)
    #
    #     self.scaler = GradScaler() if cfg.ADVMASK.PREC == "amp" else None
    #
    #     # Note that multi-gpu training could be slow because CLIP's size is
    #     # big, which slows down the copy operation in DataParallel
    #     device_count = torch.cuda.device_count()
    #     if device_count > 1:
    #         print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
    #         self.model = nn.DataParallel(self.model)
    #     self.model.eval()
    #     # self.kl_loss = KLLoss(T=self.cfg.MASK.GDR_T, alpha=self.cfg.MASK.FUSE_ALPHA)
    #     self.count = []

    def forward_backward(self, batch):
        self.set_model_mode("train")
        image, label = self.parse_batch_train(batch)
        model = self.model
        optim = self.optim
        scaler = self.scaler
        n_iter = self.epoch * self.num_batches + self.batch_idx

        # t0 = time.time()
        prec = self.cfg.ADVMASK.PREC
        if prec == "amp":
            nat_loss = torch.tensor(0.0).to(self.device)
            adv_loss = torch.tensor(0.0).to(self.device)

            with autocast():
                # generate adv perturbation
                delta = attack_pgd(model, self.preprocessing, image, label, alpha=self.cfg.ATTACK.PGD.ALPHA,
                                   attack_iters=self.cfg.ATTACK.PGD.TRAIN_ITER,
                                   epsilon=self.cfg.ATTACK.PGD.EPS,  # adaptive_epsilon, self.cfg.ATTACK.PGD.EPS
                                   train_trades=self.cfg.ATTACK.PGD.ADV_TERM != "ce")

                # model forwarding
                ##### adv forward
                image_adv = self.preprocessing(image + delta)
                output_adv, adv_image_features = model(image_adv, return_features=True)
                adv_layer_feats = torch.stack([
                    res.visual_feat for res in model.image_encoder.transformer.resblocks
                ], dim=0)  # store adv layer-wise features

                ##### clean forward
                image_clean = self.preprocessing(image)
                output_clean, clean_image_features = model(image_clean, return_features=True)
                clean_layer_feats = torch.stack([
                    res.visual_feat.detach() for res in model.image_encoder.transformer.resblocks
                ], dim=0)  # store clean layer-wise features

                # loss
                nat_loss = F.cross_entropy(output_clean, label)

                # if self.cfg.ADVMASK.LOSS_FN == 'ce_only':
                #     loss = nat_loss
                # elif self.cfg.ADVMASK.LOSS_FN == 'tecoa':
                #     a = F.cross_entropy(output_clean, label, reduction='none')
                #     b = F.cross_entropy(output_adv, label, reduction='none')
                #     with torch.no_grad():
                #         prob_clean = F.softmax(output_clean, dim=-1)
                #         ent = -(prob_clean * prob_clean.log()).sum(dim=-1)
                #         lambda_max = self.cfg.ADVMASK.LAMB1
                #         entropy_max = torch.tensor(self.num_classes).log()
                #         lambda_ = lambda_max * (1 - (ent / entropy_max))
                #     loss = (a + lambda_ * b).mean()
                # elif self.cfg.ADVMASK.LOSS_FN == 'js':
                #     adv_loss = F.cross_entropy(output_adv, label)
                #     js_loss = calculate_js_loss(output_clean, output_adv,
                #                              clean_image_features, adv_image_features,
                #                              label,
                #                              loss_fn=self.cfg.ADVMASK.LOSS_FN,
                #                              adv_term="cos",
                #                              tau=self.cfg.ADVMASK.TAU)
                #     loss = adv_loss + self.cfg.ADVMASK.LAMB1 * js_loss

                if self.cfg.ADVMASK.LOSS_FN == 'tecoa':
                    adv_loss = F.cross_entropy(output_adv, label)
                    loss = adv_loss
                elif self.cfg.ADVMASK.LOSS_FN == 'tecoa+js':
                    adv_loss = F.cross_entropy(output_adv, label)
                    js_loss = calculate_js_loss(output_clean, output_adv,
                                             clean_image_features, adv_image_features,
                                             label,
                                             loss_fn=self.cfg.ADVMASK.LOSS_FN,
                                             adv_term="cos",
                                             tau=self.cfg.ADVMASK.TAU)
                    loss = adv_loss + self.cfg.ADVMASK.LAMB1 * js_loss

                elif self.cfg.ADVMASK.LOSS_FN == 'tecoa+kl':
                    adv_loss = F.cross_entropy(output_adv, label)
                    kl_divs = calculate_elementwise_kl_div(output_adv, output_clean)
                    kl_loss = torch.mean(kl_divs)
                    loss = adv_loss + self.cfg.ADVMASK.LAMB1 * kl_loss

                elif self.cfg.ADVMASK.LOSS_FN == 'align':  # tecoa+align
                    adv_loss = F.cross_entropy(output_adv, label)
                    alignment_loss = calculate_alignment_loss(clean_layer_feats,
                                                              adv_layer_feats,
                                                              output_clean,
                                                              label,
                                                              l0=self.cfg.ADVMASK.LAYER0,
                                                              l1=self.cfg.ADVMASK.LAYER1)
                    loss = adv_loss + self.cfg.ADVMASK.LAMB1 * alignment_loss
                elif self.cfg.ADVMASK.LOSS_FN == 'fap':
                    kl_divs = calculate_elementwise_kl_div(output_adv, output_clean)
                    cosine_sims = calculate_cosine_similarity(clean_image_features, adv_image_features)
                    adv_loss = torch.mean(kl_divs * cosine_sims)
                    loss = nat_loss + self.cfg.ADVMASK.LAMB1 * adv_loss
                else: # tecoa, tecoa_only, kl, cos
                    pass
                    # adv_loss = calculate_adv_loss(output_clean, output_adv,
                    #                               clean_image_features, adv_image_features,
                    #                               label,
                    #                               loss_fn=self.cfg.ADVMASK.LOSS_FN,
                    #                               adv_term="cos",
                    #                               tau=self.cfg.ADVMASK.TAU)
                    # if self.cfg.ADVMASK.LOSS_FN == 'tecoa_only':
                    #     loss = adv_loss
                    # else:
                    #     loss = nat_loss + self.cfg.ADVMASK.LAMB1 * adv_loss

                if "ce" in self.cfg.ADVMASK.LOSS_FN:  # ce_only, ce+####
                    if self.cfg.ADVMASK.LOSS_FN == 'ce_only':
                        loss = torch.tensor(0.0).to(self.device)
                    elif self.cfg.ADVMASK.LOSS_FN == 'ce+tecoa':
                        adv_loss = F.cross_entropy(output_adv, label)
                        loss = adv_loss
                    elif self.cfg.ADVMASK.LOSS_FN == 'ce+tecoa+js':
                        adv_loss = F.cross_entropy(output_adv, label)
                        js_loss = calculate_js_loss(output_clean, output_adv,
                                                    clean_image_features, adv_image_features,
                                                    label,
                                                    loss_fn=self.cfg.ADVMASK.LOSS_FN,
                                                    adv_term="cos",
                                                    tau=self.cfg.ADVMASK.TAU)
                        loss = adv_loss + self.cfg.ADVMASK.LAMB1 * js_loss

                    elif self.cfg.ADVMASK.LOSS_FN == 'ce+tecoa+kl':
                        adv_loss = F.cross_entropy(output_adv, label)
                        kl_divs = calculate_elementwise_kl_div(output_adv, output_clean)
                        kl_loss = torch.mean(kl_divs)
                        loss = adv_loss + self.cfg.ADVMASK.LAMB1 * kl_loss

                    elif self.cfg.ADVMASK.LOSS_FN == 'ce+align':  # tecoa+align
                        adv_loss = F.cross_entropy(output_adv, label)
                        alignment_loss = calculate_alignment_loss(clean_layer_feats,
                                                                  adv_layer_feats,
                                                                  output_clean,
                                                                  label,
                                                                  l0=self.cfg.ADVMASK.LAYER0,
                                                                  l1=self.cfg.ADVMASK.LAYER1)
                        loss = adv_loss + self.cfg.ADVMASK.LAMB1 * alignment_loss
                    else:
                        raise NotImplementedError

                    loss += nat_loss

            self.model_zero_grad(None)
            names = self.get_model_names(None)
            self.detect_anomaly(loss)
            scaler.scale(loss).backward()
            for name in names:
                scaler.step(self._optims[name])
            scaler.update()

        else:
            raise NotImplementedError

            nat_loss = torch.tensor(0.0).to(self.device)
            adv_loss = torch.tensor(0.0).to(self.device)

            # # eps scheduling
            # epsilon_min = 0.0
            # epsilon_max = self.cfg.ATTACK.PGD.EPS
            # # adaptive_epsilon = epsilon_min + (epsilon_max - epsilon_min) * (self.epoch / self.max_epoch)
            # import math
            # alpha = self.cfg.ADVMASK.RAMPUP_ALPHA
            # progress_ratio = 1 - math.exp(-alpha * self.epoch / self.max_epoch)
            # adaptive_epsilon = epsilon_min + (epsilon_max - epsilon_min) * progress_ratio

            # generate adv perturbation
            delta = attack_pgd(model, self.preprocessing, image, label, alpha=self.cfg.ATTACK.PGD.ALPHA,
                               attack_iters=self.cfg.ATTACK.PGD.TRAIN_ITER,
                               epsilon=self.cfg.ATTACK.PGD.EPS,  # adaptive_epsilon, self.cfg.ATTACK.PGD.EPS
                               train_trades=self.cfg.ATTACK.PGD.ADV_TERM != "ce")

            # ########## forward hook ###################
            # activations = {}
            # def get_node_out(name):
            #     def hook(module, input, output):
            #         cls_embed = output.permute(1, 0, 2)[:, 0, :]
            #         if name in activations:
            #             activations[name].append(cls_embed)
            #         else:
            #             activations[name] = [cls_embed]
            #         return output
            #     return hook
            #
            # hooks = {}
            # # idx_list = [0,1,2,3,4,5,6,7,8,9,10,11]
            # idx_list = [0, 1, 2]
            # for idx, (name, module) in enumerate(model.image_encoder.transformer.resblocks.named_children()):
            #     if idx in idx_list:
            #         hooks[name] = module.register_forward_hook(get_node_out(name))
            # ##########################################################


            # model forwarding

            ##### adv forward
            image_adv = self.preprocessing(image + delta)
            output_adv, adv_image_features = model(image_adv, return_features=True)
            # adv_layer_feats = torch.stack([
            #     res.visual_feat for res in model.image_encoder.transformer.resblocks
            # ], dim=0)  # store adv layer-wise features

            ##### clean forward
            image_clean = self.preprocessing(image)
            output_clean, clean_image_features = model(image_clean, return_features=True)
            # clean_layer_feats = torch.stack([
            #     res.visual_feat.detach() for res in model.image_encoder.transformer.resblocks
            # ], dim=0)  # store clean layer-wise features

            # loss
            # nat_loss = F.cross_entropy(output_clean, label)
            # adv_loss = calculate_adv_loss(output_clean, output_adv,
            #                               clean_image_features, adv_image_features,
            #                               label,
            #                               adv_term="cos",
            #                               tau=self.cfg.ADVMASK.TAU)
            # ################ alignment loss ##################
            # alignment_loss = calculate_alignment_loss(clean_layer_feats, adv_layer_feats)
            # ###################################################

            ################ Overall Loss #####################
            # # lambda_1 scheduling
            # l1_min = 0.0
            # l1_max = self.cfg.ADVMASK.LAMB1
            # adaptive_l1 = l1_min + (l1_max - l1_min) * (self.epoch / self.max_epoch)
            # # import math
            # # alpha = self.cfg.ADVMASK.RAMPUP_ALPHA
            # # progress_ratio = 1 - math.exp(-alpha * self.epoch / self.max_epoch)
            # # adaptive_l1 = l1_min + (l1_max - l1_min) * progress_ratio
            # # loss = nat_loss + adaptive_l1 * adv_loss
            # loss = nat_loss + adaptive_l1 * adv_loss + self.cfg.ADVMASK.LAMB2 * alignment_loss

            # # relative coefficient scheduling
            # mode = 'exponential'  # exponential, polynomial
            # alpha0 = 1.0
            # if mode == 'exponential':
            #     beta = 0.7
            #     alpha = alpha0 * (beta ** self.epoch)
            # elif mode == 'polynomial':
            #     p = self.cfg.ADVMASK.POLY_P  # linear scheduling (p=1)
            #     ratio = alpha0 - self.epoch / self.max_epoch
            #     alpha = ratio ** p
            # else:
            #     raise NotImplementedError
            # loss = alpha * nat_loss + (1 - alpha) * adv_loss

            # ## multi-stage mask tuning
            # if self.epoch < self.cfg.ADVMASK.TURN_POINT:
            #     loss = nat_loss
            # else:
            #     loss = nat_loss + self.cfg.ADVMASK.LAMB1 * adv_loss

            # loss = nat_loss + self.cfg.ADVMASK.LAMB1 * adv_loss + self.cfg.ADVMASK.LAMB2 * alignment_loss
            # loss = nat_loss + self.cfg.ADVMASK.LAMB1 * adv_loss
            # loss = nat_loss
            # loss = adv_loss

            # loss ##############################
            nat_loss = F.cross_entropy(output_clean, label)

            if self.cfg.ADVMASK.LOSS_FN == 'ce_only':
                loss = nat_loss
            elif self.cfg.ADVMASK.LOSS_FN == 'js':
                loss = calculate_js_loss(output_clean, output_adv,
                                         clean_image_features, adv_image_features,
                                         label,
                                         loss_fn=self.cfg.ADVMASK.LOSS_FN,
                                         adv_term="cos",
                                         tau=self.cfg.ADVMASK.TAU)
            else:  # tecoa, tecoa_only, kl, cos
                adv_loss = calculate_adv_loss(output_clean, output_adv,
                                              clean_image_features, adv_image_features,
                                              label,
                                              loss_fn=self.cfg.ADVMASK.LOSS_FN,
                                              adv_term="cos",
                                              tau=self.cfg.ADVMASK.TAU)
                if self.cfg.ADVMASK.LOSS_FN == 'tecoa_only':
                    loss = adv_loss
                else:
                    loss = nat_loss + self.cfg.ADVMASK.LAMB1 * adv_loss

            self.graddrop_backward_and_update(loss)
            ###################################################

        loss_summary = {"main_loss": loss.item(),
                        "nat_loss": nat_loss.item(),
                        "adv_loss": adv_loss.item(),
                        }

        # ##########################################################################
        # # Mask regularization loss (l2 or l1 loss)
        # MASK_LOSS = self.cfg.ADVMASK.REG
        # if MASK_LOSS:
        #     reg_loss = 0
        #     for name, param in model.shared.named_parameters():
        #         if 'mask_real' in name:
        #             reg_loss += RegLoss(param, 2)
        #     mask_loss = 1e-6 * reg_loss
        #     loss += mask_loss
        #     loss_summary["mask loss"] = mask_loss.item()
        # ##########################################################################

        # tea_logits = model.original_forward(image)
        # kl_loss = self.kl_loss(logits, tea_logits, label)
        # loss_summary["kl loss"] = kl_loss.item()
        # self.graddrop_backward_and_update(loss)


        # ########## Multi-objective Learning #############
        # norm1, norm2 = self.multi_objective_backward_and_update(
        #     nat_loss, adv_loss, lambda_=1, names=None, epsilon=1e-8
        # )
        # loss_summary["norm_clean"] = norm1.item()
        # loss_summary["norm_adv"] = norm2.item()

        # thrs = [round(thr.item(), 5) for thr in self.optim.param_groups[1]['params']]
        # for idx, thr in enumerate(thrs):
        #     loss_summary[f"thr{idx}"] = thr

        # self.graddrop_backward_and_update(loss)
        #################################################

        logit_scale = model.logit_scale.exp()
        logits_clean = logit_scale * clean_image_features @ model.text_features.t()
        loss_summary["clean_acc"] = compute_accuracy(logits_clean, label)[0].item()
        logits_adv = logit_scale * adv_image_features @ model.text_features.t()
        loss_summary["adv_acc"] = compute_accuracy(logits_adv, label)[0].item()

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        # if (self.batch_idx + 1) == self.num_batches and self.epoch >= self.cfg.ADVMASK.TURN_POINT:
        #     self.update_lr()

        return loss_summary

    def graddrop_backward_and_update(self, loss, names=None):
        """Run one ordinary optimisation step for ``loss``.

        Despite the name, no gradient is dropped here: the gradients of the
        selected models are cleared, ``loss`` is passed to ``detect_anomaly``
        and back-propagated, and every selected optimizer takes one step.

        Args:
            loss: scalar tensor to back-propagate.
            names: model name(s) to update; resolved by ``get_model_names``.
        """
        self.model_zero_grad(names)
        selected = self.get_model_names(names)

        self.detect_anomaly(loss)
        loss.backward()

        for model_name in selected:
            self._optims[model_name].step()

    def multi_objective_backward_and_update(
            self, loss_a, loss_b, lambda_=1, names=None, epsilon=1e-8
    ):
        """Back-propagate two losses separately and step on their summed gradient.

        ``loss_b`` is back-propagated first and its gradients are stashed, the
        gradients are cleared, then ``loss_a`` is back-propagated. For every
        parameter that received a ``loss_b`` gradient, the final gradient is
        ``grad_a + grad_b``; parameters without a ``loss_b`` gradient keep
        whatever ``loss_a`` produced. All selected optimizers then step once.

        Args:
            loss_a: first scalar loss.
            loss_b: second scalar loss (back-propagated first).
            lambda_: accepted for interface compatibility; currently unused.
            names: model name(s) to update; resolved by ``get_model_names``.
            epsilon: accepted for interface compatibility; currently unused.

        Returns:
            ``(mean_a_norm, mean_b_norm)``: the mean per-parameter L2 norm of
            the ``loss_a`` and ``loss_b`` gradients, taken over the parameters
            that received a ``loss_b`` gradient.
        """
        self.model_zero_grad(names)
        names = self.get_model_names(names)

        # Pass 1: loss_b. The graph is retained because loss_a is
        # back-propagated afterwards and may share parts of it.
        self.detect_anomaly(loss_b)
        loss_b.backward(retain_graph=True)

        # Stash the loss_b gradients per model so that each model's parameters
        # are later paired with their own gradients (a single flat list would
        # be misaligned as soon as there is more than one model).
        b_grads = {
            name: [
                None if p.grad is None else p.grad.clone()
                for p in self._models[name].parameters()
            ]
            for name in names
        }

        # Clear the loss_b gradients without stepping the optimizers.
        for name in names:
            self._optims[name].zero_grad()

        # Pass 2: loss_a, then merge with the stashed loss_b gradients.
        self.detect_anomaly(loss_a)
        loss_a.backward()

        a_grads_norm = []
        b_grads_norm = []
        for name in names:
            for p, b_grad in zip(self._models[name].parameters(), b_grads[name]):
                if b_grad is None:
                    continue

                # zero_grad() may have reset the gradient to None; a parameter
                # untouched by loss_a then contributes a zero gradient.
                if p.grad is None:
                    a_grad = torch.zeros_like(b_grad)
                else:
                    a_grad = p.grad.clone()

                a_grads_norm.append(a_grad.norm(p=2))
                b_grads_norm.append(b_grad.norm(p=2))

                # Plain sum of the two gradients. Layer-wise gradient scaling
                # and PCGrad-style projection were tried here previously; the
                # original note says that summing/averaging the two gradients
                # made little difference compared with a weighted sum of the
                # losses.
                p.grad = a_grad + b_grad

        for name in names:
            self._optims[name].step()

        return torch.stack(a_grads_norm).mean(), torch.stack(b_grads_norm).mean()



    def parse_batch_train(self, batch):
        """Extract the training images and labels and move them to ``self.device``.

        Args:
            batch: mapping with the entries ``"img"`` and ``"label"``.

        Returns:
            ``(images, labels)``, both placed on ``self.device``.
        """
        images, labels = (batch[key].to(self.device) for key in ("img", "label"))
        return images, labels

    def before_train(self):
        """Run the parent ``before_train`` hook, then print the initial mask sparsity."""
        super().before_train()

        # Sparsity as computed by the model for the configured threshold function.
        init_sparsity = self.model.compute_sparsity(self.cfg.MASK.THRESHOLD_FN)
        print("++++++++++++ Init Sparsity: ", init_sparsity, "++++++++++++")

    def load_model(self, directory, epoch=None, init_weights=False):
        """Load checkpointed weights for every registered model from ``directory``.

        The checkpoint for model ``name`` is read from
        ``<directory>/<name>/model.pth.tar-<epoch>``. For models whose name
        contains ``'shared'``, only the ``mask_real`` entries are taken from the
        checkpoint (they are stored as packed bits and decoded here); every
        other entry keeps the model's current value.

        Args:
            directory: root directory of the checkpoints; a falsy value skips
                loading entirely.
            epoch: epoch suffix of the checkpoint file; defaults to
                ``cfg.OPTIM.MAX_EPOCH``.
            init_weights: when true, decoded mask entries equal to 1 are set to
                ``cfg.MASK.SCALE`` (zeros are left untouched).

        Raises:
            FileNotFoundError: if a checkpoint file does not exist.
            RuntimeError: if a packed mask holds fewer bits than the mask it
                is decoded into.
        """
        if not directory:
            print("Note that load_model() is skipped as no pretrained model is given")
            return

        def decode(mask, packed):
            # Unpack bytes into a 0/1 float tensor shaped like `mask`.
            # Bits are read most-significant first; trailing padding bits are
            # dropped. Done with tensor ops instead of a per-bit Python loop.
            numel = mask.numel()
            byte_vals = torch.tensor(list(packed), dtype=torch.uint8)
            bit_weights = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1], dtype=torch.uint8)
            bits = (byte_vals.unsqueeze(1) & bit_weights).ne(0).reshape(-1)[:numel]
            if bits.numel() != numel:
                raise RuntimeError(
                    "packed mask holds {} bits but {} are required".format(bits.numel(), numel)
                )
            return bits.to(torch.float32).reshape_as(mask).to(mask.device)

        names = self.get_model_names()

        # By default, the last model is loaded
        model_file = "model.pth.tar-" + str(
            self.cfg.OPTIM.MAX_EPOCH if epoch is None else epoch
        )

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError('Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            # Kept separate from the `epoch` argument, which selects the file.
            ckpt_epoch = checkpoint["epoch"]

            # Ignore fixed token vectors
            state_dict.pop("token_prefix", None)
            state_dict.pop("token_suffix", None)

            print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, ckpt_epoch))

            if 'shared' in name:
                # Start from the model's own state and overwrite only the masks.
                model_state_dict = self._models[name].state_dict()
                for key in model_state_dict:
                    if 'mask_real' in key:
                        mask = decode(model_state_dict[key], state_dict[key])
                        if init_weights:
                            mask[mask == 1.] = self.cfg.MASK.SCALE
                        model_state_dict[key] = mask.data
                state_dict = model_state_dict

            self._models[name].load_state_dict(state_dict, strict=True)
            # Return value is not used here; the call is kept in case
            # compute_sparsity has side effects (e.g. logging) - TODO confirm.
            self.model.compute_sparsity(self.cfg.MASK.THRESHOLD_FN)

    def load_2nd_model(self, directory, epoch=None):
        """Load checkpointed masks into ``self.model2.shared``.

        Only the single-model case named ``'shared'`` is supported. The
        checkpoint is read from ``<directory>/shared/model.pth.tar-<epoch>``.
        Only the ``mask_real`` entries are taken from the checkpoint (stored as
        packed bits and decoded here); every other entry keeps the current
        value of ``self.model2.shared``.

        Args:
            directory: root directory of the checkpoints; a falsy value skips
                loading entirely.
            epoch: epoch suffix of the checkpoint file; defaults to
                ``cfg.OPTIM.MAX_EPOCH``.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
            RuntimeError: if a packed mask holds fewer bits than the mask it
                is decoded into.
        """
        if not directory:
            print("Note that load_2nd_model() is skipped as no pretrained model is given")
            return

        def decode(mask, packed):
            # Unpack bytes into a 0/1 float tensor shaped like `mask`.
            # Bits are read most-significant first; trailing padding bits are
            # dropped. Done with tensor ops instead of a per-bit Python loop.
            numel = mask.numel()
            byte_vals = torch.tensor(list(packed), dtype=torch.uint8)
            bit_weights = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1], dtype=torch.uint8)
            bits = (byte_vals.unsqueeze(1) & bit_weights).ne(0).reshape(-1)[:numel]
            if bits.numel() != numel:
                raise RuntimeError(
                    "packed mask holds {} bits but {} are required".format(bits.numel(), numel)
                )
            return bits.to(torch.float32).reshape_as(mask).to(mask.device)

        names = self.get_model_names()
        assert len(names) == 1 and names[0] == 'shared'  # only consider this case

        # By default, the last model is loaded
        model_file = "model.pth.tar-" + str(
            self.cfg.OPTIM.MAX_EPOCH if epoch is None else epoch
        )

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError('Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            # Kept separate from the `epoch` argument, which selects the file.
            ckpt_epoch = checkpoint["epoch"]

            # Ignore fixed token vectors
            state_dict.pop("token_prefix", None)
            state_dict.pop("token_suffix", None)

            print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, ckpt_epoch))

            if 'shared' in name:
                # Start from model2's own state and overwrite only the masks.
                model_state_dict = self.model2.shared.state_dict()
                for key in model_state_dict:
                    if 'mask_real' in key:
                        mask = decode(model_state_dict[key], state_dict[key])
                        model_state_dict[key] = mask.data
                state_dict = model_state_dict

            self.model2.shared.load_state_dict(state_dict, strict=True)
            # Return value is not used here; the call is kept in case
            # compute_sparsity has side effects (e.g. logging) - TODO confirm.
            self.model2.compute_sparsity(self.cfg.MASK.THRESHOLD_FN)

    def set_model_mode(self, mode="train", names=None):
        """Switch the selected models between training and evaluation mode.

        In ``"train"`` mode every sub-module whose type name contains
        ``'BatchNorm'`` is put back into eval mode after the model has been
        switched to training.

        Args:
            mode: ``"train"``, ``"test"`` or ``"eval"``.
            names: model name(s) to switch; resolved by ``get_model_names``.

        Raises:
            KeyError: if ``mode`` is not one of the accepted values (raised
                while processing the first selected model).
        """
        for model_name in self.get_model_names(names):
            if mode == "train":
                net = self._models[model_name]
                net.train()
                batchnorm_layers = [
                    sub for sub in net.modules() if 'BatchNorm' in str(type(sub))
                ]
                for layer in batchnorm_layers:
                    layer.eval()
            elif mode == "test" or mode == "eval":
                self._models[model_name].eval()
            else:
                raise KeyError

    @torch.no_grad()
    def test(self, split=None, during_train=False):
        """Evaluate natural and adversarial accuracy on one data split.

        Every batch is evaluated on the clean input and on an adversarial
        input produced by AutoAttack (when ``cfg.ATTACK.TEST == 'aa'``) or by
        ``attack_pgd`` otherwise.

        Args:
            split: ``"val"`` or ``"test"``; defaults to ``cfg.TEST.SPLIT``.
                Falls back to ``"test"`` whenever the val loader is not used.
            during_train: when true, return only the first natural metric
                right after evaluation (the adversarial pass has still run).

        Returns:
            The first natural metric when ``during_train`` is true, otherwise
            the tuple ``(first natural metric, first adversarial metric)``.
        """
        # NOTE(review): the whole method runs under torch.no_grad(); the attack
        # routines presumably re-enable gradients internally - TODO confirm.
        self.set_model_mode("eval")
        self.evaluator.reset()

        # A second evaluator accumulates the adversarial predictions.
        perform_adv_test = True
        self.evaluator_adv = build_evaluator(self.cfg, lab2cname=self.lab2cname)
        self.evaluator_adv.reset()
        torch.cuda.empty_cache()

        split = self.cfg.TEST.SPLIT if split is None else split
        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
        else:
            # in case val_loader is None
            split = "test"
            data_loader = self.test_loader

        print(f"Evaluate on the *{split}* set (Natural + Adv)")

        for batch in tqdm(data_loader):
            # Natural (clean) evaluation.
            with torch.no_grad():
                images, label = self.parse_batch_test(batch)
                clean_logits = self.model_inference(self.preprocessing(images))
                self.evaluator.process(clean_logits, label)

            if not perform_adv_test:
                continue

            # Adversarial evaluation.
            torch.cuda.empty_cache()

            if self.cfg.ATTACK.TEST == 'aa':  # autoattack
                from attack.auto import attack_auto
                # NOTE(review): the AA epsilon is divided by 255 here while the
                # PGD epsilon below is passed as configured - confirm the units.
                eps = self.cfg.ATTACK.AA.EPS/255.
                images_adv = attack_auto(self.model_inference_with_normalization, images, label,
                                         text_tokens=None, prompter=None, add_prompter=None,
                                         device=images.device, attacks_to_run=['apgd-ce', 'apgd-dlr'],
                                         epsilon=eps)
                adv_input = self.preprocessing(images_adv)
            else:
                delta = attack_pgd(self.model_inference, self.preprocessing, images, label,
                                   alpha=self.cfg.ATTACK.PGD.ALPHA,
                                   attack_iters=self.cfg.ATTACK.PGD.TEST_ITER,
                                   epsilon=self.cfg.ATTACK.PGD.EPS)
                adv_input = self.preprocessing(images + delta)

            torch.cuda.empty_cache()
            with torch.no_grad():
                adv_logits = self.model_inference(adv_input)
                self.evaluator_adv.process(adv_logits, label)

        results = self.evaluator.evaluate()
        if during_train:
            return list(results.values())[0]

        sparsity = self.model.compute_sparsity(self.cfg.MASK.THRESHOLD_FN)
        print("++++++++++++ Sparsity: ", sparsity, "++++++++++++")

        for metric, value in results.items():
            self.write_scalar(f"{split}/{metric}", value, self.epoch)

        if perform_adv_test:
            results_adv = self.evaluator_adv.evaluate()
            for metric, value in results_adv.items():
                self.write_scalar(f"{split}/{metric}_adv", value, self.epoch)

            return list(results.values())[0], list(results_adv.values())[0]

        return list(results.values())[0]

    def model_inference_with_normalization(self, input):
        """Apply ``self.preprocessing`` to ``input`` and return ``self.model``'s output for it."""
        return self.model(self.preprocessing(input))

    @torch.no_grad()
    def test_dual(self, split=None, during_train=False):
        """Evaluate natural and PGD-adversarial accuracy of the dual-mask ensemble.

        Same pipeline as a regular evaluation, but every forward pass (clean,
        attack and adversarial) goes through ``dual_mask_inference``.

        Args:
            split: ``"val"`` or ``"test"``; defaults to ``cfg.TEST.SPLIT``.
                Falls back to ``"test"`` whenever the val loader is not used.
            during_train: when true, return only the first natural metric
                right after evaluation (the adversarial pass has still run).

        Returns:
            The first natural metric when ``during_train`` is true, otherwise
            the tuple ``(first natural metric, first adversarial metric)``.
        """
        # NOTE(review): the whole method runs under torch.no_grad(); attack_pgd
        # presumably re-enables gradients internally - TODO confirm.
        self.set_model_mode("eval")
        self.evaluator.reset()

        # A second evaluator accumulates the adversarial predictions.
        perform_adv_test = True
        self.evaluator_adv = build_evaluator(self.cfg, lab2cname=self.lab2cname)
        self.evaluator_adv.reset()
        torch.cuda.empty_cache()

        split = self.cfg.TEST.SPLIT if split is None else split
        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
        else:
            # in case val_loader is None
            split = "test"
            data_loader = self.test_loader

        print(f"Evaluate on the *{split}* set (Natural + Adv)")

        for batch in tqdm(data_loader):
            # Natural (clean) evaluation.
            with torch.no_grad():
                images, label = self.parse_batch_test(batch)
                clean_logits = self.dual_mask_inference(self.preprocessing(images))
                self.evaluator.process(clean_logits, label)

            if not perform_adv_test:
                continue

            # Adversarial evaluation (PGD only).
            torch.cuda.empty_cache()

            delta = attack_pgd(self.dual_mask_inference, self.preprocessing, images, label,
                               alpha=self.cfg.ATTACK.PGD.ALPHA,
                               attack_iters=self.cfg.ATTACK.PGD.TEST_ITER,
                               epsilon=self.cfg.ATTACK.PGD.EPS)
            adv_input = self.preprocessing(images + delta)

            torch.cuda.empty_cache()
            with torch.no_grad():
                adv_logits = self.dual_mask_inference(adv_input)
                self.evaluator_adv.process(adv_logits, label)

        results = self.evaluator.evaluate()
        if during_train:
            return list(results.values())[0]

        # Sparsity is reported for self.model only (not for self.model2).
        sparsity = self.model.compute_sparsity(self.cfg.MASK.THRESHOLD_FN)
        print("++++++++++++ Sparsity: ", sparsity, "++++++++++++")

        for metric, value in results.items():
            self.write_scalar(f"{split}/{metric}", value, self.epoch)

        if perform_adv_test:
            results_adv = self.evaluator_adv.evaluate()
            for metric, value in results_adv.items():
                self.write_scalar(f"{split}/{metric}_adv", value, self.epoch)

            return list(results.values())[0], list(results_adv.values())[0]

        return list(results.values())[0]

    def dual_mask_inference(self, input, tau=0.004, beta=100.0):
        """Blend the outputs of ``self.model2`` and ``self.model`` for ``input``.

        The blend is ``alpha * model(input) + (1 - alpha) * model2(input)``
        with ``alpha`` fixed to 1.0 ("naive ensembling"). ``tau`` and ``beta``
        are not used by the current implementation; a divergence-based
        ``alpha`` (sigmoid of ``beta * (js - tau)``) existed only as
        commented-out code.

        Args:
            input: model input, passed unchanged to both models.
            tau: currently unused.
            beta: currently unused.

        Returns:
            The weighted combination of the two outputs.
        """
        # model2 is evaluated first, then model (same order as before).
        clean_logits = self.model2(input)
        adv_logits = self.model(input)

        ## Naive Ensembling
        alpha = 1.0
        return alpha * adv_logits + (1 - alpha) * clean_logits

    def check(self):
        """Assert that selected weights still match a freshly built pretrained model.

        A new ``MASKCLIP`` is built from CLIP weights loaded on CPU, and the
        ``shared`` sub-modules of ``self.model`` whose type name contains
        ``'ElementWise'``, ``'BatchNorm'`` or ``'LayerNorm'`` are compared with
        it: weight, bias (when present) and, for BatchNorm, the running
        statistics. Finally the attention-pool positional embedding is
        compared with the CLIP model's.

        Raises:
            AssertionError: if any summed absolute difference is >= 1e-8.
        """
        # NOTE(review): these are plain asserts, so they are skipped under -O.
        cfg = self.cfg
        classnames = self.dm.dataset.classnames
        clip_model = load_clip_to_cpu(cfg)

        clip_model.float()
        pretrained = MASKCLIP(cfg, classnames, clip_model)

        def total_abs_diff(ours, theirs):
            # Sum of absolute differences, computed on CPU.
            return (ours.cpu() - theirs.cpu()).abs().sum()

        # The 1e-8 tolerance covers floating point inconsistencies. It applies
        # to the SUM of absolute differences, so the per-element tolerance is
        # even smaller.
        checked_kinds = ('ElementWise', 'BatchNorm', 'LayerNorm')
        module_pairs = zip(self.model.shared.modules(), pretrained.shared.modules())
        for module, module_pretrained in module_pairs:
            kind = str(type(module))
            if not any(tag in kind for tag in checked_kinds):
                continue

            assert total_abs_diff(module.weight.data, module_pretrained.weight.data) < 1e-8, \
                'module %s failed check' % (module)

            if module.bias is not None:
                assert total_abs_diff(module.bias.data, module_pretrained.bias.data) < 1e-8

            if 'BatchNorm' in kind:
                assert total_abs_diff(module.running_mean, module_pretrained.running_mean) < 1e-8
                assert total_abs_diff(module.running_var, module_pretrained.running_var) < 1e-8

        # NOTE(review): assumes the image encoder exposes `attnpool` -
        # TODO confirm this holds for every supported backbone.
        ours_pos = self.model.image_encoder.attnpool.positional_embedding.data
        clip_pos = clip_model.visual.attnpool.positional_embedding.data
        assert total_abs_diff(ours_pos, clip_pos) < 1e-8

        print('Passed checks...')

    def after_epoch(self):
        """End-of-epoch hook; only acts once the final epoch has finished.

        After the last epoch it reports the layer-wise mask change, evaluates
        on the ``"test"`` split and, when ``cfg.ADVMASK.SAVE_MODEL`` is set,
        saves the model to ``self.output_dir``. Nothing happens after any
        other epoch (no periodic checkpointing or testing).
        """
        if (self.epoch + 1) != self.max_epoch:
            return

        self.layerwise_mask_change()
        self.test(split="test")
        if self.cfg.ADVMASK.SAVE_MODEL:
            self.save_model(self.epoch, self.output_dir)

    def save_model(
            self, epoch, directory, is_best=False, val_result=None, model_name=""
    ):
        """Save a checkpoint for every registered model under ``directory/<name>``.

        For models whose name contains ``'shared'`` only the ``mask_real``
        entries are stored: each mask is binarised against
        ``cfg.MASK.THRESHOLD`` (1 where the value is greater, else 0) and
        packed eight bits per byte, most-significant bit first, zero-padded at
        the end. Other models store their full state dict.

        Args:
            epoch: zero-based epoch index; stored as ``epoch + 1``.
            directory: root output directory.
            is_best: forwarded to ``save_checkpoint``.
            val_result: stored in the checkpoint under ``"val_result"``.
            model_name: forwarded to ``save_checkpoint``.
        """
        def pack_bits(flags):
            # Pack a boolean tensor into bytes with tensor ops instead of
            # iterating over the tensor element by element in Python.
            bits = flags.reshape(-1).cpu().to(torch.uint8)
            pad = -bits.numel() % 8
            if pad:
                bits = torch.cat([bits, bits.new_zeros(pad)])
            bit_weights = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1], dtype=torch.uint8)
            byte_vals = (bits.reshape(-1, 8) * bit_weights).sum(dim=1)
            return bytes(byte_vals.tolist())

        names = self.get_model_names()

        for name in names:
            model_dict = self._models[name].state_dict()
            if 'shared' in name:
                # save binary values into bytes
                threshold = self.cfg.MASK.THRESHOLD
                model_dict = {
                    key: pack_bits(value.gt(threshold))
                    for key, value in model_dict.items()
                    if 'mask_real' in key
                }

            optim_dict = None
            if self._optims[name] is not None:
                optim_dict = self._optims[name].state_dict()

            sched_dict = None
            if self._scheds[name] is not None:
                sched_dict = self._scheds[name].state_dict()

            save_checkpoint(
                {
                    "state_dict": model_dict,
                    "epoch": epoch + 1,
                    "optimizer": optim_dict,
                    "scheduler": sched_dict,
                    "val_result": val_result
                },
                osp.join(directory, name),
                is_best=is_best,
                model_name=model_name,
            )

    def layerwise_mask_change(self):
        """Print how far each mask tensor has moved from ``cfg.MASK.SCALE``.

        Every entry of ``self.model.shared.state_dict()`` whose key contains
        ``'mask_real'`` is compared against the scalar ``cfg.MASK.SCALE`` (used
        here as the initial mask value). The sum of squared differences
        (labelled "L2" in the output) is reported four ways:

        * per state-dict key;
        * grouped by the first of ``'conv1'``, ``'attn'``, ``'mlp'`` that occurs
          in the key;
        * grouped by the block index parsed from ``resblocks.<idx>.`` (keys
          without such an index are collected under ``-1``);
        * overall.

        Averages are reported as 0.0 when the corresponding parameter count is
        zero, so the report also works when no mask tensors are present.

        Raises:
            NotImplementedError: If a mask key contains none of the known
                group substrings.
        """
        import re
        from collections import defaultdict

        def scaled_avg(l2, count):
            # Per-parameter change, scaled by 1e5 for readability; 0.0 when
            # there is nothing to average over.
            return (l2 / count) * 1e+5 if count > 0 else 0.0

        print("#### Analyzing layerwise mask changes after training...\n")

        init_mask = self.cfg.MASK.SCALE
        layer_changes = {}
        for key, current_mask in self.model.shared.state_dict().items():
            if 'mask_real' not in key:
                continue

            l2_change = ((current_mask - init_mask) ** 2).sum().item()
            num_param = current_mask.numel()
            layer_changes[key] = {
                'l2_change': l2_change,
                'num_param': num_param,
                'normalized_change_per_param': scaled_avg(l2_change, num_param),
            }

        total_change = sum(s['l2_change'] for s in layer_changes.values())
        total_param = sum(s['num_param'] for s in layer_changes.values())

        # --- Per-layer summary ---
        print("Layer-wise Mask Change Summary:")
        for layer, stats in layer_changes.items():
            print(f"  - {layer}: "
                  f"L2={stats['l2_change']:.2f}, "
                  f"Per-param Δ(x1e+5)={stats['normalized_change_per_param']:.6f}, "
                  f"Params={stats['num_param']}")

        # --- Group-wise summary ---
        group_types = ['conv1', 'attn', 'mlp']
        group_stats = {g: {'total_l2': 0.0, 'total_param': 0} for g in group_types}

        for layer_name, stats in layer_changes.items():
            # The first group substring found in the key wins.
            group = next((g for g in group_types if g in layer_name), None)
            if group is None:
                raise NotImplementedError(f"Unrecognized layer type in key: {layer_name}")
            group_stats[group]['total_l2'] += stats['l2_change']
            group_stats[group]['total_param'] += stats['num_param']

        print("\n=== Group-wise Summary ===")
        for group in group_types:
            l2 = group_stats[group]['total_l2']
            param = group_stats[group]['total_param']
            print(f"[{group.upper()}] Total L2: {l2:.2f} | "
                  f"Params: {param} | Avg Δ per param(x1e+5): {scaled_avg(l2, param):.6f}")

        # --- Layer index-wise summary ---
        print("\n=== Layer Index-wise Summary ===")
        index_stats = defaultdict(lambda: {'total_l2': 0.0, 'total_param': 0})
        # Extracts the block index from e.g. 'transformer.resblocks.3.attn.mask_real'.
        layer_idx_pattern = re.compile(r"resblocks\.(\d+)\.")
        for layer_name, stats in layer_changes.items():
            match = layer_idx_pattern.search(layer_name)
            layer_idx = int(match.group(1)) if match else -1
            index_stats[layer_idx]['total_l2'] += stats['l2_change']
            index_stats[layer_idx]['total_param'] += stats['num_param']

        for idx in sorted(index_stats):
            l2 = index_stats[idx]['total_l2']
            param = index_stats[idx]['total_param']
            print(f"[Layer {idx:02d}] Total L2: {l2:.2f} | Params: {param} | Avg Δ per param(x1e+5): {scaled_avg(l2, param):.6f}")

        # Guarded like the summaries above: with no mask tensors there is
        # nothing to divide by.
        overall_avg = total_change / total_param if total_param > 0 else 0.0
        print(f"\n[Overall] Total L2 Change: {total_change:.2f}")
        print(f"[Overall] Avg Change per Param: {overall_avg:.6f}")