﻿# -*- coding: utf-8 -*-
"""
@author: admin
"""

import os
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
# import matplotlib.pyplot as plt
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image
# from Tools import filters, JSMA
from torchvision import utils as vutils
import numpy as np
from collections import OrderedDict
import random
# from frank_wolfe import FrankWolfe
# from autoattack import AutoAttack
# import advertorch
from advertorch.attacks import LinfPGDAttack, CarliniWagnerL2Attack, DDNL2Attack, SinglePixelAttack, LocalSearchAttack, SpatialTransformAttack,L1PGDAttack
# from autoattack import AutoAttack
#from models.LeNet import LeNet5_CIFAR10
#from models.ResNet import ResNet18_cifar, ResNet_autoencoder, ResidualBlock, Classifier_head, ResNet_Encoder, Decoder, ResNet18_Dnet, ResNet50, ResNet101
#from models.AMA_densenet import ft_net_dense
import advertorch
from torch.utils.data import DataLoader, TensorDataset
import torchvision.transforms as transforms
# from models.pix2pix_unet import define_G, define_D, GANLoss, get_scheduler, update_learning_rate
import time
from timm.models import create_model
import copy

# basic settings
seed = 5

# Environment variables come first: CUDA_VISIBLE_DEVICES is only honoured when
# it is set before CUDA is initialised, so it must precede every GPU-related
# call made by this script.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# NOTE: the running interpreter chose its hash seed at start-up, so this
# assignment only affects child processes that inherit the environment.
# Export PYTHONHASHSEED in the shell to pin hashing for this process too.
os.environ['PYTHONHASHSEED'] = str(seed)

torch.manual_seed(seed)  # seed the CPU random number generator
torch.cuda.manual_seed(seed)  # seed the current GPU
torch.cuda.manual_seed_all(seed)  # seed every visible GPU
random.seed(seed)
np.random.seed(seed)

# Seeding alone does not make GPU runs repeatable: cuDNN may pick
# non-deterministic kernels, and its auto-tuner may pick different kernels
# from run to run. Pin both so repeated evaluations give the same numbers.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

NUM_EPOCHS = 300
LEARNING_RATE = 1e-3
BATCH_SIZE = 256


# def save_decoded_image(img, name):
#     img = img.view(img.size(0), 3, 32, 32)
#     save_image(img, name)
#

# --- preprocessing -----------------------------------------------------------
# Every image is resized to 32x32, converted to a tensor and normalized
# per channel with the mean/std given below.
_preprocess_steps = [
    transforms.Resize([32, 32]),
    transforms.ToTensor(),
    # Alternative normalization kept for reference:
    # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
]
transform = transforms.Compose(_preprocess_steps)

# --- datasets ----------------------------------------------------------------
# Both splits share one root directory and are downloaded if missing.
_cifar10_root = '/mnt/data/zyhhh/datasets/CIFAR10'
trainset = datasets.CIFAR10(root=_cifar10_root, train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root=_cifar10_root, train=False, download=True, transform=transform)

# --- loaders -----------------------------------------------------------------
# Neither loader shuffles, so batches arrive in dataset order.
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=False)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)


# target_model = torch.load('/mnt/8T-data/zyhhh/project/test_code/saving_models/supplement/resnet18_cifar10.pkl').eval()
# target_model = torch.load('/mnt/8T-data/zyhhh/project/test_code/saving_models/attack_eps_cifar10/vanilla/CIFAR10_wideresnet34_none.pkl')

