# import torch
# import torch.nn.functional as F  # 新增的修复行
# import numpy as np
# from sklearn.metrics import roc_auc_score
# from tqdm import tqdm
# from sklearn.svm import SVC
# import utils
# import arg_parser
# import evaluation
# from copy import deepcopy
# import copy
# from collections import OrderedDict

# import matplotlib.pyplot as plt



# def get_attention_masks(model, imgs, patch_size=16, topk_ratio=0.1):
#     """提取ViT最后一层CLS token的注意力权重并生成mask（保持设备一致）"""
#     attentions = []
    
#     # 注册钩子捕获注意力权重（不转移到CPU）
#     def hook_fn(module, input, output):
#         attentions.append(input[0].detach())  # 保留原始设备（GPU）

#     handle = model.blocks[-1].attn.attn_drop.register_forward_hook(hook_fn)
    
#     with torch.no_grad():
#         _ = model(imgs)
#     handle.remove()
    
#     # 修正设备
#     attn_weights = attentions[0].softmax(dim=-1)  # 形状 [B, H, N+1, N+1]
#     cls_attn = attn_weights[:, :, 0, 1:]         # [B, H, N]
#     cls_attn = cls_attn.mean(dim=1)              # [B, N]
    
#     # 生成mask时指定设备与images一致
#     B, N = cls_attn.shape
#     k = int(topk_ratio * N)
#     topk_indices = torch.topk(cls_attn, k, dim=-1).indices
    
#     mask = torch.zeros(B, N, dtype=torch.bool, device=imgs.device)  # 关键修复
#     mask.scatter_(1, topk_indices, True)
#     return mask


# def mask_images(images, masks, patch_size=16, mode='zero'):
#     """根据mask遮蔽图像patch（显式设备管理）"""
#     B, C, H, W = images.shape
#     ph, pw = H // patch_size, W // patch_size
#     device = images.device  # 获取输入图像设备
    
#     # 确保mask在相同设备
#     masks = masks.to(device)
    
#     # 后续操作保持设备一致
#     masks = masks.view(B, ph, pw)
#     masks = masks.unsqueeze(-1).expand(-1, -1, -1, patch_size)
#     masks = masks.unsqueeze(-2).expand(-1, -1, -1, patch_size, -1)
#     masks = masks.permute(0, 1, 3, 2, 4).reshape(B, H, W)
#     masks = masks.unsqueeze(1)  # [B, 1, H, W]
    
#     if mode == 'zero':
#         masked = images * (~masks)
#     elif mode == 'noise':
#         noise = torch.randn_like(images) * 0.5
#         masked = torch.where(masks, noise, images)
#     return masked

# def compute_masked_mia_confidence(model, shadow_train_loader, shadow_test_loader, masked_loader):
#     # 计算遮蔽样本的MIA置信度
#     mia_results = SVC_MIA(
#         shadow_train=shadow_train_loader,  # 影子模型训练集
#         shadow_test=shadow_test_loader,    # 影子模型测试集
#         target_train=None,                 # 目标模型训练集（可选）
#         target_test=masked_loader,         # 遮蔽后的测试集
#         model=model
#     )
#     return mia_results["confidence"]

# def main():
#     args = arg_parser.parse_args()

#     if args.seed:
#         utils.setup_seed(args.seed)
#     seed = args.seed

#     (
#         model,
#         train_loader_full,
#         val_loader,
#         test_loader,
#         marked_loader,
#     ) = utils.setup_model_dataset(args)
#     model.cuda()

#     checkpoint = torch.load(args.model_path, weights_only=False)
#     if "state_dict" in checkpoint.keys():
#         checkpoint = checkpoint["state_dict"]
#     model.load_state_dict(checkpoint, strict=False)
    
#     def replace_loader_dataset(
#         dataset, batch_size=args.batch_size, seed=1, shuffle=True
#     ):
#         utils.setup_seed(seed)
#         return torch.utils.data.DataLoader(
#             dataset,
#             batch_size=batch_size,
#             num_workers=0,
#             pin_memory=True,
#             shuffle=shuffle,
#         )
    
