import os
import sys
import time
import glob
import torch
import utils
import logging
import argparse
import torch.nn as nn
import torch.utils

from torch.autograd import Variable
from model_search import Network
from dataset import get_num_class, get_num_channel, get_dataloaders, get_label_name
from utils import reproducibility, parse_genotype, load_augmentor
from config import get_policy_space, get_warmup_config
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from warmup_scheduler import GradualWarmupScheduler


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` does not work with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so such a flag can never be
    switched off from the command line.  This converter accepts the usual
    spellings instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string keeps the meaning it had under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value (true/false), got {value!r}')


# Command-line interface of the script.  Options are grouped as:
# data, optimisation, device/model, regularisation, run control and
# augmentation policy.
parser = argparse.ArgumentParser("ada_aug")

# --- data ---------------------------------------------------------------
parser.add_argument('--dataroot', type=str,
                    default='./ada_aug/data', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10',
                    choices=['mnist', 'reduced_mnist', 'reduced_mnist018', 'reduced_mnist369',
                             'cifar10', 'reduced_cifar10', 'cifar100', 'reduced_cifar100',
                             'svhn', 'reduced_svhn', 'imagenet', 'reduced_imagenet', 'emnist', 'reduced_emnist',
                             'pet', 'car', 'flower', 'caltech', 'aircraft', 'reduced_pet', 'reduced_car',
                             'reduced_flower', 'reduced_caltech', 'reduced_aircraft', 'cifar_svhn', 'reduced_cifar_svhn'],
                    help='name of dataset')
parser.add_argument('--train_portion', type=float,
                    default=0.5, help='portion of training data')
parser.add_argument('--batch_size', type=int, default=96, help='batch size')
parser.add_argument('--num_workers', type=int, default=0, help="num_workers")

# --- optimisation -------------------------------------------------------
parser.add_argument('--learning_rate', type=float,
                    default=0.025, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float,
                    default=0.0001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float,
                    default=3e-4, help='weight decay')
parser.add_argument('--grad_clip', type=float,
                    default=5, help='gradient clipping')

# --- device / model -----------------------------------------------------
# Boolean options take an explicit value, e.g. ``--use_cuda False``.
parser.add_argument('--use_cuda', type=_str2bool, default=True,
                    help="use cuda default True")
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--use_parallel', type=_str2bool, default=False,
                    help="use data parallel default False")
parser.add_argument('--init_channels', type=int,
                    default=36, help='num of init channels')
parser.add_argument('--layers', type=int, default=20,
                    help='total number of layers')
parser.add_argument('--model_name', type=str,
                    default='wresnet40_2', help="model_name")
parser.add_argument('--model_path', type=str,
                    default='saved_models', help='path to save the model')
parser.add_argument('--auxiliary', action='store_true',
                    default=False, help='use auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float,
                    default=0.4, help='weight for auxiliary loss')
parser.add_argument('--temperature', type=float,
                    default=1.0, help="temperature")

# --- regularisation -----------------------------------------------------
parser.add_argument('--cutout', action='store_true',
                    default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int,
                    default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float,
                    default=0.2, help='drop path probability')

# --- run control --------------------------------------------------------
parser.add_argument('--epochs', type=int, default=600,
                    help='num of training epochs')
parser.add_argument('--report_freq', type=float,
                    default=50, help='report frequency')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--add_aug', action='store_true',
                    default=False, help='use add augment')

# --- augmentation policy ------------------------------------------------
parser.add_argument('--policy_path', type=str, default='./', help='policy path')
parser.add_argument('--aug_mode', type=str, default='vector', help="augmentation model",
                    choices=["vector", "projection", "cnn", "randaug", "dada", "fa", "aa", "pba"])
parser.add_argument('--k_ops', type=int, default=1,
                    help="number of augmentation applied during training")
parser.add_argument('--restore_path', type=str, default='./', help='model path')
parser.add_argument('--restore', action='store_true', default=False, help='restore model')
parser.add_argument('--n_proj_layer', type=int, default=0,
                    help="number of hidden layer in augmentation policy projection")
# parser.add_argument('--aug_n_class', type=int, default=0, help="number of class augmentor is trained on")

args = parser.parse_args()

# The experiment name "debug" (as typed on the command line, before the
# timestamp is appended) sends all output below ./debug instead of ./eval.
debug = args.save == "debug"

# Final output directory: <debug | eval/<dataset>>/<name>-<timestamp>.
_run_name = '{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
_parent_dirs = ('debug',) if debug else ('eval', args.dataset)
args.save = os.path.join(*_parent_dirs, _run_name)

# The TensorBoard writer is created before the experiment directory helper
# runs; keep this order.
writer = SummaryWriter(f'{args.save}/board')
# NOTE(review): create_exp_dir presumably creates the directory and stores a
# copy of the listed scripts - confirm in utils.
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Log to stdout and, with the same message layout, to <save>/log.txt.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)


# Augmentation sub-policy search space for the selected dataset.
sub_policies = get_policy_space(args.dataset)


def print_genotype(geno):
    """Log one INFO line per entry of ``geno``.

    Each entry must be indexable; its items 0, 1 and 2 are printed under the
    labels ``op``, ``m`` and ``w`` together with the entry's position.
    """
    for position, entry in enumerate(geno):
        op, m, w = entry[0], entry[1], entry[2]
        logging.info(f'{position}, \t op: {op} m: {m} w: {w}')


def model_debut():
    """Print the literal marker ``debug`` to standard output."""
    marker = 'debug'
    print(marker)


def main():
    """Train ``args.model_name`` on ``args.dataset`` and report its accuracy.

    Steps, all driven by the module-level ``args``:

    1. Exit with status 1 when CUDA is unavailable; otherwise select the GPU
       and seed everything through ``reproducibility``.
    2. Build the task model and the data loaders.
    3. With ``--add_aug``, build an augmentor network and load it through
       ``load_augmentor``; otherwise train without one.
    4. Train with SGD (Nesterov momentum) under cosine annealing, wrapped in
       a gradual warm-up unless the dataset name contains ``mnist``.
    5. After every epoch run validation and overwrite ``weights.pt`` in the
       experiment directory; every ``args.report_freq`` epochs also run the
       test set.
    6. Run a final test pass and log where the results were saved.

    Raises:
        ValueError: if ``--add_aug`` is set and ``--policy_path`` has fewer
            than three '/'-separated components.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    torch.cuda.set_device(args.gpu)
    reproducibility(args.seed)

    logging.info('gpu device = %d', args.gpu)
    logging.info("args = %s", args)

    n_class = get_num_class(args.dataset)
    n_channel = get_num_channel(args.dataset)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    model = Network(
        args.model_name, n_class, n_channel, args.use_cuda,
        args.use_parallel, temperature=1.0, criterion=criterion)
    model = model.cuda()
    model.set_augmenting(False)  # Then the forward function will not do augmentation

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    # if use add_aug, the image returned by dataloader should be a list of un-after-transformed images
    train_queue, valid_queue, _, test_queue = get_dataloaders(
        args.dataset, args.batch_size, args.num_workers,
        args.dataroot, args.cutout, args.cutout_length,
        split=args.train_portion, split_idx=0, target_lb=-1,
        search=args.add_aug, search_epoch=1, self_search=False)

    if args.add_aug:
        # The dataset the augmentor was trained on is recovered from the
        # third '/'-separated component of the policy path.  Fail with a
        # clear message instead of a bare IndexError when the path is too
        # short (the default './' is).
        path_parts = args.policy_path.split('/')
        if len(path_parts) < 3:
            raise ValueError(
                f"--policy_path {args.policy_path!r} must contain at least three "
                "'/'-separated components; the third one names the dataset "
                "the augmentor was trained on")
        aug_dset = path_parts[2]
        aug_n_class = get_num_class(aug_dset)
        # NOTE(review): the original list named 'reduced_mnist' twice;
        # 'reduced_emnist' may have been intended. Behaviour is unchanged
        # here - confirm before adding it.
        if args.dataset in ('mnist', 'emnist', 'reduced_mnist'):
            aug_model_name = 'resnet50'
        else:
            aug_model_name = 'wresnet40_2'
        print(aug_dset)
        print(aug_model_name)
        augmentor = Network(
                aug_model_name, aug_n_class, n_channel, args.use_cuda,
                args.use_parallel, temperature=args.temperature,
                criterion=None, writer=writer)
        augmentor = augmentor.cuda()
        after_transforms = train_queue.dataset.after_transforms

        augmentor = load_augmentor(args, augmentor, after_transforms, n_class, sub_policies, aug_model_name)

    else:
        augmentor = None

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=True
        )

    logging.info(f'Dataset: {args.dataset}')
    logging.info(f'  |total: {len(train_queue.dataset)}')
    # Batch count times batch size: an upper bound when the last batch is short.
    logging.info(f'  |train: {len(train_queue)*args.batch_size}')
    logging.info(f'  |valid: {len(valid_queue)*args.batch_size}')

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)
    scheduler_name = 'CosineAnnealing'

    # Substring test: every dataset whose name contains "mnist" (including
    # the emnist variants) trains without warm-up.
    if 'mnist' not in args.dataset:
        m, e = get_warmup_config(args.dataset, args.model_name)
        scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=m,
                total_epoch=e,
                after_scheduler=scheduler)
        logging.info(f'Scheduler: {scheduler_name} with WarmUp (m:{m}, e:{e})')

    if args.restore:
        # Resume at the epoch after the one returned by restore_ckpt.
        trained_epoch = utils.restore_ckpt(model, optimizer, scheduler, args.restore_path, location=args.gpu) + 1
        n_epoch = args.epochs - trained_epoch
        logging.info(f'Restoring model from {args.restore_path}, starting from epoch {trained_epoch}')
    else:
        trained_epoch = 0
        n_epoch = args.epochs

    for i_epoch in range(n_epoch):
        epoch = trained_epoch + i_epoch
        lr = scheduler.get_last_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)

        train_acc, _ = train(
            train_queue, model, criterion, optimizer, epoch, augmentor)
        logging.info('train_acc %f', train_acc)

        valid_acc, _, _, _ = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        scheduler.step()

        # Same file every epoch: only the latest checkpoint is kept.
        utils.save_ckpt(model, optimizer, scheduler, epoch, os.path.join(args.save, 'weights.pt'))

        if args.add_aug and args.aug_mode in ['projection']:
            class2label = get_label_name(args.dataset, args.dataroot)
            augmentor.mix_augment.save_history(class2label)

        if epoch % args.report_freq == 0:
            test_acc, _, test_acc5, _ = infer(test_queue, model, criterion)
            logging.info('test_acc %f %f', test_acc, test_acc5)

    if args.add_aug and args.aug_mode in ['projection', 'vector']:
        try:
            augmentor.mix_augment.plot_history()
        except Exception:
            # Plotting is best-effort and must not abort the final test, but
            # a failure should be visible in the log instead of vanishing.
            logging.warning('could not plot the augmentation history', exc_info=True)
    test_acc, _, test_acc5, _ = infer(test_queue, model, criterion)
    logging.info('test_acc %f %f', test_acc, test_acc5)
    logging.info(f'save to {args.save}')


def train(train_queue, model, criterion, optimizer, epoch, augmentor):
    """Run one training epoch.

    Args:
        train_queue: iterable yielding ``(input, target)`` batches.
        model: network to optimise; called as ``model(input)``.
        criterion: loss callable taking ``(logits, target)``.
        optimizer: optimiser stepped once per batch.
        epoch: index of the current epoch.  Used for the global step in the
            progress log, as the TensorBoard step of the sample images and,
            in ``pba`` mode, as the second argument of the augmentor.
        augmentor: object providing ``get_aug_images``; only used when
            ``args.add_aug`` is set.

    Returns:
        ``(top1.avg, objs.avg)``: running top-1 accuracy and running loss
        over the epoch, as computed by ``utils.AvgrageMeter``.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    # Historical name: with topk=(1, 3) below this meter holds top-3 accuracy.
    top5 = utils.AvgrageMeter()
    model.train()
    model.set_augmenting(False)

    for step, (images, target) in enumerate(train_queue):
        # if add_aug, the cuda() is done by augment agent
        target = target.cuda(non_blocking=True)
        if step == 0:
            # Tag the image with the epoch: ``step`` is always 0 here, so
            # using it made every epoch collide on the same TensorBoard step.
            writer.add_image('train', make_grid(images, normalize=True), epoch)

        if not args.add_aug:
            images = images.cuda()
        elif args.aug_mode != 'pba':
            images = augmentor.get_aug_images(images, target)
        else:
            # 'pba' is driven by the epoch instead of the labels.
            images = augmentor.get_aug_images(images, epoch)

        if step == 0:
            writer.add_image('train_aug', make_grid(images, normalize=True), epoch)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, target)

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 3))
        n = images.size(0)
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        global_step = step + epoch * len(train_queue)

        if global_step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', global_step, objs.avg, top1.avg, top5.avg)

    return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
    """Evaluate ``model`` on ``valid_queue`` without tracking gradients.

    Args:
        valid_queue: iterable yielding ``(input, target)`` batches.
        model: network to evaluate; switched to eval mode with augmentation
            disabled.
        criterion: loss callable taking ``(logits, target)``.

    Returns:
        ``(top1.avg, objs.avg, top5.avg, objs.avg)``.  The loss average
        appears twice; callers unpack four values, so the shape is kept.
        ``top5`` is a historical name: with ``topk=(1, 3)`` it holds top-3
        accuracy.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()
    model.set_augmenting(False)
    with torch.no_grad():
        for images, target in valid_queue:
            # Plain tensors suffice: the Variable wrapper is deprecated and
            # no longer needed.
            images = images.cuda()
            target = target.cuda(non_blocking=True)

            logits = model(images)
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 3))
            n = images.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

    return top1.avg, objs.avg, top5.avg, objs.avg


# Start training only when the file is executed as a script.
if __name__ == '__main__':
    main()
