import torch
import torch.nn as nn
from tqdm import tqdm

from utils import load_data
from models import ResNet18, ResNet50, ResNet152, CIFAR_CNN

# ---------------------------------------------------------------------------
# Run configuration. Values previously tried for each knob (edit to sweep):
#   WD:         1e-2, 8e-3, 5e-3, 3e-3, 2e-3, 1e-3, 5e-4, 1e-4, 1e-5, 1e-6
#   last_same:  True, False
#   dropout:    0.0, 0.2, 0.5
#   l1_reg:     0.0, 1e-3, 1e-4, 1e-5, 1e-6
#   model_arch: 'cifarcnn', 'resnet18', 'resnet50', 'resnet152'
# ---------------------------------------------------------------------------
WD = 3e-3            # weight decay
last_same = False    # False -> the 'linear' parameters get their own weight decay
dropout = 0.0        # passed to the model constructor
l1_reg = 0.0         # coefficient of the L1 penalty on all parameters (0 disables it)
model_arch = 'resnet18'

# Checkpoint name: always encodes WD, and appends a suffix for each setting
# that differs from its default (last_same=True, dropout=0, l1_reg=0, resnet18).
SAVE_NAME = (
    'weightdecay-%s' % WD
    + ('' if last_same else '-lastdiff')
    + ('-dropout%s' % dropout if dropout != 0.0 else '')
    + ('-l1reg%s' % l1_reg if l1_reg != 0.0 else '')
    + ('-%s' % model_arch if model_arch != 'resnet18' else '')
)
print(SAVE_NAME)

# Supported architectures -> constructor. Each constructor is called as
# ctor(normalizer, dropout). Validating here (before the possibly slow data
# load) replaces the old if/elif chain, which had no else branch: an unknown
# name surfaced later as a NameError on `model`.
MODEL_CTORS = {
    'resnet18': ResNet18,
    'resnet50': ResNet50,
    'resnet152': ResNet152,
    'cifarcnn': CIFAR_CNN,
}
if model_arch not in MODEL_CTORS:
    raise ValueError('unknown model_arch %r; expected one of %s'
                     % (model_arch, sorted(MODEL_CTORS)))

trainset, testset, trainloader, testloader, normalizer = load_data()
print (len(trainset), len(testset))

model = MODEL_CTORS[model_arch](normalizer, dropout)
model = model.to('cuda')

criterion = nn.CrossEntropyLoss()
if last_same:
    # One parameter group: WD applies to every parameter.
    # (lr=0.1 was used previously; this branch deliberately uses 0.01.)
    print ("using smaller lr")
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=WD)
else:
    # Two parameter groups: parameters whose name contains 'linear' use a
    # fixed weight decay of 5e-4, and all others use WD.
    # NOTE(review): relies on the model naming its classifier layer with
    # 'linear'. If no parameter name matches, last_params is empty and every
    # parameter silently receives WD. Confirm for each architecture.
    last_params = []
    other_params = []
    for name, p in model.named_parameters():
        if 'linear' in name:
            last_params.append(p)
        else:
            other_params.append(p)
    sgd_params = [{'params': last_params, 'weight_decay': 5e-4},
                  {'params': other_params, 'weight_decay': WD}]
    optimizer = torch.optim.SGD(sgd_params, lr=0.1, momentum=0.9)

# Alternative schedule: torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
# Current schedule: divide the lr by 10 after epochs 100 and 150.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100,150], gamma=0.1)

def train(epoch):
    """Run one training epoch of the module-level ``model`` over ``trainloader``.

    Each batch is moved to 'cuda'. The cross-entropy loss is computed, and when
    ``l1_reg`` is non-zero an L1 penalty over all model parameters is added.
    The loss is then backpropagated and ``optimizer`` takes a step. A tqdm bar
    shows the running mean loss and accuracy.

    Args:
        epoch: epoch index, used only in the printed header.

    Returns:
        ``(mean loss per batch, accuracy in percent)``. The reported loss
        includes the L1 penalty when it is enabled.
    """
    print('\nEpoch: %d' % epoch)
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0
    with tqdm(trainloader) as progress:
        for step, (inputs, targets) in enumerate(progress, start=1):
            inputs = inputs.to('cuda')
            targets = targets.to('cuda')

            optimizer.zero_grad()
            logits = model(inputs)
            loss = criterion(logits, targets)
            if l1_reg != 0.0:
                penalty = 0.0
                for param in model.parameters():
                    penalty += torch.norm(param.view(-1), 1)
                loss = loss + l1_reg * penalty
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = logits.max(1).indices
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()
            progress.set_description('Loss: %.3f | Acc:%.3f%%' % (running_loss / step, 100. * n_correct / n_seen))

    accuracy = 100. * n_correct / n_seen
    return running_loss / len(trainloader), accuracy

def test(epoch):
    """Evaluate the module-level ``model`` on ``testloader`` with autograd disabled.

    Args:
        epoch: epoch index; accepted for symmetry with ``train`` but unused.

    Returns:
        ``(mean loss per batch, accuracy in percent)``.
    """
    model.eval()
    running_loss, n_correct, n_seen = 0, 0, 0
    with torch.no_grad(), tqdm(testloader) as progress:
        for step, (inputs, targets) in enumerate(progress, start=1):
            inputs, targets = inputs.to('cuda'), targets.to('cuda')
            logits = model(inputs)
            running_loss += criterion(logits, targets).item()

            predicted = logits.max(1).indices
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()
            progress.set_description('Loss: %.3f | Acc:%.3f%%' % (running_loss / step, 100. * n_correct / n_seen))

    accuracy = 100. * n_correct / n_seen
    return running_loss / len(testloader), accuracy


# Train for 200 epochs. After each epoch, evaluate on the test set, advance
# the lr scheduler, and checkpoint the weights whenever test accuracy
# improves on the best seen so far.
import os

SAVE_DIR = './saved_model'
# torch.save does not create missing parent directories. Create it up front
# so the first checkpoint does not crash after a full epoch of training.
os.makedirs(SAVE_DIR, exist_ok=True)

best_acc = 0.0
for epoch in range(200):
    train(epoch)
    _, cur_acc = test(epoch)
    scheduler.step()
    if cur_acc > best_acc:
        best_acc = cur_acc
        torch.save(model.state_dict(), os.path.join(SAVE_DIR, '%s.pth' % SAVE_NAME))
