import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np

from utils import load_data, load_svhn_data
from models import ResNet18

# Number of transfer-head updates train() performs after each main-model batch.
TRANS_REPEAT = 3

# Regulariser selection.  Exactly one METHOD / INIT_LAMDA / FINAL_LAMDA group
# is active at a time; the other groups are kept commented out as presets.

# METHOD = 'alpha-approx'
# INIT_LAMDA = 3.0
# FINAL_LAMDA = 100.0

METHOD = 'alpha-approxrepr'
INIT_LAMDA = 10.0
FINAL_LAMDA = 300.0

# METHOD = 'alpha-last'
# INIT_LAMDA = 0.03
# FINAL_LAMDA = ?   (no final value recorded for this preset)

# Running index that save_tune() appends to each checkpoint file name.
SAVE_CNT = 0

# Path prefix shared by every checkpoint written by save_tune().
SAVE_NAME = './saved_model/%s-tune-%s' % (METHOD, INIT_LAMDA)
print(SAVE_NAME)

# Main-task data: datasets, their loaders and the input normaliser.
trainset, testset, trainloader, testloader, normalizer = load_data()
print(len(trainset), len(testset))

# Import retained even though it is no longer used in this section: code
# outside this view may still rely on the name.
from itertools import cycle


def _repeat_forever(loader):
    """Yield batches from ``loader`` endlessly, starting a fresh pass each time.

    ``itertools.cycle`` is deliberately avoided here: it saves a copy of every
    item produced during the first pass and afterwards replays those saved
    copies.  For a data loader that means every batch stays in memory and the
    loader itself is never iterated again.

    If a pass over ``loader`` produces nothing, the generator finishes, so a
    later ``next()`` raises ``StopIteration`` instead of looping forever.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


# Only the fourth value returned by load_svhn_data() is used; train() pulls
# batches from it with next().
_, _, _, transfer_loader, _ = load_svhn_data()
transfer_loader = _repeat_forever(transfer_loader)

# Build the network on the GPU and restore the weights saved for the active
# METHOD / INIT_LAMDA combination.
model = ResNet18(normalizer)
model = model.to('cuda')
# map_location='cpu' makes the load independent of the device the checkpoint
# was written from; load_state_dict() then copies into the GPU parameters.
model.load_state_dict(
    torch.load('saved_model/%s-%.4f.pth' % (METHOD, INIT_LAMDA),
               map_location='cpu'))

# The 10-way transfer head is attached only after the load above, so that the
# main checkpoint is matched against the model without this extra layer
# (presumably the checkpoint has no ``transfer_linear`` entries -- TODO confirm).
model.transfer_linear = nn.Linear(model.linear.in_features, 10).to('cuda')
transfer_state_dict = torch.load(
    'saved_model/%s-%.4f-transfer-svhn.pth' % (METHOD, INIT_LAMDA),
    map_location='cpu')

# Keep only the entries of the ``linear`` layer and drop that prefix so the
# keys line up with a bare nn.Linear (``weight`` / ``bias``).
_FC_PREFIX = 'linear.'
fc_state_dict = {
    k[len(_FC_PREFIX):]: v
    for k, v in transfer_state_dict.items()
    if k.startswith(_FC_PREFIX)
}
model.transfer_linear.load_state_dict(fc_state_dict)

# Cross-entropy is used for both the main task and the transfer head.
criterion = nn.CrossEntropyLoss()

# ``model.transfer_linear`` is already attached at this point, so
# ``model.parameters()`` would include it.  The head has its own optimizer
# below; registering it in both could let the main optimizer's weight decay /
# momentum touch the head as well (whether that happens depends on how
# zero_grad() clears gradients in the installed PyTorch).  Exclude it here so
# each parameter is owned by exactly one optimizer.
_transfer_params = list(model.transfer_linear.parameters())
_transfer_param_ids = {id(p) for p in _transfer_params}
_main_params = [p for p in model.parameters() if id(p) not in _transfer_param_ids]

optimizer = torch.optim.SGD(_main_params, lr=0.001, momentum=0.9, weight_decay=5e-4)
optimizer_transfer = torch.optim.SGD(_transfer_params, lr=1e-3, momentum=0.9, weight_decay=5e-4)

def save_tune(model):
    """Write ``model.state_dict()`` to the next numbered checkpoint file.

    The file name is ``SAVE_NAME`` followed by ``_<SAVE_CNT>.pth``.  The
    module-level counter ``SAVE_CNT`` is advanced by one after the save, so
    successive calls produce successive file names.
    """
    global SAVE_CNT
    checkpoint_path = SAVE_NAME + '_%d.pth' % SAVE_CNT
    torch.save(model.state_dict(), checkpoint_path)
    SAVE_CNT = SAVE_CNT + 1

def _alpha_penalty(x, pred, features):
    """Return the regularisation term selected by the module-level ``METHOD``.

    * ``'alpha-approx'``: batch mean of the L2 norm of d(v . pred)/dx.
    * ``'alpha-approxrepr'``: the same, with ``features`` in place of ``pred``.
    * ``'alpha-last'``: L2 norm of the flattened ``model.linear.weight``.

    ``v`` is a fresh random direction per sample, normalised to unit length
    along dim 1.  ``x`` must require grad and be the input that ``pred`` and
    ``features`` were computed from.

    Raises:
        ValueError: if ``METHOD`` is none of the names above.
    """
    if METHOD == 'alpha-last':
        return model.linear.weight.view(-1).norm()
    if METHOD == 'alpha-approx':
        out = pred
    elif METHOD == 'alpha-approxrepr':
        out = features
    else:
        # An ``assert`` would vanish under ``python -O``; fail explicitly.
        raise ValueError('unknown METHOD: %r' % (METHOD,))

    rand_v = torch.FloatTensor(*out.shape).normal_().to('cuda')
    rand_v = rand_v / rand_v.norm(2, dim=1, keepdim=True)
    tgt_val = (rand_v * out).sum(1).mean()
    # autograd.grad returns the input gradient directly.  Unlike
    # ``tgt_val.backward(create_graph=True)`` it does not write graph-carrying
    # gradients into every parameter's ``.grad``, so no manual cleanup of
    # gradients / cached GPU memory is needed afterwards.  create_graph=True
    # keeps the result differentiable so the penalty can be trained through.
    (grad_x,) = torch.autograd.grad(tgt_val, x, create_graph=True)
    return grad_x.reshape(x.shape[0], -1).norm(2, dim=1).mean()


def train(epoch, cur_lamda):
    """Run one pass over ``trainloader`` with the penalised loss.

    For every batch the main model is updated on
    ``criterion(pred, y) + cur_lamda * penalty`` (penalty chosen by
    ``METHOD``).  After that, ``TRANS_REPEAT`` batches are drawn from
    ``transfer_loader`` and only ``model.transfer_linear`` is updated on them
    via ``optimizer_transfer``.  A checkpoint is written through
    ``save_tune`` every 50 batches.

    Args:
        epoch: epoch index; only used for the printed header.
        cur_lamda: weight of the penalty term in the loss.

    Returns:
        ``(mean total loss per batch, main-task accuracy in percent)``.

    Raises:
        ValueError: if ``METHOD`` is not a known method name.
    """
    print('\nEpoch: %d' % epoch)
    model.train()
    train_loss = 0
    train_loss_alpha = 0
    correct = 0
    transfer_correct = 0
    total = 0
    with tqdm(trainloader) as pbar:
        for batch_idx, (x, y) in enumerate(pbar):
            x, y = x.to('cuda'), y.to('cuda')
            # The penalty differentiates with respect to the input.
            x.requires_grad_()

            features = model.calc_representation(x)
            pred = model.linear(features)
            loss_alpha = _alpha_penalty(x, pred, features)
            # float() so a NumPy scalar weight multiplies like a plain number.
            loss = criterion(pred, y) + float(cur_lamda) * loss_alpha

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_loss_alpha += loss_alpha.item()
            _, pred_c = pred.max(1)
            total += y.size(0)
            correct += pred_c.eq(y).sum().item()

            # transfer
            for _ in range(TRANS_REPEAT):
                tx, ty = next(transfer_loader)
                tx, ty = tx.to('cuda'), ty.to('cuda')
                # optimizer_transfer only steps the transfer head, so the
                # backbone forward needs no autograd graph.
                with torch.no_grad():
                    t_features = model.calc_representation(tx)
                t_pred = model.transfer_linear(t_features)
                t_loss = criterion(t_pred, ty)
                optimizer_transfer.zero_grad()
                t_loss.backward()
                optimizer_transfer.step()
                t_pred_c = torch.argmax(t_pred, 1)
                # Per-batch accuracy fraction; averaged in the progress line.
                transfer_correct += t_pred_c.eq(ty).sum().item() / len(ty)

            pbar.set_description('Loss: %.3f | Loss_alpha: %.3f | Acc:%.3f%% | Transfer acc:%.3f%%'%(train_loss/(batch_idx+1), train_loss_alpha/(batch_idx+1), 100.*correct/total, 100.*transfer_correct/(batch_idx+1)/TRANS_REPEAT))
            if (batch_idx+1) % 50 == 0:
                save_tune(model)

    acc = 100.*correct/total
    return train_loss/len(trainloader), acc

def test(epoch):
    """Evaluate ``model`` on ``testloader`` without gradient tracking.

    Args:
        epoch: accepted for symmetry with ``train``; not used in the body.

    Returns:
        ``(mean loss per batch, accuracy in percent)``.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad(), tqdm(testloader) as pbar:
        for step, (inputs, targets) in enumerate(pbar, start=1):
            inputs = inputs.to('cuda')
            targets = targets.to('cuda')

            logits = model(inputs)
            loss_sum += criterion(logits, targets).item()

            # Index of the largest logit per row is the predicted class.
            predicted = logits.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()
            pbar.set_description('Loss: %.3f | Acc:%.3f%%'%(loss_sum/step, 100.*n_correct/n_seen))

    return loss_sum/len(testloader), 100.*n_correct/n_seen


# Warm-up pass at the initial penalty weight.  SAVE_CNT is reset afterwards,
# so checkpoint numbering restarts at 0 and files written during the warm-up
# are overwritten by the main run.
print("warm up epoch")
train(0, INIT_LAMDA)
SAVE_CNT = 0

# Single source of truth for the schedule length (previously ``range(200)``
# and a separate literal ``199`` had to be kept in sync by hand).
NUM_EPOCHS = 200

best_acc = 0.0
for epoch in range(NUM_EPOCHS):
    # Fraction of the schedule completed: 0 at the first epoch, 1 at the last.
    progress = epoch / (NUM_EPOCHS - 1) if NUM_EPOCHS > 1 else 0.0
    # Log-linear interpolation, i.e. a geometric ramp from INIT_LAMDA to
    # FINAL_LAMDA.
    CUR_LAMDA = np.exp((np.log(FINAL_LAMDA) - np.log(INIT_LAMDA)) * progress + np.log(INIT_LAMDA))
    print("current lambda:", CUR_LAMDA)
    train(epoch, CUR_LAMDA)
    _, test_acc = test(epoch)
    best_acc = max(best_acc, test_acc)
    save_tune(model)

print("best test acc: %.3f%%" % best_acc)
