import os

import torch
import torch.nn as nn
from tqdm import tqdm

from utils import load_data
from attack_lib import LinfPGDAttack, L2PGDAttack
from models import ResNet18

# Threat model used for adversarial training.
# Other settings tried previously (kept for reference):
#   attack = 'l2'   with eps = 1.0, 0.1 or 0.02
#   attack = 'linf' with eps = 32./255
attack = 'linf'
eps = 8./255

# Supported threat models -> PGD attack class. Validating the choice before any
# expensive setup fails fast, instead of loading data and building the model
# only to hit a NameError on `adversary` at the first training step.
ATTACKS = {
    'linf': LinfPGDAttack,
    'l2': L2PGDAttack,
}
if attack not in ATTACKS:
    raise ValueError("unknown attack %r; expected one of %s" % (attack, sorted(ATTACKS)))

# load_data returns the datasets, their loaders and a normalizer; the normalizer
# is handed to ResNet18 (presumably applied inside the model — TODO confirm in
# models.ResNet18). complex_aug=True presumably enables stronger augmentation —
# TODO confirm in utils.load_data.
trainset, testset, trainloader, testloader, normalizer = load_data(complex_aug=True)
print(len(trainset), len(testset))

model = ResNet18(normalizer).to('cuda')
adversary = ATTACKS[attack](model, epsilon=eps)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Step decay: lr 0.1 -> 0.01 at milestone 100 -> 0.001 at milestone 150.
# Alternative tried: CosineAnnealingLR(optimizer, T_max=200).
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

def train(epoch):
    """Run one epoch of PGD adversarial training over ``trainloader``.

    Each batch is replaced by adversarial examples produced by ``adversary``,
    and the model is updated on those adversarial examples only.

    Args:
        epoch: Epoch index; used only for the progress printout.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` computed on the adversarial
        training batches (mean_loss is the average of per-batch losses).
    """
    print('\nEpoch: %d' % epoch)
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    with tqdm(trainloader) as pbar:
        for batch_idx, (x, y) in enumerate(pbar):
            x, y = x.to('cuda'), y.to('cuda')

            # Craft the adversarial batch first. The attack back-propagates
            # through the model; if it uses loss.backward() it leaves gradients
            # on the parameters. Clearing them *after* the attack guarantees the
            # update below uses only the gradient of the training loss.
            # NOTE(review): the model is in train mode here, so BatchNorm running
            # statistics are presumably also updated by the attack's forward passes.
            adv_x = adversary.perturb(x, y)
            optimizer.zero_grad()

            pred = model(adv_x)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, pred_c = pred.max(1)
            total += y.size(0)
            correct += pred_c.eq(y).sum().item()
            pbar.set_description('Loss: %.3f | Acc:%.3f%%'%(train_loss/(batch_idx+1), 100.*correct/total))

    acc = 100.*correct/total
    return train_loss/len(trainloader), acc

def test(epoch):
    """Evaluate the model on clean (unperturbed) batches from ``testloader``.

    Gradients are disabled and the model is switched to eval mode.

    Args:
        epoch: Not used; accepted so the call mirrors ``train(epoch)``.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` over the test set, where
        mean_loss is the average of per-batch losses.
    """
    model.eval()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        with tqdm(testloader) as progress:
            for step, (inputs, targets) in enumerate(progress, start=1):
                inputs = inputs.to('cuda')
                targets = targets.to('cuda')
                logits = model(inputs)
                batch_loss = criterion(logits, targets)

                running_loss += batch_loss.item()
                predicted = logits.max(1)[1]
                n_seen += targets.size(0)
                n_correct += predicted.eq(targets).sum().item()
                progress.set_description('Loss: %.3f | Acc:%.3f%%' % (running_loss / step, 100. * n_correct / n_seen))

    accuracy = 100. * n_correct / n_seen
    return running_loss / len(testloader), accuracy


#model.load_state_dict(torch.load('./saved_model/naive.pth'))
# Train for 200 epochs, stepping the LR schedule once per epoch, and checkpoint
# the model whenever clean test accuracy improves.
# NOTE(review): checkpoint selection uses clean accuracy only; robust accuracy
# is not measured in this loop.
best_acc = 0.0
save_path = './saved_model/advTaug-%s-%.4f.pth' % (attack, eps)
# torch.save does not create missing directories; without this the run would
# crash at the first checkpoint, after a full epoch of training.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
for epoch in range(200):
    train(epoch)
    _, cur_acc = test(epoch)
    scheduler.step()  # once per epoch, matching the epoch-based milestones
    if cur_acc > best_acc:
        best_acc = cur_acc
        torch.save(model.state_dict(), save_path)