#     forget_dataset = copy.deepcopy(marked_loader.dataset)

#     if args.dataset == "svhn":
#         try:
#             marked = forget_dataset.targets < 0
#         except:
#             marked = forget_dataset.labels < 0
#         forget_dataset.data = forget_dataset.data[marked]
#         try:
#             forget_dataset.targets = -forget_dataset.targets[marked] - 1
#         except:
#             forget_dataset.labels = -forget_dataset.labels[marked] - 1
#         forget_loader = replace_loader_dataset(forget_dataset, seed=seed, shuffle=True)
#         retain_dataset = copy.deepcopy(marked_loader.dataset)
#         try:
#             marked = retain_dataset.targets >= 0
#         except:
#             marked = retain_dataset.labels >= 0
#         retain_dataset.data = retain_dataset.data[marked]
#         try:
#             retain_dataset.targets = retain_dataset.targets[marked]
#         except:
#             retain_dataset.labels = retain_dataset.labels[marked]
#         retain_loader = replace_loader_dataset(retain_dataset, seed=seed, shuffle=True)
#         assert len(forget_dataset) + len(retain_dataset) == len(
#             train_loader_full.dataset
#         )
#     else:
#         try:
#             marked = forget_dataset.targets < 0
#             forget_dataset.data = forget_dataset.data[marked]
#             forget_dataset.targets = -forget_dataset.targets[marked] - 1
#             forget_loader = replace_loader_dataset(
#                 forget_dataset, seed=seed, shuffle=True
#             )
#             retain_dataset = copy.deepcopy(marked_loader.dataset)
#             marked = retain_dataset.targets >= 0
#             retain_dataset.data = retain_dataset.data[marked]
#             retain_dataset.targets = retain_dataset.targets[marked]
#             retain_loader = replace_loader_dataset(
#                 retain_dataset, seed=seed, shuffle=True
#             )
#             assert len(forget_dataset) + len(retain_dataset) == len(
#                 train_loader_full.dataset
#             )
#         except:
#             marked = forget_dataset.targets < 0
#             forget_dataset.imgs = forget_dataset.imgs[marked]
#             forget_dataset.targets = -forget_dataset.targets[marked] - 1
#             forget_loader = replace_loader_dataset(
#                 forget_dataset, seed=seed, shuffle=True
#             )
#             retain_dataset = copy.deepcopy(marked_loader.dataset)
#             marked = retain_dataset.targets >= 0
#             retain_dataset.imgs = retain_dataset.imgs[marked]
#             retain_dataset.targets = retain_dataset.targets[marked]
#             retain_loader = replace_loader_dataset(
#                 retain_dataset, seed=seed, shuffle=True
#             )
#             assert len(forget_dataset) + len(retain_dataset) == len(
#                 train_loader_full.dataset
#             )

#     print(f"number of retain dataset {len(retain_dataset)}")
#     print(f"number of forget dataset {len(forget_dataset)}")
#     unlearn_data_loaders = OrderedDict(
#         retain=retain_loader, forget=forget_loader, val=val_loader, test=test_loader
#     )

#     test_len = len(test_loader.dataset)
#     forget_len = len(forget_dataset)
#     retain_len = len(retain_dataset)

#     utils.dataset_convert_to_test(retain_dataset, args)
#     utils.dataset_convert_to_test(forget_loader, args)
#     utils.dataset_convert_to_test(test_loader, args)

#     shadow_train = torch.utils.data.Subset(retain_dataset, list(range(test_len)))
#     shadow_train_loader = torch.utils.data.DataLoader(
#         shadow_train, batch_size=args.batch_size, shuffle=False
#     )

#     # utils.dataset_convert_to_test(train_loader.dataset, args)

#     # shadow_train = torch.utils.data.Subset(train_loader.dataset, list(range(test_len)))
#     # shadow_train_loader = torch.utils.data.DataLoader(
#     #         shadow_train, batch_size=args.batch_size, shuffle=False
#     #     )

