import os

import torch
import torch.nn as nn
from tqdm import tqdm

from models import ResNet18
from utils import load_data

class L2PGDReprAttack(object):
    """L2-bounded PGD attack on a model's representation.

    Starting from a random point near ``x``, the attack takes normalised
    gradient-ascent steps that increase the mean-squared error between
    ``model.calc_representation(x)`` and ``model.calc_representation(adv_x)``.
    After every step the perturbation is projected back onto the L2 ball of
    radius ``epsilon`` around ``x`` and the result is clamped to ``[0, 1]``.

    Args:
        model: object exposing ``calc_representation(batch)``.
        num_steps: number of gradient-ascent steps.
        epsilon: radius of the per-sample L2 ball.
    """

    # Lower bound applied to norms used as divisors, so that an all-zero
    # gradient or perturbation yields zeros instead of NaN (0 / 0).
    _TINY = 1e-12

    def __init__(self, model, num_steps=10, epsilon=1.0):
        self.model = model
        self.num_steps = num_steps
        self.epsilon = epsilon
        # Step size is a fixed fraction of the ball radius.
        self.alpha = epsilon / 4

    @staticmethod
    def _per_sample_norm(t):
        """Return each sample's L2 norm, shaped to broadcast against ``t``."""
        norms = t.flatten(1).norm(dim=1)
        # One trailing singleton axis per non-batch dimension of ``t``; this
        # works for any input rank >= 2, not just 4-D image batches.
        return norms.reshape(-1, *([1] * (t.dim() - 1)))

    def _project(self, x, adv_x):
        """Project ``adv_x`` onto the epsilon-ball around ``x``, then clamp to [0, 1]."""
        delta = adv_x - x
        delta_norm = self._per_sample_norm(delta)
        unit = delta / torch.clamp(delta_norm, min=self._TINY)
        delta = unit * torch.clamp(delta_norm, 0, self.epsilon)
        return torch.clamp(x + delta, 0, 1)

    def perturb(self, x, y):
        """Return adversarial examples for the batch ``x``.

        Args:
            x: input batch with the batch dimension first.
            y: labels; accepted for interface compatibility and not used by
                the representation loss.

        Returns:
            A detached tensor with the same shape as ``x``.
        """
        x = x.detach()
        # The clean representation is a fixed reference: only the gradient
        # with respect to ``adv_x`` is needed, so do not build a graph for it.
        with torch.no_grad():
            x_repr = self.model.calc_representation(x)

        # Random start: uniform noise, projected back into the L2 ball.
        noise = torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)
        adv_x = self._project(x, x + noise)

        for _ in range(self.num_steps):
            adv_x.requires_grad_()
            # enable_grad so the attack also works when called under no_grad.
            with torch.enable_grad():
                adv_repr = self.model.calc_representation(adv_x)
                loss = nn.functional.mse_loss(x_repr, adv_repr)
            grad = torch.autograd.grad(loss, [adv_x])[0]
            # Normalised ascent direction; guarded against a zero gradient.
            grad = grad / torch.clamp(self._per_sample_norm(grad), min=self._TINY)
            adv_x = self._project(x, adv_x.detach() + self.alpha * grad.detach())
        return adv_x.detach()

# Training objective used by train():
#   'advTrepr'    - cross-entropy on the predictions for the adversarial batch.
#   'advTreprSSL' - cross-entropy on the clean batch plus a weighted MSE term
#                   between clean and adversarial representations.
#METHOD = 'advTrepr'
METHOD = 'advTreprSSL'

# Threat model: norm type and perturbation budget. The commented lines are
# alternative settings.
attack = 'l2'
eps = 1.0
#eps = 0.1
#eps=0.02

#attack = 'linf'
##eps = 8./255
#eps = 32./255

trainset, testset, trainloader, testloader, normalizer = load_data()
print(len(trainset), len(testset))

model = ResNet18(normalizer)
model = model.to('cuda')
print(model)
if attack == 'linf':
    # NOTE(review): LinfPGDReprAttack is neither defined nor imported in this
    # file, so selecting 'linf' raises NameError here unless it is provided.
    adversary = LinfPGDReprAttack(model, epsilon=eps)
