## Code to train and evaluate a ViT model (Hugging Face ViTForImageClassification) with a LIMA-based ranking regularizer on ImageNet/CUB-200/Cars-196/FGVC-Aircraft/VGG Flowers-102 datasets (derived from our CGC ResNet training script).
## This code is adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py

import argparse
import os
import random
import shutil
import time
import warnings
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from datasets.imagefolder_lima_ssl import ImageFolder
import models.resnet_multigpu_cgc as resnet
from transformers import ViTForImageClassification
import torch.nn.functional as F

from models.utils import compute_gradcam_mask_vit, perform_gradcam_aug, eclip
import numpy as np

import logging

from train_eval_cgc import save_dir
from PIL import Image
import cv2

from LIMA.models.submodular_single_modal import BlackBoxSingleModalSubModularExplanationQuicktest


def get_logger(logpath, filepath, package_files=None, displaying=True, saving=True, debug=False):
    """Configure the root logger and log the source of ``filepath`` (and ``package_files``).

    Args:
        logpath: file that log records are appended to when ``saving`` is true.
        filepath: file whose path and full contents are written to the log.
        package_files: extra files whose paths and contents are also logged
            (default: none).
        displaying: attach a console ``StreamHandler``.
        saving: attach a ``FileHandler`` appending to ``logpath``.
        debug: log at DEBUG instead of INFO.

    Returns:
        The root ``logging.Logger``.
    """
    # A ``None`` default avoids sharing one mutable list across calls.
    if package_files is None:
        package_files = []

    logger = logging.getLogger()
    level = logging.DEBUG if debug else logging.INFO
    logger.setLevel(level)
    if saving:
        # Explicit UTF-8: the logged sources may contain non-ASCII text.
        info_file_handler = logging.FileHandler(logpath, mode="a", encoding="utf-8")
        info_file_handler.setLevel(level)
        logger.addHandler(info_file_handler)
    if displaying:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        logger.addHandler(console_handler)

    # Log the path and full text of the source file(s) alongside the run's logs.
    # UTF-8 is read explicitly so this does not depend on the locale's default encoding.
    logger.info(filepath)
    with open(filepath, "r", encoding="utf-8") as f:
        logger.info(f.read())

    for package_file in package_files:
        logger.info(package_file)
        with open(package_file, "r", encoding="utf-8") as package_f:
            logger.info(package_f.read())

    return logger


# Architectures accepted by --arch. The ViT variants are the ones main_worker
# knows how to build. The ResNet names are kept only so existing command lines
# still parse: main_worker has no ResNet branch, so they only work together
# with --dataset cub, which always loads the CUB ViT checkpoint.
model_names = ['b16_224', 'l16_224', 's16_224', 't16_224', 'resnet18', 'resnet50']
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--data_dir', default='/mnt/huawei/jiaoxh/data/ImageNet100', help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='l16_224', choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) + ' (default: l16_224)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=10, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=32, type=int,
                    metavar='N',
                    help='mini-batch size (default: 32), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=5e-5, type=float, metavar='LR', help='initial learning rate (default: 5e-5)', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum (not used by the AdamW optimizer)')
parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay (default: 1e-2)', dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')

parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')

parser.add_argument('--save_dir', default='checkpoint/lima/resume', type=str, metavar='SV_PATH', help='root directory for checkpoints (default: checkpoint/lima/resume)')
parser.add_argument('--log_dir', default='checkpoint/bs_256-res18/imagenet100-cgc-train-epoch60/logs', type=str, metavar='LG_PATH', help='path to write logs (default: logs)')
parser.add_argument('--dataset', type=str, default='imagenet100', help='dataset to use: [imagenet100, tiny_imagenet, cub, dogs, fgvc, flowers, cars] (default: imagenet100)')

parser.add_argument('-t', type=float, default=0.5, help='temperature of the NCE softmax loss (default: 0.5)')
parser.add_argument('--lambda', default=0.5, type=float, metavar='LAM', help='lambda hyperparameter for GCAM loss', dest='lambda_val')
parser.add_argument('--reg-freq', default=20, type=int, metavar='RF', help='apply the LIMA ranking regularizer every RF iterations (default: 20)')