#         # 生成遮蔽样本
#     sample_imgs1, sample_labels1 = next(iter(shadow_train_loader))
#     sample_imgs1 = sample_imgs1.cuda()
#     masks1 = get_attention_masks(model, sample_imgs1, topk_ratio=0.05)
#     high_attn_masked1 = mask_images(sample_imgs1, masks1, mode='noise')
    
#     # 创建遮蔽数据的DataLoader
#     masked_loader1 = torch.utils.data.DataLoader(
#         torch.utils.data.TensorDataset(high_attn_masked1.cpu(), sample_labels1.cpu()), 
#         batch_size=128
#     )


#     sample_imgs2, sample_labels2 = next(iter(forget_loader))
#     sample_imgs2 = sample_imgs1.cuda()
#     masks2 = get_attention_masks(model, sample_imgs2, topk_ratio=0.05)
#     high_attn_masked2 = mask_images(sample_imgs1, masks2, mode='noise')
    
#     # 创建遮蔽数据的DataLoader
#     masked_loader2 = torch.utils.data.DataLoader(
#         torch.utils.data.TensorDataset(high_attn_masked1.cpu(), sample_labels1.cpu()), 
#         batch_size=128
#     )
    
#     # 生成遮蔽样本
#     sample_imgs0, sample_labels0 = next(iter(test_loader))
#     sample_imgs0 = sample_imgs0.cuda()
#     masks0 = get_attention_masks(model, sample_imgs0, topk_ratio=0.05)
#     high_attn_masked0 = mask_images(sample_imgs0, masks0, mode='noise')
    
#     # 创建遮蔽数据的DataLoader
#     masked_loader0 = torch.utils.data.DataLoader(
#         torch.utils.data.TensorDataset(high_attn_masked0.cpu(), sample_labels0.cpu()), 
#         batch_size=128
#     )
    
#     # 分类准确率评估
#     def eval_accuracy(loader):
#         correct = 0
#         for imgs, labels in tqdm(loader, desc='Evaluating Accuracy'):
#             imgs, labels = imgs.cuda(), labels.cuda()
#             with torch.no_grad():
#                 outputs = model(imgs)
#             correct += (outputs.argmax(1) == labels).sum().item()
#         return correct / len(loader.dataset)
    
#     print(f"Original Acc: {eval_accuracy(test_loader):.4f}")
#     print(f"Masked Acc: {eval_accuracy(masked_loader0):.4f}")  # 此处传入masked_loader


#     # original_mia = evaluation.SVC_MIA(
#     #         shadow_train=shadow_train_loader,
#     #         shadow_test=test_loader,
#     #         target_train=None,
#     #         target_test=forget_loader,
#     #         model=model,
#     #     )
#     # masked_mia = evaluation.SVC_MIA(
#     #         shadow_train=shadow_train_loader,
#     #         shadow_test=test_loader,
#     #         target_train=None,
#     #         target_test=masked_loader2,
#     #         model=model,
#     #     )

#     # print(f"原始样本置信度泄漏：{original_mia['confidence']:.4f}")
#     # print(f"遮蔽样本置信度泄漏：{masked_mia['confidence']:.4f}")



# if __name__ == '__main__':
#     main()


import torch
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
from sklearn.svm import SVC
import utils
import arg_parser
import evaluation
from copy import deepcopy
import copy
from collections import OrderedDict
import matplotlib.pyplot as plt
from PIL import Image
import os


