import os

import torch
import torch.nn as nn
from tqdm import tqdm

from utils import load_data
from models import ResNet18, ResNet50, ResNet152, CIFAR_CNN

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# METHOD selects which regulariser train() adds to the cross-entropy loss.
# Values handled by train(): 'alpha', 'alpha-out', 'alpha-true',
# 'alpha-approx', 'alpha-approxrepr', 'alpha-last'.
METHOD = 'alpha-approx'

# Weight multiplying the regulariser term.
# Values previously listed here for the gradient-based methods:
#   3, 6, 10, 20, 30, 50, 60, 100, 150, 200, 300, 600, 1000, 3000, 6000, 10000
# and for 'alpha-last':
#   1e-3, -0.3, -0.1, 0.03, 0.1, 0.3, 1.0
LAMDA_ALPHA = 300.0

# Backbone to train: 'resnet18' or 'cifarcnn'.
model_arch = 'resnet18'

# Checkpoint path. For any architecture other than resnet18 the architecture
# name is appended *after* the '.pth' extension (e.g. '...pth-cifarcnn').
SAVE_NAME = './saved_model/%s-%.4f.pth' % (METHOD, LAMDA_ALPHA) + (
    '' if model_arch == 'resnet18' else '-%s' % model_arch
)
print(SAVE_NAME)

# Datasets and loaders; `normalizer` is handed to the model constructor below.
trainset, testset, trainloader, testloader, normalizer = load_data()
print(len(trainset), len(testset))

# Instantiate the requested backbone. Fail loudly on an unknown name instead
# of hitting a NameError on `model` a few lines later.
if model_arch == 'resnet18':
    model = ResNet18(normalizer, dropout=0.0)
elif model_arch == 'cifarcnn':
    model = CIFAR_CNN(normalizer, dropout=0.0)
else:
    raise ValueError('unsupported model_arch: %r' % (model_arch,))
model = model.to('cuda')

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Learning rate is multiplied by 0.1 at milestones 100 and 150; the scheduler
# is stepped once per epoch by the training loop at the bottom of the file.
# (Alternative that was tried: CosineAnnealingLR(optimizer, T_max=200).)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100,150], gamma=0.1)

def _input_grad_norm(scalar, x):
    """Return the batch-mean L2 norm of d(scalar)/dx, keeping it differentiable.

    Uses torch.autograd.grad rather than scalar.backward(create_graph=True) so
    that no gradients are accumulated into the model parameters (or x.grad) as
    a side effect. create_graph=True keeps the result differentiable so the
    penalty can itself be back-propagated.
    """
    (grad_x,) = torch.autograd.grad(scalar, x, create_graph=True)
    return grad_x.reshape(x.shape[0], -1).norm(2, dim=1).mean()


def _random_unit_rows(like):
    """Return a CUDA tensor shaped like `like` whose rows are random unit vectors."""
    rand_v = torch.FloatTensor(*like.shape).normal_().to('cuda')
    return rand_v / rand_v.norm(2, dim=1, keepdim=True)


def _alpha_penalty(x, features, pred):
    """Compute the regulariser selected by the module-level METHOD.

    Args:
        x: input batch; must require grad for every method except 'alpha-last'.
        features: output of model.calc_representation(x).
        pred: output of model.linear(features).

    Returns:
        Scalar tensor with the (unweighted) penalty.

    Raises:
        ValueError: if METHOD is not one of the supported names.
    """
    if METHOD == 'alpha':
        # Input gradient of the mean per-sample feature norm.
        return _input_grad_norm(features.norm(2, dim=1).mean(), x)
    if METHOD == 'alpha-out':
        # Same as 'alpha' but on the outputs of the final linear layer.
        return _input_grad_norm(pred.norm(2, dim=1).mean(), x)
    if METHOD == 'alpha-true':
        # One gradient per output column, stacked along dim 1, then the norm
        # of each sample's flattened stack. Note the columns are reduced with
        # sum(), not mean(), unlike the other gradient-based methods.
        rows = [
            torch.autograd.grad(pred[:, i].sum(), x, create_graph=True)[0]
            for i in range(pred.shape[1])
        ]
        jac = torch.stack(rows, 1)
        return jac.reshape(jac.shape[0], -1).norm(2, dim=1).mean()
    if METHOD == 'alpha-approx':
        # Input gradient of pred projected on a random unit direction per
        # sample; presumably a cheap stochastic stand-in for 'alpha-true'.
        rand_v = _random_unit_rows(pred)
        return _input_grad_norm((rand_v * pred).sum(1).mean(), x)
    if METHOD == 'alpha-approxrepr':
        # Same random-projection scheme applied to the features.
        rand_v = _random_unit_rows(features)
        return _input_grad_norm((rand_v * features).sum(1).mean(), x)
    if METHOD == 'alpha-last':
        # Norm of the flattened weight of the final linear layer.
        return model.linear.weight.view(-1).norm()
    raise ValueError('unknown METHOD: %r' % (METHOD,))


