from __future__ import print_function
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import MultiStepLR
import torchattacks
from advertorch.attacks import LinfPGDAttack,CarliniWagnerL2Attack,DDNL2Attack,SpatialTransformAttack,GradientSignAttack
from autoattack import AutoAttack
#writer = SummaryWriter('runs/experiment_1')

# Quantization resolution: Quantization_(x, q_level) floors x onto a grid of step 1/q_level.
q_level = 3
# Number of extra (conv 64->64, BatchNorm, clamp, quantize) stages VGG_16 inserts after conv1.
number_mediate_layers = 1

# Checkpoint that main() loads into VGG_16 before evaluating it.
load_name = "ours/vgg16_helmet_1layers_T7_2bits_ouralgo90_1_8_2_o.pt"
# NOTE(review): train_step, step1_factor and step2_factor are not referenced by the code
# shown here; presumably leftovers from the training script -- confirm before removing.
train_step = [True,False,False]

step1_factor = 0
#print(step1_factor*10)
step2_factor = 1e-5
# Feature-stack layouts consumed by VGG_16._make_layers:
#   int         -> 3x3 conv with that many output channels (padding 1) + BatchNorm + ReLU
#   (c, p) pair -> 3x3 conv with c output channels and padding p + BatchNorm + ReLU
#   'M'         -> 2x2 max-pool with stride 2
cfg = {
    'o' : [128,128,'M',256,256,'M',512,512,'M',(1024,0),'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}


class Quantization(torch.autograd.Function):
    """Floor-quantize a tensor onto a grid of step ``1/constant``.

    Forward: ``floor(tensor * constant) / constant``.
    Backward: straight-through-style. The incoming gradient is passed through
    unchanged except that every element is clipped to [-100, 100]. No gradient
    is produced for ``constant``.
    """

    @staticmethod
    def forward(ctx, tensor, constant=1):
        ctx.constant = constant
        scaled = torch.mul(tensor, constant)
        return torch.div(torch.floor(scaled), constant)

    @staticmethod
    def backward(ctx, grad_output):
        bound = 100
        # hardtanh(g / bound) * bound clips the gradient to [-bound, bound].
        clipped = F.hardtanh(grad_output / bound)
        return bound * clipped, None


Quantization_ = Quantization.apply

class Clamp_q_(nn.Module):
    """Clamp activations to ``[min, max]``, then floor-quantize them with ``Quantization_``.

    Args:
        min: lower clamp bound.
        max: upper clamp bound.
        q_level: quantization resolution (grid step ``1/q_level``). It defaults to the
            module-level ``q_level`` as it was when the class was defined.
    """

    def __init__(self, min=0.0, max=1, q_level=q_level):
        super(Clamp_q_, self).__init__()
        self.min, self.max, self.q_level = min, max, q_level

    def forward(self, x):
        bounded = torch.clamp(x, min=self.min, max=self.max)
        return Quantization_(bounded, self.q_level)

class CalculateLoss(torch.nn.Module):
    """Quadratic penalty pulling activations toward the centres of their quantization bins.

    With grid step ``1/q_level`` and upper bound ``c_max``, each element ``x`` gets a
    target and a "penalised value". Later rules override earlier ones:

    * default: the target is ``(2*m + 1) * 0.5 / q_level``, where
      ``m = round(x*q_level - 0.5 - 1e-5)``. That is (approximately) the centre of the
      bin containing ``x``; the ``-1e-5`` sends exact grid points to the lower bin.
      The penalised value is ``x``.
    * ``0 <= x <= 1/q_level``: the target is 0 and the penalised value is ``0.5 * x``.
    * ``x <= 0``: the target is ``x`` (no penalty).
    * ``x >= c_max + 0.5/q_level``: the target is ``x`` (no penalty).

    Returns ``sum(|value - target|**2) / 2`` as a 0-dim tensor.
    """

    def __init__(self, q_level, c_max):
        super(CalculateLoss, self).__init__()
        self.q_level = q_level
        self.max_value = c_max

    def forward(self, x):
        levels = self.q_level
        upper = self.max_value

        below_zero = x <= 0
        above_top = x >= upper + 0.5 / levels
        # Both uses of this mask refer to the *input* x, so compute it once up front.
        first_bin = (x >= 0) & (x <= 1 / levels)

        odd_index = 2 * torch.round(x * levels - 0.5 - 1e-5) + 1
        target = (odd_index * 0.5) / levels
        target = torch.where(first_bin, 0, target)
        target = torch.where(below_zero, x, target)
        target = torch.where(above_top, x, target)

        value = torch.where(first_bin, 0.5 * x, x)
        return torch.sum(torch.pow(torch.abs(value - target), 2) / 2)


class AddGaussianNoise(object):
    """Transform that adds Gaussian noise to a tensor.

    Calling it returns ``tensor + randn * std + mean``, where ``randn`` comes from
    ``torch.randn(tensor.size())`` (default dtype, CPU). The added noise therefore
    follows ``N(mean, std**2)`` element-wise.
    """

    def __init__(self, mean=0., std=1.):
        self.std, self.mean = std, mean

    def __call__(self, tensor):
        jitter = torch.randn(tensor.size()) * self.std
        return tensor + jitter + self.mean

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"

class VGG_16(nn.Module):
    """VGG-style classifier preceded by a quantized convolutional front end.

    Pipeline of ``forward``:
      conv1 -> bn1 -> clamp[0, clamp_max] -> Quantization_(q_level)
      -> ``number_mediate_layers`` x (conv -> bn -> clamp[0, clamp_max] -> quantize)
      -> conv2 (back to 3 channels)
      -> ``features`` (built from ``cfg[vgg_name]``) -> flatten -> ``classifier`` (10 logits)

    Args:
        vgg_name: key into the module-level ``cfg`` dict selecting the feature stack.
        quantize_factor: accepted for backward compatibility; not used.
        clamp_max: upper bound of the activation clamp applied before every quantization
            (the callers in this file pass 1, the previously hard-coded bound).
        bias: whether the convolutions built by ``_make_layers`` use a bias term
            (conv1, conv2 and the intermediate convs always do).
    """

    def __init__(self, vgg_name, quantize_factor=-1, clamp_max=1.0, bias=True):
        super(VGG_16, self).__init__()
        self.clamp_max = clamp_max
        self.bias = bias
        # Submodules are created in the original order so random initialisation is unchanged.
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=True)
        self.bn1 = nn.BatchNorm2d(64)

        # Not used by forward(); kept so the module tree is unchanged.
        self.relu_ = nn.ReLU()
        self.conv_layers = nn.ModuleList()
        self.bn_layers = nn.ModuleList()
        for _ in range(number_mediate_layers):
            self.conv_layers.append(nn.Conv2d(64, 64, 3, 1, 1, bias=True))
            self.bn_layers.append(nn.BatchNorm2d(64))

        # Project back to 3 channels so the standard VGG feature stack can follow.
        self.conv2 = nn.Conv2d(64, 3, 3, 1, 1, bias=True)

        self.features = self._make_layers(cfg[vgg_name])
        # Activation-quantization regularizer. forward() does not compute it (its value
        # was never part of the output); kept as a submodule so attributes are unchanged.
        self.cal = CalculateLoss(q_level, 1)
        # Expects 512 flattened features: true for 'VGG16'/'VGG19' on 32x32 inputs
        # (five 2x2 max-pools -> 1x1x512). Presumably CIFAR-sized inputs -- confirm.
        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(512, 10),
        )

    def _quantize_act(self, x):
        """Clamp ``x`` to [0, clamp_max] and floor-quantize it onto the 1/q_level grid."""
        x = torch.clamp(x, min=0, max=self.clamp_max)
        return Quantization_(x, q_level)

    def forward(self, x):
        x = self._quantize_act(self.bn1(self.conv1(x)))
        for conv, bn in zip(self.conv_layers, self.bn_layers):
            x = self._quantize_act(bn(conv(x)))
        x = self.conv2(x)
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

    def _make_layers(self, cfg):
        """Build the feature stack: ints / (channels, padding) tuples -> conv+BN+ReLU, 'M' -> max-pool."""
        layers = []
        in_channels = 3

        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                padding = x[1] if isinstance(x, tuple) else 1
                out_channels = x[0] if isinstance(x, tuple) else x
                layers += [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding, bias=self.bias),
                           nn.BatchNorm2d(out_channels),
                           nn.ReLU()]
                in_channels = out_channels

        # 1x1 average pool with stride 1: an identity op, kept for state/structure parity.
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)
          