def get_attention_masks(model, imgs, patch_size=16, topk_ratio=0.1):
    """Build a boolean mask marking the patches with the highest CLS attention.

    A forward hook on ``model.blocks[-1].attn.attn_drop`` captures the tensor
    fed into that module during a single no-grad forward pass over ``imgs``.
    The captured tensor is softmax-normalised over its last dimension, the row
    at index 0 (columns 1 onward) is taken for every head, the heads are
    averaged, and the ``k = int(topk_ratio * N)`` largest entries per sample
    are marked ``True``.

    Args:
        model: Module exposing ``blocks[-1].attn.attn_drop`` (assumed to be a
            ViT whose token 0 is the CLS token -- TODO confirm for the model
            in use).
        imgs: Input batch passed to ``model``; the mask lives on its device.
        patch_size: Unused; kept so existing callers keep working.
        topk_ratio: Fraction of the N patch tokens to mark. The resulting
            count is clamped to ``[0, N]``.

    Returns:
        Bool tensor of shape ``[B, N]`` with ``True`` at the selected patches.

    Raises:
        RuntimeError: If the hooked module was never called during the forward
            pass (for example when the attention implementation bypasses
            ``attn_drop``).
    """
    captured = []

    def hook_fn(module, inputs, output):
        # Keep the tensor on its original device; only cut it from autograd.
        captured.append(inputs[0].detach())

    handle = model.blocks[-1].attn.attn_drop.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            model(imgs)
    finally:
        # Always unregister, even if the forward pass raised, so the hook
        # cannot accumulate on the model across calls.
        handle.remove()

    if not captured:
        raise RuntimeError(
            "attn_drop of the last block was never called during the forward "
            "pass; no attention weights were captured"
        )

    # NOTE(review): if attn_drop already receives softmax-normalised weights
    # (as in timm-style attention), this second softmax flattens the
    # distribution before the head average -- confirm that this is intended.
    attn_weights = captured[0].softmax(dim=-1)      # indexed as [B, H, N+1, N+1]
    cls_attn = attn_weights[:, :, 0, 1:].mean(dim=1)  # [B, N], averaged over heads

    B, N = cls_attn.shape
    k = min(max(int(topk_ratio * N), 0), N)  # clamp so topk never sees an invalid k
    topk_indices = torch.topk(cls_attn, k, dim=-1).indices

    mask = torch.zeros(B, N, dtype=torch.bool, device=imgs.device)
    mask.scatter_(1, topk_indices, True)
    return mask


def mask_images(images, masks, patch_size=16, mode='noise'):
    """Blank out the image patches selected by a per-patch boolean mask.

    Args:
        images: Tensor unpacked as ``[B, C, H, W]``. ``H`` and ``W`` must be
            multiples of ``patch_size``, otherwise the reshape to ``[B, H, W]``
            fails.
        masks: Boolean tensor with ``B * (H // patch_size) * (W // patch_size)``
            elements, one flag per patch in row-major patch order.
        patch_size: Side length of a square patch in pixels.
        mode: ``'zero'`` multiplies masked pixels by 0; ``'noise'`` replaces
            them with Gaussian noise scaled by 0.5.

    Returns:
        Tensor shaped like ``images`` with the selected patches replaced.

    Raises:
        ValueError: If ``mode`` is neither ``'zero'`` nor ``'noise'``.
    """
    # Fail early with a clear message instead of an UnboundLocalError at return.
    if mode not in ('zero', 'noise'):
        raise ValueError(f"Unsupported mode {mode!r}; expected 'zero' or 'noise'")

    B, C, H, W = images.shape
    rows, cols = H // patch_size, W // patch_size  # patch grid dimensions

    # reshape (not view) so non-contiguous masks are accepted as well.
    patch_mask = masks.to(images.device).reshape(B, rows, cols)

    # Blow every patch flag up to a patch_size x patch_size pixel tile:
    # [B, rows, cols] -> [B, rows, cols, ps, ps] -> [B, rows, ps, cols, ps] -> [B, H, W]
    patch_mask = patch_mask.unsqueeze(-1).expand(-1, -1, -1, patch_size)
    patch_mask = patch_mask.unsqueeze(-2).expand(-1, -1, -1, patch_size, -1)
    pixel_mask = patch_mask.permute(0, 1, 3, 2, 4).reshape(B, H, W).unsqueeze(1)  # [B, 1, H, W]

    if mode == 'zero':
        return images * (~pixel_mask)

    noise = torch.randn_like(images) * 0.5
    return torch.where(pixel_mask, noise, images)