parser.add_argument('--mode', type=str, default='normal', help="training mode: 'normal' or 'divide' (default: normal)")
parser.add_argument('--attribution', type=str, default='eclip', help='attribution method (default: eclip)')
parser.add_argument('--divide_version', type=int, default=51, help='divide-mode variant; selects the divide_v<N> checkpoint sub-directory (default: 51)')

args = parser.parse_args()

# Meaning of --divide_version (only used with --mode divide):
# divide_v1: whole vit
# divide_v2: embedding + encoder
# divide_v3: encoder
# divide_v4: embedding
# divide_v5_1: first 3 encoder layers
# divide_v5_4: last 3 encoder layers

best_acc1 = 0

# Checkpoint directory layout:
#   <--save_dir>/<arch for imagenet100, dataset otherwise>/[divide_v<N>/]<run tag>
_run_group = args.arch if args.dataset == 'imagenet100' else args.dataset
save_dir = '/'.join([args.save_dir, _run_group, ''])
if args.mode == 'divide':
    save_dir = save_dir + 'divide_v{}/'.format(args.divide_version)
total_epoch = args.epochs

# The run tag names the score combination returned by compute_smdl_score
# (collaboration + consistency). Earlier experiments used the tags
# 'conf+cons', 'colla+conf', 'colla', 'conf', 'cons', 'full', 'two_loss',
# 'three_loss', '0908_2' and 'ab-lambda/' in the same position.
save_dir = save_dir + 'colla+cons-epoch{}-lambda{}-reg_freq{}'.format(args.epochs, args.lambda_val, args.reg_freq)

def main():
    """Prepare the run directory and logger, seed RNGs, then launch the worker(s).

    Creates ``save_dir`` and ``save_dir/logs``, logs this script's source and the
    parsed ``args``, and then either spawns ``ngpus_per_node`` processes
    (``--multiprocessing-distributed``) or runs ``main_worker`` in-process.
    """
    print(save_dir)
    log_dir = save_dir + '/logs'
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    logger = get_logger(logpath=os.path.join(log_dir, 'logs'), filepath=os.path.abspath(__file__))
    logger.info(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # mp.spawn calls main_worker(i, *args). main_worker needs the logger as
        # its 4th argument, which was previously omitted, so every worker
        # crashed with a TypeError.
        # NOTE(review): loggers pickle by name, so each spawned child gets its
        # own root logger; handlers attached in the parent are not carried over.
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args, logger))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args, logger)


