from __future__ import print_function
from __future__ import absolute_import

import os
import sys
import time
import glob
import torch
import utils
import logging
import argparse
import torch.nn as nn
import torch.utils

from torch.autograd import Variable
from model_search import Network
from architect import Architect
from config import get_policy_space, get_search_divider
from dataset import get_dataloaders, get_num_class, get_num_channel, get_label_name
from utils import reproducibility
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

def _str2bool(value):
    """Convert a command-line string such as 'true'/'false' into a bool.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is ``True``, so any
    non-empty value would enable the flag. This converter accepts the usual
    spellings and rejects anything else with a clear argparse error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser("ada_aug")
parser.add_argument('--dataroot', type=str, default='ada_aug/data', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10',
                    choices=['cifar10', 'reduced_cifar10', 'cifar100', 'reduced_cifar100',
                             'svhn', 'reduced_svhn', 'imagenet', 'reduced_imagenet', 'emnist', 'reduced_emnist',
                             'pet', 'car', 'flower', 'caltech', 'aircraft', 'reduced_pet', 'reduced_car',
                             'reduced_flower', 'reduced_caltech', 'reduced_aircraft', 'cifar_svhn', 'reduced_cifar_svhn'],
                    help='name of dataset')
parser.add_argument('--batch_size', type=int, default=512, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.400, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=2e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=1, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--epochs', type=int, default=20, help='num of training epochs')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--train_portion', type=float, default=1, help='portion of training data')
parser.add_argument('--unrolled', action='store_true', default=False, help='use one-step unrolled validation loss')
parser.add_argument('--arch_learning_rate', type=float, default=1e-2, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
# Boolean options take an explicit value (e.g. `--use_parallel true`); they are
# parsed with _str2bool because `type=bool` would treat "False" as True.
parser.add_argument('--use_cuda', type=_str2bool, default=True, help="use cuda default True")
parser.add_argument('--use_parallel', type=_str2bool, default=False, help="use data parallel default False")
parser.add_argument('--model_name', type=str, default='wresnet40_2', help="model_name")
parser.add_argument('--num_workers', type=int, default=0, help="num_workers")
parser.add_argument('--k_ops', type=int, default=1, help="number of augmentation applied during training")
parser.add_argument('--temperature', type=float, default=0.1, help="temperature")
parser.add_argument('--search_latent', action='store_true', default=False, help="search in latent space")
parser.add_argument('--self_search', action='store_true', default=False, help="search within the train set")
parser.add_argument('--aug_mode', type=str, default='vector', help="augmentation model", choices=["vector", "projection", "cnn"])
parser.add_argument('--architect', type=str, default='iterative', help="architecture update method", choices=["darts", "iterative"])
parser.add_argument('--arch_freq', type=float, default=1, help='architecture update frequency')
parser.add_argument('--n_proj_layer', type=int, default=0, help="number of hidden layer in augmentation policy projection")

args = parser.parse_args()
# A run whose --save name is "debug" is written under debug/ instead of
# search/<dataset>/; either way the directory name gets a timestamp suffix.
debug = args.save == "debug"

_run_dir = 'search-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
_parent_dir = 'debug' if debug else os.path.join('search', args.dataset)
args.save = os.path.join(_parent_dir, _run_dir)

# TensorBoard events go to <save>/board. create_exp_dir presumably snapshots
# the current *.py files into the experiment directory — confirm in utils.
writer = SummaryWriter(f'{args.save}/board')
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Log INFO and above to stdout, and mirror the same records into <save>/log.txt.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)


def print_genotype(geno):
    """Log every sub-policy of ``geno`` on its own INFO line.

    Each entry is indexed positionally: element 0 is logged as the op,
    element 1 as ``m`` and element 2 as ``w``; any further elements are ignored.
    """
    for idx, entry in enumerate(geno):
        op, magnitude, weight = entry[0], entry[1], entry[2]
        logging.info(f'{idx}, \t op: {op} m: {magnitude} w: {weight}')


def save_genotype(geno):
    """Overwrite ``<args.save>/policy.txt`` with one CSV line per sub-policy.

    Each line holds elements 0, 1 and 2 of the sub-policy joined by commas.
    """
    policy_path = os.path.join(args.save, 'policy.txt')
    with open(policy_path, 'w') as handle:
        handle.writelines(
            f"{entry[0]},{entry[1]},{entry[2]}\n" for entry in geno
        )


def save_policy(model):
    """Persist the current augmentation policy under ``args.save``.

    * ``vector``: the genotype is written to ``policy.txt`` via ``save_genotype``.
    * ``projection`` / ``cnn``: the policy network (``model.projection`` or
      ``model.aug_cnn``) is saved with ``utils.save``. The augmentation history
      is then dumped with the label mapping from ``get_label_name``.
    * Any other ``aug_mode``: nothing is saved.
    """
    mode = args.aug_mode
    if mode == 'vector':
        save_genotype(model.genotype())
        return

    # aug_mode -> (model attribute holding the policy network, output file name)
    network_outputs = {
        'projection': ('projection', 'projection_weights.pt'),
        'cnn': ('aug_cnn', 'aug_cnn_weights.pt'),
    }
    if mode not in network_outputs:
        return

    attr_name, file_name = network_outputs[mode]
    utils.save(getattr(model, attr_name), os.path.join(args.save, file_name))
    label_names = get_label_name(args.dataset, args.dataroot)
    model.mix_augment.save_history(label_names)


# Augmentation search space for args.dataset (from config.get_policy_space);
# main() hands it to model.add_augment_agent.
sub_policies = get_policy_space(args.dataset)


def _build_search_loaders(split_idx, search_divider):
    """Return the ``(train, valid, search, test)`` loaders for the search phase.

    Every call site in ``main`` goes through this helper. That way the initial
    split and each per-epoch ``self_search`` reshuffle are built with identical
    arguments, including ``search_divider``.

    Args:
        split_idx: index of the train/valid split to build.
        search_divider: forwarded to ``get_dataloaders(search_divider=...)``;
            ``main`` passes ``get_search_divider(args.model_name)``.
    """
    return get_dataloaders(
        args.dataset, args.batch_size, args.num_workers,
        args.dataroot, args.cutout, args.cutout_length,
        split=args.train_portion, split_idx=split_idx, target_lb=-1,
        search=True, search_epoch=args.epochs,
        self_search=args.self_search, search_divider=search_divider)


def main():
    """Search an augmentation policy jointly with the model weights.

    Builds the network, optimizer and data loaders from the parsed ``args``.
    It then runs ``train`` (policy updates through ``Architect`` interleaved
    with weight updates) and ``infer`` for ``args.epochs`` epochs,
    checkpointing the model and policy after every epoch. Finally it reports
    test accuracy and logs the policy-history figure to TensorBoard.

    Exits the process with status 1 when no CUDA device is available.
    """
    start_time = time.time()
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    torch.cuda.set_device(args.gpu)

    reproducibility(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    n_class = get_num_class(args.dataset)
    n_channel = get_num_channel(args.dataset)

    model = Network(
        args.model_name, n_class, n_channel, args.use_cuda, args.use_parallel,
        temperature=args.temperature, criterion=criterion, search=True,
        latent=args.search_latent, writer=writer)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    # Model-specific divider for the search loader; it must be passed on every
    # (re)build of the loaders, not just the first one.
    div = get_search_divider(args.model_name)

    train_queue, valid_queue, search_queue, test_queue = _build_search_loaders(0, div)

    logging.info(f'Dataset: {args.dataset}')
    logging.info(f'  |total: {len(train_queue.dataset)}')
    logging.info(f'  |train: {len(train_queue)*args.batch_size}')
    logging.info(f'  |valid: {len(valid_queue)*args.batch_size}')
    logging.info(f'  |search: {len(search_queue)*div}')

    after_transforms = train_queue.dataset.after_transforms
    model.add_augment_agent(sub_policies, after_transforms, aug_mode=args.aug_mode,
                            search=True, k_ops=args.k_ops, sampling='max', save_dir=args.save, n_proj_layer=args.n_proj_layer)
    logging.info("Augmentation Params:")
    if args.aug_mode == 'vector':
        print_genotype(model.genotype())

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    for epoch in range(args.epochs):
        lr = scheduler.get_last_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)

        # training
        train_acc, train_obj = train(train_queue, search_queue, model, architect, criterion, optimizer, lr, epoch, args.arch_freq)

        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion, epoch)

        model.writer.add_scalars('loss', {
            'train': train_obj,
            'valid': valid_obj,
        }, epoch)
        model.writer.add_scalars('accuracy', {
            'train': train_acc,
            'valid': valid_acc,
        }, epoch)

        logging.info(f'train_acc {train_acc} valid_acc {valid_acc}')
        scheduler.step()

        if args.aug_mode == 'vector':
            print_genotype(model.genotype())

        utils.save(model, os.path.join(args.save, 'weights.pt'))
        save_policy(model)

        # With self_search, re-split train/valid for the next epoch. The test
        # loader from the initial build is kept. Same divider as the first build.
        if epoch != args.epochs - 1 and args.self_search:
            train_queue, valid_queue, search_queue, _ = _build_search_loaders(epoch + 1, div)

    test_acc, _ = infer(test_queue, model, criterion, args.epochs)
    logging.info(f'test_acc {test_acc}')

    figure = model.mix_augment.plot_history()
    model.writer.add_figure('policy', figure)
    # Flush pending TensorBoard events (including the figure above) to disk.
    model.writer.close()

    end_time = time.time()
    elapsed = end_time - start_time

    logging.info('elapsed time: %.3f Hours' % (elapsed / 3600.))
    logging.info(f'saved to: {args.save}')

# @profile
def train(train_queue, search_queue, model, architect, criterion, optimizer, lr, epoch, arch_freq):
    """Run one epoch of interleaved policy search and weight training.

    For every batch of ``train_queue``:
      1. every ``arch_freq`` steps, update the augmentation policy through
         ``architect`` using data from ``search_queue``;
      2. take one SGD step on the model weights (with gradient-norm clipping
         at ``args.grad_clip``);
      3. accumulate loss / accuracy meters and log every ``args.report_freq``
         global steps.

    Args:
        train_queue: loader yielding ``(input, target)`` training batches.
        search_queue: loader used by the architect for policy updates.
        model: search network; must provide ``set_augmenting``, ``set_search``
            and ``writer``.
        architect: ``Architect`` whose ``name`` selects the update rule
            ('iterative' or 'darts').
        criterion: loss applied to ``(logits, target)``.
        optimizer: weight optimizer (also handed to the darts update).
        lr: current weight learning rate (only used by the darts update).
        epoch: epoch index, used for ``global_step`` and step-0 image logging.
        arch_freq: policy-update period in steps.

    Returns:
        ``(top1.avg, objs.avg)``: sample-weighted mean top-1 accuracy and loss.
    """
    # set_target_gpu(1)
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    # NOTE: despite the name, this meter tracks top-3 accuracy (topk=(1, 3) below).
    top5 = utils.AvgrageMeter()

    model.set_augmenting(True)

    for step, (input, target) in enumerate(train_queue):
        # Re-entered every step, presumably because the architect update may
        # switch the model's mode — confirm in Architect.
        model.train()
        n = target.size(0)
        global_step = epoch * len(train_queue) + step

        # Only the target is moved to the GPU here. NOTE(review): `input` stays
        # where the loader put it; presumably the model / augmentation agent
        # handles its device placement — confirm in model_search.Network.
        target = Variable(target, requires_grad=False).cuda(non_blocking=True)

        s_time = time.time()
        if step % arch_freq == 0:
            if architect.name == 'iterative':
                architect.step_iterative_unroll(search_queue)

                # Disabled alternative: single-batch iterative update.
                # input_search, target_search = next(iter(search_queue))
                # input_search = Variable(input_search, requires_grad=True)
                # target_search = Variable(target_search, requires_grad=False).cuda(non_blocking=True)
                # architect.step_iterative(input_search, target_search)

            elif architect.name == 'darts':
                # A fresh iterator is created on every update, so this takes the
                # first batch the loader yields each time (random only if the
                # loader shuffles — NOTE(review): confirm).
                input_search, target_search = next(iter(search_queue))
                input_search = Variable(input_search, requires_grad=True)
                target_search = Variable(target_search, requires_grad=False).cuda(non_blocking=True)

                architect.step_darts(input, target, input_search, target_search, lr, optimizer, unrolled=True)  # NOTE(review): hard-coded; the --unrolled flag (args.unrolled) is ignored here.
        search_time = time.time() - s_time
        s_time = time.time()
        optimizer.zero_grad()
        # training iteration
        model.set_search(False)
        if step == 0:
            # Once per epoch: log the first training batch as an image grid and
            # pass `epoch` to the model (the other steps call it without).
            model.writer.add_image('train', make_grid(input), epoch)
            logits = model(input, target, epoch)
        else:
            logits = model(input, target)  # no epoch argument on steps after the first
        loss = criterion(logits, target)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits.detach(), target.detach(), topk=(1, 3))
        objs.update(loss.detach().item(), n)
        top1.update(prec1.detach().item(), n)
        top5.update(prec5.detach().item(), n)

        if global_step % args.report_freq == 0:
            # Columns: global step, mean loss, top-1, top-3, policy-update time + weight-update time.
            logging.info('  |train %03d %e %f %f | %.3f + %.3f sec', global_step, objs.avg, top1.avg, top5.avg, search_time, time.time()-s_time)

    return top1.avg, objs.avg


def infer(valid_queue, model, criterion, epoch):
    """Evaluate ``model`` on ``valid_queue`` with augmentation off and no gradients.

    ``epoch`` is accepted for call-site symmetry with ``train`` but is unused.
    Top-3 accuracy is computed alongside top-1 but not returned.

    Returns:
        ``(top-1 accuracy, mean loss)``, both averaged over samples.
    """
    loss_meter = utils.AvgrageMeter()
    acc1_meter = utils.AvgrageMeter()
    acc3_meter = utils.AvgrageMeter()

    model.eval()
    model.set_augmenting(False)

    with torch.no_grad():
        for images, labels in valid_queue:
            images = Variable(images).cuda()
            labels = Variable(labels).cuda(non_blocking=True)

            logits = model(images)
            loss = criterion(logits, labels)
            acc1, acc3 = utils.accuracy(logits, labels, topk=(1, 3))

            batch = images.size(0)
            loss_meter.update(loss.detach().item(), batch)
            acc1_meter.update(acc1.detach().item(), batch)
            acc3_meter.update(acc3.detach().item(), batch)

    return acc1_meter.avg, loss_meter.avg


# Run the search when executed as a script. Note that argument parsing and
# experiment-directory/logging setup above run at import time regardless.
if __name__ == '__main__':
    main()
