from __future__ import print_function
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import MultiStepLR
import torchattacks
import torchattacks
from advertorch.attacks import LinfPGDAttack,CarliniWagnerL2Attack,DDNL2Attack,SpatialTransformAttack,GradientSignAttack
from autoattack import AutoAttack
#writer = SummaryWriter('runs/experiment_1')

# Quantization resolution: Quantization floors activations to multiples of
# 1/q_level, and CalculateLoss uses the same value for its bin width.
q_level = 3
# Number of extra 64->64 conv + batch-norm stages that ResNet inserts in its
# quantized front end.
number_mediate_layers =2
#load_name = "../baseline/chenyao_baseline_relu.pt"
# Checkpoint file whose state dict main() loads into the model.
load_name = "2layers_50_q2_t3_ours90_1.pt"
# NOTE(review): not referenced anywhere in the code visible in this file;
# presumably per-stage training switches -- confirm before relying on it.
train_step = [True,False,False]

# Scale of the identity target in orthogonality_loss: the penalty is
# ||W^T W - step1_factor * I||^2.
step1_factor = 0.001
#print(step1_factor*10)
# NOTE(review): not referenced anywhere in the code visible in this file.
step2_factor = 2e-6
# Layer-layout tables keyed by architecture name. Not referenced in the code
# visible in this file; integers are presumably channel counts and 'M' a
# pooling marker -- TODO confirm against whatever consumes them.
cfg = {
    'o' : [128,128,'M',256,256,'M',512,512,'M',(1024,0),'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}


class Quantization(torch.autograd.Function):
    """Floor quantizer with a clipped pass-through gradient.

    Forward maps every element to the largest multiple of ``1 / constant``
    that does not exceed it. Backward hands the incoming gradient straight
    through, limited element-wise to the range ``[-100, 100]``.
    """

    @staticmethod
    def forward(ctx, tensor, constant=1):
        """Return ``floor(tensor * constant) / constant``."""
        ctx.constant = constant
        scaled = tensor * constant
        return scaled.floor() / constant

    @staticmethod
    def backward(ctx, grad_output):
        """Return the clipped upstream gradient; ``constant`` gets no gradient."""
        # hardtanh limits grad / 100 to [-1, 1]; scaling back yields [-100, 100].
        clipped = F.hardtanh(grad_output / 100) * 100
        return clipped, None


# Functional alias used by the modules below.
Quantization_ = Quantization.apply

class Clamp_q_(nn.Module):
    """Activation that clips its input and then floor-quantizes it.

    Args:
        min: lower clipping bound.
        max: upper clipping bound.
        q_level: quantization resolution handed to ``Quantization_``;
            defaults to the module-level ``q_level`` at definition time.
    """

    def __init__(self, min=0.0, max=1,q_level = q_level):
        super().__init__()
        self.min, self.max, self.q_level = min, max, q_level

    def forward(self, x):
        """Clip ``x`` to ``[self.min, self.max]`` and quantize the result."""
        clipped = x.clamp(min=self.min, max=self.max)
        return Quantization_(clipped, self.q_level)

class CalculateLoss(torch.nn.Module):
    """Sum-of-squares penalty tying activations to per-bin reference values.

    Element-wise, with ``q = q_level``:

    * ``x <= 0`` or ``x >= c_max + 0.5 / q``: no penalty (the reference is
      ``x`` itself).
    * ``0 < x <= 1 / q``: the reference is 0 and ``x`` is halved first, so
      the contribution is ``(0.5 * x) ** 2 / 2``.
    * otherwise: the reference is ``(k * 0.5) / q`` with
      ``k = 2 * round(x * q - 0.5 - 1e-5) + 1`` (an odd integer), and the
      contribution is ``(x - reference) ** 2 / 2``.

    Args:
        q_level: quantization resolution.
        c_max: upper end of the activation range.
    """

    def __init__(self, q_level,c_max):
        super().__init__()
        self.q_level = q_level
        self.max_value = c_max

    def forward(self, x):
        """Return the scalar penalty summed over every element of ``x``."""
        levels = self.q_level
        upper = self.max_value

        # Region masks, all taken from the unmodified input.
        below_zero = x <= 0
        above_top = x >= upper + 0.5 / levels
        first_bin = (x >= 0) & (x <= 1 / levels)

        # Odd multiple of 0.5 / levels selected by rounding.
        odd_index = 2 * torch.round(x * levels - 0.5 - 1e-5) + 1
        reference = (odd_index * 0.5) / levels

        # The first bin is referenced to zero; the saturated regions are
        # referenced to x itself and therefore contribute nothing.
        reference = torch.where(first_bin, 0, reference)
        reference = torch.where(below_zero | above_top, x, reference)

        # Inside the first bin the input is halved before comparison.
        penalised = torch.where(first_bin, 0.5 * x, x)

        return ((penalised - reference).abs().pow(2) / 2).sum()


class AddGaussianNoise:
    """Transform that adds ``N(mean, std**2)`` noise to a tensor.

    Args:
        mean: offset added to every element.
        std: scale applied to the standard-normal sample.
    """

    def __init__(self, mean=0., std=1.):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Return ``tensor`` plus freshly sampled noise of the same size."""
        scaled_noise = torch.randn(tensor.size()) * self.std
        return tensor + scaled_noise + self.mean

    def __repr__(self):
        return type(self).__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions and an additive shortcut.

    Args:
        in_planes: number of input channels.
        planes: number of channels produced by both convolutions.
        stride: stride of the first convolution (and of the shortcut
            projection when one is needed).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # The shortcut is the identity unless the shape changes, in which
        # case a strided 1x1 convolution + batch norm matches it.
        if stride != 1 or in_planes != out_planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.downsample = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.downsample(x)
        return F.relu(out)
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus shortcut.

    The block outputs ``expansion * planes`` channels.

    Args:
        in_planes: number of input channels.
        planes: width of the two inner convolutions.
        stride: stride of the 3x3 convolution (and of the shortcut
            projection when one is needed).
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # The shortcut is the identity unless the shape changes, in which
        # case a strided 1x1 convolution + batch norm matches it.
        if stride != 1 or in_planes != out_planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.downsample = nn.Sequential()

    def forward(self, x):
        """Run the three conv/bn stages, add the shortcut, apply ReLU."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        out = self.bn3(self.conv3(hidden))
        out += self.downsample(x)
        return F.relu(out)
class ResNet(nn.Module):
    """ResNet classifier preceded by a quantized convolutional front end.

    Front end: ``conv1``/``bn1`` (3 -> 64 channels), then
    ``number_mediate_layers`` conv/bn pairs (64 -> 64). Each of these stages
    is followed by a clamp to [0, 1] and floor quantization at the
    module-level ``q_level``. ``conv2`` maps the result back to 3 channels.

    Backbone: ``conv``/``bn``, four residual stages built from ``block``,
    4x4 average pooling and a linear classifier.

    Args:
        block: residual block class; must expose ``expansion``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: size of the output layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64

        # ---- quantized front end ----
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=True)
        self.bn1 = nn.BatchNorm2d(64)

        self.conv_layers = nn.ModuleList()
        self.bn_layers = nn.ModuleList()
        for _ in range(number_mediate_layers):
            self.conv_layers.append(nn.Conv2d(64, 64, 3, 1, 1, bias=True))
            self.bn_layers.append(nn.BatchNorm2d(64))

        self.conv2 = nn.Conv2d(64, 3, 3, 1, 1, bias=True)

        # ---- residual backbone ----
        self.conv = nn.Conv2d(3, 64, kernel_size=3,
                              stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        # Quantization penalty evaluated on the front-end pre-activations.
        self.cal = CalculateLoss(q_level, 1)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        x = self.bn1(self.conv1(x))
        # NOTE(review): y accumulates the quantization penalty of every
        # front-end stage but is never returned or stored, so callers only
        # ever receive the logits.
        y = self.cal(x)
        x = Quantization_(torch.clamp(x, min=0, max=1), q_level)

        for conv, bn in zip(self.conv_layers, self.bn_layers):
            x = bn(conv(x))
            y += self.cal(x)
            x = Quantization_(torch.clamp(x, min=0, max=1), q_level)

        x = self.conv2(x)

        out = F.relu(self.bn(self.conv(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.linear(out)
          

def orthogonality_loss(model, beta):
    """Return ``beta / 2 * sum ||W^T W - step1_factor * I||^2`` over weights.

    Every parameter whose name contains ``'weight'`` and that has more than
    one dimension contributes. 4-D (convolution) weights are flattened to
    ``(out_channels, -1)`` first; other shapes are used as they are. The
    Gram matrix is ``W^T W``, so its side equals the second dimension of the
    (flattened) weight.

    Args:
        model: module whose ``named_parameters()`` are scanned.
        beta: regularization strength.
    """
    penalty = 0.0
    for name, param in model.named_parameters():
        if 'weight' not in name or param.dim() <= 1:
            continue

        flat = param.view(param.size(0), -1) if param.dim() == 4 else param
        gram = flat.T @ flat
        identity = torch.eye(gram.size(0), device=param.device)
        penalty += (gram - step1_factor * identity).pow(2).sum()

    return beta / 2.0 * penalty
def r_fgsm_attack(image, epsilon, data_grad, clip_min=-1, clip_max=1):
    """Random-start FGSM: one random sign step followed by one gradient step.

    The image is first moved by ``epsilon`` in a random sign direction and
    clipped, then moved by ``epsilon`` along ``sign(data_grad)`` and clipped
    again. Note that ``data_grad`` is supplied by the caller, so the gradient
    step uses whatever point the caller differentiated at.

    Args:
        image: batch of inputs to perturb.
        epsilon: size of each of the two steps.
        data_grad: gradient of the loss with respect to the input.
        clip_min: lower bound applied after each step. The default keeps the
            historical ``-1``; pass ``0`` for inputs that live in [0, 1].
        clip_max: upper bound applied after each step.

    Returns:
        The perturbed batch, on the same device as ``data_grad``.
    """
    # Follow the gradient's device instead of forcing CUDA, so the attack
    # also runs on CPU-only machines.
    image = image.to(data_grad.device)

    random_step = torch.randn_like(image).sign() * epsilon
    perturbed_image = torch.clamp(image + random_step, clip_min, clip_max)

    perturbed_image = perturbed_image + epsilon * data_grad.sign()
    return torch.clamp(perturbed_image, clip_min, clip_max)


def test(model, device, test_loader,noise,use_function,clamp_max):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i in range(1):
            correct_l = []
            for ti in range(1):
                correct = 0
                for data, target in test_loader:
                    data, target = data.to(device), target.to(device)
                    #onehot = torch.nn.functional.one_hot(target, 10)
                    output = model(data)
                    test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
                    pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                    #print(pred.eq(target.view_as(pred)).sum().item())
                    correct += pred.eq(target.view_as(pred)).sum().item()
                correct_l.append(correct)
            correct_l_n = np.array(correct_l)
            mean_val = np.mean(correct_l_n) / 1
            min_val = (np.min(correct_l_n) - np.mean(correct_l_n)) / 1
            max_val = (np.max(correct_l_n) - np.mean(correct_l_n)) / 1
            print(f"{mean_val:.2f}({min_val:.2f}, {max_val:.2f})")

    #print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
    return correct

def fgsm_attack(image, epsilon, data_grad, clip_min=-1, clip_max=1):
    """Single-step FGSM: move ``image`` by ``epsilon`` along ``sign(data_grad)``.

    Args:
        image: batch of inputs to perturb.
        epsilon: step size.
        data_grad: gradient of the loss with respect to ``image``.
        clip_min: lower bound of the valid input range. The default keeps
            the historical ``-1``; pass ``0`` for inputs that live in [0, 1].
        clip_max: upper bound of the valid input range.

    Returns:
        The perturbed batch, clipped to ``[clip_min, clip_max]``.
    """
    perturbed_image = image + epsilon * data_grad.sign()
    return torch.clamp(perturbed_image, clip_min, clip_max)

def test_att(model, device, test_loader, epsilon):
    """Evaluate ``model`` under a single-step FGSM attack of size ``epsilon``.

    For every batch the input gradient of the summed cross-entropy loss is
    computed, ``fgsm_attack`` builds the adversarial batch, and the model is
    scored on it.

    Args:
        model: network to attack; switched to eval mode.
        device: device each batch is moved to.
        test_loader: loader yielding ``(data, target)``; ``.dataset`` is used
            for the total sample count in the report.
        epsilon: FGSM step size.

    Returns:
        The number of adversarial examples classified correctly (a count,
        not a ratio).
    """
    model.eval()
    correct = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        data.requires_grad = True

        # Gradient of the loss with respect to the clean input.
        output = model(data)
        loss = F.cross_entropy(output, target, reduction='sum')
        model.zero_grad()
        loss.backward()
        data_grad = data.grad.detach()

        perturbed_data = fgsm_attack(data, epsilon, data_grad)

        # Scoring needs no autograd graph.
        with torch.no_grad():
            pred_perturbed = model(perturbed_data).argmax(dim=1, keepdim=True)
        correct += pred_perturbed.eq(target.view_as(pred_perturbed)).sum().item()

    # Report a real percentage; the raw count is still what is returned.
    total = len(test_loader.dataset)
    accuracy_pct = 100.0 * correct / total if total else 0.0
    final_acc = correct
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {total} = {accuracy_pct:.2f}%")
    return final_acc

def test_att_rfgsm(model, device, test_loader, epsilon):
    """Evaluate ``model`` under the random-start FGSM attack.

    For every batch the input gradient of the summed cross-entropy loss is
    computed at the clean input and handed to ``r_fgsm_attack`` together with
    ``epsilon / 2``, which that helper uses for both its random step and its
    gradient step.

    Args:
        model: network to attack; switched to eval mode.
        device: device each batch is moved to.
        test_loader: loader yielding ``(data, target)``; ``.dataset`` is used
            for the total sample count in the report.
        epsilon: total perturbation budget (half per step).

    Returns:
        The number of adversarial examples classified correctly (a count,
        not a ratio).
    """
    model.eval()
    correct = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        data.requires_grad = True

        # Gradient of the loss with respect to the clean input.
        output = model(data)
        loss = F.cross_entropy(output, target, reduction='sum')
        model.zero_grad()
        loss.backward()
        data_grad = data.grad.detach()

        perturbed_data = r_fgsm_attack(data, epsilon/2, data_grad)

        # Scoring needs no autograd graph.
        with torch.no_grad():
            pred_perturbed = model(perturbed_data).argmax(dim=1, keepdim=True)
        correct += pred_perturbed.eq(target.view_as(pred_perturbed)).sum().item()

    # Report a real percentage; the raw count is still what is returned.
    total = len(test_loader.dataset)
    accuracy_pct = 100.0 * correct / total if total else 0.0
    final_acc = correct
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {total} = {accuracy_pct:.2f}%")
    return final_acc

def test_att_pytorch(model, device, test_loader, attack):
    model.eval()
    correct = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        perturbed_data = attack(data, target)
        
        output_perturbed = model(perturbed_data)
        pred_perturbed = output_perturbed.argmax(dim=1, keepdim=True)

        correct += pred_perturbed.eq(target.view_as(pred_perturbed)).sum().item()

    final_acc = correct / len(test_loader.dataset)
    print(final_acc)
    #print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {len(test_loader.dataset)} = {final_acc:.2f}%")
    return final_acc

def test_att_pytorch_AA(model, device, test_loader, attack):
    model.eval()
    correct = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        perturbed_data = attack.run_standard_evaluation(data, target)
        
        output_perturbed = model(perturbed_data)
        pred_perturbed = output_perturbed.argmax(dim=1, keepdim=True)

        correct += pred_perturbed.eq(target.view_as(pred_perturbed)).sum().item()

    final_acc = correct / len(test_loader.dataset)
    print(final_acc)
    #print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {len(test_loader.dataset)} = {final_acc:.2f}%")
    return final_acc



def test_random(model, device, test_loader,noise,use_function,clamp_max,epsilon=8/255):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i in range(1):
            correct_l = []
            for ti in range(1):
                correct = 0
                for data, target in test_loader:
                    data, target = data.to(device), target.to(device)
                    noise_ = torch.FloatTensor(data.size()).uniform_(-epsilon, epsilon).cuda()
                    data+=noise_
                    #onehot = torch.nn.functional.one_hot(target, 10)
                    output = model(data)
                    test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
                    pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                    #print(pred.eq(target.view_as(pred)).sum().item())
                    correct += pred.eq(target.view_as(pred)).sum().item()
                correct_l.append(correct)
            correct_l_n = np.array(correct_l)
            mean_val = np.mean(correct_l_n) / 1
            min_val = (np.min(correct_l_n) - np.mean(correct_l_n)) / 1
            max_val = (np.max(correct_l_n) - np.mean(correct_l_n)) / 1
            print(f"{mean_val:.2f}({min_val:.2f}, {max_val:.2f})")

    #print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
    return correct

def main():
    """Load the checkpoint named by ``load_name`` and report its robustness.

    Builds the CIFAR-10 datasets, restores a ``ResNet(BasicBlock, [2, 2, 2, 2])``
    from ``load_name`` and prints its score on: clean data, uniform random
    noise, FGSM, random-start FGSM, PGD-100, PGD-20, CW and DDN, all with
    ``epsilon = 8/255`` where an epsilon applies.

    The model and the checkpoint follow the selected device, so ``--no-cuda``
    and CPU-only machines are honoured at this level.
    """
    # Command-line settings. Several options come from a training template
    # and are not read by the evaluation below.
    parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 robustness evaluation')
    parser.add_argument('--batch-size', type=int, default=2048, metavar='N',
                        help='input batch size for training (default: 2048)')
    # The evaluation previously ignored this flag and always used 512; the
    # default is now 512 so the out-of-the-box behaviour is unchanged.
    parser.add_argument('--test-batch-size', type=int, default=512, metavar='N',
                        help='input batch size for testing (default: 512)')
    parser.add_argument('--epochs', type=int, default=3, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=1, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--loss_scale', type=float, default=0, metavar='LR',
                        help='loss scale (default: 0); not read by this script')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--noise', type=float, default=0, metavar='S',
                        help='noise value forwarded to the test functions (default: 0)')
    parser.add_argument('--use_function', type=str, default='relu', metavar='S',
                        help='function name forwarded to the test functions (default: relu)')
    parser.add_argument('--clamp_max', type=int, default=1, metavar='S',
                        help='clamp maximum forwarded to the test functions (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--T', type=int, default=5, metavar='N',
                        help='SNN time window')
    parser.add_argument('--resume', type=str, default=None, metavar='RESUME',
                        help='Resume model from checkpoint')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    #torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'batch_size': args.batch_size}
    if use_cuda:
        kwargs.update({'num_workers': 32,
                       'pin_memory': True,
                       'shuffle': True},
                     )

    transform_train = transforms.Compose([
        transforms.ToTensor(),
        #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    transform=transforms.Compose([
        transforms.ToTensor(),
        #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    # Training set: the plain images plus five augmented copies. It is not
    # used by the evaluation, but creating it with download=True is what
    # fetches the CIFAR-10 files that the test split below relies on.
    dataset1 = datasets.CIFAR10('../data', train=True, download=True,
                       transform=transform_train)

    # The augmentation pipeline does not change between copies; build it once.
    transform_train_1 = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.RandomCrop(32, padding = 6),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        AddGaussianNoise(std=0.01),
    ])
    for _ in range(5):
        dataset1 = dataset1 + datasets.CIFAR10('../data', train=True, download=True,
                       transform=transform_train_1)

    dataset2 = datasets.CIFAR10('../data', train=False,
                       transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1,**kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2,batch_size=args.test_batch_size)

    # Place the model and the checkpoint tensors on the selected device.
    model = ResNet(BasicBlock, [2, 2, 2, 2]).to(device)
    state = torch.load(load_name, map_location=device)
    load_result = model.load_state_dict(state, strict=False)
    # strict=False tolerates mismatches; report them so that a partly
    # initialised model is not evaluated unnoticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"warning: checkpoint mismatch - missing keys: {load_result.missing_keys}, "
              f"unexpected keys: {load_result.unexpected_keys}")

    # Training objects kept from the template; the evaluation does not step them.
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr,weight_decay = 0.001)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2)
    epsilon = 8/255

    print("------------------------------------------------------------")
    print("clean")
    test(model, device, test_loader, args.noise,args.use_function,args.clamp_max)
    print("------------------------------------------------------------")
    print("random")
    test_random(model, device, test_loader, args.noise,args.use_function,args.clamp_max,epsilon)
    print("------------------------------------------------------------")
    print("fgsm")
    attack = torchattacks.FGSM(model, eps=epsilon)
    test_att_pytorch(model, device, test_loader, attack)
    print("------------------------------------------------------------")
    print("fgsm+R")
    test_att_rfgsm(model, device, test_loader, epsilon)
    print("------------------------------------------------------------")
    print("PGD-100")
    attack = torchattacks.PGD(model, eps=8/255, alpha=0.0007, steps=100)
    test_att_pytorch(model, device, test_loader, attack)
    print("------------------------------------------------------------")
    print("PGD-20")
    attack = torchattacks.PGD(model, eps=8/255, alpha=0.003, steps=20)
    test_att_pytorch(model, device, test_loader, attack)

    print("------------------------------------------------------------")
    print("CW")
    attack = torchattacks.CW(model, c=1, kappa=1, steps=10, lr=0.01)
    test_att_pytorch(model, device, test_loader, attack)

    print("------------------------------------------------------------")
    print("DDN")
    attack = DDNL2Attack(model, nb_iter=20, gamma=0.05, init_norm=1.0, quantize=True, levels=16, clip_min=0.0,
                        clip_max=1.0, targeted=False, loss_fn=None)
    test_att_pytorch(model, device, test_loader, attack)

# Run the evaluation only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()





