import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from data_utils import *
from attack_baseline import *
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.transforms.functional import rotate
import random
import torch
import torch.nn as nn
import cv2
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

def evaluate_cifar(model, processor, data_loader, backdoor_indices, description="Evaluating"):
    """Measure clean and backdoor accuracy of an image-feature model with a classifier head.

    Each batch is preprocessed with ``processor`` (``do_rescale=False``), encoded
    with ``model.get_image_features`` and classified by ``model.classifier``.
    A sample belongs to the backdoor split when its position in the loader's
    iteration order is in ``backdoor_indices``; otherwise it belongs to the
    clean split. Positions are counted with a running offset, so
    ``data_loader`` must iterate the dataset in order (no shuffling) for
    positions to match dataset indices.

    Args:
        model: object exposing ``eval()``, ``device``, ``get_image_features`` and ``classifier``.
        processor: callable returning an object with a ``pixel_values`` tensor.
        data_loader: iterable of ``(images, labels)`` batches.
        backdoor_indices: iterable of integer dataset positions that were poisoned.
        description: label for the tqdm progress bar.

    Returns:
        ``(accuracy_clean, accuracy_backdoor)``; each is 0 when its split is empty.
    """
    model.eval()  # switch to evaluation mode
    device = model.device
    # Build the lookup set once so membership checks are O(1) regardless of the input container.
    backdoor_set = {int(idx) for idx in backdoor_indices}

    correct_clean = 0
    total_clean = 0
    correct_backdoor = 0
    total_backdoor = 0
    offset = 0  # dataset position of the first sample in the current batch

    with torch.no_grad():  # no gradients needed for evaluation
        for images, labels in tqdm(data_loader, desc=description):
            # Preprocess and keep inputs on the same device as the model.
            inputs = processor(images=images, return_tensors="pt", do_rescale=False).pixel_values.to(device)

            image_features = model.get_image_features(pixel_values=inputs)
            outputs = model.classifier(image_features)
            _, predicted = torch.max(outputs, 1)

            labels = labels.to(device)
            hits = (predicted == labels).cpu()
            batch_len = hits.numel()

            # Mask of samples in this batch that belong to the poisoned split.
            is_backdoor = torch.tensor(
                [pos in backdoor_set for pos in range(offset, offset + batch_len)],
                dtype=torch.bool,
            )
            offset += batch_len

            n_backdoor = int(is_backdoor.sum())
            total_backdoor += n_backdoor
            correct_backdoor += int(hits[is_backdoor].sum())
            total_clean += batch_len - n_backdoor
            correct_clean += int(hits[~is_backdoor].sum())

    accuracy_clean = correct_clean / total_clean if total_clean > 0 else 0
    accuracy_backdoor = correct_backdoor / total_backdoor if total_backdoor > 0 else 0

    return accuracy_clean, accuracy_backdoor
def evaluate_imagenet(model, processor, data_loader, backdoor_indices, description="Evaluating"):
    """Measure clean and backdoor accuracy for the ImageNet-style loaders.

    The evaluation procedure is identical to ``evaluate_cifar``: the same model
    interface, preprocessing and index bookkeeping. This function therefore
    delegates to it instead of keeping a second copy of the loop. The separate
    name is kept for existing callers.

    Args:
        model: object exposing ``eval()``, ``device``, ``get_image_features`` and ``classifier``.
        processor: callable returning an object with a ``pixel_values`` tensor.
        data_loader: iterable of ``(images, labels)`` batches, iterated in dataset order.
        backdoor_indices: iterable of integer dataset positions that were poisoned.
        description: label for the tqdm progress bar.

    Returns:
        ``(accuracy_clean, accuracy_backdoor)``; each is 0 when its split is empty.
    """
    return evaluate_cifar(model, processor, data_loader, backdoor_indices, description=description)

