import os
import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch import nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from models.resnet_models import *
import data_loaders
from torchattacks.attack import Attack
from tqdm import tqdm
from functions import TET_loss, seed_all, get_logger
from models.VGG_models import *
from models.WideResNet import *
# Expose only physical GPU 3 to this process. This must take effect before
# CUDA is initialised; it is set before any torch.cuda call below (assumes
# none of the project imports above has already initialised CUDA).
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed_all(1000)  # project helper from `functions`; seeds RNGs for reproducibility

# Data loading
train_dataset, val_dataset, znorm = data_loaders.cifar_dataset(use_cifar10=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                          num_workers=16, pin_memory=True)
test_loader = DataLoader(val_dataset, batch_size=128, shuffle=False,
                         num_workers=16, pin_memory=True)

# Model: VGG-11 with 10 output classes, restored from a local checkpoint.
# `znorm` from the dataset helper is passed as the model's `norm`
# (presumably input normalisation statistics -- confirm in models.VGG_models).
model = vgg11(num_classes=10, norm=znorm)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
# instead of failing inside torch.load.
model.load_state_dict(torch.load('cifar10_vgg11_noise_lag60_rat.pth',
                                 map_location=device))
model = model.to(device)
model.eval()  # switch to evaluation mode

# Shared store filled by the forward hooks built below: maps the name given
# to get_activation to the most recent detached output recorded under it.
activations = {}


def get_activation(name):
    """Return a forward hook that saves a module's output into ``activations``.

    The hook takes ``(module, input, output)``, the signature expected by
    ``nn.Module.register_forward_hook``. Each call overwrites
    ``activations[name]`` with ``output.detach()``, so only the latest
    forward pass is kept and the stored tensor carries no autograd history.
    """
    def _record(module, inputs, output):
        activations.update({name: output.detach()})

    return _record

# Pass data through the model


def PGD(model, images, labels, eps=8/255, alpha=0.01, iters=50,
        random_start=True, device=None):
    r"""
    PGD (Projected Gradient Descent) attack under an L-infinity budget.

    Starting from ``images`` (optionally plus uniform noise in
    ``[-eps, eps]``), take ``iters`` signed-gradient ascent steps of size
    ``alpha`` on the cross-entropy loss. After every step the result is
    projected back onto the ``eps``-ball around ``images`` and clamped to
    ``[0, 1]``.

    Args:
        model: The target model. Its forward must return a ``(logits, _)``
            pair. Logits with more than two dimensions are averaged over
            dim 1 (assumed to be the time-step axis, e.g. (batch, T, classes)
            -- TODO confirm); 2-D logits are used as-is.
        images: The clean input images; the valid pixel range is [0, 1].
        labels: The true labels of the images.
        eps: The maximum L-infinity perturbation (epsilon).
        alpha: The step size for each iteration.
        iters: The number of iterations.
        random_start: If True (the original behaviour), start from a uniformly
            random point in the eps-ball instead of the clean images.
        device: Device to run the attack on. If None, the module-level
            ``device`` is used when it exists, otherwise ``images.device``.

    Returns:
        adv_images: The detached adversarial examples, on ``device``.

    Note:
        Puts ``model`` into eval mode and leaves it there.
    """
    if device is None:
        # Keep the original behaviour of using the script-level `device`, but
        # fall back to the input's device when used outside that script.
        device = globals().get("device", images.device)

    model.eval()
    images = images.clone().detach().to(device)
    labels = labels.clone().detach().to(device)
    loss_fn = nn.CrossEntropyLoss()

    adv_images = images.clone().detach()
    if random_start:
        # Uniform random start inside the eps-ball (standard PGD initialisation).
        adv_images = adv_images + torch.empty_like(adv_images).uniform_(-eps, eps)
        adv_images = torch.clamp(adv_images, min=0, max=1).detach()

    for _ in range(iters):
        adv_images.requires_grad = True
        outputs, _ = model(adv_images)
        if outputs.dim() > 2:
            # Collapse the extra axis (dim 1) by averaging, e.g. over time steps.
            outputs = outputs.mean(1)
        cost = loss_fn(outputs, labels)

        grad = torch.autograd.grad(cost, adv_images,
                                   retain_graph=False, create_graph=False)[0]

        # Gradient *ascent* step: move to increase the loss.
        adv_images = adv_images.detach() + alpha * grad.sign()

        # Project onto the eps-ball around the clean images, then onto [0, 1].
        delta = torch.clamp(adv_images - images, min=-eps, max=eps)
        adv_images = torch.clamp(images + delta, min=0, max=1).detach()

    return adv_images



def val(model, test_loader, device, T):
    """Evaluate adversarial (PGD) and clean accuracy of ``model``.

    Every batch from ``test_loader`` is attacked with ``PGD`` using its
    default settings. The model is then evaluated, without gradients, on
    both the clean and the adversarial inputs.

    Args:
        model: Network whose forward returns a ``(logits, _)`` pair.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the inputs are moved to before attacking/evaluating.
        T: Only its sign is used. When ``T > 0`` the logits are averaged
            over dim 1 (presumably the time-step axis -- TODO confirm);
            otherwise they are used as-is.

    Returns:
        ``(adv_acc, clean_acc)``: top-1 accuracy in percent on the
        adversarial and on the clean inputs, respectively.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    correct_ori = 0
    total = 0

    for inputs, targets in tqdm(test_loader):
        inputs = inputs.to(device)
        # PGD clones its input, so `inputs` still holds the clean batch.
        adv_inputs = PGD(model, inputs, targets)

        with torch.no_grad():
            # Run the clean pass unconditionally. Previously it only ran when
            # T > 0, leaving `outputs_ori` undefined (NameError) otherwise.
            outputs_ori, _ = model(inputs)
            outputs, _ = model(adv_inputs)
            if T > 0:
                # Average the per-step logits before taking the argmax.
                outputs_ori = outputs_ori.mean(1)
                outputs = outputs.mean(1)

        _, predicted_ori = outputs_ori.cpu().max(1)
        _, predicted = outputs.cpu().max(1)
        targets = targets.cpu()
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        correct_ori += predicted_ori.eq(targets).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples")
    final_acc = 100 * correct / total
    final_acc_ori = 100 * correct_ori / total
    return final_acc, final_acc_ori


if __name__ == "__main__":
    # The guard matters because DataLoader workers (num_workers=16) started
    # with the "spawn" method re-import this module. Without it, each worker
    # would re-run the evaluation and fail while bootstrapping.
    # val returns (adversarial accuracy, clean accuracy), both in percent;
    # the 8 is passed as val's T argument.
    acc, acc_ori = val(model, test_loader, device, 8)
    print(acc, acc_ori)


