import argparse
import math
import os
import pickle
import random

import numpy as np
import scipy as sp
from scipy import stats

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.autograd import Function, Variable
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor


# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Command-line configuration for the Gaussian-mixture experiment.
# NOTE(review): --mu is parsed, but the experiment loop sweeps a fixed list of
# mu values and never reads args.mu.
parser = argparse.ArgumentParser(description='Gaussian Mixture data')
parser.add_argument(
    '--seed', type=int, default=1, metavar='S',
    help='random seed (default: 1)',
)
parser.add_argument(
    '--mu', type=float, default=1.0,
    help='size of signal',
)
parser.add_argument(
    '--N', type=int, default=100,
    help='number of training samples',
)
parser.add_argument(
    '--dim', type=int, default=100,
    help='dimension',
)
args = parser.parse_args()

class net(nn.Module):
    """Two-layer fully connected ReLU network with a single output unit.

    Args:
        d: number of input features once each sample is flattened.
        width: number of hidden units.
    """

    def __init__(self, d=784, width=1000):
        super().__init__()
        # Layers are created in this order so the global torch RNG initialises
        # fc1 before fc2, which keeps seeded runs reproducible.
        self.fc1 = nn.Linear(in_features=d, out_features=width, bias=True)
        self.fc2 = nn.Linear(in_features=width, out_features=1, bias=True)

    def forward(self, x):
        """Flatten every dimension after the batch axis and return (batch, 1) scores."""
        flat = x.view(x.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


class Logistic_Loss(torch.nn.Module):
    """Mean logistic loss ``log(1 + exp(-target * score))``."""

    def __init__(self):
        super(Logistic_Loss, self).__init__()

    def forward(self, inputs, target):
        """Return the mean logistic loss of ``inputs`` against ``target``.

        Args:
            inputs: model scores. They are transposed before the product, so an
                (N, 1) score column lines up with an (N,) target vector.
            target: labels (in this script they are +1 / -1).

        Returns:
            Scalar tensor: mean over all elements of
            ``softplus(-target * inputs.t())``.
        """
        margin = target * inputs.t()
        # softplus(z) == log(1 + exp(z)), but it does not overflow for large z.
        # The naive form returned inf for badly misclassified points.
        return torch.mean(F.softplus(-margin))


def norms(Z):
    """Return the L2 norm of each sample of ``Z`` as a (batch, 1) column.

    Every dimension after the first is flattened before the norm is taken.
    """
    per_sample = Z.view(Z.size(0), -1)
    return per_sample.norm(dim=1, keepdim=True)
def pgd_l2(model, X, y, criterion, epsilon, alpha, num_iter):
    """Find an L2-bounded adversarial perturbation by projected gradient ascent.

    The perturbation ``delta`` starts at zero. Each of ``num_iter`` steps:

    1. Moves ``delta`` by ``alpha`` along the gradient of
       ``criterion(model(X + delta), y)``, normalised per sample (norms are
       taken over all but the first dimension).
    2. Rescales any sample whose perturbation norm exceeds ``epsilon`` back
       onto the epsilon-ball.

    Args:
        model: callable mapping an input batch to predictions.
        X: clean input batch; the perturbation has the same shape.
        y: targets passed to ``criterion``.
        criterion: scalar loss ``criterion(pred, y)`` to maximise.
        epsilon: L2 radius of the per-sample perturbation ball.
        alpha: L2 length of each ascent step.
        num_iter: number of ascent steps.

    Returns:
        The detached perturbation ``delta``, same shape as ``X``.
    """
    delta = torch.zeros_like(X, requires_grad=True)
    for _ in range(num_iter):
        loss = criterion(model(X + delta), y)
        # Differentiate w.r.t. delta only. Unlike loss.backward(), this does
        # not accumulate attack gradients into the model parameters' .grad.
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            # A sample's gradient can be exactly zero (e.g. every ReLU
            # inactive). Without the floor that would divide 0/0 and inject NaNs.
            delta += alpha * grad / norms(grad).clamp(min=1e-12)
            delta *= epsilon / norms(delta).clamp(min=epsilon)
    return delta.detach()
    
    
def _sign_accuracy(logits, labels):
    """Return the fraction of samples whose predicted sign matches the label.

    ``logits`` is squeezed on its last dimension before the comparison, so a
    (batch, 1) score column lines up with a (batch,) label vector. A score of
    exactly 0 has sign 0 and never matches a +/-1 label. The result is a numpy
    float (match count / batch size).
    """
    hits = (logits.sign().squeeze(-1) == labels).detach().cpu().numpy()
    return np.sum(hits) / len(labels)


Seed = args.seed
# --mu is parsed but unused: the sweep below iterates a fixed list of mu values.
N = args.N
dim = args.dim

# Create the output directory up front. Otherwise a missing 'log/' would only
# raise after the whole training run had finished.
os.makedirs('log', exist_ok=True)

for N_train in [N]:
    for d in [dim]:
        for seed in [Seed]:
            np.random.seed(seed)
            random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

            N_test = 2000
            # Label-noise rate: each label is flipped relative to its cluster
            # sign with probability beta (see y_train / y_test below).
            beta = 0.1
            for mu in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]:
                # Every sample's class mean is +/- mu * e_1: the signal lives
                # in the first coordinate and has L2 norm mu.
                mu_train = np.zeros(shape=(N_train, d))
                mu_train[:, 0] = 1
                mu_train = mu_train / np.linalg.norm(mu_train, axis=1)[0] * mu
                mu_test = np.zeros(shape=(N_test, d))
                mu_test[:, 0] = 1
                mu_test = mu_test / np.linalg.norm(mu_test, axis=1)[0] * mu

                # Steps: draw clean cluster signs, add unit Gaussian noise around
                # +/- mu, then make labels that keep the cluster sign w.p. 1 - beta.
                # The draw order below fixes the RNG stream; keep it.
                yc_train = np.sign(np.random.normal(size=(N_train, 1)))
                x_train = mu_train * yc_train + np.random.normal(0, 1, size=(N_train, d))
                y_train = (stats.bernoulli.rvs(beta, size=(N_train, 1)) - 0.5) * (-2) * yc_train

                yc_test = np.sign(np.random.normal(size=(N_test, 1)))
                x_test = mu_test * yc_test + np.random.normal(0, 1, size=(N_test, d))
                y_test = (stats.bernoulli.rvs(beta, size=(N_test, 1)) - 0.5) * (-2) * yc_test

                T = 1000
                attack_iters = 20
                lr = 0.1

                X_train_tensor = torch.from_numpy(x_train).to(torch.float32).to(device)
                Y_train_tensor = torch.from_numpy(y_train).squeeze(-1).to(torch.float32).to(device)
                X_test_tensor = torch.from_numpy(x_test).to(torch.float32).to(device)
                Y_test_tensor = torch.from_numpy(y_test).squeeze(-1).to(torch.float32).to(device)

                model = net(d=d, width=1000).to(device)
                criterion = Logistic_Loss()
                optimizer = optim.SGD(model.parameters(), lr=lr)

                for rate in [0.1]:
                    print('N', N_train, 'mu', mu, 'd', d, 'eps', rate, 'seed', seed)
                    trainloss, robusttrainacc, cleantestacc, robusttestacc, robusttestopt = [], [], [], [], []
                    name = ('mu' + str(mu) + 'trainN' + str(N_train) + 'dim' + str(d)
                            + 'rate' + str(rate) + 'T' + str(T) + 'lr' + str(lr)
                            + 'beta' + str(beta) + 'seed' + str(seed))
                    eps = rate * mu
                    alpha = eps / attack_iters * 4
                    # Fixed test perturbation of L2 norm eps (rows of mu_test
                    # have norm mu) pointing along -y_test * e_1. It does not
                    # depend on t, so build it once per rate.
                    delta_opt = torch.from_numpy(mu_test * rate * y_test * (-1)).to(torch.float32).to(device)
                    for t in range(T):
                        # Adversarial training: attack the current model on the
                        # training set, then take one SGD step on the attacked batch.
                        delta = pgd_l2(model, X_train_tensor, Y_train_tensor, criterion, eps, alpha, attack_iters)
                        X, y = X_train_tensor + delta, Y_train_tensor

                        # Decay the learning rate 10x once past the first half of training.
                        if t > T / 2:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = lr / 10

                        optimizer.zero_grad()
                        pred = model(X)
                        loss = criterion(pred, y)
                        loss.backward()
                        optimizer.step()
                        # Robust train accuracy uses the pre-step predictions.
                        robustacc = _sign_accuracy(pred, y)

                        # Plain evaluation passes need no autograd graph.
                        with torch.no_grad():
                            clean_accuracy = _sign_accuracy(model(X_test_tensor), Y_test_tensor)

                        # The attack needs gradients, so it stays outside no_grad.
                        delta = pgd_l2(model, X_test_tensor, Y_test_tensor, criterion, eps, alpha, attack_iters)

                        with torch.no_grad():
                            robust_accuracy = _sign_accuracy(model(X_test_tensor + delta), Y_test_tensor)
                            robust_acc_opt = _sign_accuracy(model(X_test_tensor + delta_opt), Y_test_tensor)

                        trainloss.append(loss.item())
                        cleantestacc.append(clean_accuracy)
                        robusttrainacc.append(robustacc)
                        robusttestacc.append(robust_accuracy)
                        robusttestopt.append(robust_acc_opt)

                    with open('log/' + name + '.pkl', 'wb') as f:
                        pickle.dump((trainloss, robusttrainacc, cleantestacc, robusttestacc, robusttestopt), f)
