import os
import argparse
import time

import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn

from models import model_dict

from dataset.cifar100 import get_cifar100_dataloaders

from helper.util import adjust_learning_rate
from helper.loops import train, validate

from gem import GEM


def parse_option(argv=None):
    """Parse command-line options and derive run-specific paths.

    Besides parsing, this function post-processes the namespace:

    * ``learning_rate`` is forced to 0.01 for MobileNetV2 / ShuffleV2.
    * ``lr_decay_epochs`` is converted from a comma-separated string to a
      list of ints.
    * ``model_path``, ``log_pth``, ``model_name``, ``save_folder``,
      ``log_key`` and ``log_folder`` are added, and the two folders are
      created on disk if they do not exist yet.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
    parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=240, help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='150,180,210', help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    # dataset
    parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100'], help='dataset')

    # model
    parser.add_argument('--model', type=str, default='MobileNetV2',
                        choices=['resnet20', 'resnet32', 'wrn_40_2', 'vgg8', 'MobileNetV2', 'ShuffleV2'])
    # loss
    parser.add_argument('--loss', type=str, default='gem', choices=['gem'], help='training loss')

    # gem
    parser.add_argument('-a', '--alpha', type=float, default=0.005, help='weight balance for second moment')
    parser.add_argument('-b', '--beta', type=float, default=0.05, help='weight balance for squared mean')

    parser.add_argument('-t', '--trial', type=int, default=0, help='the experiment id')

    opt = parser.parse_args(argv)

    # NOTE: these two models always train with lr 0.01; this overrides any
    # value passed via --learning_rate.
    if opt.model in ['MobileNetV2', 'ShuffleV2']:
        opt.learning_rate = 0.01

    opt.model_path = './save/model'
    opt.log_pth = './save/log'

    # '150,180,210' -> [150, 180, 210]; blank segments (e.g. from a trailing
    # comma) are skipped instead of crashing int().
    opt.lr_decay_epochs = [
        int(it) for it in opt.lr_decay_epochs.split(',') if it.strip()
    ]

    opt.model_name = '{}_{}_a_{}_b_{}_trial_{}'.format(opt.model, opt.loss, opt.alpha, opt.beta, opt.trial)

    # exist_ok avoids the check-then-create race when several runs start at once.
    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    os.makedirs(opt.save_folder, exist_ok=True)

    opt.log_key = '{}_{}'.format(opt.model, opt.loss)

    opt.log_folder = os.path.join(opt.log_pth, opt.log_key)
    os.makedirs(opt.log_folder, exist_ok=True)

    return opt


def _append_log_row(log_fname, fields):
    """Append one tab-separated, newline-terminated row to ``log_fname``.

    Every field is converted with ``str`` before joining.
    """
    with open(log_fname, 'a') as log:
        log.write('\t'.join(str(field) for field in fields) + '\n')


def main():
    """Train the selected model and log per-epoch results.

    Steps: parse options, build the CIFAR-100 loaders, model, SGD optimizer
    and GEM criterion, then for every epoch adjust the learning rate, train,
    validate and append one row to the log file. The checkpoint with the
    highest validation accuracy is saved as ``<model>_best.pth`` and the
    final state as ``<model>_last.pth`` inside ``opt.save_folder``.

    Raises:
        NotImplementedError: If ``opt.dataset`` or ``opt.loss`` is not
            handled here.
    """
    # Starts as a plain int; replaced by whatever validate() returns as
    # accuracy once an epoch beats it.
    best_acc = 0

    opt = parse_option()

    # dataloader
    if opt.dataset == 'cifar100':
        train_loader, val_loader = get_cifar100_dataloaders(batch_size=opt.batch_size, num_workers=opt.num_workers)
        n_cls = 100
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model = model_dict[opt.model](num_classes=n_cls)

    # optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    if opt.loss == 'gem':
        criterion = GEM(opt.alpha, opt.beta)
    else:
        raise NotImplementedError(opt.loss)

    if torch.cuda.is_available():
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    # log file (opened in append mode, so re-running an experiment adds a
    # new header and new rows to the existing file)
    log_fname = os.path.join(opt.log_folder, '{experiment}.txt'.format(experiment=opt.model_name))
    _append_log_row(log_fname, ("test_acc", "test_ce", "train_acc", "train_ce"))

    # routine
    for epoch in range(1, opt.epochs + 1):

        adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        # the printed loss is always ce, not the total loss
        train_acc, train_loss = train(epoch, train_loader, model, criterion, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
        # the printed loss is always ce, not the total loss
        test_acc, _, test_loss = validate(val_loader, model, criterion, opt)

        _append_log_row(log_fname, (test_acc.item(), test_loss, train_acc.item(), train_loss))

        # save the best model
        if test_acc > best_acc:
            best_acc = test_acc
            state = {
                'epoch': epoch,
                'model': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }
            save_file = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model))
            print('saving the best model!')
            torch.save(state, save_file)

    print('best accuracy:', best_acc)
    # best_acc is still the initial int 0 when no epoch ran or no epoch
    # exceeded 0 accuracy; only call .item() when it is available so the
    # final log line and the last checkpoint are still written.
    best_value = best_acc.item() if hasattr(best_acc, 'item') else best_acc
    with open(log_fname, 'a') as log:
        log.write('best accuracy: ' + str(best_value) + '\n')

    # save model
    state = {
        'opt': opt,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)

# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()