elif attack == 'l2':
    adversary = L2PGDReprAttack(model, epsilon=eps)
else:
    # Fail here rather than with an undefined `adversary` inside train().
    raise ValueError('unknown attack: %r' % (attack,))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Learning rate is multiplied by 0.1 at epochs 100 and 150.
#scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100,150], gamma=0.1)

def train(epoch):
    """Run one epoch of adversarial training over ``trainloader``.

    For every batch an adversarial counterpart is generated with the
    module-level ``adversary`` and the model is updated according to
    ``METHOD``:

    * ``'advTrepr'``: cross-entropy on the predictions for the adversarial
      batch (the reported accuracy is therefore on adversarial inputs).
    * ``'advTreprSSL'``: cross-entropy on the clean batch plus
      ``5.0 * MSE(clean representation, adversarial representation)``
      (the reported accuracy is on clean inputs).

    Args:
        epoch: epoch index, used only for the progress message.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)``.

    Raises:
        ValueError: if ``METHOD`` is not one of the supported objectives.
    """
    # Validate with a real exception (``assert`` is stripped under ``-O``)
    # and do it before any expensive attack has been computed.
    if METHOD not in ('advTrepr', 'advTreprSSL'):
        raise ValueError('unknown METHOD: %r' % (METHOD,))

    print('\nEpoch: %d' % epoch)
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    with tqdm(trainloader) as pbar:
        for batch_idx, (x, y) in enumerate(pbar):
            x, y = x.to('cuda'), y.to('cuda')
            optimizer.zero_grad()

            adv_x = adversary.perturb(x, y)
            if METHOD == 'advTrepr':
                pred = model(adv_x)
                loss = criterion(pred, y)
            else:  # 'advTreprSSL'
                feature = model.calc_representation(x)
                adv_feature = model.calc_representation(adv_x)
                pred = model.linear(feature)
                # Classification loss on clean data plus a consistency penalty
                # pulling the adversarial representation towards the clean one.
                loss = criterion(pred, y) + 5.0*nn.functional.mse_loss(feature, adv_feature)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, pred_c = pred.max(1)
            total += y.size(0)
            correct += pred_c.eq(y).sum().item()
            pbar.set_description('Loss: %.3f | Acc:%.3f%%'%(train_loss/(batch_idx+1), 100.*correct/total))

    acc = 100.*correct/total
    return train_loss/len(trainloader), acc

def test(epoch):
    """Evaluate the model on ``testloader`` without computing gradients.

    The model is switched to eval mode and each batch is scored with the
    module-level ``criterion``; running loss and accuracy are shown on the
    progress bar.

    Args:
        epoch: accepted for symmetry with ``train``; not used here.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)``.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    with torch.no_grad():
        with tqdm(testloader) as bar:
            for step, (inputs, labels) in enumerate(bar, start=1):
                inputs = inputs.to('cuda')
                labels = labels.to('cuda')
                logits = model(inputs)

                loss_sum += criterion(logits, labels).item()
                # Index of the largest logit is the predicted class.
                predicted = logits.max(1)[1]
                n_seen += labels.size(0)
                n_correct += predicted.eq(labels).sum().item()
                bar.set_description('Loss: %.3f | Acc:%.3f%%'%(loss_sum/step, 100.*n_correct/n_seen))

    accuracy = 100.*n_correct/n_seen
    return loss_sum/len(testloader), accuracy


# Optional warm start from a previously saved checkpoint:
#model.load_state_dict(torch.load('./saved_model/naive.pth'))

# Create the checkpoint directory up front; otherwise the first torch.save
# fails only after a full epoch of training has already been spent.
os.makedirs('./saved_model', exist_ok=True)

# Train for 200 epochs and keep the weights with the best clean test accuracy.
best_acc = 0.0
for epoch in range(200):
    train(epoch)
    _, cur_acc = test(epoch)
    scheduler.step()
    if cur_acc > best_acc:
        best_acc = cur_acc
        torch.save(model.state_dict(), './saved_model/%s-%s-%.4f.pth'%(METHOD, attack, eps))