class NCESoftmaxLoss(nn.Module):
    """Softmax cross-entropy loss (a.k.a. info-NCE loss in the CPC paper).

    Each row of logits is divided by the temperature ``T`` and scored with
    cross-entropy against label 0, i.e. column 0 holds the positive logit.
    """

    def __init__(self, T=0.01):
        super(NCESoftmaxLoss, self).__init__()
        self.T = T  # temperature dividing the logits
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the mean info-NCE loss.

        Args:
            x: logits of shape (bsz, K); extra singleton dimensions are dropped.

        Returns:
            Scalar loss tensor.
        """
        bsz = x.shape[0]
        # Drop singleton dims except the batch dim. A bare ``squeeze()`` would
        # also collapse a batch of one and then mismatch the label shape.
        x = x.reshape(bsz, *[d for d in x.shape[1:] if d != 1])
        x = torch.div(x, self.T)
        # Labels live on the logits' device instead of a hard-coded ``.cuda()``.
        label = torch.zeros(bsz, dtype=torch.long, device=x.device)
        return self.criterion(x, label)


class PatchLogitLoss(nn.Module):
    """Hinge penalty when other patches out-score a selected patch on the target class.

    Args:
        alpha: margin added to the other patches' logits before the hinge.
    """

    def __init__(self, alpha=0.0):
        super(PatchLogitLoss, self).__init__()
        self.alpha = alpha

    def forward(self, selected_patch_logits, other_logits, t):
        """Compute the per-element hinge restricted to the target class.

        Args:
            selected_patch_logits: logits of the selected patch, broadcastable
                against ``other_logits`` (e.g. shape (1, num_classes)).
            other_logits: logits of the other patches, shape (num_others, num_classes).
            t: target class index.

        Returns:
            Unreduced tensor shaped like ``other_logits``:
            ``relu(alpha + other - selected)`` in column ``t``, zero elsewhere,
            divided by ``num_others``.
        """
        # Size the one-hot mask from the logits (was hard-coded to 100 classes)
        # and build it on their device to avoid CPU/GPU mismatches.
        mask = torch.zeros(other_logits.shape[-1], device=other_logits.device)
        mask[t] = 1.0
        return torch.relu((self.alpha + other_logits - selected_patch_logits) * mask) / other_logits.shape[0]


def normalize(x):
    """Scale each row of ``x`` (along dim 1) to unit L2 norm.

    Rows whose norm is zero are returned as zeros instead of NaN, because
    ``F.normalize`` clamps the divisor at a small epsilon.
    """
    return F.normalize(x, p=2, dim=1)


def generate_masks_np(shape, patch_size):
    """Build one binary NumPy mask per cell of a ``patch_size`` x ``patch_size`` grid.

    Args:
        shape: ``(height, width, channels)`` of the image.
        patch_size: number of grid cells along each side (``patch_size ** 2``
            masks in total); despite the name this is not a pixel size.

    Returns:
        ``uint8`` array of shape ``(patch_size ** 2, height, width, channels)``.
        Mask ``row * patch_size + col`` is 1 inside that grid cell and 0 elsewhere.
        Cells are ``height // patch_size`` by ``width // patch_size`` pixels, so
        any remainder rows/columns at the bottom/right belong to no mask.
    """
    height, width, channels = shape
    cell_w = width // patch_size
    cell_h = height // patch_size
    num_cells = patch_size * patch_size

    masks = np.zeros((num_cells, height, width, channels), dtype=np.uint8)
    for cell in range(num_cells):
        row, col = divmod(cell, patch_size)
        top, left = row * cell_h, col * cell_w
        masks[cell, top:top + cell_h, left:left + cell_w, :] = 1
    return masks


def generate_masks_torch(patch_size, image_size=224, channels=3):
    """Build one binary mask per cell of a ``patch_size`` x ``patch_size`` grid.

    Args:
        patch_size: number of grid cells along each side (``patch_size ** 2``
            masks in total); despite the name this is not a pixel size.
        image_size: side length in pixels of the square image (default 224).
        channels: number of image channels (default 3).

    Returns:
        ``uint8`` tensor of shape
        ``(patch_size * patch_size, channels, image_size, image_size)``.
        Mask ``row * patch_size + col`` is 1 inside that grid cell and 0 elsewhere.
        Cells are ``image_size // patch_size`` pixels wide, so when ``image_size``
        is not divisible by ``patch_size`` the last rows/columns belong to no mask.
    """
    cell = image_size // patch_size
    masks = torch.zeros((patch_size * patch_size, channels, image_size, image_size), dtype=torch.uint8)
    for row in range(patch_size):
        for col in range(patch_size):
            top, left = row * cell, col * cell
            masks[row * patch_size + col, :, top:top + cell, left:left + cell] = 1
    return masks

# Precomputed 6x6 grid masks for 224x224 inputs: a NumPy (HWC) copy and a CUDA
# (CHW) copy. Building the CUDA copy requires a GPU at import time.
patches_ndarray = generate_masks_np(shape=(224, 224, 3), patch_size=6)
patches_tensor = generate_masks_torch(patch_size=6).to('cuda')

def preprocess(image):
    """Convert an OpenCV image into a CUDA float tensor.

    Used directly in ``train`` and also handed to the LIMA explainer as its
    preprocessing function.

    Args:
        image: array of shape (H, W, 3) in BGR channel order, as returned by
            ``cv2.imread`` in ``train``. NOTE(review): confirm the explainer
            also passes BGR arrays.

    Returns:
        Float tensor of shape (3, H, W) on the current CUDA device, scaled to
        [0, 1] for ``uint8`` input. No resizing and no mean/std normalisation is applied.
    """
    # cv2 loads BGR; wrapping that array directly as 'RGB' swapped red and blue.
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return transforms.ToTensor()(Image.fromarray(rgb, 'RGB')).cuda()


def compute_smdl_score(sub_images, org_img, model, target_filter):
    """Score candidate image regions with LIMA's consistency + collaboration terms.

    Args:
        sub_images: batch of partially masked images, (N, 3, H, W). Pixels that
            are exactly 0 are treated as removed.
        org_img: the full image, broadcastable against ``sub_images``
            (e.g. (1, 3, H, W)).
        model: callable returning an object whose ``.logits`` has shape
            (N, num_classes).
        target_filter: one-hot vector over classes selecting the target class.

    Returns:
        Tensor of shape (N,) holding ``consistency + collaboration`` per sample:
        consistency is P(target | sub_image), and collaboration is
        ``1 - P(target | org_img - sub_image)`` (the complementary image).

    The input tensors are not modified.
    """
    # -mean/std of the ImageNet normalisation, per channel: the value a black
    # pixel takes after normalisation. Removed (zero) pixels are filled with it.
    fill = torch.tensor([-2.1179, -2.0357, -1.8044], dtype=torch.float32,
                        device=sub_images.device).view(1, 3, 1, 1)

    def _fill_removed(batch):
        # Out-of-place so the caller's tensors are never mutated. The old
        # in-place fill corrupted ``org_img - sub_images`` when the same tensor
        # was scored twice.
        return torch.where(batch == 0, fill.to(batch.dtype), batch)

    # The complement must come from the unfilled regions.
    sub_images_reverse = org_img - sub_images

    # 1. Consistency: target probability of the kept region.
    predicted_scores = torch.softmax(model(_fill_removed(sub_images)).logits, dim=-1)
    score_consistency = torch.sum(predicted_scores * target_filter, dim=-1)

    # 2. Collaboration: 1 - target probability of the complementary region.
    # This is computed per sample. It was previously summed over the whole
    # batch and all classes, which gave every sample the same batch-size-
    # dependent offset and made cross-batch comparisons meaningless.
    # (LIMA's confidence term, 1 - normalised entropy, is not part of this
    # combination and is therefore not computed.)
    predicted_scores_reverse = torch.softmax(model(_fill_removed(sub_images_reverse)).logits, dim=-1)
    score_collaboration = 1 - torch.sum(predicted_scores_reverse * target_filter, dim=-1)

    return score_consistency + score_collaboration


def _num_classes_for(dataset):
    """Return the number of output classes used for ``dataset`` (1000 when unlisted)."""
    return {
        'tiny_imagenet': 200,
        'imagenet100': 100,
        'cub': 200,
        'dogs': 120,
        'fgvc': 100,
        'flowers': 102,
        'cars': 196,
    }.get(dataset, 1000)


def _build_vit(args, logger):
    """Instantiate the ``ViTForImageClassification`` selected by ``args.dataset`` / ``args.arch``.

    ``--dataset cub`` always loads the CUB checkpoint, whatever ``--arch`` says.

    Raises:
        ValueError: if ``args.arch`` is not a supported ViT variant (previously
            this surfaced later as an UnboundLocalError on ``model``).
    """
    if args.dataset == 'cub':
        return ViTForImageClassification.from_pretrained('pretrained_model',
                                                         subfolder='vit-base-patch16-224-cub',
                                                         ignore_mismatched_sizes=True)
    if args.arch == 'b16_224':
        model = ViTForImageClassification.from_pretrained('pretrained_model',
                                                          subfolder='vit-base-patch16-224', ignore_mismatched_sizes=True)
    elif args.arch == 'l16_224':
        model = ViTForImageClassification.from_pretrained('pretrained_model',
                                                          subfolder='vit-large-patch16-224', ignore_mismatched_sizes=True)
    elif args.arch == 's16_224':
        model = ViTForImageClassification.from_pretrained('google/vit-small-patch16-224')
    elif args.arch == 't16_224':
        model = ViTForImageClassification.from_pretrained('google/vit-tiny-patch16-224')
    else:
        raise ValueError("unsupported architecture '{}'".format(args.arch))
    logger.info("=> creating model '{}'".format(args.arch))
    return model


def main_worker(gpu, ngpus_per_node, args, logger=None):
    """Build the model, optimizer and data loaders, then evaluate or train.

    Args:
        gpu: GPU index for this worker, or ``None``.
        ngpus_per_node: number of visible GPUs on this node.
        args: parsed command-line arguments (``gpu`` and possibly ``rank`` are
            overwritten here).
        logger: logger for progress output. Defaults to the root logger, so the
            function also works when launched via ``mp.spawn`` without one.

    Updates the module-level ``best_acc1`` and writes checkpoints to ``save_dir``.
    """
    global best_acc1
    if logger is None:
        logger = logging.getLogger()
    args.gpu = gpu

    if args.gpu is not None:
        logger.info("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    num_classes = _num_classes_for(args.dataset)
    val_dir_name = 'val'

    model = torch.nn.DataParallel(_build_vit(args, logger))
    start_epoch = args.start_epoch

    # NOTE(review): the starting weights always come from this hard-coded
    # checkpoint, regardless of --arch/--dataset/--pretrained. It is loaded
    # after DataParallel wrapping, so its keys are expected to carry the
    # 'module.' prefix.
    state_dict = torch.load('checkpoint/l16_224/eclip-epoch10-baseline-first/model_best.pth.tar')['state_dict']
    model.load_state_dict(state_dict)
    model = model.cuda()

    logger.info(model)

    # define loss functions and optimizers
    xent_criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    contrastive_criterion = NCESoftmaxLoss(args.t).cuda(args.gpu)

    cudnn.benchmark = True
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    optimizer_fc = None
    if args.mode == 'divide':
        optimizer_fc = torch.optim.AdamW(model.module.classifier.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                checkpoint = torch.load(args.resume, map_location='cuda:{}'.format(args.gpu))
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None and torch.is_tensor(best_acc1):
                # best_acc1 may come from a checkpoint saved on a different GPU.
                # Mid-epoch checkpoints can also hold the plain int 0, which has no .to().
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

            start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.step(start_epoch)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    traindir = os.path.join(args.data_dir, 'train')
    valdir = os.path.join(args.data_dir, val_dir_name)

    normalize_transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                               std=[0.229, 0.224, 0.225])

    train_dataset = ImageFolder(traindir)   # transforms are handled within the implementation

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize_transform,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, contrastive_criterion, xent_criterion, args, logger)
        return

    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train() requires num_classes in both modes; the divide-mode call
        # previously omitted it and failed with a TypeError.
        if args.mode == 'divide':
            train(train_loader, model, contrastive_criterion, xent_criterion, (optimizer, optimizer_fc),
                  epoch, args, logger, num_classes)
        else:
            train(train_loader, model, contrastive_criterion, xent_criterion, optimizer,
                  epoch, args, logger, num_classes)

        scheduler.step()

        # evaluate on validation set
        acc1 = validate(val_loader, model, contrastive_criterion, xent_criterion, args, logger)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best, save_dir)


def _search_lima_sequences(model, paths, indices, patch_masks):
    """Run the LIMA submodular search on each selected sample and return its patch order.

    Args:
        model: classifier handed to the LIMA explainer.
        paths: image file paths of the current batch.
        indices: batch positions of the samples to explain.
        patch_masks: grid masks as built by ``generate_masks_torch``, on the GPU.

    Returns:
        A list aligned with ``indices``. Each entry is the searched patch order
        (list of patch indices), or ``None`` when the search returned nothing
        or the image could not be read.
    """
    sequences = []
    explainer = BlackBoxSingleModalSubModularExplanationQuicktest(
        model,
        preprocess,
        k=50,
        lambda1=1,
        lambda2=1,
        lambda3=1,
        trunc_iter=8)
    for index in indices:
        image = cv2.imread(paths[index])
        if image is None:
            # cv2.imread returns None for unreadable files. Skip the sample but
            # keep the result list aligned with ``indices``.
            sequences.append(None)
            continue
        image = cv2.resize(image, (224, 224))
        element_sets_V = patch_masks * preprocess(image)
        explainer.k = len(element_sets_V)
        _, _, saved_json_file = explainer(element_sets_V)
        sequences.append(saved_json_file['sequence'].tolist() if saved_json_file else None)
    return sequences


def _backward_rank_loss(model, images, targets, indices, sequences, patch_masks, num_classes):
    """Back-propagate the LIMA ranking terms for every searched sample.

    For each sample, its masked patches are flipped along the width axis and
    walked in the searched order. Each newly selected patch (added to the
    patches already selected) is penalised by every remaining candidate that
    scores higher under ``compute_smdl_score``. The final patch is penalised
    until its target-class probability reaches 0.7. Every term is
    back-propagated immediately, so gradients accumulate in ``model``.

    Returns:
        ``(rank_loss, count)``: the summed, detached loss value and the number
        of ranking terms plus 1e-7 (safe to use as a divisor).
    """
    rank_loss = 0
    count = 1e-7
    num_patches = patch_masks.shape[0]
    replace_values = torch.tensor([-2.1179, -2.0357, -1.8044], dtype=torch.float32)
    for j, sequence in zip(indices, sequences):
        if sequence is None:
            continue
        # Masked, width-flipped patches for this sample only (the old code
        # materialised them for the whole batch).
        element_sets_flip = torch.flip(images[j].unsqueeze(0) * patch_masks, dims=[-1])
        target_filter = torch.zeros(num_classes, device=images.device)
        target_filter[targets[j]] = 1.0
        image_flip = element_sets_flip.sum(0).unsqueeze(0)
        searched = []
        for element_index in sequence:
            # Each candidate patch combined with the patches selected so far.
            accumulated = element_sets_flip + element_sets_flip[searched, :].sum(0).unsqueeze(0)
            searched.append(element_index)
            selected_region = accumulated[element_index].unsqueeze(0)

            if element_index == sequence[-1]:
                for c in range(3):
                    zero_indices = selected_region[:, c, :, :] == 0
                    selected_region[:, c, :, :][zero_indices] = replace_values[c]
                probs = torch.softmax(model(selected_region).logits, dim=-1)
                r_loss = torch.relu((0.7 - probs) * target_filter).sum()
                rank_loss += r_loss.item()
                r_loss.backward()
                count += 1
            else:
                remaining = torch.ones(num_patches, dtype=torch.bool)
                remaining[searched] = False
                other_region = accumulated[remaining]

                selected_score = compute_smdl_score(selected_region, image_flip, model, target_filter)
                with torch.no_grad():
                    better = compute_smdl_score(other_region, image_flip, model, target_filter) > selected_score
                if not better.any():
                    # No violating candidates: the term is exactly zero, so
                    # skip the forward pass on an empty batch.
                    continue
                other_score = compute_smdl_score(other_region[better], image_flip, model, target_filter)
                r_loss = torch.relu(other_score - selected_score).sum()
                rank_loss += r_loss.item()
                count += other_score.shape[0]
                r_loss.backward()
    return rank_loss, count


def train(train_loader, model, contrastive_criterion, xent_criterion, optimizer, epoch, args, logger, num_classes=None):
    """Run one training epoch: cross-entropy on three views plus a periodic LIMA ranking loss.

    Every batch contributes the summed model losses of ``images``, ``xe_images``
    and ``aug_images``. Every ``args.reg_freq`` iterations, correctly classified
    samples with confidence > 0.8 are explained with the LIMA search. The
    resulting ranking loss is back-propagated, and its gradients are averaged
    over the number of ranking terms before the cross-entropy gradients are added.

    Args:
        train_loader: loader yielding the 10-tuple unpacked below.
        model: classifier called as ``model(pixel_values=..., labels=...)``
            (``main_worker`` wraps it in ``nn.DataParallel``).
        contrastive_criterion, xent_criterion: unused; kept for interface compatibility.
        optimizer: optimizer, or ``(optimizer, optimizer_fc)`` when ``args.mode == 'divide'``.
        epoch: current epoch (logging and checkpoint naming).
        args: parsed command-line arguments.
        logger: receives progress lines and rank-loss reports.
        num_classes: classifier width. When ``None`` it is taken from the logits.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    xe_losses = AverageMeter('XE Loss', ':.4e')
    gc_losses = AverageMeter('GC Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, xe_losses, gc_losses, losses, top1, top5],
        logger,
        prefix="Epoch: [{}]".format(epoch))

    if args.mode == 'divide':
        optimizer, optimizer_fc = optimizer

    # 6x6 grid patch masks, built once. The old code assigned them inside the
    # loop, which made the name local to train() and raised UnboundLocalError
    # whenever a regularisation batch had no high-confidence sample.
    patch_masks = generate_masks_torch(6).cuda()
    output_hidden_states = args.lambda_val > 0.0
    confidence_threshold = 0.8

    # switch to train mode
    model.train()
    train_len = len(train_loader)
    train_iter = iter(train_loader)
    end = time.time()
    for i in range(train_len):
        (xe_images, images, aug_images,
         transforms_i, transforms_j, transforms_h, transforms_w, hor_flip,
         targets, paths) = next(train_iter)

        # measure data loading time
        data_time.update(time.time() - end)

        xe_images = xe_images.cuda(args.gpu, non_blocking=True)
        images = images.cuda(args.gpu, non_blocking=True)
        aug_images = aug_images.cuda(args.gpu, non_blocking=True)
        targets = targets.cuda(args.gpu, non_blocking=True)

        outputs = model(pixel_values=images, labels=targets, output_hidden_states=output_hidden_states)
        xe_outputs = model(pixel_values=xe_images, labels=targets, output_hidden_states=output_hidden_states)
        aug_outputs = model(pixel_values=aug_images, labels=targets, output_hidden_states=output_hidden_states)

        # A multi-GPU nn.DataParallel gathers one scalar loss per replica into a
        # vector. .mean() averages it and is a no-op for a scalar.
        xe_loss = outputs.loss.mean() + xe_outputs.loss.mean() + aug_outputs.loss.mean()
        rank_loss = 0

        if i % args.reg_freq == 0:
            confidences, predicted = F.softmax(outputs.logits, dim=1).max(dim=1)
            # Only explain samples that are classified correctly and confidently.
            high_confidence_correct = (confidences > confidence_threshold) & (predicted == targets)
            indices = torch.nonzero(high_confidence_correct).squeeze(1).cpu().numpy()

            sequences = _search_lima_sequences(model, paths, indices, patch_masks)
            rank_loss, count = _backward_rank_loss(
                model, images, targets, indices, sequences, patch_masks,
                num_classes if num_classes is not None else outputs.logits.shape[-1])

            rank_loss /= count
            logger.info('Rank Loss: {}'.format(rank_loss)+', use {} samples'.format(len([item for item in sequences if item is not None])))

            # Average the rank-loss gradients accumulated above before the
            # cross-entropy gradients are added. This used to run on every
            # iteration, using a stale count and dividing grads that could be None.
            for p in model.parameters():
                if p.grad is not None:
                    p.grad /= count

        loss = xe_loss

        # measure accuracy and record loss
        acc1, acc5 = accuracy(aug_outputs.logits, targets, topk=(1, 5))

        losses.update(loss.item(), images.size(0))
        xe_losses.update(xe_loss.item(), images.size(0))
        gc_losses.update(rank_loss, images.size(0))

        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do an optimizer step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

        if i % 800 == 0:
            save_checkpoint({
                'epoch': epoch*4000 + i,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, False, save_dir)


def validate(val_loader, model, contrastive_criterion, criterion, args, logger):
    """Evaluate ``model`` on ``val_loader`` and return the mean top-1 accuracy.

    The model is called as ``model(pixel_values=images, labels=targets)`` and
    the returned object's ``.loss`` and ``.logits`` feed the loss and
    accuracy meters. Progress is logged every ``args.print_freq`` batches and
    a top-1 / top-5 summary is logged once the loader is exhausted.

    Args:
        val_loader: iterable of ``(images, targets)`` batches; ``len()`` is
            used for the progress counter.
        model: classifier whose output exposes ``.loss`` and ``.logits``.
        contrastive_criterion: unused here; kept for call-site compatibility.
        criterion: unused here; the loss is taken from ``outputs.loss``.
        args: namespace providing ``gpu`` and ``print_freq``.
        logger: object exposing ``info(msg)``.

    Returns:
        Average top-1 accuracy (percent) over all evaluated samples.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        logger,
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    # Only move labels to the GPU when one exists; an unconditional .cuda()
    # raises on CPU-only hosts even when no specific GPU was requested.
    use_cuda = torch.cuda.is_available()

    with torch.no_grad():
        end = time.time()
        for i, (images, targets) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if use_cuda:
                targets = targets.cuda(args.gpu, non_blocking=True)

            # compute output
            outputs = model(pixel_values=images, labels=targets)
            # If the model is wrapped in nn.DataParallel, the per-replica
            # scalar losses are gathered into a vector and .item() would fail.
            # .mean() reduces that case and leaves a scalar loss unchanged.
            loss = outputs.loss.mean()
            output = outputs.logits

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, targets, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        logger.info(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, save_dir):
    """Write ``state`` to ``save_dir`` and optionally mirror it as the best model.

    The checkpoint file is named ``checkpoint_<epoch>.pth.tar``, where
    ``<epoch>`` is ``state['epoch']`` converted to text and left-padded with
    zeros to at least three characters. When ``is_best`` is truthy, the newly
    written file is also copied to ``model_best.pth.tar`` in the same directory.
    """
    padded_epoch = str(state['epoch']).zfill(3)
    ckpt_path = os.path.join(save_dir, ''.join(('checkpoint_', padded_epoch, '.pth.tar')))
    torch.save(state, ckpt_path)

    if not is_best:
        return
    shutil.copyfile(ckpt_path, os.path.join(save_dir, 'model_best.pth.tar'))


class AverageMeter(object):
    """Track the most recent value and the running weighted mean of a metric.

    ``name`` labels the meter in its string form. ``fmt`` is a format spec
    (for example ``':6.2f'``) applied to both the latest value and the mean.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the running sum and the sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest reading and weight it as ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # e.g. 'Loss 5.00 (3.00)': the latest value, then the running mean in parentheses.
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(**vars(self))


class ProgressMeter(object):
    """Emit one tab-separated progress line through a logger on each ``display``.

    Args:
        num_batches: total batch count, shown after the slash in the counter.
        meters: objects whose ``str()`` is appended to every line.
        logger: object exposing ``info(msg)``; each line is sent there.
        prefix: text placed before the ``[batch/total]`` counter.
    """

    def __init__(self, num_batches, meters, logger, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix
        self.logger = logger

    def display(self, batch):
        """Log the counter for ``batch`` followed by every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        self.logger.info('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Return a ``'[{:Nd}/total]'`` template, where N is the width of ``total``."""
        width = len(str(num_batches // 1))
        field = '{:' + str(width) + 'd}'
        return ''.join(['[', field, '/', field.format(num_batches), ']'])


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy, in percent, for each k in ``topk``.

    ``output`` holds per-class scores; the top-k is taken along dim 1.
    ``target`` holds one class index per row of ``output``. The result is a
    list with one 1-element tensor per entry of ``topk``, in the same order.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        # (maxk, batch): row r holds each sample's r-th highest-scoring class.
        ranked = output.topk(max(topk), 1, True, True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]


if __name__ == '__main__':
    # Start the program only when this file is executed as a script; importing it has no side effects here.
    main()