def r_fgsm_attack(image, epsilon, data_grad):
    """Random-start FGSM step on images in [0, 1].

    The image first moves by ``epsilon`` in a random sign direction and is clamped to
    [0, 1]. It then moves by ``epsilon`` along ``sign(data_grad)`` and is clamped again.
    The result stays on ``image``'s device, so CPU runs work too.

    Args:
        image: input batch.
        epsilon: size of *each* of the two steps. test_att_rfgsm passes ``eps / 2``, so
            the total L-inf displacement before clamping is at most ``eps``.
        data_grad: gradient of the loss w.r.t. ``image`` (same shape).

    Returns:
        The perturbed batch, clamped to [0, 1].

    NOTE(review): test_att_rfgsm takes ``data_grad`` at the clean image, not at the
    randomly shifted point used by the original R-FGSM formulation.
    """
    random_step = torch.randn_like(image).sign() * epsilon
    perturbed_image = torch.clamp(image + random_step, 0, 1)
    perturbed_image = torch.clamp(perturbed_image + epsilon * data_grad.sign(), 0, 1)
    return perturbed_image






def test(model, device, test_loader,noise,use_function,clamp_max):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i in range(1):
            correct_l = []
            for ti in range(1):
                correct = 0
                for data, target in test_loader:
                    data, target = data.to(device), target.to(device)
                    #onehot = torch.nn.functional.one_hot(target, 10)
                    output = model(data)
                    test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
                    pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                    #print(pred.eq(target.view_as(pred)).sum().item())
                    correct += pred.eq(target.view_as(pred)).sum().item()
                correct_l.append(correct)
            correct_l_n = np.array(correct_l)
            mean_val = np.mean(correct_l_n) / 1
            min_val = (np.min(correct_l_n) - np.mean(correct_l_n)) / 1
            max_val = (np.max(correct_l_n) - np.mean(correct_l_n)) / 1
            print(f"{mean_val:.2f}({min_val:.2f}, {max_val:.2f})")

    #print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
    return correct

