import argparse
import math
import os
import pickle
import random

import numpy as np
import scipy as sp
from scipy import stats

import torch, torchvision
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable, Function
from art.utils import load_dataset, check_and_transform_label_format
from torchvision import datasets, transforms


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Command-line options, registered in this order as (flag, add_argument kwargs).
_CLI_OPTIONS = (
    ('--seed', dict(type=int, default=1, metavar='S',
                    help='random seed (default: 1)')),
    ('--mu', dict(type=float, default=1.0, help='size of signal')),
    ('--rate', dict(type=float, default=0.1, help='perturbation rate')),
    ('--N', dict(type=int, default=100, help='number of training samples')),
    ('--dim', dict(type=int, default=100, help='dimension')),
    ('--num1', dict(default=0, type=int, help='mnist number1')),
    ('--num2', dict(default=1, type=int, help='mnist number2')),
)
parser = argparse.ArgumentParser(description='MNIST')
for _flag, _spec in _CLI_OPTIONS:
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()

# Load MNIST through ART. The labels are treated as one-hot: class ids are
# recovered with argmax over axis 1 below.
(x_train, y_train_onehot), (x_test, y_test_onehot), min_, max_ = load_dataset("mnist")

# Flatten every image to a 784-vector and cast images/labels to float32.
x_train = np.asarray(np.reshape(x_train, (-1, 784)), dtype=np.float32)
x_test = np.asarray(np.reshape(x_test, (-1, 784)), dtype=np.float32)
y_train_onehot = np.asarray(y_train_onehot, dtype=np.float32)
y_test_onehot = np.asarray(y_test_onehot, dtype=np.float32)

# Integer class ids.
y_train = y_train_onehot.argmax(axis=1)
y_test = y_test_onehot.argmax(axis=1)



class net(nn.Module):
    """Two-layer fully connected ReLU network producing one score per sample.

    Args:
        d: number of input features after flattening (default 784 = 28 * 28).
        width: number of hidden units.
    """

    def __init__(self, d=784, width=1000):
        super().__init__()
        self.fc1 = nn.Linear(d, width, bias=True)
        self.fc2 = nn.Linear(width, 1, bias=True)

    def forward(self, x):
        """Flatten all but the first dimension and return a (batch, 1) score tensor."""
        flat = x.view(x.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


class Logistic_Loss(torch.nn.Module):
    """Mean logistic loss ``log(1 + exp(-target * score))``.

    ``inputs`` are raw model scores and ``target`` the labels (the training
    script uses -1/+1). The scores are transposed before the product, so
    (N, 1) scores pair element-wise with (N,) labels.
    """

    def __init__(self):
        super(Logistic_Loss, self).__init__()

    def forward(self, inputs, target):
        margins = -target * inputs.t()
        # softplus(z) == log(1 + exp(z)), computed stably: the naive form
        # overflows to inf for z > ~88 in float32, and the backward pass then
        # yields NaN gradients.
        return torch.mean(F.softplus(margins))


def norms(Z):
    """Return per-sample L2 norms of ``Z`` as a (batch, 1) column.

    All dimensions after the first are flattened before the norm is taken.
    """
    flattened = Z.view(Z.shape[0], -1)
    return flattened.norm(dim=1).unsqueeze(1)
def pgd_l2(model, X, y, criterion, epsilon, alpha, num_iter):
    """L2-bounded projected gradient ascent on ``criterion(model(X + delta), y)``.

    Starting from ``delta = 0``, each of ``num_iter`` steps:

    1. moves ``delta`` by ``alpha`` along the per-sample L2-normalized gradient;
    2. projects ``delta`` back onto the per-sample L2 ball of radius ``epsilon``.

    Norms are taken over all dimensions but the first.

    Args:
        model: callable mapping a batch to scores.
        X: input batch; the first dimension indexes samples.
        y: labels passed unchanged to ``criterion``.
        criterion: scalar loss ``criterion(scores, y)`` to be maximized.
        epsilon: L2 radius of the perturbation (<= 0 means no perturbation).
        alpha: step length of each ascent step.
        num_iter: number of ascent steps.

    Returns:
        A detached perturbation with the same shape as ``X``. ``X + delta`` is
        not clipped to any valid input range.

    Gradients are taken w.r.t. ``delta`` only, so the model parameters'
    ``.grad`` fields are left untouched.
    """
    if epsilon <= 0:
        # The projection below would compute 0 / 0 = NaN for a zero radius.
        return torch.zeros_like(X)

    def per_sample_norm(t):
        # L2 norm over all but the first dim, shaped to broadcast against t
        # (e.g. (B, 1) for 2-D input, (B, 1, 1, 1) for image batches).
        flat = t.reshape(t.shape[0], -1).norm(dim=1)
        return flat.view(-1, *([1] * (t.dim() - 1)))

    delta = torch.zeros_like(X, requires_grad=True)
    for _ in range(num_iter):
        loss = criterion(model(X + delta), y)
        # Unlike loss.backward(), this does not accumulate into model params.
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            # Normalized ascent step. The floor avoids 0 / 0 = NaN for samples
            # whose gradient is exactly zero (e.g. saturated loss, dead ReLUs).
            delta += alpha * grad / per_sample_norm(grad).clamp(min=1e-12)
            # Project onto the epsilon-ball (a no-op for samples already inside).
            delta *= epsilon / per_sample_norm(delta).clamp(min=epsilon)
    return delta.detach()
    
    
Seed = args.seed
MU = args.mu  # NOTE(review): unused; the mu values are swept explicitly below.
Rate = args.rate
N = args.N
dim = args.dim  # NOTE(review): unused in this script.

num1 = args.num1
num2 = args.num2

LOG_DIR = 'log_mnist'
T = 1000  # training iterations per configuration
attack_iters = 20  # PGD steps per attack (training and evaluation)
lr = 0.1  # initial SGD step size; divided by 10 once t > T / 2
beta = 0.1  # not used in training; only recorded in the log-file name
SIZES = [5, 7, 10, 14, 17, 20, 22, 25, 27, 28, 30, 32, 34, 36, 38, 40]
MUS = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]


