import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#from networks import *
import torch.nn as nn
#import foolbox
import argparse
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
import copy
import math
import numpy as np
import os
import argparse

parser = argparse.ArgumentParser()
# "cpu" is accepted so the script can run on machines without a GPU.
# --device is required: without it torch.device(None) below would raise.
parser.add_argument('--device', choices=["cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3"], type=str,
                    required=True, help="torch device to run the evaluation on")
# "i_fgm" is listed so main()'s L2 attack branch is reachable; the default
# "ifgsm" matches what main() ran when --method was omitted.
parser.add_argument('--method', choices=["fgsm", "ifgsm", "i_fgm"], type=str, default="ifgsm",
                    help="attack to evaluate: fgsm, ifgsm (iterative L-inf) or i_fgm (iterative L2)")
# NOTE(review): --average is parsed but not used by the visible code.
parser.add_argument('--average', help="number of average when inference", type=int)

args = parser.parse_args()
device = torch.device(args.device)
method = args.method
average = args.average
import numpy as np
import math
from PIL import Image
import matplotlib.pyplot as plt

s = torch.ones(1)  # default time input `t` for the models; re-bound on the target device further down

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (spatial size is preserved when 1).
    """
    return nn.Conv2d(in_planes, out_planes, 3, stride, 1, bias=False)


class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm stages plus a shortcut.

    The shortcut is the identity unless a ``downsample`` module is supplied,
    in which case it is applied to the block input. Note that ``self.relu``
    is actually an in-place ELU activation.
    """

    # Channel multiplier between ``planes`` and the block's output channels.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Submodules are registered in this exact order: it determines the
        # parameter-initialisation RNG sequence and the state_dict key order.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ELU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        branch = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)

class LeNet_Cifar(nn.Module):
    """LeNet-style classifier with a learned, ``t``-scaled additive input offset.

    ``forward(x, t)`` adds ``t * self.weight`` to ``x`` before ``conv1``.
    ``conv1`` takes 3 channels and ``fc1`` expects 16*5*5 features, which
    implies (batch, 3, 32, 32) inputs.

    ``self.weight`` used to have shape (16, 32, 32). That cannot broadcast
    against a (batch, 3, 32, 32) image, so ``forward`` always raised. It now
    has one entry per input channel/pixel. NOTE: a state_dict saved with the
    old shape will not load into this module.
    """

    def __init__(self):
        super(LeNet_Cifar, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        # Additive input offset, scaled by `t` in forward(); shape matches one image.
        self.weight = nn.Parameter(0.036 * torch.randn(3, 32, 32))

        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x, t):
        x = F.max_pool2d(F.relu(self.conv1(x + t * self.weight)), kernel_size=(2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), kernel_size=2)

        x = x.view(x.size(0), -1)
        x = F.elu(self.fc1(x))
        x = F.elu(self.fc2(x))
        return self.fc3(x)