def fgsm_attack(image, epsilon, data_grad, clip_min=0.0, clip_max=1.0):
    """One FGSM step: ``clamp(image + epsilon * sign(data_grad), clip_min, clip_max)``.

    The inputs in this script are ``ToTensor()`` images in [0, 1] (no Normalize), so the
    default clamp is [0, 1]. That matches r_fgsm_attack and the torchattacks/advertorch
    attacks used here. The previous [-1, 1] clamp let perturbed pixels leave the valid
    image range.

    Args:
        image: input batch.
        epsilon: L-inf step size.
        data_grad: gradient of the loss w.r.t. ``image`` (same shape).
        clip_min, clip_max: valid input range.
    """
    perturbed_image = image + epsilon * data_grad.sign()
    return torch.clamp(perturbed_image, clip_min, clip_max)

def test_att(model, device, test_loader, epsilon):
    """FGSM robustness evaluation.

    For every batch, the input gradient of the summed cross-entropy is taken, then the
    batch is perturbed with fgsm_attack, and the correct predictions on the perturbed
    inputs are counted. The function prints the count and the accuracy as a percentage
    of ``len(test_loader.dataset)``.

    Returns:
        The number of correctly classified adversarial examples (a count, not a percentage).
    """
    model.eval()
    correct = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        data.requires_grad_(True)

        loss = F.cross_entropy(model(data), target, reduction='sum')
        # Only the input gradient is needed; this skips filling parameter .grad fields.
        data_grad, = torch.autograd.grad(loss, data)

        # The evaluation forward pass needs no autograd graph.
        with torch.no_grad():
            perturbed_data = fgsm_attack(data, epsilon, data_grad)
            pred_perturbed = model(perturbed_data).argmax(dim=1, keepdim=True)
        correct += pred_perturbed.eq(target.view_as(pred_perturbed)).sum().item()

    total = len(test_loader.dataset)
    # The original printed the raw count with a '%' sign; report a real percentage.
    accuracy_pct = 100.0 * correct / total if total else 0.0
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {total} = {accuracy_pct:.2f}%")
    return correct

