import argparse
import os

import torch
from torch import nn
from torch.autograd import Variable
import torch.optim as optim
import torchvision
from torchvision import transforms

from models.resnet import *
from models.ViT import ViT


### Simple Attacks
def norms(Z):
    """Compute the per-sample L2 norm over all but the first dimension.

    Args:
        Z: tensor whose first dimension indexes samples.

    Returns:
        Tensor with shape ``(Z.shape[0], 1, ..., 1)`` and the same number of
        dimensions as ``Z``, so it broadcasts against ``Z`` directly (for
        example when dividing a gradient by its own norm).
    """
    batch = Z.shape[0]
    # reshape (rather than view) also accepts non-contiguous inputs.
    per_sample = Z.reshape(batch, -1).norm(dim=1)
    # One singleton axis per non-batch dimension of Z, instead of assuming
    # a 4-D (image batch) input.
    return per_sample.reshape(batch, *([1] * (Z.dim() - 1)))

def pgd_whitebox_l2(model, X, y, epsilon, num_steps, step_size, stats=None):
    """Run a projected-gradient-ascent attack constrained to an L2 ball.

    Each step moves the current point along the per-sample L2-normalised
    gradient of the cross-entropy loss, projects the perturbation back onto
    the L2 ball of radius ``epsilon`` around ``X``, and clamps the result to
    ``[0, 1]``.

    Args:
        model: callable mapping a batch of inputs to logits.
        X: clean input batch; it is not modified.
        y: target class indices passed to ``nn.CrossEntropyLoss``.
        epsilon: radius of the L2 ball around ``X``.
        num_steps: number of gradient steps.
        step_size: L2 length of each step.
        stats: optional ``(mean, std)`` pair; when truthy, inputs go through
            ``transforms.Normalize(*stats)`` before reaching ``model``.

    Returns:
        The perturbed batch as a leaf tensor with ``requires_grad=True``.
    """
    # Build the optional normaliser and the loss once instead of every step.
    normalize = transforms.Normalize(*stats) if stats else None
    criterion = nn.CrossEntropyLoss()

    X_clean = X.detach()
    X_pgd = X.detach().requires_grad_(True)
    for _ in range(num_steps):
        with torch.enable_grad():
            inputs = X_pgd if normalize is None else normalize(X_pgd)
            loss = criterion(model(inputs), y)
        # autograd.grad returns only the input gradient; unlike
        # loss.backward() it does not accumulate into the model parameters'
        # .grad, so the attack cannot leak into a later optimizer step.
        (grad,) = torch.autograd.grad(loss, X_pgd)

        # A sample with an all-zero gradient would otherwise produce 0/0 = NaN
        # and poison the whole batch; with the clamp its step is simply zero.
        grad_norm = norms(grad).clamp(min=torch.finfo(grad.dtype).tiny)
        X_step = X_pgd.detach() + step_size * grad / grad_norm

        # Project the perturbation back onto the epsilon-ball around X.
        eta = X_step - X_clean
        eta = eta * (epsilon / norms(eta).clamp(min=epsilon))
        X_pgd = torch.clamp(X_clean + eta, 0, 1.0).requires_grad_(True)
    return X_pgd


def pgd_whitebox_linf(model, X, y, epsilon, num_steps, step_size, stats=None):
    """Run a projected-gradient-ascent attack constrained to an Linf ball.

    Each step moves the current point by ``step_size`` along the sign of the
    cross-entropy gradient, clips the perturbation element-wise to
    ``[-epsilon, epsilon]`` around ``X``, and clamps the result to ``[0, 1]``.

    Args:
        model: callable mapping a batch of inputs to logits.
        X: clean input batch; it is not modified.
        y: target class indices passed to ``nn.CrossEntropyLoss``.
        epsilon: maximum absolute per-element perturbation.
        num_steps: number of gradient steps.
        step_size: per-element size of each signed step.
        stats: optional ``(mean, std)`` pair; when truthy, inputs go through
            ``transforms.Normalize(*stats)`` before reaching ``model``.

    Returns:
        The perturbed batch as a leaf tensor with ``requires_grad=True``.
    """
    # Build the optional normaliser and the loss once instead of every step.
    normalize = transforms.Normalize(*stats) if stats else None
    criterion = nn.CrossEntropyLoss()

    X_clean = X.detach()
    X_pgd = X.detach().requires_grad_(True)
    for _ in range(num_steps):
        with torch.enable_grad():
            inputs = X_pgd if normalize is None else normalize(X_pgd)
            loss = criterion(model(inputs), y)
        # autograd.grad returns only the input gradient; unlike
        # loss.backward() it does not accumulate into the model parameters'
        # .grad, so the attack cannot leak into a later optimizer step.
        (grad,) = torch.autograd.grad(loss, X_pgd)

        X_step = X_pgd.detach() + step_size * grad.sign()
        # Project the perturbation back onto the epsilon-box around X.
        eta = torch.clamp(X_step - X_clean, -epsilon, epsilon)
        X_pgd = torch.clamp(X_clean + eta, 0, 1.0).requires_grad_(True)
    return X_pgd