def save_image(tensor, save_path):
    """Min-max normalise an image tensor and write it to ``save_path``.

    Args:
        tensor: Image tensor laid out as ``[C, H, W]``, ``[B, C, H, W]`` (only
            the first item of the batch is saved) or ``[H, W]``.
        save_path: Destination file path. Its parent directory is created when
            the path contains one.

    Raises:
        ValueError: If ``tensor`` has an unsupported number of dimensions.
    """
    # dirname is '' for a bare file name, and os.makedirs('') would raise.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # detach() so tensors that require grad can be converted to NumPy. The
    # arithmetic below allocates new arrays, so the caller's data is untouched.
    img_tensor = tensor.detach().cpu()
    if img_tensor.ndim == 4:
        img_tensor = img_tensor[0]  # keep only the first image of a batch

    if img_tensor.ndim == 3:
        img_np = img_tensor.numpy().transpose(1, 2, 0)  # [C, H, W] -> [H, W, C]
        if img_np.shape[2] == 1:
            # PIL expects a 2-D array for single-channel images.
            img_np = img_np[:, :, 0]
    elif img_tensor.ndim == 2:
        img_np = img_tensor.numpy()
    else:
        raise ValueError(
            f"Expected a 2-D, 3-D or 4-D image tensor, got {tensor.ndim} dimensions"
        )

    # Rescale to [0, 1]; the epsilon avoids a division by zero on constant images.
    lo, hi = img_np.min(), img_np.max()
    img_np = (img_np - lo) / (hi - lo + 1e-8)

    # Convert to 8-bit and write to disk.
    img_pil = Image.fromarray((img_np * 255).astype(np.uint8))
    img_pil.save(save_path)
    print(f"已保存: {save_path}")


def compute_masked_mia_confidence(model, shadow_train_loader, shadow_test_loader, masked_loader):
    """Run ``evaluation.SVC_MIA`` against the masked samples and return its "confidence" entry.

    The shadow loaders are forwarded as ``shadow_train`` / ``shadow_test``,
    ``masked_loader`` is forwarded as ``target_test`` and ``target_train`` is
    always ``None``.
    """
    attack_kwargs = {
        "shadow_train": shadow_train_loader,
        "shadow_test": shadow_test_loader,
        "target_train": None,
        "target_test": masked_loader,
        "model": model,
    }
    attack_report = evaluation.SVC_MIA(**attack_kwargs)
    return attack_report["confidence"]


def _resolve_attr(dataset, candidates):
    """Return the first attribute name in ``candidates`` that ``dataset`` defines."""
    for name in candidates:
        if hasattr(dataset, name):
            return name
    raise AttributeError(
        f"{type(dataset).__name__} has none of the attributes {tuple(candidates)}"
    )


def _subset_marked(marked_dataset, label_attr, data_attr, forget):
    """Deep-copy ``marked_dataset`` keeping only its forget (label < 0) or retain (label >= 0) samples."""
    subset = copy.deepcopy(marked_dataset)
    labels = getattr(subset, label_attr)
    keep = labels < 0 if forget else labels >= 0
    setattr(subset, data_attr, getattr(subset, data_attr)[keep])
    kept_labels = labels[keep]
    # Forget samples carry negative labels t and are mapped back with -t - 1
    # (presumably undoing a -(label + 1) marking -- confirm against the code
    # that builds marked_loader).
    setattr(subset, label_attr, -kept_labels - 1 if forget else kept_labels)
    return subset