def test_att_rfgsm(model, device, test_loader, epsilon):
    """R-FGSM robustness evaluation (random sign step, then gradient-sign step).

    For every batch, the input gradient of the summed cross-entropy is taken at the clean
    image. That gradient and ``epsilon / 2`` are passed to r_fgsm_attack, and the correct
    predictions on the result are counted. The function prints the count and the accuracy
    as a percentage of ``len(test_loader.dataset)``.

    Returns:
        The number of correctly classified perturbed examples (a count, not a percentage).
    """
    model.eval()
    correct = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        data.requires_grad_(True)

        loss = F.cross_entropy(model(data), target, reduction='sum')
        # Only the input gradient is needed; this skips filling parameter .grad fields.
        data_grad, = torch.autograd.grad(loss, data)

        # Perturbation and evaluation need no autograd graph.
        with torch.no_grad():
            perturbed_data = r_fgsm_attack(data, epsilon / 2, data_grad)
            pred_perturbed = model(perturbed_data).argmax(dim=1, keepdim=True)
        correct += pred_perturbed.eq(target.view_as(pred_perturbed)).sum().item()

    total = len(test_loader.dataset)
    # The original printed the raw count with a '%' sign; report a real percentage.
    accuracy_pct = 100.0 * correct / total if total else 0.0
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {total} = {accuracy_pct:.2f}%")
    return correct

def test_att_pytorch(model, device, test_loader, attack):
    model.eval()
    correct = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        perturbed_data = attack(data, target)
        
        output_perturbed = model(perturbed_data)
        pred_perturbed = output_perturbed.argmax(dim=1, keepdim=True)

        correct += pred_perturbed.eq(target.view_as(pred_perturbed)).sum().item()

    final_acc = correct / len(test_loader.dataset)
    print(final_acc)
    #print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {len(test_loader.dataset)} = {final_acc:.2f}%")
    return final_acc

def test_att_pytorch_AA(model, device, test_loader, attack):
    model.eval()
    correct = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        perturbed_data = attack.run_standard_evaluation(data, target)
        
        output_perturbed = model(perturbed_data)
        pred_perturbed = output_perturbed.argmax(dim=1, keepdim=True)

        correct += pred_perturbed.eq(target.view_as(pred_perturbed)).sum().item()

    final_acc = correct / len(test_loader.dataset)
    print(final_acc)
    #print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {len(test_loader.dataset)} = {final_acc:.2f}%")
    return final_acc



def test_random(model, device, test_loader,noise,use_function,clamp_max,epsilon=8/255):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i in range(1):
            correct_l = []
            for ti in range(1):
                correct = 0
                for data, target in test_loader:
                    data, target = data.to(device), target.to(device)
                    noise_ = torch.FloatTensor(data.size()).uniform_(-epsilon, epsilon).cuda()
                    data+=noise_
                    #onehot = torch.nn.functional.one_hot(target, 10)
                    output = model(data)
                    test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
                    pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                    #print(pred.eq(target.view_as(pred)).sum().item())
                    correct += pred.eq(target.view_as(pred)).sum().item()
                correct_l.append(correct)
            correct_l_n = np.array(correct_l)
            mean_val = np.mean(correct_l_n) / 1
            min_val = (np.min(correct_l_n) - np.mean(correct_l_n)) / 1
            max_val = (np.max(correct_l_n) - np.mean(correct_l_n)) / 1
            print(f"{mean_val:.2f}({min_val:.2f}, {max_val:.2f})")

    #print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
    return correct
def test_att_pytorch_TPAP(model, device, test_loader, attack, purify_eps=8/255):
    """Robust accuracy with and without FGSM-based test-time purification.

    For every batch:
      1. ``data_adv = attack(data, target)`` crafts adversarial inputs.
      2. The model's predictions on ``data_adv`` are scored against ``target``
         ("before purification").
      3. ``data_adv`` is purified with one ``torchattacks.FGSM`` step of size
         ``purify_eps``, computed against the model's *predicted* labels from step 2.
         The predictions on the result are scored against ``target``
         ("after purification").

    Args:
        model: classifier; switched to eval mode here.
        device: device the batches are moved to.
        test_loader: iterable of ``(data, target)`` batches.
        attack: callable ``attack(data, target) -> adversarial data``.
        purify_eps: L-inf step size of the purification FGSM (default 8/255).

    Prints ``<accuracy after purification> <accuracy before purification>`` as fractions.

    Returns:
        The constant 1, kept for compatibility (callers in this file ignore it).
    """
    model.eval()
    # The purifier depends only on the model and eps, so build it once, not per batch.
    purifier = torchattacks.FGSM(model, eps=purify_eps)
    total = 0
    correct_before = 0
    correct_after = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        data_adv = attack(data, target)

        # Pure evaluation forward passes need no autograd graph.
        with torch.no_grad():
            pred_before = torch.max(model(data_adv.float()), 1)[1]
        correct_before += (pred_before == target).sum().item()

        # Purification targets the model's own predictions; this reuses the forward pass
        # above instead of recomputing it.
        data_pur = purifier(data_adv, pred_before).to(device)
        with torch.no_grad():
            pred_after = torch.max(model(data_pur.float()), 1)[1]
        correct_after += (pred_after == target).sum().item()

        total += target.size(0)

    print(correct_after / total, correct_before / total)
    return 1