# FGSM_N = advertorch.attacks.GradientSignAttack(predict=target_model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=0.3)
#
# FGSM_T = advertorch.attacks.GradientSignAttack(predict=target_model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=0.3,
# targeted=True)
#
# BIM_N = advertorch.attacks.LinfBasicIterativeAttack(predict=target_model, loss_fn=nn.CrossEntropyLoss(), eps=0.3, nb_iter=10,
#                                                     eps_iter=0.05,
#                                                     clip_min=0.0, clip_max=1.0, targeted=False)
# BIM_T = advertorch.attacks.LinfBasicIterativeAttack(predict=target_model, loss_fn=nn.CrossEntropyLoss(), eps=0.3, nb_iter=10,
#                                                     eps_iter=0.05,
#                                                     clip_min=0.0, clip_max=1.0, targeted=True)
#
#
# MMT_N = advertorch.attacks.MomentumIterativeAttack(predict=target_model, loss_fn=nn.CrossEntropyLoss(), eps=0.3, nb_iter=10,
#                                                    decay_factor=1.0, eps_iter=0.01,
#                                                    clip_min=0.0, clip_max=1.0, targeted=False)
# MMT_T = advertorch.attacks.MomentumIterativeAttack(predict=target_model, loss_fn=nn.CrossEntropyLoss(), eps=0.3, nb_iter=10,
#                                                    decay_factor=1.0, eps_iter=0.01,
#                                                    clip_min=0.0, clip_max=1.0, targeted=True)
#
#
# PGD_N = LinfPGDAttack(
#             target_model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=0.3,
#             nb_iter=10, eps_iter=0.01, rand_init=True, clip_min=0.0, clip_max=1.0, targeted=False)
#
# PGD_T = LinfPGDAttack(
#             target_model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=0.3,
#             nb_iter=10, eps_iter=0.01, rand_init=True, clip_min=0.0, clip_max=1.0, targeted=True)
#
#
# CW_N = CarliniWagnerL2Attack(
#     target_model, 10, clip_min=0.0, clip_max=1.0, max_iterations=500, confidence=1, initial_const=1, learning_rate=1e-2,
#     binary_search_steps=4, targeted=False)
#
# CW_T = CarliniWagnerL2Attack(
#     target_model, 10, clip_min=0.0, clip_max=1.0, max_iterations=500, confidence=1, initial_const=1, learning_rate=1e-2,
#     binary_search_steps=4, targeted=True)
#
# CW_N = CarliniWagnerL2Attack(
#     target_model, 10, clip_min=0.0, clip_max=1.0, max_iterations=100, confidence=1, initial_const=1, learning_rate=1e-2,
#     binary_search_steps=4, targeted=False)
#
# CW_T = CarliniWagnerL2Attack(
#     target_model, 10, clip_min=0.0, clip_max=1.0, max_iterations=100, confidence=1, initial_const=1, learning_rate=1e-2,
#     binary_search_steps=4, targeted=True)
#
#
# DDN = DDNL2Attack(target_model, nb_iter=1000, gamma=0.05, init_norm=1.0, quantize=True, levels=256, clip_min=0.0,
#                         clip_max=1.0, targeted=False, loss_fn=None)
#
# STA = SpatialTransformAttack(
#     target_model, 10, clip_min=0.0, clip_max=1.0, max_iterations=5000, search_steps=20, targeted=False)
#
# AA_N = AutoAttack(target_model, norm='Linf', eps=0.3, version='standard')
#
# JSMA_T = advertorch.attacks.JacobianSaliencyMapAttack(predict=target_model, num_classes=10, clip_min=0.0, clip_max=1.0,
#                                                     loss_fn=None, theta=1.0, gamma=1.0,
#                                                     comply_cleverhans=False)


'''
Teacher models
'''
# normal teacher
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/normal_teacher/resnet18/best_model.pth').cuda().eval() # [none=95.03, fgsm=50.09, pgd=11.72]
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/normal_teacher/resnet18/last_model.pth').cuda().eval() # [none=94.74, fgsm=50.46, pgd=12.08]
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/normal_teacher/resnet18/best_model_atpgd.pth').cuda().eval() # [none=91.15, fgsm=77.83, pgd=56.92]
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/normal_teacher/resnet18/last_model_atpgd.pth').cuda().eval() # [none=90.41, fgsm=77.68, pgd=55.83]



# nasty teacher
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/nasty_teacher/nasty_resnet18/best_model.pth') # [none=94.37, fgsm=51.75, pgd=26.45]
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/nasty_teacher/nasty_resnet18/last_model.pth') # [none=94.15, fgsm=52.05, pgd=26.54]
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/nasty_teacher/nasty_resnet18/best_model_atpgd_nasty.pth') # [none=90.85, fgsm=82.86, pgd=62.37]
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/nasty_teacher/nasty_resnet18/last_model_atpgd_nasty.pth') # [none=90.48, fgsm=82.25, pgd=61.86]

# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments_2.1/CIFAR10/nasty_model/nasty_wrn3410/best_model_atpgd_nasty.pth')

'''
Student models
'''
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/kd_from_normal_resnet18/cnn/best_model.pth').cuda().eval() # [none=87.47, fgsm=39.57, pgd=13.31]
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/kd_from_normal_resnet18/cnn/last_model.pth').cuda().eval() # [none=87.83, fgsm=39.09, pgd=13.03]

# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/kd_from_normal_resnet18/cnn/best_model_kd(adv)_from_normal_atpgd.pth') # [none=74.43, fgsm=45.70, pgd=35.20], old:[none=74.32, fgsm=47.63, pgd=20.66]
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/kd_from_normal_resnet18/cnn/last_model_kd(adv)_from_normal_atpgd.pth') # [none=62.64, fgsm=54.30, pgd=43.45], old:[none=67.22, fgsm=50.78, pgd=21.84]

# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/kd_from_normal_resnet18/cnn/best_model_kd(adv)_from_normal_atpgd_2.pth') # [none=82.31, fgsm=55.32, pgd=29.67]
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/kd_from_normal_resnet18/cnn/last_model_kd(adv)_from_normal_atpgd_2.pth') # [none=81.10, fgsm=53.57, pgd=29.85]

# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/kd_from_nasty_resnet18/cnn/best_model_kd(adv)_from_nasty_atpgd.pth') # [none= ,fgsm=, pgd=]
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments/CIFAR10/kd_from_nasty_resnet18/cnn/last_model_kd(adv)_from_nasty_atpgd.pth') # [none= ,fgsm=, pgd=]


'''
model with EDM
'''
# Checkpoint under evaluation. The file is loaded as a whole pickled object and
# the network is taken from its `.module` attribute (presumably the model was
# saved while wrapped in nn.DataParallel -- TODO confirm).
# Alternative checkpoints are kept commented out for quick switching.
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments_EDM/CIFAR10/nasty_model/nasty_wrn3410_with_EDM/best_model_atpgd_nasty_with_EDM_1m.pth').module # [none= ,fgsm=, pgd=]
_edm_checkpoint = '/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments_EDM/CIFAR10/nasty_model/nasty_wrn3410_with_EDM/last_model_atpgd_nasty_with_EDM_1m.pth'
_edm_wrapper = torch.load(_edm_checkpoint)
test_model = _edm_wrapper.module  # [none= ,fgsm=, pgd=]

# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments_EDM/CIFAR10/nasty_model/nasty_wrn3410_with_EDM/best_model_atpgd_nasty_with_EDM_5m.pth').module # [none= ,fgsm=, pgd=]
# test_model = torch.load('/mnt/data/zyhhh/project/Nasty-Teacher-main_162/experiments_EDM/CIFAR10/nasty_model/nasty_wrn3410_with_EDM/last_model_atpgd_nasty_with_EDM_5m.pth').module # [none= ,fgsm=, pgd=]



# ---------------------------------------------------------------------------
# Robust-accuracy evaluation of `test_model` on the test loader.
#
# The loader yields *normalized* tensors, whereas the attack is configured for
# images in [0, 1] (clip_min=0.0, clip_max=1.0, eps=8/255). Running the attack
# directly on normalized tensors clamps every adversarial image into [0, 1] of
# the normalized space, which distorts the image far beyond the eps budget.
# The attack is therefore run in pixel space through a thin wrapper that
# re-applies the normalization before calling the network.
#
# NOTE(review): this changes the reported PGD accuracy, so figures recorded
# with the earlier normalized-space attack will not reproduce.
# ---------------------------------------------------------------------------


class _PixelSpaceModel(nn.Module):
    """Adapter that lets a network trained on normalized inputs take [0, 1] pixels.

    The wrapped network still receives ``(x - mean) / std``, so its output on
    clean data is unchanged up to float rounding; only the space in which the
    attack operates moves to raw pixel intensities.

    Args:
        model: network that expects normalized inputs.
        mean: per-channel mean used by the data pipeline.
        std: per-channel standard deviation used by the data pipeline.
    """

    def __init__(self, model, mean, std):
        super().__init__()
        self.model = model
        # Registered as buffers so they follow the module across .cuda()/.to().
        self.register_buffer(
            "mean", torch.as_tensor(mean, dtype=torch.float32).view(1, -1, 1, 1))
        self.register_buffer(
            "std", torch.as_tensor(std, dtype=torch.float32).view(1, -1, 1, 1))

    def forward(self, pixels):
        """Normalize ``pixels`` and return the wrapped network's output."""
        return self.model((pixels - self.mean) / self.std)

    def to_pixels(self, normalized):
        """Undo the normalization, clamping float round-off back into [0, 1]."""
        return (normalized * self.std + self.mean).clamp(0.0, 1.0)