# Command-line configuration. parse_known_args() is used, so unrecognised
# flags are collected in `unknown` instead of raising an error.
parser = argparse.ArgumentParser(description='PAG Research argparser')
parser.add_argument('--batch_size', type=int, default=64, help='input batch size for training')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float)
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
parser.add_argument('--seed', type=int, default=7, metavar='S', help='random seed (default: 7)')
parser.add_argument('--arch', type=str, default='rn18', metavar='S', help='rn18, vit')
parser.add_argument('--attack', type=str, default='L2', help='L2 or Linf')
args, unknown = parser.parse_known_args()

print(args)

# Seed torch's RNG before the model is constructed below.
torch.manual_seed(args.seed)

# Choose the attack function and its budget. Linf with epsilon 8/255 is the
# fallback: any --attack value other than exactly 'L2' ends up here.
attack_batch = pgd_whitebox_linf
epsilon = 8 / 255
if args.attack == 'L2':
    attack_batch = pgd_whitebox_l2
    epsilon = 0.5

# Data transforms: random crop + horizontal flip for training, tensor
# conversion only for test. No normalisation is applied here.
transform_train = torchvision.transforms.Compose([transforms.RandomCrop(32, padding=4),
                                                 transforms.RandomHorizontalFlip(),
                                                 transforms.ToTensor()])
# NOTE(review): transform_test is not referenced in the visible part of the
# file; testset below passes transforms.ToTensor() directly.
transform_test = torchvision.transforms.Compose([transforms.ToTensor()])

# CIFAR-10 train/test sets, downloaded to ./data if missing. The training
# loader drops the last incomplete batch.
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True,
                                         drop_last=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
# NOTE(review): testloader is not used in the visible part of the file.
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=True, num_workers=2)

# Build the model on the GPU, wrapped in DataParallel. Any --arch value other
# than 'vit' selects ResNet18.
if args.arch == 'vit':
    model = torch.nn.DataParallel(ViT()).cuda()
else:
    model = torch.nn.DataParallel(ResNet18()).cuda()
# Report the number of trainable parameters.
print(f'Using {args.arch} with {sum(p.numel() for p in model.parameters() if p.requires_grad)} learnable params')

# SGD optimizer; its learning rate is overwritten each epoch by
# adjust_learning_rate.
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)


# (mean, std) per-channel lists handed to the attack functions, which pass
# them to transforms.Normalize. Presumably CIFAR-10 statistics -- confirm.
stats = ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])

# train function
def train_func(model, train_loader, optimizer):
    """Train ``model`` for one epoch on adversarial examples.

    For every batch, adversarial inputs are crafted with the module-level
    ``attack_batch`` (7 steps, step size ``1.5 * epsilon / 7``), and the model
    is updated with the cross-entropy loss on those inputs.

    Args:
        model: network to train; switched to train mode.
        train_loader: iterable yielding ``(data, target)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
    """
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    for batch_idx, (data, target) in enumerate(train_loader):
        images, target = data.cuda(), target.cuda()
        X_pgd = attack_batch(model=model,
                             X=images,
                             y=target,
                             epsilon=epsilon,
                             num_steps=7,
                             step_size=1.5 * epsilon / 7,
                             stats=stats)
        # Only the adversarial values are needed from here on; detaching
        # avoids computing an unused input gradient in the backward pass.
        X_pgd = X_pgd.detach()
        # Clear gradients AFTER the attack: any parameter gradients left
        # behind by the attack's backward passes must not be added to the
        # training gradient applied by optimizer.step().
        optimizer.zero_grad()
        # NOTE(review): the attack is computed through Normalize(*stats), but
        # this forward pass feeds the un-normalised X_pgd to the model.
        # Confirm which preprocessing evaluation uses and align the two.
        pred = model(X_pgd)
        # CE loss
        loss = criterion(pred, target)
        if batch_idx % 100 == 0:
            print(f'Train loss in batch {batch_idx}: {loss}')
        loss.backward()
        optimizer.step()


def adjust_learning_rate(optimizer, epoch):
    """Apply the step-wise learning-rate schedule for the given epoch.

    The base rate ``args.lr`` is used until half of ``args.epochs``, then it
    is scaled by 0.1, by 0.01 from three quarters onward, and by 0.001 once
    ``epoch`` reaches ``args.epochs``. The result is written to every
    parameter group of ``optimizer``.
    """
    total_epochs = args.epochs
    # (threshold epoch, multiplier); checked in order, last match wins.
    milestones = (
        (0.5 * total_epochs, 0.1),
        (0.75 * total_epochs, 0.01),
        (total_epochs, 0.001),
    )
    new_lr = args.lr
    for threshold, factor in milestones:
        if epoch >= threshold:
            new_lr = args.lr * factor
    for group in optimizer.param_groups:
        group['lr'] = new_lr


best_acc = -1
# Make sure the checkpoint directory exists before training starts, so a
# missing directory cannot make torch.save fail after the whole run.
checkpoint_path = f'checkpoints/AT-attack-{args.attack}-{args.arch}-lr-{args.lr}.pt'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
for epoch in range(1, args.epochs + 1):
    # adjust learning rate for SGD
    adjust_learning_rate(optimizer, epoch)
    # training
    print(f"Training epoch {epoch}")
    train_func(model, dataloader, optimizer)
# Save the final weights once, after the last epoch.
torch.save(model.state_dict(), checkpoint_path)