from __future__ import print_function
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.optim.lr_scheduler import MultiStepLR
import torchattacks
from advertorch.attacks import LinfPGDAttack,CarliniWagnerL2Attack,DDNL2Attack,SpatialTransformAttack,GradientSignAttack
from autoattack import AutoAttack
from torch.utils.tensorboard import SummaryWriter
# Activation quantization granularity: Clamp_q_ maps activations onto
# multiples of 1/q_level. Values tried so far: 1, 3, 7, 15 (presumably the
# 1/2/3/4-bit settings, i.e. 2**bits - 1 -- confirm).
q_level = 3

# Checkpoint paths. `load_name` is read by main(); `save_name` is not read
# in the code visible here (NOTE(review): likely shared with a training script).
load_name = "weight_/resnet_t3_noise_70.pt"
save_name = "weight/resnet_t3_noise.pt"

# Training-stage toggles and loss weights (not read in the code visible here).
# Other stage settings used previously: [False, True, False] and
# [False, False, True].
train_step = [True, False, False]
step1_factor = 0
step2_factor = 1e-5


class Quantization(torch.autograd.Function):
    """Floor-quantize a tensor onto a grid of step ``1/constant``.

    Forward computes ``floor(tensor * constant) / constant``.
    Backward is a straight-through estimator: the incoming gradient is passed
    to ``tensor`` as-is except that it is clipped to [-100, 100] via
    ``100 * hardtanh(g / 100)``. ``constant`` receives no gradient.
    """

    @staticmethod
    def forward(ctx, tensor, constant=1):
        ctx.constant = constant
        scaled = torch.mul(tensor, constant)
        return torch.div(torch.floor(scaled), constant)

    @staticmethod
    def backward(ctx, grad_output):
        clipped = F.hardtanh(grad_output / 100)
        return 100 * clipped, None


# Functional alias used by the activation modules below.
Quantization_ = Quantization.apply

class Clamp_q_(nn.Module):
    """Clamp activations to ``[min, max]`` and then floor-quantize them.

    Quantization uses ``Quantization_`` with step ``1/q_level``; the default
    ``q_level`` is the module-level setting at class-definition time.
    """

    def __init__(self, min=0.0, max=1, q_level=q_level):
        super().__init__()
        self.min, self.max, self.q_level = min, max, q_level

    def forward(self, x):
        clipped = x.clamp(min=self.min, max=self.max)
        return Quantization_(clipped, self.q_level)

class CalculateLoss(torch.nn.Module):
    """Quantization-aware activation penalty.

    Pulls each activation toward a target ``seq_val`` chosen from the
    quantization grid of step ``1/q_level``:

    * ``x <= 0``: target is ``x`` itself (no penalty).
    * ``0 < x <= 1/q_level`` (first bin): ``x`` is halved and the target is
      0, so the penalty is ``(0.5*x)**2 / 2``.
    * ``x >= c_max + 0.5/q_level``: target is ``x`` itself (no penalty).
    * Otherwise: the midpoint ``(n + 0.5)/q_level`` of the bin
      ``[n/q_level, (n+1)/q_level)`` containing ``x``. The ``-1e-5`` offset
      applied before rounding assigns values lying within about
      ``1e-5/q_level`` above a bin edge to the bin below.

    Args:
        q_level: grid resolution (grid step is ``1/q_level``).
        c_max: upper activation bound; penalties stop half a grid step above it.

    ``forward`` returns a 0-dim tensor ``sum((x' - seq_val)**2 / 2)`` over all
    elements, where ``x'`` is ``x`` with first-bin values halved.
    """
    def __init__(self, q_level,c_max):
        super(CalculateLoss, self).__init__()
        self.q_level = q_level
        self.max_value=c_max

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q_level = self.q_level
        c_max = self.max_value
        # Regions where no penalty is applied (target == x).
        Safe_zero_mask = (x <= 0)
        Safe_one_mask =  (x >= c_max + 0.5/q_level)

        x_scaled = x * q_level
        # k is the odd integer 2n+1 for x in bin [n/q, (n+1)/q) (up to the
        # 1e-5 offset), so k*0.5/q_level is that bin's midpoint.
        k = 2 * torch.round(x_scaled - 0.5 - 1e-5) + 1
        seq_val = (k * 0.5) / q_level

        # Order matters: the first-bin target (0) is set first, then the two
        # no-penalty masks override it where they apply.
        seq_val = torch.where((x >= 0) &(x <= 1/q_level), 0, seq_val)
        seq_val = torch.where(Safe_zero_mask, x, seq_val)
        seq_val = torch.where(Safe_one_mask, x, seq_val)
        # First-bin activations are halved before being compared with target 0.
        x = torch.where((x >= 0) &(x <= 1/q_level), 0.5*x, x)
        # Half squared distance, summed over every element.
        act_loss =  torch.sum(torch.pow(torch.abs(x - seq_val),2)/2)

        # Alternative (unused) formulation kept for reference:
        #act_loss = (0.5 / q_level) ** 2 * torch.mean(torch.pow((torch.abs(x - seq_val) + 1e-10) / (0.5 / q_level), 0.5))

        return act_loss