class ResNet_Cifar(nn.Module):
    """CIFAR-style ResNet whose forward pass is conditioned on an extra input ``t``.

    ``forward(x, t)`` prepends a constant channel filled with ``t`` to ``x``.
    ``conv1`` expects 4 channels, so ``x`` must have 3. Three stages
    (16/32/64 planes, the last two with stride 2) are followed by
    ``AvgPool2d(8)``, which assumes 32x32 inputs (8x8 after the two stride-2
    stages). ``t`` must broadcast against a (batch, 1, H, W) tensor.
    """

    def __init__(self, block, layers, num_classes=10):
        super(ResNet_Cifar, self).__init__()
        # Running channel count consumed/updated by _make_layer.
        self.inplanes = 16
        self.conv1 = nn.Conv2d(4, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ELU(inplace=True)  # ELU, despite the attribute name
        self.layer1 = self._make_layer(block, 16, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AvgPool2d(8, stride=1)
        self.fc = nn.Linear(64 * block.expansion, num_classes)
        self._init_weights()

    def _init_weights(self):
        """He-normal (fan-out) init for convs; BatchNorm set to weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.out_channels * module.kernel_size[0] * module.kernel_size[1]
                nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first may change stride/width."""
        out_planes = planes * block.expansion
        shortcut = None
        if stride != 1 or self.inplanes != out_planes:
            # 1x1 projection so the shortcut matches the new shape.
            shortcut = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, shortcut)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def forward(self, x, t):
        time_channel = torch.ones_like(x[:, :1, :, :]) * t
        h = self.relu(self.bn1(self.conv1(torch.cat([time_channel, x], 1))))
        h = self.layer3(self.layer2(self.layer1(h)))
        h = self.avgpool(h)
        return self.fc(h.view(h.size(0), -1))


def lenet_cifar(**kwargs):
    """Construct a ``LeNet_Cifar``, forwarding any keyword arguments."""
    return LeNet_Cifar(**kwargs)

def resnet8_cifar(**kwargs):
    """ResNet-8 for CIFAR: 3 stages of 1 BasicBlock each."""
    return ResNet_Cifar(BasicBlock, [1] * 3, **kwargs)

def resnet20_cifar(**kwargs):
    """ResNet-20 for CIFAR: 3 stages of 3 BasicBlocks each."""
    return ResNet_Cifar(BasicBlock, [3] * 3, **kwargs)

def resnet32_cifar(**kwargs):
    """ResNet-32 for CIFAR: 3 stages of 5 BasicBlocks each."""
    return ResNet_Cifar(BasicBlock, [5] * 3, **kwargs)


def resnet44_cifar(**kwargs):
    """ResNet-44 for CIFAR: 3 stages of 7 BasicBlocks each."""
    return ResNet_Cifar(BasicBlock, [7] * 3, **kwargs)


def resnet56_cifar(**kwargs):
    """ResNet-56 for CIFAR: 3 stages of 9 BasicBlocks each."""
    return ResNet_Cifar(BasicBlock, [9] * 3, **kwargs)


def resnet110_cifar(**kwargs):
    """ResNet-110 for CIFAR: 3 stages of 18 BasicBlocks each."""
    return ResNet_Cifar(BasicBlock, [18] * 3, **kwargs)


s = torch.ones(1, device=device)  # default time input `t` passed to the model by the attack functions

def fgsm(model, inputs, labels, epsilon=8 / 255, x_val_min=0, x_val_max=1, t=None):
    """Run a single-step FGSM (L-inf) attack and count the examples that survive it.

    Args:
        model: callable ``model(x, t)`` returning class scores for ``x``.
        inputs: clean input batch.
        labels: integer class labels for ``inputs``.
        epsilon: size of the single signed-gradient step.
        x_val_min, x_val_max: valid input range; the adversarial batch is clamped to it.
        t: second argument forwarded to ``model``; defaults to the module-level ``s``.

    Returns:
        0-dim float tensor: number of adversarial examples still classified as ``labels``.
    """
    if t is None:
        t = s
    inputs_adv = inputs.clone().detach().requires_grad_()
    loss = F.cross_entropy(model(inputs_adv, t), labels)
    grad, = torch.autograd.grad(loss, inputs_adv)

    # One step up the loss gradient's sign, then back into the valid input range.
    inputs_adv = torch.clamp(inputs_adv.detach() + epsilon * grad.sign(), x_val_min, x_val_max)

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        pred = model(inputs_adv, t).argmax(dim=1)
    return pred.eq(labels).float().sum()

N = 1  # NOTE(review): not referenced by any visible code

def i_fgsm(model, inputs, labels, epsilon=4 / 255, alpha=1 / 255, iteration=20, x_val_min=0,
           x_val_max=1, t=None):
    """Run iterative FGSM (L-inf) and count the examples that survive it.

    Each of ``iteration`` steps moves ``alpha`` along the sign of the loss
    gradient. The result is projected onto the L-inf ball of radius
    ``epsilon`` around ``inputs`` and clamped to ``[x_val_min, x_val_max]``.

    Args:
        model: callable ``model(x, t)`` returning class scores for ``x``.
        inputs: clean input batch.
        labels: integer class labels for ``inputs``.
        epsilon: L-inf perturbation budget.
        alpha: per-step size.
        iteration: number of attack steps.
        x_val_min, x_val_max: valid input range.
        t: second argument forwarded to ``model``; defaults to the module-level ``s``.

    Returns:
        0-dim float tensor: number of adversarial examples still classified as ``labels``.
    """
    if t is None:
        t = s
    inputs = inputs.detach()
    inputs_adv = inputs.clone().requires_grad_()
    for _ in range(iteration):
        loss = F.cross_entropy(model(inputs_adv, t), labels)
        grad, = torch.autograd.grad(loss, inputs_adv)
        with torch.no_grad():
            stepped = inputs_adv + alpha * grad.sign()
            # Project back into the epsilon L-inf ball, then the valid range.
            stepped = inputs + torch.clamp(stepped - inputs, -epsilon, epsilon)
            stepped = torch.clamp(stepped, x_val_min, x_val_max)
        inputs_adv = stepped.requires_grad_()

    # Evaluation only (the original `outputs=+= ...` here was a SyntaxError).
    with torch.no_grad():
        pred = model(inputs_adv, t).argmax(dim=1)
    return pred.eq(labels).float().sum()

def i_fgm(model, inputs, labels, epsilon=0.5, alpha=0.1, iteration=50, x_val_min=0, x_val_max=1,
          t=None):
    """Run the iterative fast gradient method under an L2 budget and count survivors.

    Each step moves every example by ``alpha`` along its own L2-normalised
    loss gradient. It then rescales that example's total perturbation so its
    L2 norm is at most ``epsilon``, and clamps to ``[x_val_min, x_val_max]``.
    Norms are per example. The previous version used one norm over the whole
    batch, which shrank each example's effective budget with batch size.

    Args:
        model: callable ``model(x, t)`` returning class scores for ``x``.
        inputs: clean input batch (first dim is the batch).
        labels: integer class labels for ``inputs``.
        epsilon: per-example L2 perturbation budget.
        alpha: per-step L2 step length.
        iteration: number of attack steps.
        x_val_min, x_val_max: valid input range.
        t: second argument forwarded to ``model``; defaults to the module-level ``s``.

    Returns:
        int: number of adversarial examples still classified as ``labels``.
    """
    def _l2(v):
        # Per-example L2 norm shaped (batch, 1, ..., 1) so it broadcasts against v.
        return v.flatten(1).norm(dim=1).view((-1,) + (1,) * (v.dim() - 1))

    if t is None:
        t = s
    tiny = 1e-12  # guards against division by a zero gradient / zero perturbation
    inputs = inputs.detach()
    inputs_adv = inputs.clone().requires_grad_()
    for _ in range(iteration):
        loss = F.cross_entropy(model(inputs_adv, t), labels)
        grad, = torch.autograd.grad(loss, inputs_adv)
        with torch.no_grad():
            stepped = inputs_adv + alpha * grad / _l2(grad).clamp_min(tiny)
            delta = stepped - inputs
            # Shrink only when the perturbation exceeds the budget (factor <= 1).
            factor = torch.clamp(epsilon / _l2(delta).clamp_min(tiny), max=1.0)
            stepped = torch.clamp(inputs + delta * factor, x_val_min, x_val_max)
        inputs_adv = stepped.requires_grad_()

    # Evaluation only (the original `outputs=+= ...` here was a SyntaxError).
    with torch.no_grad():
        pred = model(inputs_adv, t).argmax(dim=1)
    return pred.eq(labels).sum().item()


def main():
    """Evaluate the ResNet-56 checkpoint '56-5-10.pt' against an attack on CIFAR-10 test data.

    The attack is chosen from the module-level ``method`` (``--method``):
    ``"fgsm"`` -> ``fgsm``, ``"i_fgm"`` -> ``i_fgm``, anything else
    (``"ifgsm"`` or unset) -> ``i_fgsm``. After every batch it prints the
    running accuracy on adversarial examples.
    """
    test_set = datasets.CIFAR10(root='./data', train=False, download=True,
                                transform=transforms.ToTensor())
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False, num_workers=4)

    model = resnet56_cifar().eval()
    model.load_state_dict(torch.load('56-5-10.pt', map_location=lambda storage, loc: storage))
    model.to(device)

    # Previously "fgsm" silently ran i_fgsm and the i_fgm branch was unreachable.
    attacks = {"fgsm": fgsm, "ifgsm": i_fgsm, "i_fgm": i_fgm}
    attack = attacks.get(method, i_fgsm)

    total_correct = 0.0
    seen = 0
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # float() normalises tensor (fgsm/i_fgsm) and int (i_fgm) results.
        total_correct += float(attack(model, inputs, labels))
        seen += labels.size(0)
        # Count real samples instead of assuming full batches of 100 / 10000 total.
        print("[{}/{}] Accuracy:{}".format(seen, len(test_set), total_correct / seen))


if __name__ == "__main__":
    # Only run the evaluation when executed as a script, not on import.
    main()