def _normalization_stats(pipeline):
    """Return ``(mean, std)`` of the first step in ``pipeline`` exposing both.

    ``pipeline`` may be a composed transform (with a ``transforms`` list), a
    single transform, or ``None``. When no step carries ``mean``/``std`` the
    identity normalization ``((0.0,), (1.0,))`` is returned.
    """
    for step in getattr(pipeline, "transforms", [pipeline]):
        if hasattr(step, "mean") and hasattr(step, "std"):
            return step.mean, step.std
    return (0.0,), (1.0,)


test_model.eval()
_mean, _std = _normalization_stats(testloader.dataset.transform)
# .cuda() also moves the wrapped network, so the model and inputs share a device.
pixel_model = _PixelSpaceModel(test_model, _mean, _std).cuda().eval()

# The attack is loop-invariant, so build it once instead of once per batch.
PGD_N = LinfPGDAttack(
    pixel_model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=8 / 255,
    nb_iter=10, eps_iter=2 / 255, rand_init=True, clip_min=0.0, clip_max=1.0,
    targeted=False)

total = 0
correct = 0
for img, label in testloader:
    img, label = img.cuda(), label.cuda()
    img = pixel_model.to_pixels(img)  # back to [0, 1] so eps/clip are meaningful

    # Clean accuracy instead of robust accuracy:
    # img_test = img

    # Single-step alternative:
    # FGSM_N = advertorch.attacks.GradientSignAttack(predict=pixel_model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=8/255)
    # img_test = FGSM_N.perturb(img, label)

    # Pass the ground-truth labels: without them the attack targets the
    # model's own predictions, which can push already-misclassified samples
    # towards the correct class and inflate the measured accuracy.
    img_test = PGD_N.perturb(img, label)

    # The attack needs gradients; scoring its output does not.
    with torch.no_grad():
        x = pixel_model(img_test)

    prediction = x.argmax(dim=1)
    total += label.size(0)
    correct += (prediction == label).sum().item()
print('Accuracy=%.2f' % (100.00 * correct / total))
# print('There are ' + str(correct) + ' correct pictures.')
# print('There are ' + str(correct.item()) + ' correct pictures.')





















'''
from deeprobust.image.attack.pgd import PGD
from deeprobust.image.attack.fgsm import FGSM
model = torch.load('/mnt/8T-data/zyhhh/project/test_code/saving_models/deeprobust/CIFAR10_resnet18_atpgd_deeprobust.pkl')
def adv_data(model, data, output, ep=0.3, num_steps=10, perturb_step_size=0.01):
    """
    Generate input(adversarial) data for training.
    """

    adversary = PGD(model)
    data_adv = adversary.generate(data, output.flatten(), epsilon=ep, num_steps=num_steps, step_size=perturb_step_size)
    output = model(data_adv)  # output covers only the adversarial data; no benign samples are included
    return data_adv, output

def calculate_loss(output, target, redmode = 'mean'):
        """
        Calculate loss for training.
        """

        loss = F.cross_entropy(output, target, reduction = redmode)
        return loss


test_loss = 0
correct = 0
test_loss_adv = 0
correct_adv = 0
for data, target in testloader:
    data, target = data.cuda(), target.cuda()

    # print clean accuracy
    output = model(data)
    test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
    pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
    correct += pred.eq(target.view_as(pred)).sum().item()

    # print adversarial accuracy
    print(data.shape, target.shape)
    data_adv, output_adv = adv_data(model=model, data=data, output=target, ep=0.3, num_steps=10, perturb_step_size=0.01)

    test_loss_adv += calculate_loss(output_adv, target, redmode='sum').item()  # sum up batch loss
    pred_adv = output_adv.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
    correct_adv += pred_adv.eq(target.view_as(pred_adv)).sum().item()

test_loss /= len(testloader.dataset)
test_loss_adv /= len(testloader.dataset)

print('\nTest set: Clean loss: {:.3f}, Clean Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(testloader.dataset),
    100. * correct / len(testloader.dataset)))

print('\nTest set: Adv loss: {:.3f}, Adv Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss_adv, correct_adv, len(testloader.dataset),
    100. * correct_adv / len(testloader.dataset)))
'''