def main():
    """Evaluate a pretrained quantized VGG-16 on the CIFAR-10 test set under several attacks.

    The function loads ``--resume`` (or the module-level ``load_name`` when not given)
    into ``VGG_16('VGG16')``. It then reports, in order: clean accuracy, uniform-noise
    accuracy (eps = 8/255), FGSM, R-FGSM, PGD-20, CW and DDN. FGSM, PGD-20, CW and DDN
    go through test_att_pytorch_TPAP.

    Nothing is trained here. The training-related flags are still parsed so existing
    command lines keep working, but they have no effect.
    """
    parser = argparse.ArgumentParser(description='CIFAR-10 adversarial evaluation of a quantized VGG-16')
    parser.add_argument('--batch-size', type=int, default=768, metavar='N',
                        help='input batch size for training (default: 768; unused by this evaluation)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000; unused, evaluation uses 128)')
    parser.add_argument('--epochs', type=int, default=3, metavar='N',
                        help='number of epochs to train (default: 3; unused)')
    parser.add_argument('--lr', type=float, default=1, metavar='LR',
                        help='learning rate (default: 1.0; unused)')
    parser.add_argument('--loss_scale', type=float, default=0, metavar='LR',
                        help='loss scale (default: 0; unused)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7; unused)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass (unused)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1; currently not applied)')
    parser.add_argument('--noise', type=float, default=0, metavar='S',
                        help='passed through to test() and test_random() (default: 0)')
    parser.add_argument('--use_function', type=str, default='relu', metavar='S',
                        help='passed through to test() and test_random() (default: relu)')
    parser.add_argument('--clamp_max', type=int, default=1, metavar='S',
                        help='passed through to test() and test_random() (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--T', type=int, default=5, metavar='N',
                        help='SNN time window')
    parser.add_argument('--resume', type=str, default=None, metavar='RESUME',
                        help='checkpoint to evaluate (default: the module-level load_name)')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Inputs stay in [0, 1] (no Normalize); every attack below clamps to that range.
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    # download=True so the CIFAR-10 archive is fetched if it is not already present.
    dataset2 = datasets.CIFAR10('../data', train=False, download=True,
                                transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset2, batch_size=128)

    model = VGG_16('VGG16', clamp_max=1, bias=True).to(device)
    checkpoint_path = args.resume if args.resume else load_name
    # map_location lets a checkpoint saved on GPU be evaluated on CPU (--no-cuda).
    state_dict = torch.load(checkpoint_path, map_location=device)
    incompatible = model.load_state_dict(state_dict, strict=False)
    # strict=False tolerates key mismatches; report them rather than silently evaluating
    # partially random weights.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"warning: missing keys {incompatible.missing_keys}, "
              f"unexpected keys {incompatible.unexpected_keys}")

    epsilon = 8/255

    print("------------------------------------------------------------")
    print("clean")
    test(model, device, test_loader, args.noise, args.use_function, args.clamp_max)
    print("------------------------------------------------------------")
    print("random")
    test_random(model, device, test_loader, args.noise, args.use_function, args.clamp_max, epsilon)
    print("------------------------------------------------------------")
    print("fgsm")
    attack = torchattacks.FGSM(model, eps=epsilon)
    test_att_pytorch_TPAP(model, device, test_loader, attack)
    print("------------------------------------------------------------")
    print("fgsm+R")
    test_att_rfgsm(model, device, test_loader, epsilon)

    print("------------------------------------------------------------")
    print("PGD-20")
    attack = torchattacks.PGD(model, eps=8/255, alpha=0.003, steps=20)
    test_att_pytorch_TPAP(model, device, test_loader, attack)

    print("------------------------------------------------------------")
    print("CW")
    attack = torchattacks.CW(model, c=1, kappa=1, steps=10, lr=0.01)
    test_att_pytorch_TPAP(model, device, test_loader, attack)

    print("------------------------------------------------------------")
    print("DDN")
    attack = DDNL2Attack(model, nb_iter=20, gamma=0.05, init_norm=1.0, quantize=True, levels=16, clip_min=0.0,
                         clip_max=1.0, targeted=False, loss_fn=None)
    test_att_pytorch_TPAP(model, device, test_loader, attack)

# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()





