import argparse
import os

import torch
import torch.nn as nn
import torch.optim as optim

from vgg import VGG
from dataset import CIFAR10, SVHN

def _evaluate(net, loader, device):
    """Count correct top-1 predictions of ``net`` over every batch in ``loader``.

    The network is put in eval mode for the duration of the pass and is
    always switched back to train mode afterwards.

    Returns:
        A ``(correct, total)`` tuple of ints.
    """
    correct = 0
    total = 0
    net.eval()
    try:
        with torch.no_grad():
            for data in loader:
                images, labels = data[0].to(device), data[1].to(device)
                _, predicted = torch.max(net(images), 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        net.train()
    return correct, total


def train_cls(trainloader, testloader, epoches, save_dir, name=None, weight_decay=5e-4,
              load_from=None):
    """Train a VGG19 classifier with SGD and save its ``state_dict``.

    Args:
        trainloader: Iterable of batches; element 0 of each batch is the
            input tensor and element 1 the label tensor.
        testloader: Iterable of batches in the same layout, used for the
            periodic accuracy report.
        epoches: Number of passes over ``trainloader``.
        save_dir: Directory the weights are written to. It is created
            (including missing parents) before training starts.
        name: File name for the saved weights; ``'final_model.pt'`` when None.
        weight_decay: Weight decay passed to the SGD optimizer.
        load_from: Optional path of a ``state_dict`` checkpoint to start from.
            When omitted, the function falls back to ``args.load_from`` if a
            module-level ``args`` exists, which keeps the command-line entry
            point of this file working unchanged.
    """
    log_interval = 100  # mini-batches between running-loss reports
    eval_interval = 5   # epochs between test-set accuracy reports

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Create the output directory up front: a bad path should fail now,
    # not after the whole training run has finished. makedirs also handles
    # nested paths such as 'models/cifar', which os.mkdir cannot.
    os.makedirs(save_dir, exist_ok=True)

    if load_from is None:
        # Backward compatibility: the script used to read the parsed CLI
        # namespace straight from module scope.
        load_from = getattr(globals().get("args"), "load_from", None)

    net = VGG('VGG19')
    if load_from:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        net.load_state_dict(torch.load(load_from, map_location=device))
    net.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=weight_decay)

    for epoch in range(epoches):

        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % log_interval == log_interval - 1:
                # '\r' keeps the progress report on a single terminal line.
                print('[%d, %5d] loss: %.3f' %
                    (epoch + 1, i + 1, running_loss / log_interval), end='\r')
                running_loss = 0.0

        if epoch % eval_interval == 0:
            correct, total = _evaluate(net, testloader, device)

            print('Epoch:',epoch)
            if total:
                print('Accuracy of the network on the %d test images: %d %%' % (
                    total, 100 * correct / total))
            else:
                # An empty test loader would otherwise divide by zero.
                print('Accuracy of the network: no test images available')

    print('Finished Training')

    torch.save(net.state_dict(), os.path.join(save_dir, name if name is not None else 'final_model.pt'))


if __name__ == "__main__":
    # Command-line entry point: train one classifier per training split
    # returned by the chosen dataset and save each as model_<i>.pt.
    dataset_classes = {'cifar10': CIFAR10, 'svhn': SVHN}

    parser = argparse.ArgumentParser("Classifier training")
    # required + choices: a missing or misspelled dataset is rejected by
    # argparse with a usage message instead of a NameError further down.
    parser.add_argument("--dataset", type=str, required=True,
                        choices=sorted(dataset_classes))
    parser.add_argument("--load_from", type=str)
    parser.add_argument("--save_to", type=str, default='models/cifar')
    parser.add_argument("--epoches", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--weight_decay", type=float, default=5e-4)

    # NOTE: train_cls reads `args.load_from` from module scope, so this
    # name must stay `args`.
    args = parser.parse_args()

    dataset = dataset_classes[args.dataset]()
    dataset.split()
    split_loaders, testloader = dataset.get_dataloaders(batch_size=args.batch_size, split=True)
    for i, trainloader in enumerate(split_loaders):
        train_cls(trainloader, testloader, epoches=args.epoches, save_dir=args.save_to,
                  name='model_' + str(i) + '.pt', weight_decay=args.weight_decay)

    