# Train a CNN model (MobileNetV2 or ResNet-34); currently only the CIFAR-10 dataset is supported.

import os
import copy
import argparse
import pickle

from tqdm import tqdm
import torch
import torchvision
from torchvision import transforms

from data import get_mnist, get_cifar
from mobilenet_v2 import MobileNetV2
from resnet import resnet34

def _save_checkpoint(model, exp_name, tag):
    """Serialize the full model object to ./model_checkpoints/<exp_name>_model_<tag>.pth."""
    ckpt_dir = './model_checkpoints'
    os.makedirs(ckpt_dir, exist_ok=True)
    fp = os.path.join(ckpt_dir, f'{exp_name}_model_{tag}.pth')
    torch.save(model, fp)


def _evaluate(model, data_loader, criterion, device):
    """Return (mean per-batch loss, accuracy) of `model` over `data_loader`.

    The model is switched to eval mode so that layers such as dropout behave
    deterministically and normalization layers do not update their running
    statistics from evaluation data. The caller is responsible for switching
    back to train mode before resuming training.
    """
    model.eval()
    agg_loss = 0.
    num_corr = 0.
    num_ex = 0.
    with torch.no_grad():
        for model_in, target in data_loader:
            model_in, target = model_in.to(device), target.to(device)
            outputs = model(model_in)
            agg_loss += criterion(outputs, target).item()
            _, predicted = outputs.max(1)
            num_ex += target.size(0)
            num_corr += predicted.eq(target).sum().item()
    return agg_loss / len(data_loader), num_corr / num_ex