def load_and_poison_data(args):
    """Load the dataset named by ``args.dataset`` and inject a backdoor into it.

    A fixed 1% of the training set and 1% of the test set, chosen with
    ``np.random.choice``, are passed through ``create_trigger`` for
    ``args.attack_type`` and relabelled to ``args.target_label``.

    Args:
        args: namespace-like config.
            Always read: ``dataset``, ``attack_type``, ``target_label``.
            Read per attack type:
            - Refool: ``reflection_candidates``, ``alpha_b``, ``ghost_rate``, ``sigma``
            - ISSBA: ``secret_size``, ``ISSBA_lr``, ``ISSBA_epochs``
            - BATT: ``angle``
            - Trojan / BadNet: ``patch_size``
            - Blended: ``use_hello``, ``blend_ration``, plus ``patch_size`` when
              ``use_hello`` is false

    Returns:
        ``(poisoned_train_dataset, poisoned_test_dataset, test_poison_indices)``.
        The datasets yield ``(image, label)`` pairs. ``test_poison_indices`` is
        the set of test-set positions that were poisoned.

    Raises:
        ValueError: if ``args.dataset`` is not "CIFAR-10", "CIFAR-100" or "imagenet-tiny".
    """
    # Convert PIL images to float tensors in [0, 1] (no normalisation).
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Load the dataset selected by args.dataset.
    if args.dataset == "CIFAR-10":
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    elif args.dataset == "CIFAR-100":
        trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
        testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
    elif args.dataset == "imagenet-tiny":
        trainset, testset = load_tiny_imagenet_data()
    else:
        raise ValueError("Unsupported dataset")

    # Build the 64x64 "Hello Kitty" trigger (used by Blended for non-CIFAR datasets).
    # NOTE(review): the "outline" slices cover almost the whole canvas (rows 8-55
    # are filled completely), so the result is a mostly black square with yellow
    # and pink patches; the eyes are black on black. Confirm this is intended.
    def create_hello_kitty_trigger_64():
        # Colours (RGB, uint8)
        WHITE = [255, 255, 255]
        BLACK = [0, 0, 0]
        YELLOW = [255, 255, 0]
        PINK = [255, 182, 193]

        # Blank 64x64x3 canvas
        img = np.zeros((64, 64, 3), dtype=np.uint8)

        # White background (face)
        img[:, :] = WHITE

        # Black "outline"
        img[0, 10:54] = BLACK
        img[1:3, 4:60] = BLACK
        img[3:8, 2:62] = BLACK
        img[8:56, 0:64] = BLACK
        img[56:60, 2:62] = BLACK
        img[60:62, 4:60] = BLACK
        img[62:64, 10:54] = BLACK

        # Eyes
        img[20:28, 16:24] = BLACK
        img[20:28, 40:48] = BLACK

        # Nose
        img[32:40, 28:36] = YELLOW

        # Pink bow
        img[8:16, 26:38] = PINK

        # PIL Image -> float tensor (3, 64, 64) in [0, 1]
        pil_img = Image.fromarray(img)
        return transforms.ToTensor()(pil_img)

    # Build the 32x32 "Hello Kitty" trigger (used by Blended for CIFAR datasets).
    # NOTE(review): as above, rows 4-27 are filled black by the "outline" slices.
    def create_hello_kitty_trigger():
        # Colours (RGB, uint8)
        WHITE = [255, 255, 255]
        BLACK = [0, 0, 0]
        YELLOW = [255, 255, 0]
        PINK = [255, 182, 193]

        # Blank 32x32x3 canvas
        img = np.zeros((32, 32, 3), dtype=np.uint8)

        # White background (face)
        img[:, :] = WHITE

        # Black "outline"
        img[0, 5:27] = BLACK
        img[1, 2:30] = BLACK
        img[2:4, 1:31] = BLACK
        img[4:28, 0:32] = BLACK
        img[28:30, 1:31] = BLACK
        img[30, 2:30] = BLACK
        img[31, 5:27] = BLACK

        # Eyes
        img[10:14, 8:12] = BLACK
        img[10:14, 20:24] = BLACK

        # Nose
        img[16:20, 14:18] = YELLOW

        # Pink bow
        img[4:8, 13:19] = PINK

        # PIL Image -> float tensor (3, 32, 32) in [0, 1]
        pil_img = Image.fromarray(img)
        return transforms.ToTensor()(pil_img)

    hello_kitty_trigger = create_hello_kitty_trigger()
    hello_kitty_trigger_64 = create_hello_kitty_trigger_64()

    def create_trigger(img, attack_type):
        """Return ``img`` with the trigger for ``attack_type`` applied.

        Unknown attack types return ``img`` unchanged. BadNet and Trojan modify
        the input tensor in place.
        """

        def create_wanet_grids(img_shape, s=0.5, grid_rescale=1):
            """Build a ``(1, H, W, 2)`` sampling grid for ``F.grid_sample``.

            ``grid[..., 0]`` is the x (column) coordinate and ``grid[..., 1]``
            is the y (row) coordinate, both in [-1, 1], as ``grid_sample`` expects.
            """
            h, w = img_shape
            n = 30  # resolution of the coarse warp grid (tunable)

            # Identity grid: the x coordinate varies along columns, y along rows.
            ys = torch.linspace(-1, 1, steps=h)
            xs = torch.linspace(-1, 1, steps=w)
            grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')  # each (h, w)
            identity_grid = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)

            # Coarse n x n grid with the same (x, y) layout.
            coarse = torch.linspace(-1, 1, steps=n)
            coarse_y, coarse_x = torch.meshgrid(coarse, coarse, indexing='ij')
            noise_grid = torch.stack([coarse_x, coarse_y], dim=-1).unsqueeze(0)

            # Resize the coarse grid to (h, w) so it can be added to the identity grid.
            noise_grid = F.interpolate(noise_grid.permute(0, 3, 1, 2), size=(h, w), mode='bilinear', align_corners=True)
            noise_grid = noise_grid.permute(0, 2, 3, 1)

            grid = identity_grid + s * noise_grid / h
            return torch.clamp(grid * grid_rescale, -1, 1)

        if attack_type == "Refool":
            # Work in numpy/PIL space.
            is_tensor = isinstance(img, torch.Tensor)
            if is_tensor:
                img_np = img.numpy()
            else:
                img_np = img

            # Move a leading channel axis to the end: (C, H, W) -> (H, W, C).
            if img_np.shape[0] == 1 or img_np.shape[0] == 3:
                img_np = np.transpose(img_np, (1, 2, 0))

            img_pil = Image.fromarray((img_np * 255).astype(np.uint8))

            # Pick a random reflection image; presumably PIL images, since
            # .resize(size) is called on them below.
            reflection_index = np.random.randint(0, len(args.reflection_candidates))
            reflection = args.reflection_candidates[reflection_index]

            # Match the reflection to the input image size.
            reflection = reflection.resize(img_pil.size)

            img_np = np.array(img_pil).astype(float) / 255.0
            reflection_np = np.array(reflection).astype(float) / 255.0

            # Alpha-blend; a negative args.alpha_b means "sample alpha per image".
            alpha_b = np.random.uniform(0.55, 0.95) if args.alpha_b < 0 else args.alpha_b
            blended = img_np * alpha_b + reflection_np * (1 - alpha_b)

            # With probability args.ghost_rate, apply a Gaussian blur
            # (a negative args.sigma means "sample sigma").
            if np.random.rand() < args.ghost_rate:
                sigma = np.random.uniform(1, 5) if args.sigma < 0 else args.sigma
                blended = cv2.GaussianBlur(blended, (0, 0), sigma)

            poisoned_img = np.clip(blended, 0, 1)

            # Back to a (C, H, W) float tensor.
            poisoned_img = torch.from_numpy(poisoned_img.transpose(2, 0, 1)).float()

            return poisoned_img

        elif attack_type == "WaNet":
            # Record the input type BEFORE converting, so numpy inputs get numpy outputs.
            was_numpy = not isinstance(img, torch.Tensor)
            if was_numpy:
                img = torch.tensor(img).float()

            # grid_sample needs a 4-D (N, C, H, W) input.
            if img.dim() == 3:
                img = img.unsqueeze(0)

            grid = create_wanet_grids(img.shape[-2:])
            # One grid per batch element (grid_sample requires matching batch sizes).
            grid = grid.expand(img.size(0), -1, -1, -1)

            poisoned_img = F.grid_sample(img, grid, align_corners=True).squeeze(0)

            return poisoned_img.numpy() if was_numpy else poisoned_img

        elif attack_type == "ISSBA":
            # Local import so this branch does not depend on a later module-level import.
            import torch.optim as optim

            # Record the input type BEFORE converting, so numpy inputs get numpy outputs.
            was_numpy = not isinstance(img, torch.Tensor)
            if was_numpy:
                img = torch.tensor(img).float()
            img = img.to('cuda')
            # Encoder/Decoder expect a 4-D (N, C, H, W) input.
            if img.dim() == 3:
                img = img.unsqueeze(0)

            # Encoder: image + broadcast secret bits -> residual in [-1, 1].
            class Encoder(nn.Module):
                def __init__(self, secret_size):
                    super(Encoder, self).__init__()
                    self.conv1 = nn.Conv2d(3 + secret_size, 32, 3, padding=1)
                    self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
                    self.conv3 = nn.Conv2d(32, 16, 3, padding=1)
                    self.conv4 = nn.Conv2d(16, 3, 3, padding=1)

                def forward(self, image, secret):
                    secret = secret.view(-1, args.secret_size, 1, 1).repeat(1, 1, image.size(2), image.size(3))
                    x = torch.cat([image, secret], dim=1)
                    x = F.relu(self.conv1(x))
                    x = F.relu(self.conv2(x))
                    x = F.relu(self.conv3(x))
                    x = torch.tanh(self.conv4(x))
                    return x

            # Decoder: encoded image -> per-bit probabilities of the secret.
            class Decoder(nn.Module):
                def __init__(self, secret_size):
                    super(Decoder, self).__init__()
                    self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
                    self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
                    self.conv3 = nn.Conv2d(32, 16, 3, padding=1)
                    self.fc = nn.Linear(16 * img.size(2) * img.size(3), secret_size)

                def forward(self, x):
                    x = F.relu(self.conv1(x))
                    x = F.relu(self.conv2(x))
                    x = F.relu(self.conv3(x))
                    x = x.view(x.size(0), -1)
                    x = torch.sigmoid(self.fc(x))
                    return x

            encoder = Encoder(args.secret_size).to('cuda')
            decoder = Decoder(args.secret_size).to('cuda')

            # Train the encoder/decoder pair on this single image.
            optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=args.ISSBA_lr)
            criterion = nn.BCELoss()

            for epoch in tqdm(range(args.ISSBA_epochs), desc="Training ISSBA"):
                secret = torch.randint(0, 2, (img.size(0), args.secret_size), dtype=torch.float32).to('cuda')

                optimizer.zero_grad()

                encoded = encoder(img, secret)
                decoded = decoder(img + encoded)

                loss = criterion(decoded, secret)
                loss.backward()

                optimizer.step()

                if (epoch + 1) % 10 == 0:  # log the loss every 10 epochs
                    print(f"Epoch [{epoch + 1}/{args.ISSBA_epochs}], Loss: {loss.item():.4f}")

            # Use the trained encoder to produce the poisoned image.
            with torch.no_grad():
                secret = torch.randint(0, 2, (img.size(0), args.secret_size), dtype=torch.float32).to('cuda')
                residual = encoder(img, secret)
                poisoned_img = img + residual
                poisoned_img = torch.clamp(poisoned_img, 0, 1)

            if was_numpy:
                print("Converting result to numpy array...")
                return poisoned_img.squeeze(0).cpu().numpy()
            return poisoned_img.squeeze(0).cpu()

        elif attack_type == "BATT":
            is_numpy = isinstance(img, np.ndarray)
            original_dim = img.dim() if isinstance(img, torch.Tensor) else img.ndim

            if is_numpy:
                img = torch.from_numpy(img).float()

            # Work on a 4-D (N, C, H, W) tensor.
            if img.dim() == 3:
                img = img.unsqueeze(0)

            if img.shape[1] != 3 or img.dim() != 4:
                raise ValueError(f"Input image should have shape (N, 3, H, W) or (3, H, W), got {img.shape}")

            # Rotate by args.angle degrees.
            poisoned_img = rotate(img, args.angle)

            # Restore the caller's rank.
            if original_dim == 3:
                poisoned_img = poisoned_img.squeeze(0)

            if is_numpy:
                poisoned_img = poisoned_img.numpy()

            return poisoned_img

        elif attack_type == "Trojan":
            patch_size = args.patch_size
            mosaic_size = 1  # side length (pixels) of each mosaic cell
            is_torch_tensor = isinstance(img, torch.Tensor)

            # Empty-input check that works for both tensors and numpy arrays.
            num_elements = img.numel() if is_torch_tensor else np.size(img)
            if num_elements == 0:
                print("Warning: Input image is empty")
                return img

            # Random mosaic pattern (H, W, C), upsampled by mosaic_size.
            mosaic_pattern = np.random.rand(patch_size, patch_size, 3)
            mosaic_pattern = np.repeat(np.repeat(mosaic_pattern, mosaic_size, axis=0), mosaic_size, axis=1)

            # For tensor inputs, convert the pattern to a channel-first tensor on the same device.
            if is_torch_tensor:
                mosaic_pattern = torch.from_numpy(mosaic_pattern).float()
                if img.is_cuda:
                    mosaic_pattern = mosaic_pattern.cuda()
                if img.dim() == 3:  # (C, H, W)
                    mosaic_pattern = mosaic_pattern.permute(2, 0, 1)
                elif img.dim() == 4:  # (N, C, H, W)
                    mosaic_pattern = mosaic_pattern.permute(2, 0, 1).unsqueeze(0)

            # Clamp the patch to the image size.
            if patch_size > min(img.shape[-2:]):
                print(f"Warning: patch_size ({patch_size}) is larger than image dimensions {img.shape[-2:]}")
                patch_size = min(img.shape[-2:])

            # Paste the mosaic into the bottom-right corner (in place).
            if is_torch_tensor:
                img[..., -patch_size:, -patch_size:] = mosaic_pattern[..., :patch_size, :patch_size]
            else:
                img[..., -patch_size:, -patch_size:] = mosaic_pattern[:patch_size, :patch_size, :]
            return img

        elif attack_type == "BadNet":
            # White square in the bottom-right corner (in place).
            img[:, -args.patch_size:, -args.patch_size:] = 1.0

        elif attack_type == "Blended":
            # This branch indexes (C, H, W) tensors.
            if img.dim() != 3:
                raise ValueError(f"Blended attack expects a (C, H, W) image, got {tuple(img.shape)}")
            if args.use_hello:
                # Place the Hello Kitty trigger in the bottom-right corner, blended by args.blend_ration.
                trigger = torch.zeros_like(img)
                mask = torch.zeros_like(img)
                if 'cifar' in args.dataset.lower():
                    trigger[:, -32:, -32:] = hello_kitty_trigger
                    mask[:, -32:, -32:] = args.blend_ration
                else:
                    trigger[:, -64:, -64:] = hello_kitty_trigger_64
                    mask[:, -64:, -64:] = args.blend_ration
            else:
                # White patch_size x patch_size trigger in the bottom-right corner.
                trigger = torch.ones_like(img)
                trigger[:, :-args.patch_size, :-args.patch_size] = 0
                # Blend only inside that patch.
                mask = torch.zeros_like(img)
                mask[:, -args.patch_size:, -args.patch_size:] = args.blend_ration

            img = img * (1 - mask) + trigger * mask
        return img

    def poison_dataset(dataset, poison_rate, target_label, attack_type):
        """Poison a random ``poison_rate`` fraction of ``dataset``.

        Returns ``(list of (img, label), set of poisoned positions)``.
        """
        poisoned_data = []
        poison_indices = set(np.random.choice(len(dataset), size=int(len(dataset) * poison_rate), replace=False))

        for idx, (img, label) in enumerate(dataset):
            if idx in poison_indices:
                img = create_trigger(img, attack_type)
                label = target_label
            poisoned_data.append((img, label))

        return poisoned_data, poison_indices

    # Poison 1% of the training set.
    poisoned_train, train_poison_indices = poison_dataset(trainset, 0.01, args.target_label, args.attack_type)

    # Poison 1% of the test set (fixed rate).
    poisoned_test, test_poison_indices = poison_dataset(testset, 0.01, args.target_label, args.attack_type)

    class PoisonedDataset(torch.utils.data.Dataset):
        """Thin Dataset wrapper over an in-memory list of (img, label) pairs."""

        def __init__(self, data):
            self.data = data

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return len(self.data)

    poisoned_train_dataset = PoisonedDataset(poisoned_train)
    poisoned_test_dataset = PoisonedDataset(poisoned_test)

    return poisoned_train_dataset, poisoned_test_dataset, test_poison_indices