def main():
    """Export original and attention-masked test images, then report test accuracy.

    Steps:
        1. Parse arguments, seed the RNGs (when a seed is given), build the
           model and loaders via ``utils.setup_model_dataset`` and load the
           checkpoint at ``args.model_path``.
        2. Split ``marked_loader.dataset`` into forget (negative labels) and
           retain (non-negative labels) datasets.
        3. For every test image, save the original under ``./images/original``
           and a copy with the top 5% attention patches zeroed under
           ``./images/masked``.
        4. Print the classification accuracy on the unmasked test set.
    """
    args = arg_parser.parse_args()

    # Seed only when a non-zero seed was supplied.
    if args.seed:
        utils.setup_seed(args.seed)
    seed = args.seed

    # Build the model and the data loaders, then move the model to the GPU.
    (model, train_loader_full, val_loader, test_loader, marked_loader) = utils.setup_model_dataset(args)
    model = model.cuda()

    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(args.model_path, weights_only=False)
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    model.load_state_dict(checkpoint, strict=False)
    print("模型加载完成")

    def replace_loader_dataset(dataset, batch_size=args.batch_size, seed=1, shuffle=True):
        # Re-seed before creating the loader, as the original pipeline does.
        utils.setup_seed(seed)
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, num_workers=0, pin_memory=True, shuffle=shuffle
        )

    # Pick the attribute names that hold labels and samples. SVHN may expose
    # its labels as `labels`; other datasets may keep their samples in `imgs`
    # instead of `data`.
    marked_dataset = marked_loader.dataset
    if args.dataset == "svhn":
        label_attr = _resolve_attr(marked_dataset, ("targets", "labels"))
        data_attr = _resolve_attr(marked_dataset, ("data",))
    else:
        label_attr = _resolve_attr(marked_dataset, ("targets",))
        data_attr = _resolve_attr(marked_dataset, ("data", "imgs"))

    forget_dataset = _subset_marked(marked_dataset, label_attr, data_attr, forget=True)
    forget_loader = replace_loader_dataset(forget_dataset, seed=seed, shuffle=True)
    retain_dataset = _subset_marked(marked_dataset, label_attr, data_attr, forget=False)
    # retain_loader is not consumed below; it is still created so the RNG is
    # re-seeded at the same points as before.
    retain_loader = replace_loader_dataset(retain_dataset, seed=seed, shuffle=True)
    assert len(forget_dataset) + len(retain_dataset) == len(train_loader_full.dataset), (
        "forget and retain splits do not add up to the full training set"
    )

    # Dataset statistics.
    print(f"保留数据集大小: {len(retain_dataset)}")
    print(f"遗忘数据集大小: {len(forget_dataset)}")

    # Shadow training set for the optional MIA evaluation at the end: the
    # first test_len retain samples, capped at the size of the retain set.
    test_len = len(test_loader.dataset)
    shadow_train = torch.utils.data.Subset(
        retain_dataset, list(range(min(test_len, len(retain_dataset))))
    )
    shadow_train_loader = torch.utils.data.DataLoader(
        shadow_train, batch_size=args.batch_size, shuffle=False
    )

    # Output directories for the exported images.
    os.makedirs("./images/original", exist_ok=True)
    os.makedirs("./images/masked", exist_ok=True)

    # Export every test image together with its attention-masked version.
    model.eval()
    global_idx = 0  # running sample index; independent of the loader's batch size
    with torch.no_grad():
        for imgs, labels in tqdm(test_loader, desc="处理测试集图像"):
            imgs = imgs.cuda()

            # One forward pass per batch rather than one per image.
            masks = get_attention_masks(model, imgs, topk_ratio=0.05)
            masked_imgs = mask_images(imgs, masks, mode='zero')

            for img, masked_img, label in zip(imgs, masked_imgs, labels.tolist()):
                original_path = f"./images/original/test_{global_idx}_label{label}.png"
                save_image(img, original_path)

                masked_path = f"./images/masked/test_{global_idx}_label{label}_masked.png"
                save_image(masked_img, masked_path)

                global_idx += 1

    def eval_accuracy(loader):
        """Return the fraction of samples in ``loader`` the model classifies correctly."""
        correct = 0
        total = 0
        # Inference only: do not build autograd graphs.
        with torch.no_grad():
            for imgs, labels in tqdm(loader, desc="评估准确率"):
                imgs, labels = imgs.cuda(), labels.cuda()
                predicted = model(imgs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return correct / total if total else 0.0

    # Only the unmasked test set is evaluated; the masked images are written
    # to disk above but not scored here.
    print(f"原始图像测试集准确率: {eval_accuracy(test_loader):.4f}")

    # # Optional membership-inference evaluation (uncomment when needed)
    # original_mia = evaluation.SVC_MIA(
    #     shadow_train=shadow_train_loader,
    #     shadow_test=test_loader,
    #     target_train=None,
    #     target_test=forget_loader,
    #     model=model,
    # )
    # print(f"原始样本MIA置信度: {original_mia['confidence']:.4f}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