def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators."""
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _binary_subset(images, labels, neg_digit, pos_digit):
    """Keep two digits, labelled -1 (``neg_digit``) and +1 (``pos_digit``).

    Returns ``(images, signs)`` with all ``neg_digit`` samples first.
    """
    neg = images[labels == neg_digit]
    pos = images[labels == pos_digit]
    signs = np.concatenate((np.full(len(neg), -1), np.full(len(pos), 1)))
    return np.concatenate((neg, pos)), signs


def _sign_accuracy(scores, labels):
    """Fraction of samples whose score sign equals its -1/+1 label.

    A score of exactly 0 has sign 0 and therefore counts as an error.
    """
    correct = (scores.sign().squeeze(-1) == labels).sum().item()
    return correct / len(labels)


def _resize_flat(images, size):
    """Move flattened 28x28 images to ``device``, resize to size x size, re-flatten."""
    tensor = torch.from_numpy(images).float().to(device)
    resized = transforms.Resize((size, size))(tensor.reshape(len(tensor), 1, 28, 28))
    return resized.reshape(-1, size * size)


def _scale_rows(features, mu):
    """Rescale every row of ``features`` to L2 norm ``mu``."""
    return features / torch.norm(features, dim=1)[:, None] * mu


X, y = _binary_subset(x_train, y_train, num1, num2)
n = len(y)
idx = np.arange(n)
# Seed before shuffling: the first N_train shuffled samples form the training
# set, so the subset must be reproducible from --seed too.
np.random.seed(Seed)
np.random.shuffle(idx)
X_train = X[idx]
Y_train = y[idx]

X_test, Y_test = _binary_subset(x_test, y_test, num1, num2)

# Create the output directory up front instead of failing after training.
os.makedirs(LOG_DIR, exist_ok=True)

for N_train in [N]:
    Y_train_tensor = torch.from_numpy(Y_train[:N_train]).long().to(device)
    Y_test_tensor = torch.from_numpy(Y_test).long().to(device)
    for size in SIZES:
        # Resizing does not depend on mu, so do it once per image size.
        train_small = _resize_flat(X_train[:N_train, :], size)
        test_small = _resize_flat(X_test, size)
        for seed in [Seed]:
            for mu in MUS:
                X_train_tensor = _scale_rows(train_small, mu)
                X_test_tensor = _scale_rows(test_small, mu)

                for rate in [Rate]:
                    # Re-seed per configuration so each logged run is
                    # reproducible on its own, independent of earlier runs.
                    _seed_everything(seed)
                    model = net(d=size * size, width=1000).to(device)
                    criterion = Logistic_Loss()
                    optimizer = optim.SGD(model.parameters(), lr=lr)

                    print('N', N_train, 'mu', mu, 'size', size, 'eps', rate, 'seed', seed)
                    trainloss, robusttrainacc, cleantestacc, robusttestacc = [], [], [], []
                    name = ('mu' + str(mu) + 'trainN' + str(N_train) + 'size' + str(size)
                            + 'rate' + str(rate) + 'T' + str(T) + 'lr' + str(lr)
                            + 'beta' + str(beta) + 'seed' + str(seed))
                    eps = rate * mu
                    alpha = eps / attack_iters * 4
                    for t in range(T):
                        # Adversarial training step on freshly attacked training data.
                        delta = pgd_l2(model, X_train_tensor, Y_train_tensor,
                                       criterion, eps, alpha, attack_iters)
                        x_adv = X_train_tensor + delta

                        if t > T / 2:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = lr / 10

                        optimizer.zero_grad()
                        pred = model(x_adv)
                        loss = criterion(pred, Y_train_tensor)
                        loss.backward()
                        optimizer.step()
                        # Accuracy on the adversarial batch, scored before the update.
                        robustacc = _sign_accuracy(pred.detach(), Y_train_tensor)

                        with torch.no_grad():
                            clean_accuracy = _sign_accuracy(model(X_test_tensor), Y_test_tensor)
                        # The attack needs input gradients, so it runs outside no_grad.
                        delta_test = pgd_l2(model, X_test_tensor, Y_test_tensor,
                                            criterion, eps, alpha, attack_iters)
                        with torch.no_grad():
                            robust_accuracy = _sign_accuracy(
                                model(X_test_tensor + delta_test), Y_test_tensor)

                        trainloss.append(loss.item())
                        cleantestacc.append(clean_accuracy)
                        robusttrainacc.append(robustacc)
                        robusttestacc.append(robust_accuracy)

                    with open(os.path.join(LOG_DIR, name + '.pkl'), 'wb') as f:
                        pickle.dump((trainloss, robusttrainacc, cleantestacc, robusttestacc), f)