class AddGaussianNoise(object):
    """Transform that adds i.i.d. Gaussian noise ``N(mean, std**2)`` to a tensor.

    Args:
        mean: mean of the added noise.
        std: standard deviation of the added noise.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Sample on the input's own device so the transform also works on
        # CUDA tensors (torch.randn(size) alone always yields a CPU tensor).
        # The default dtype is kept, so dtype promotion matches the original.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions and clamp/quantize activations.

    Data flow in ``forward``::

        x -> conv1 -> bn1 -> clamp_ -> conv2 -> bn2 -> (+ downsample(x)) -> clamp_

    ``downsample`` is an identity (empty ``nn.Sequential``) unless the stride
    or channel count changes, in which case it is a 1x1 conv + BN projection.
    ``forward`` also stores clones of the pre-activation tensors as
    ``bn1_output``, ``bn2_output`` and ``bn3_output``; they are not read in
    this file (presumably inspected by analysis/training code -- confirm).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.clamp_ = Clamp_q_()
        # Identity shortcut by default; registered here so the submodule order
        # stays conv1, bn1, conv2, bn2, clamp_, downsample, calculate_loss.
        self.downsample = nn.Sequential()
        self.calculate_loss = CalculateLoss(q_level, 1)

        out_planes = self.expansion * planes
        if stride == 1 and in_planes == out_planes:
            return
        # Shape changes: project the shortcut with a strided 1x1 conv + BN.
        self.downsample = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_planes),
        )

    def forward(self, x):
        h = self.bn1(self.conv1(x))
        self.bn1_output = h.clone()

        h = self.bn2(self.conv2(self.clamp_(h)))
        self.bn2_output = h.clone()

        h += self.downsample(x)
        self.bn3_output = h.clone()
        return self.clamp_(h)
    
class ResNet(nn.Module):
    """ResNet backbone with clamped/quantized activations (see ``Clamp_q_``).

    Stem: 3x3 conv (3 -> 64 channels) + BN + clamp/quantize, followed by four
    stages built from ``block``, a 4x4 average pool and a linear classifier.

    Args:
        block: residual block class (e.g. ``BasicBlock``); must expose an
            ``expansion`` attribute.
        num_blocks: number of blocks in each of the four stages.
        num_classes: size of the output logit vector.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)
        self.clamp_ = Clamp_q_()
        # Activation regularizer. forward() does not use it; it is kept so
        # existing references to ``model.calculate_loss`` (presumably in the
        # training code -- confirm) keep working.
        self.calculate_loss = CalculateLoss(q_level, 1)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: first block uses ``stride``, the rest stride 1.

        Updates ``self.in_planes`` so the next stage receives the right width.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for a batch of images.

        The previous version also computed ``calculate_loss`` on the stem
        output and then discarded it; that dead work (and the extra autograd
        nodes it added during attack gradient passes) has been removed.

        The final ``avg_pool2d(out, 4)`` + ``view`` expects a 4x4 map here,
        i.e. 32x32 inputs (three stride-2 stages after a stride-1 stem).
        """
        out = self.clamp_(self.bn(self.conv(x)))

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.linear(out)

def r_fgsm_attack(image, epsilon, data_grad, clip_min=0.0, clip_max=1.0):
    """Random-sign step followed by a gradient-sign step (R+FGSM style).

    Computes::

        x1 = clip(image + epsilon * sign(N(0, I)))
        x2 = clip(x1 + epsilon * sign(data_grad))

    where ``clip`` restricts values to ``[clip_min, clip_max]``.

    Args:
        image: input batch.
        epsilon: size of *each* of the two steps.
        data_grad: gradient of the loss w.r.t. ``image`` (same shape).
        clip_min, clip_max: valid pixel range. Default [0, 1] matches
            ``transforms.ToTensor()`` output without normalization. The old
            hard-coded [-1, 1] let adversarial images go below 0.

    Returns:
        The perturbed batch, on ``data_grad``'s device.
    """
    # Follow the gradient's device instead of forcing .cuda(), so CPU runs work.
    image = image.to(data_grad.device)
    random_sign = torch.randn_like(image).sign()
    perturbed = torch.clamp(image + epsilon * random_sign, clip_min, clip_max)
    perturbed = perturbed + epsilon * data_grad.sign()
    return torch.clamp(perturbed, clip_min, clip_max)


def test(model, device, test_loader,noise,use_function,clamp_max):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i in range(1):
            correct_l = []
            for ti in range(1):
                correct = 0
                for data, target in test_loader:
                    data, target = data.to(device), target.to(device)
                    #onehot = torch.nn.functional.one_hot(target, 10)
                    output = model(data)
                    test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
                    pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                    #print(pred.eq(target.view_as(pred)).sum().item())
                    correct += pred.eq(target.view_as(pred)).sum().item()
                correct_l.append(correct)
            correct_l_n = np.array(correct_l)
            mean_val = np.mean(correct_l_n) / 1
            min_val = (np.min(correct_l_n) - np.mean(correct_l_n)) / 1
            max_val = (np.max(correct_l_n) - np.mean(correct_l_n)) / 1
            print(f"{mean_val:.2f}({min_val:.2f}, {max_val:.2f})")
    final_acc = correct / len(test_loader.dataset)
    print(final_acc)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
    return correct

def fgsm_attack(image, epsilon, data_grad, clip_min=0.0, clip_max=1.0):
    """Single-step FGSM: ``clip(image + epsilon * sign(data_grad))``.

    Args:
        image: input batch.
        epsilon: L-inf step size.
        data_grad: gradient of the loss w.r.t. ``image`` (same shape).
        clip_min, clip_max: valid pixel range. Default [0, 1] matches
            ``transforms.ToTensor()`` output without normalization. The old
            hard-coded [-1, 1] let adversarial images go below 0.

    Returns:
        The perturbed batch.
    """
    perturbed = image + epsilon * data_grad.sign()
    return torch.clamp(perturbed, clip_min, clip_max)

def test_att(model, device, test_loader, epsilon):
    """Evaluate accuracy under a single-step FGSM attack (``fgsm_attack``).

    The input gradient of the summed cross-entropy loss is taken at the clean
    input, the adversarial batch is built from it, and the model's predictions
    on that batch are scored.

    Returns:
        Adversarial accuracy as a fraction of ``len(test_loader.dataset)``.
    """
    model.eval()
    correct = 0
    n_samples = len(test_loader.dataset)

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        data.requires_grad_(True)

        loss = F.cross_entropy(model(data), target, reduction='sum')
        # Only the input gradient is needed; autograd.grad avoids filling the
        # parameters' .grad buffers.
        (data_grad,) = torch.autograd.grad(loss, data)

        perturbed_data = fgsm_attack(data.detach(), epsilon, data_grad)
        with torch.no_grad():
            pred = model(perturbed_data).argmax(dim=1)
        correct += pred.eq(target).sum().item()

    final_acc = correct / n_samples if n_samples else 0.0
    print(final_acc)
    # Previously printed 100*correct as the numerator and a fraction with "%".
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {n_samples} = {100 * final_acc:.2f}%")
    return final_acc

def test_att_rfgsm(model, device, test_loader, epsilon):
    """Evaluate accuracy under the attack implemented by ``r_fgsm_attack``.

    The budget ``epsilon`` is split evenly: a random-sign step of
    ``epsilon/2`` followed by a gradient-sign step of ``epsilon/2``.

    NOTE(review): the input gradient is taken at the clean input, not at the
    randomly perturbed point as in the original R+FGSM formulation; numbers
    may therefore not match that attack exactly.

    Returns:
        Number of correctly classified adversarial samples (int).
    """
    model.eval()
    correct = 0
    n_samples = len(test_loader.dataset)

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        data.requires_grad_(True)

        loss = F.cross_entropy(model(data), target, reduction='sum')
        # Only the input gradient is needed; autograd.grad avoids filling the
        # parameters' .grad buffers.
        (data_grad,) = torch.autograd.grad(loss, data)

        perturbed_data = r_fgsm_attack(data.detach(), epsilon / 2, data_grad)
        with torch.no_grad():
            pred = model(perturbed_data).argmax(dim=1)
        correct += pred.eq(target).sum().item()

    final_acc = correct / n_samples if n_samples else 0.0
    print(final_acc)
    # Previously the raw count was formatted as a percentage.
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {n_samples} = {100 * final_acc:.2f}%")
    return correct

def test_att_pytorch(model, device, test_loader, attack):
    model.eval()
    correct = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        perturbed_data = attack(data, target)
        
        output_perturbed = model(perturbed_data)
        pred_perturbed = output_perturbed.argmax(dim=1, keepdim=True)

        correct += pred_perturbed.eq(target.view_as(pred_perturbed)).sum().item()

    final_acc = correct / len(test_loader.dataset)
    print(final_acc)
    #print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {len(test_loader.dataset)} = {final_acc:.2f}%")
    return final_acc

def test_att_pytorch_AA(model, device, test_loader, attack):
    model.eval()
    correct = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        perturbed_data = attack.run_standard_evaluation(data, target)
        
        output_perturbed = model(perturbed_data)
        pred_perturbed = output_perturbed.argmax(dim=1, keepdim=True)

        correct += pred_perturbed.eq(target.view_as(pred_perturbed)).sum().item()

    final_acc = correct / len(test_loader.dataset)
    print(final_acc)
    #print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {len(test_loader.dataset)} = {final_acc:.2f}%")
    return final_acc



def test_random(model, device, test_loader,noise,use_function,clamp_max,epsilon=8/255):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i in range(1):
            correct_l = []
            for ti in range(1):
                correct = 0
                for data, target in test_loader:
                    data, target = data.to(device), target.to(device)
                    noise_ = torch.FloatTensor(data.size()).uniform_(-epsilon, epsilon).cuda()
                    data+=noise_
                    #onehot = torch.nn.functional.one_hot(target, 10)
                    output = model(data)
                    test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
                    pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                    #print(pred.eq(target.view_as(pred)).sum().item())
                    correct += pred.eq(target.view_as(pred)).sum().item()
                correct_l.append(correct)
            correct_l_n = np.array(correct_l)
            mean_val = np.mean(correct_l_n) / 1
            min_val = (np.min(correct_l_n) - np.mean(correct_l_n)) / 1
            max_val = (np.max(correct_l_n) - np.mean(correct_l_n)) / 1
            print(f"{mean_val:.2f}({min_val:.2f}, {max_val:.2f})")
        final_acc = correct / len(test_loader.dataset)
        print(final_acc)
    #print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
    return correct


def main():
    """Evaluate a pretrained quantized ResNet on the SVHN test split.

    Loads the weights at ``load_name`` and prints accuracy on clean inputs,
    under uniform random noise (10 repeats), and under FGSM, R+FGSM, PGD-20,
    CW and DDN attacks. Images are loaded with ``transforms.ToTensor()`` only
    (no normalization).
    """
    # Several flags (batch-size, epochs, lr, loss_scale, gamma, dry-run,
    # log-interval, save-model, T, resume) mirror the training CLI and are
    # not read by this evaluation script.
    parser = argparse.ArgumentParser(description='PyTorch SVHN robustness evaluation')
    parser.add_argument('--batch-size', type=int, default=384, metavar='N',
                        help='input batch size for training (default: 384)')
    parser.add_argument('--test-batch-size', type=int, default=512, metavar='N',
                        help='input batch size for testing (default: 512)')
    parser.add_argument('--epochs', type=int, default=3, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=1, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--loss_scale', type=float, default=0, metavar='F',
                        help='loss scale factor (default: 0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--noise', type=float, default=0, metavar='F',
                        help='noise level passed to the test helpers (default: 0)')
    parser.add_argument('--use_function', type=str, default='relu', metavar='NAME',
                        help='activation function name passed to the test helpers (default: relu)')
    parser.add_argument('--clamp_max', type=int, default=1, metavar='N',
                        help='activation clamp maximum passed to the test helpers (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--T', type=int, default=5, metavar='N',
                        help='SNN time window')
    parser.add_argument('--resume', type=str, default=None, metavar='RESUME',
                        help='Resume model from checkpoint')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    # --seed was previously ignored; seeding makes the random-noise and
    # random-start attack results reproducible.
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    dataset2 = datasets.SVHN('../../../../robustness_qnn_input_01/data', split='test', download=True,
                             transform=transform)
    # --test-batch-size was previously ignored in favour of a hard-coded 512;
    # its default is now 512 so the effective behavior is unchanged.
    test_loader = torch.utils.data.DataLoader(dataset2, batch_size=args.test_batch_size)

    model = ResNet(BasicBlock, [2, 2, 2, 2]).to(device)
    # strict=False tolerates key mismatches; report them so a checkpoint that
    # silently fails to load (e.g. a 'module.' prefix) does not go unnoticed.
    incompatible = model.load_state_dict(torch.load(load_name, map_location=device), strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print("WARNING: checkpoint/model key mismatch")
        print("  missing keys:", incompatible.missing_keys)
        print("  unexpected keys:", incompatible.unexpected_keys)

    separator = "------------------------------------------------------------"
    epsilon = 8/255

    print(separator)
    print("clean")
    test(model, device, test_loader, args.noise, args.use_function, args.clamp_max)
    for _ in range(10):
        print(separator)
        print("random")
        test_random(model, device, test_loader, args.noise, args.use_function, args.clamp_max, epsilon)

    print(separator)
    print("fgsm")
    attack = torchattacks.FGSM(model, eps=epsilon)
    test_att_pytorch(model, device, test_loader, attack)

    print(separator)
    print("fgsm+R")
    test_att_rfgsm(model, device, test_loader, epsilon)

    print(separator)
    print("PGD-20")
    attack = torchattacks.PGD(model, eps=epsilon, alpha=0.003, steps=20)
    test_att_pytorch(model, device, test_loader, attack)

    print(separator)
    print("CW")
    attack = torchattacks.CW(model, c=1, kappa=1, steps=10, lr=0.01)
    test_att_pytorch(model, device, test_loader, attack)

    print(separator)
    print("DDN")
    attack = DDNL2Attack(model, nb_iter=20, gamma=0.05, init_norm=1.0, quantize=True, levels=16, clip_min=0.0,
                         clip_max=1.0, targeted=False, loss_fn=None)
    test_att_pytorch(model, device, test_loader, attack)




if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()