class Autoencoder(nn.Module):
    """Small convolutional autoencoder used to purify training images.

    Encoder: two stride-2 3x3 convolutions (3 -> 16 -> 32 channels) followed by
    a 7x7 convolution to 64 channels. The decoder mirrors this with transposed
    convolutions and ends in a Sigmoid, so reconstructions lie in (0, 1).
    For a 32x32 input the spatial sizes are 32 -> 16 -> 8 -> 2 and back.
    Submodule indices (``encoder.0``, ``decoder.4``, ...) are part of the
    state_dict layout.
    """

    def __init__(self):
        super().__init__()
        down_layers = [
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 7),
        ]
        up_layers = [
            nn.ConvTranspose2d(64, 32, 7),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        ]
        self.encoder = nn.Sequential(*down_layers)
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Encode then decode ``x``; the output has the same shape as the input."""
        latent = self.encoder(x)
        return self.decoder(latent)

def autoencoder_main(args, autoencoder_params, poisoned_trainset):
    """Train an ``Autoencoder`` on ``poisoned_trainset`` and return its reconstructions.

    Every image is replaced by its reconstruction and the labels are kept, so no
    samples are dropped. Presumably the bottleneck reconstructs image content
    but not small trigger patterns (TODO confirm per attack type).

    Args:
        args: unused; kept for interface compatibility.
        autoencoder_params: dict with ``learning_rate``, ``batch_size`` and ``epochs``.
        poisoned_trainset: dataset yielding ``(image_tensor, label)`` pairs.

    Returns:
        ``TensorDataset`` of ``(reconstructed_images, labels)`` in dataset order.
    """
    # Fall back to CPU so the defense also runs without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_epochs = autoencoder_params['epochs']
    batch_size = autoencoder_params['batch_size']

    autoencoder = Autoencoder().to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=autoencoder_params['learning_rate'])
    train_loader = DataLoader(poisoned_trainset, batch_size=batch_size, shuffle=True)

    print("Training autoencoder...")
    autoencoder.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for data, _ in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            img = data.to(device)
            output = autoencoder(img)
            loss = criterion(output, img)  # reconstruction loss against the input itself

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        # max(..., 1) avoids ZeroDivisionError for an empty dataset.
        mean_loss = epoch_loss / max(len(train_loader), 1)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {mean_loss:.4f}')

    # Step 3: purify the data by replacing each image with its reconstruction (nothing is removed).
    def clean_data(dataset, autoencoder):
        autoencoder.eval()
        # shuffle=False keeps outputs aligned with dataset order; batching avoids
        # one forward pass per sample (the model has no batch-dependent layers).
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        cleaned_batches, target_batches = [], []
        with torch.no_grad():
            for imgs, labels in tqdm(loader, desc="Cleaning data"):
                cleaned_batches.append(autoencoder(imgs.to(device)).cpu())
                target_batches.append(torch.as_tensor(labels))
        return TensorDataset(torch.cat(cleaned_batches), torch.cat(target_batches))

    print("Cleaning training set...")
    cleaned_trainset = clean_data(poisoned_trainset, autoencoder)
    return cleaned_trainset
#autoencoder进行防御，数据清洗，然后返回清洗后的数据
def autoencoder_defense(args, autoencoder_params):
    """Poison the configured dataset, then purify its training split with an autoencoder.

    Returns:
        ``(cleaned_trainset, poisoned_testset, test_poison_indices)``.
    """
    train_split, test_split, poisoned_test_ids = load_and_poison_data(args)
    purified_train = autoencoder_main(args, autoencoder_params, train_split)
    return purified_train, test_split, poisoned_test_ids

#ShrinkPad部分
class ShrinkPad:
    """ShrinkPad input transformation: shrink the image, then randomly zero-pad it back.

    The image is bilinearly resized to ``(size_map - pad) x (size_map - pad)``
    and then zero-padded by ``pad`` pixels in total on each axis. The split
    between the two sides is drawn with ``random.randint``, so the output is
    ``size_map x size_map``.
    """

    def __init__(self, size_map, pad):
        self.size_map = size_map
        self.pad = pad

    def preprocess(self, img):
        """Shrink and randomly pad ``img``.

        Args:
            img: float tensor of shape ``(C, H, W)`` or ``(N, C, H, W)``. For a
                batch, all images share the same random padding offsets.

        Returns:
            Tensor with the same rank as ``img`` and spatial size ``size_map x size_map``.

        Raises:
            ValueError: if ``img`` is not 3-D or 4-D.
        """
        if img.dim() not in (3, 4):
            raise ValueError(f"ShrinkPad expects a (C, H, W) or (N, C, H, W) tensor, got shape {tuple(img.shape)}")
        batched = img.dim() == 4
        batch = img if batched else img.unsqueeze(0)

        # Shrink the image.
        shrink_size = self.size_map - self.pad
        shrunk = F.interpolate(batch, size=(shrink_size, shrink_size), mode='bilinear', align_corners=False)

        # Random split of the padding between opposite sides.
        pad_top = random.randint(0, self.pad)
        pad_bottom = self.pad - pad_top
        pad_left = random.randint(0, self.pad)
        pad_right = self.pad - pad_left

        # F.pad takes (left, right, top, bottom) for the last two dims.
        padded = F.pad(shrunk, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)

        return padded if batched else padded.squeeze(0)


def shrinkpad_main(args, shrinkpad_params, poisoned_trainset):
    """Apply the ShrinkPad transform to every sample of ``poisoned_trainset``.

    Non-tensor images are converted with ``transforms.ToTensor()``. Each result
    is clamped to [0, 1]. Labels are passed through unchanged.

    Args:
        args: unused; kept for interface compatibility.
        shrinkpad_params: dict with ``size_map`` and ``pad``.
        poisoned_trainset: iterable of ``(image, label)`` pairs.

    Returns:
        An in-memory Dataset of the transformed ``(image, label)`` pairs.
    """
    # Named to avoid shadowing the module-level ``shrinkpad_defense`` function.
    preprocessor = ShrinkPad(size_map=shrinkpad_params['size_map'], pad=shrinkpad_params['pad'])

    def _shrink_and_pad(sample):
        tensor = sample if isinstance(sample, torch.Tensor) else transforms.ToTensor()(sample)
        # Keep pixel values in [0, 1].
        return torch.clamp(preprocessor.preprocess(tensor), 0, 1)

    processed_pairs = [(_shrink_and_pad(img), label) for img, label in poisoned_trainset]

    class ProcessedDataset(torch.utils.data.Dataset):
        """Thin Dataset wrapper over an in-memory list of (img, label) pairs."""

        def __init__(self, data):
            self.data = data

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return len(self.data)

    return ProcessedDataset(processed_pairs)

def shrinkpad_defense(args, shrinkpad_params):
    """Poison the configured dataset, then apply ShrinkPad to its training split.

    Returns:
        ``(cleaned_trainset, poisoned_testset, test_poison_indices)``.
    """
    train_split, test_split, poisoned_test_ids = load_and_poison_data(args)
    transformed_train = shrinkpad_main(args, shrinkpad_params, train_split)
    return transformed_train, test_split, poisoned_test_ids



import torch
import numpy as np
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

class SCALE_UP:
    """SCALE-UP backdoor detector based on scaled prediction consistency (SPC).

    For each input, pixel values are multiplied by every factor in
    ``scale_set`` and clamped to [0, 1]. SPC is the fraction of those scaled
    copies whose predicted class equals the prediction on the unscaled input.
    Per the SCALE-UP method, backdoored inputs tend to keep their prediction
    under amplification, so a HIGH SPC is treated as suspicious.

    Attributes:
        threshold: decision threshold for ``detect``; ``None`` until the caller sets it.
        mean, std: SPC statistics on ``valset``, or ``None`` without a valset.
            When present, ``detect`` standardises scores with them.
        retention_rate: stored for callers; not used by this class.
    """

    def __init__(self, model, processor, scale_set, retention_rate=0.9, valset=None):
        self.model = model
        self.processor = processor
        self.scale_set = scale_set
        self.retention_rate = retention_rate
        self.valset = valset
        self.threshold = None
        self.mean = None
        self.std = None
        if self.valset:
            self.compute_statistics()

    def compute_statistics(self):
        """Compute mean/std of SPC scores over ``self.valset``."""
        val_loader = DataLoader(self.valset, batch_size=128, shuffle=False)
        device = next(self.model.parameters()).device
        all_spc_scores = []

        for images, _ in val_loader:
            spc = self.compute_spc(images.to(device))
            all_spc_scores.extend(spc.cpu().numpy())

        all_spc_scores = np.array(all_spc_scores)
        self.mean = np.mean(all_spc_scores)
        self.std = np.std(all_spc_scores)

    def _predict(self, images):
        """Return predicted class indices for a batch of images.

        Images are passed with ``do_rescale=False``, so they are presumably
        already scaled to [0, 1].
        """
        pixel_values = self.processor(images=images, return_tensors="pt", do_rescale=False).pixel_values.to(self.model.device)
        features = self.model.get_image_features(pixel_values=pixel_values)
        return torch.argmax(self.model.classifier(features), dim=1)

    def compute_spc(self, images):
        """Return a per-sample SPC tensor with values in [0, 1]."""
        self.model.eval()
        with torch.no_grad():
            original_pred = self._predict(images)

            spc = torch.zeros(images.size(0), device=self.model.device)
            for scale in self.scale_set:
                scaled_pred = self._predict(torch.clamp(images * scale, 0, 1))
                spc += (scaled_pred == original_pred).float()

            spc /= len(self.scale_set)
        return spc

    def detect(self, images):
        """Return a boolean mask that is True for inputs flagged as backdoored.

        Raises:
            ValueError: if ``self.threshold`` has not been set.
        """
        if self.threshold is None:
            raise ValueError("SCALE_UP.threshold must be set before calling detect()")
        spc = self.compute_spc(images)
        if self.mean is not None:
            # Guard against a degenerate validation set with zero spread.
            spc = (spc - self.mean) / (self.std if self.std > 0 else 1.0)
        # High consistency under amplification is the backdoor signal.
        return spc > self.threshold

def SCALE_UP_defense(args, scale_up_params, train_dataset):
    """Filter ``train_dataset`` with SCALE-UP scaled-prediction-consistency (SPC) scores.

    Every training sample gets an SPC score from ``SCALE_UP.compute_spc``.
    Per the SCALE-UP method, backdoored samples keep their prediction when
    pixel values are amplified, so HIGH SPC is the suspicious signal. The
    lowest-SPC half of the dataset is kept, so at least 50% of the samples
    are removed. Ties are broken deterministically by dataset position.

    Args:
        args: config with ``dataset`` ("CIFAR-10", "CIFAR-100", otherwise Tiny-ImageNet).
        scale_up_params: dict with ``model``, ``processor``, ``scale_set``, ``retention_rate``.
        train_dataset: dataset yielding ``(image_tensor, label)`` pairs.

    Returns:
        ``Subset`` of ``train_dataset`` with the retained indices in ascending order.
    """
    # Load the validation set (used by SCALE_UP to compute SPC statistics).
    transform = transforms.Compose([transforms.ToTensor()])
    if args.dataset == 'CIFAR-10':
        valset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    elif args.dataset == 'CIFAR-100':
        valset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
    else:
        _, valset = load_tiny_imagenet_data()

    scale_up = SCALE_UP(
        model=scale_up_params['model'],
        processor=scale_up_params['processor'],
        scale_set=scale_up_params['scale_set'],
        retention_rate=scale_up_params['retention_rate'],
        valset=valset
    )
    device = next(scale_up.model.parameters()).device

    # shuffle=False so score positions match dataset indices.
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False)

    score_batches = []
    for images, _ in tqdm(train_loader, desc="Computing SPC scores"):
        score_batches.append(scale_up.compute_spc(images.to(device)).cpu().numpy())
    all_spc_scores = np.concatenate(score_batches) if score_batches else np.empty(0)

    # Keep the half with the LOWEST SPC. Ranking is invariant to the valset
    # standardisation, so raw scores suffice. A stable sort makes the
    # selection deterministic.
    num_keep = len(train_dataset) // 2
    order = np.argsort(all_spc_scores, kind='stable')
    cleaned_indices = sorted(order[:num_keep].tolist())

    cleaned_trainset = Subset(train_dataset, cleaned_indices)

    print(f"原始数据集大小: {len(train_dataset)}, 清洗后数据集大小: {len(cleaned_trainset)}")
    print(f"删除样本比例: {1 - len(cleaned_trainset) / len(train_dataset):.2%}")

    return cleaned_trainset


def ABL_defense(args, ABL_params):
    """Poison the configured dataset, then filter its training split with ABL.

    Returns:
        ``(cleaned_trainset, poisoned_testset, test_poison_indices)``.
    """
    train_split, test_split, poisoned_test_ids = load_and_poison_data(args)
    return ABL_main(ABL_params, train_split), test_split, poisoned_test_ids
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import numpy as np
import random
from tqdm import tqdm

class LGALoss(nn.Module):
    """Local gradient ascent (LGA) wrapper around a base criterion.

    Returns ``sign(loss - gamma) * loss``. While the base loss is above
    ``gamma`` it is minimised as usual. Once it drops below ``gamma`` the sign
    flips and gradient steps push it back up, which tends to keep the loss
    near ``gamma``.

    Args:
        loss: base criterion called as ``loss(logits, targets)``; expected to return a scalar.
        gamma: loss floor at which the sign flips.
    """

    def __init__(self, loss, gamma):
        super(LGALoss, self).__init__()
        self.loss = loss
        self.gamma = gamma

    def forward(self, logits, targets):
        base_value = self.loss(logits, targets)
        direction = torch.sign(base_value - self.gamma)
        return base_value * direction


def ABL_main(ABL_para, poisoned_trainset):
    """Filter a poisoned training set in the style of Anti-Backdoor Learning (ABL).

    Stage 1 trains a small CNN with ``LGALoss``. Stage 2 ranks every training
    sample by its per-sample cross-entropy under that model and treats the
    ``split_ratio`` fraction with the LOWEST loss as suspected-poisoned.
    Stage 3 trains a fresh CNN on the remaining samples.

    Args:
        ABL_para: dict with keys ``seed``, ``data``, ``num`` (number of output
            classes), ``lr``, ``momentum``, ``weight_decay``, ``gamma``,
            ``batch_size``, ``num_workers``, ``pre_epochs``, ``split_ratio``
            and ``clean_epochs``.
        poisoned_trainset: dataset yielding ``(image_tensor, label)`` pairs.

    Returns:
        ``Subset`` of ``poisoned_trainset`` with the suspected-poisoned samples removed.
    """
    torch.manual_seed(ABL_para['seed'])
    np.random.seed(ABL_para['seed'])
    random.seed(ABL_para['seed'])

    # The two MaxPool2d(2) layers divide each side by 4, so the flattened
    # feature map is 32 * k * k for inputs of side 4 * k.
    k = 16 if ABL_para['data'] == 'imagenet-tiny' else 8
    batch_size = ABL_para['batch_size']
    num_workers = ABL_para['num_workers']

    def build_model():
        """Return a freshly initialised two-conv-layer CNN classifier."""
        return nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(32 * k * k, ABL_para['num']),
        )

    def build_optimizer(net):
        """Return an SGD optimizer for ``net`` using the configured hyper-parameters."""
        return optim.SGD(net.parameters(), lr=ABL_para['lr'], momentum=ABL_para['momentum'],
                         weight_decay=ABL_para['weight_decay'])

    # Fall back to CPU so the defense also runs without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = build_model().to(device)

    criterion = nn.CrossEntropyLoss(reduction='mean')  # scalar loss
    optimizer = build_optimizer(model)

    lga_loss = LGALoss(criterion, ABL_para['gamma'])
    train_loader = DataLoader(poisoned_trainset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)

    print("Stage 1: Pre-training with LGA loss")
    model.train()
    for epoch in tqdm(range(ABL_para['pre_epochs']), desc="Pre-training"):
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = lga_loss(model(inputs), targets)
            loss.backward()
            optimizer.step()

    print("Stage 2: Identifying potential poisoned samples")
    model.eval()
    # Per-sample losses must line up with dataset indices, so iterate in order.
    ordered_loader = DataLoader(poisoned_trainset, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers)
    criterion_no_reduction = nn.CrossEntropyLoss(reduction='none')  # one loss per sample
    losses = []
    with torch.no_grad():
        for inputs, targets in tqdm(ordered_loader, desc="Calculating losses"):
            inputs, targets = inputs.to(device), targets.to(device)
            batch_losses = criterion_no_reduction(model(inputs), targets)
            losses.extend(batch_losses.cpu().numpy())

    # Per the ABL premise, backdoored samples are fit fastest, so the
    # lowest-loss fraction is isolated.
    sorted_indices = np.argsort(losses, kind='stable')
    num_poison = int(len(poisoned_trainset) * ABL_para['split_ratio'])
    clean_indices = sorted_indices[num_poison:]

    clean_subset = Subset(poisoned_trainset, clean_indices)
    clean_loader = DataLoader(clean_subset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)

    # NOTE(review): the stage-3 model is not returned, so this stage only costs
    # compute; it is kept to preserve the existing pipeline behaviour and logging.
    model = build_model().to(device)
    optimizer = build_optimizer(model)

    print("Stage 3: Re-training on clean data")
    model.train()
    for epoch in tqdm(range(ABL_para['clean_epochs']), desc="Clean training"):
        for inputs, targets in clean_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

    print(f"ABL completed. {len(clean_subset)} samples identified as clean.")
    return clean_subset