def train_cnn():
    """Train a CNN classifier configured entirely from command-line flags.

    Steps performed:
      1. Parse CLI arguments.
      2. Build train/test data loaders (only ``--dataset cifar10`` is handled).
      3. Construct the model (``mobilenetv2`` or ``resnet``), an SGD optimizer,
         a cross-entropy criterion and, with ``--use-lr-sched``, a cosine
         annealing LR scheduler stepped once per epoch.
      4. Run ``--epochs`` epochs, recording per-iteration and per-epoch
         training loss/accuracy and, with ``--validate``, per-epoch test
         loss/accuracy.
      5. Optionally checkpoint the model and pickle the recorded metrics.

    Checkpointing (``--model-save-iters``):
      * ``-1`` (default) or any negative value: no checkpoints are written.
      * ``0``: only the untrained and final models are saved.
      * ``N > 0``: additionally save every ``N`` training iterations.
    Checkpointing requires ``--exp-name``, which is used as the file prefix.

    Metrics are written to ``./results/<exp_name>.pckl`` only when
    ``--exp-name`` is given.

    Raises:
        NotImplementedError: for an unsupported dataset or model name.
        SystemExit: (via argparse) when checkpointing is requested without
            ``--exp-name``.
    """
    # parse all the arguments for the script
    # (description previously read 'MLP Training', which was wrong for this script)
    parser = argparse.ArgumentParser(description='CNN Training')
    parser.add_argument('--dataset', type=str, default='cifar10')
    parser.add_argument('--model-name', type=str, default='mobilenetv2')
    parser.add_argument('--use-lr-sched', action='store_true', default=False)
    parser.add_argument('--lr', type=float, default=1e-6, metavar='LR')
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--trn-size', type=int, default=50000)
    parser.add_argument('--validate', action='store_true', default=False)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--verbose', action='store_true', default=False)
    parser.add_argument('--check-hull', action='store_true', default=False)
    parser.add_argument('--downsample-method', type=str, default='uniform')
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--model-save-iters', type=int, default=-1)
    parser.add_argument('--exp-name', type=str, default=None)
    parser.add_argument('--dropout', type=float, default=0.2)
    args = parser.parse_args()

    # Checkpoint file names are prefixed with the experiment name, so fail
    # early with a clear message instead of a TypeError on `None + str`.
    save_ckpts = args.model_save_iters > -1
    if save_ckpts and args.exp_name is None:
        parser.error('--exp-name is required when --model-save-iters is set')

    # get the data
    if args.dataset == 'cifar10':
        if args.downsample_method is None:
            downsample_params = None
        else:
            downsample_params = {
                'size': args.trn_size,
                'method': args.downsample_method,
            }
        num_classes = 10
        trn_dl, test_dl = get_cifar(
                batch_size=args.batch_size, augment=True, binarize=False,
                downsample_params=downsample_params)
    else:
        raise NotImplementedError('Not a supported dataset')

    # check if there is a GPU available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # construct the model and optimizer
    if args.model_name == 'mobilenetv2':
        model = MobileNetV2(
                num_classes=num_classes, width_mult=1., dropout=args.dropout,
                for_cifar=True)
    elif args.model_name == 'resnet':
        model = resnet34(
                pretrained=False, num_classes=num_classes, in_chans=3,
                for_cifar=True)
    else:
        raise NotImplementedError(f'{args.model_name} is not a supported model')
    model = model.to(device)
    opt = torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum,
            weight_decay=args.weight_decay)
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = None
    if args.use_lr_sched:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)

    # save the model prior to any training
    if save_ckpts:
        _save_checkpoint(model, args.exp_name, 'iter0')

    # main training loop
    iter_accs = []
    iter_losses = []
    trn_losses = []
    trn_accs = []
    test_accs = []
    test_losses = []
    agg_iter = 0
    for e in range(args.epochs):
        # training loop for single epoch
        if args.verbose:
            print(f'Running Epoch {e + 1} / {args.epochs}')
        # validation (below) puts the model in eval mode; restore train mode
        model.train()
        agg_trn_loss = 0.
        num_corr = 0.
        num_ex = 0.
        for model_in, target in trn_dl:
            model_in, target = model_in.to(device), target.to(device)
            opt.zero_grad()
            outputs = model(model_in)
            loss = criterion(outputs, target)
            loss.backward()
            opt.step()

            # compute metrics
            train_loss = loss.item()
            _, predicted = outputs.max(1)
            total = target.size(0)
            correct = predicted.eq(target).sum().item()

            # update running metrics
            agg_trn_loss += train_loss
            num_corr += correct
            num_ex += total
            iter_accs.append((correct / total))
            iter_losses.append(train_loss)

            # optionally save the model based on iterations.
            # The `> 0` guard matters: in Python `x % -1 == 0` is always true
            # (so the default of -1 would checkpoint every iteration) and a
            # value of 0 would raise ZeroDivisionError.
            agg_iter += 1
            if args.model_save_iters > 0 and agg_iter % args.model_save_iters == 0:
                _save_checkpoint(model, args.exp_name, f'iter{agg_iter}')

        # print out training metrics at the end of each epoch
        agg_trn_loss = agg_trn_loss / len(trn_dl)
        agg_trn_acc = num_corr / num_ex
        trn_losses.append(agg_trn_loss)
        trn_accs.append(agg_trn_acc)
        if args.verbose:
            print(f'Training Loss: {agg_trn_loss:.4f}')
            print(f'Training Acc.: {agg_trn_acc:.4f}')

        # evaluate model performance on test set (in eval mode)
        if args.validate:
            test_loss, test_acc = _evaluate(model, test_dl, criterion, device)
            test_losses.append(test_loss)
            test_accs.append(test_acc)
            if args.verbose:
                print(f'Test Loss: {test_losses[-1]:.4f}')
                print(f'Test Acc.: {test_accs[-1]:.4f}')

        # progress the learning rate scheduler
        if scheduler is not None:
            scheduler.step()

    # save the model after training
    if save_ckpts:
        _save_checkpoint(model, args.exp_name, 'final')

    # save the final results of training
    if args.exp_name is not None:
        os.makedirs('./results', exist_ok=True)
        fp = os.path.join('./results', args.exp_name + '.pckl')
        all_result = {
            'iter_accs': iter_accs,
            'iter_losses': iter_losses,
            'trn_loss': trn_losses,
            'trn_acc': trn_accs,
            'test_acc': test_accs,
            # previously collected but never persisted
            'test_loss': test_losses,
        }
        with open(fp, 'wb') as f:
            pickle.dump(all_result, f)

# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    train_cnn()