def train(epoch):
    """Run one training epoch over `trainloader` with the regularised loss.

    Reads the module-level model, trainloader, criterion, optimizer, METHOD and
    LAMDA_ALPHA. The optimised loss is
    criterion(pred, y) + LAMDA_ALPHA * penalty.

    Args:
        epoch: epoch index, used only for the log line.

    Returns:
        Tuple (mean total loss per batch, training accuracy in percent).
    """
    print('\nEpoch: %d' % epoch)
    model.train()
    train_loss = 0
    train_loss_alpha = 0
    correct = 0
    total = 0
    with tqdm(trainloader) as pbar:
        for batch_idx, (x, y) in enumerate(pbar):
            x, y = x.to('cuda'), y.to('cuda')
            # The gradient-based penalties differentiate w.r.t. the input.
            x.requires_grad_()

            # Forward pass split in two so the penalties can see the features.
            features = model.calc_representation(x)
            pred = model.linear(features)

            loss_alpha = _alpha_penalty(x, features, pred)
            loss = criterion(pred, y) + LAMDA_ALPHA * loss_alpha

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_loss_alpha += loss_alpha.item()
            _, pred_c = pred.max(1)
            total += y.size(0)
            correct += pred_c.eq(y).sum().item()
            pbar.set_description('Loss: %.3f | Loss_alpha: %.3f | Acc:%.3f%%'%(train_loss/(batch_idx+1), train_loss_alpha/(batch_idx+1), 100.*correct/total))
            # NOTE(review): kept from the original; presumably a workaround
            # for memory growth. May be removable now that the penalty no
            # longer goes through backward(create_graph=True) - confirm by
            # profiling before dropping it.
            torch.cuda.empty_cache()

    acc = 100.*correct/total
    return train_loss/len(trainloader), acc

def test(epoch):
    """Evaluate the module-level `model` on `testloader` without gradients.

    Args:
        epoch: accepted for symmetry with train(); not used in the body.

    Returns:
        Tuple (mean cross-entropy loss per batch, accuracy in percent).
    """
    model.eval()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad(), tqdm(testloader) as pbar:
        for step, (inputs, labels) in enumerate(pbar, start=1):
            inputs = inputs.to('cuda')
            labels = labels.to('cuda')

            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()

            # Index of the largest output per sample is the predicted class.
            predicted = outputs.max(1)[1]
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()

            # Running averages shown on the progress bar.
            pbar.set_description(
                'Loss: %.3f | Acc:%.3f%%' % (running_loss / step, 100. * n_correct / n_seen)
            )

    accuracy = 100. * n_correct / n_seen
    return running_loss / len(testloader), accuracy


# torch.save does not create missing parent directories. Make sure the
# checkpoint directory exists before training starts, so a fresh checkout
# does not fail at the first save after a full epoch of work.
os.makedirs(os.path.dirname(SAVE_NAME) or '.', exist_ok=True)

# Train for 200 epochs, stepping the LR scheduler once per epoch and keeping
# only the checkpoint with the best test accuracy seen so far.
best_acc = 0.0
for epoch in range(200):
    train(epoch)
    _, cur_acc = test(epoch)
    scheduler.step()
    if cur_acc > best_acc:
        best_acc = cur_acc
        torch.save(model.state_dict(), SAVE_NAME)
