import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import json
import os
import sys
import logging
import argparse

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torchvision.datasets as dset
from torch.autograd import grad

from nasbench_1shot1 import eval_nasi as naseval
from nasbench_1shot1.utils import NasbenchWrapper
from nasbench_1shot1.search_spaces.search_space_1 import SearchSpace1
from nasbench_1shot1.search_spaces.search_space_2 import SearchSpace2
from nasbench_1shot1.search_spaces.search_space_3 import SearchSpace3
from nasbench_1shot1.model_search import Network
import utils

def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"``), so boolean options that
    take a value need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


parser = argparse.ArgumentParser("cifar")
parser.add_argument("--data", type=str, default="../../data",help="location of the data corpus")
parser.add_argument("--nasbench_data", type=str, 
                    default="nasbench_full.tfrecord",
                    help="location of the nasbench data")
parser.add_argument("--batch_size", type=int, default=64, help="batch size")
parser.add_argument("--steps", type=int, default=800, help="steps to optimize architecture")
parser.add_argument("--report_freq", type=float,default=10, help="report frequency")
parser.add_argument("--gpu", type=int, default=0, help="gpu device id")
parser.add_argument("--init_channels", type=int,default=16, help="num of init channels")
parser.add_argument("--init_alphas", type=float,default=0, help="initial weights of braches")
parser.add_argument("--search_space", choices=["1", "2", "3"], default="1")
parser.add_argument("--layers", type=int, default=9, help="total number of layers")
parser.add_argument("--cutout", action="store_true",default=False, help="use cutout")
parser.add_argument("--gumbel", action="store_true",default=False, help="use gumbel")
parser.add_argument("--adaptive", action="store_true",default=False, help="adaptive constraint")
# type=_str2bool (not type=bool) so that "--output_weights False" really disables it.
parser.add_argument('--output_weights', type=_str2bool, default=True, help='Whether to use weights on the output nodes')
parser.add_argument("--penalty", type=float, default=1,help="operation regularizer")
parser.add_argument("--constraint", type=float, default=1000, help="constraint for regularizer")
parser.add_argument("--rand_label", action="store_true",default=False, help="use rand_labeld search")
parser.add_argument("--rand_data", action="store_true",default=False, help="use rand_data data")
parser.add_argument("--save", type=str, default="nasbench",help="experiment name")
parser.add_argument("--seed", type=int, default=0, help="random seed")
args = parser.parse_args()

# Give every run its own output directory keyed by experiment name and seed,
# e.g. "search/nasbench-S0" with the default arguments.
args.save = "search/{}-S{}".format(args.save, args.seed)
utils.create_exp_dir(args.save, scripts_to_save=None)

# Log to stdout and mirror every record into <save>/log.txt.
log_format = "%(asctime)s %(message)s"
logging.basicConfig(
    format=log_format,
    datefmt="%m/%d %I:%M:%S %p",
    level=logging.INFO,
    stream=sys.stdout,
)
fh = logging.FileHandler(os.path.join(args.save, "log.txt"))
fh.setFormatter(logging.Formatter(log_format))
logging.root.addHandler(fh)

CIFAR_CLASSES = 10  # number of CIFAR-10 classes

# Persist the exact run configuration next to the logs.
with open(os.path.join(args.save, 'config.json'), 'w') as fp:
    json.dump(vars(args), fp)


def main():
    """Run the architecture search end to end.

    Builds the selected NAS-Bench-1Shot1 search space and the one-shot
    ``Network`` on the chosen GPU, wraps the CIFAR-10 training split in a
    shuffled DataLoader, runs :func:`search` and logs a final NAS-Bench
    evaluation of the returned scores.

    Raises ``ValueError`` for an unknown ``--search_space`` and exits with
    status 1 when no CUDA device is available.
    """
    space_factories = {'1': SearchSpace1, '2': SearchSpace2, '3': SearchSpace3}
    factory = space_factories.get(args.search_space)
    if factory is None:
        raise ValueError('Unknown search space')
    search_space = factory()

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Pin the device, disable cuDNN autotuning and seed every RNG in use so a
    # given --seed reproduces the run.
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = False
    cudnn.deterministic = True
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    logging.info('gpu device = %d', args.gpu)
    logging.info("args = %s", args)

    model = Network(
        args.init_channels,
        CIFAR_CLASSES,
        args.layers,
        output_weights=args.output_weights,
        steps=search_space.num_intermediate_nodes,
        search_space=search_space,
        init_alphas=args.init_alphas,
        gumbel=args.gumbel,
    ).cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    # Only the training-time transform is needed; the second one is discarded.
    train_transform, _unused_transform = utils._data_transforms_cifar10(args)
    train_queue = torch.utils.data.DataLoader(
        dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform),
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=2,
    )

    logging.info('loading data from nasbench')
    nasbench = NasbenchWrapper(dataset_file=args.nasbench_data)

    final_arch = search(train_queue, model, nasbench=nasbench)
    logging.info('final evaluation')
    real_time_eval(final_arch, nasbench=nasbench)


def search(train_queue, model, nasbench=None):
    """Score the architecture parameters by accumulating normalised gradients.

    For each mini-batch of ``train_queue`` (at most ``args.steps`` of them) this:

    1. computes the cross-entropy task loss and its gradients w.r.t. the model
       parameters with ``create_graph=True`` so they can be differentiated again,
    2. turns those gradients into :func:`PenaltyLoss` and back-propagates it,
    3. divides every architecture-parameter gradient by the largest
       ``model.arch_param_grad_norm()`` seen so far and adds it to a running
       per-parameter score.

    With ``--adaptive`` the constraint at step ``s`` is the mean of
    ``args.constraint`` and the ``s`` traces observed so far; otherwise it is
    ``args.constraint``.

    Args:
        train_queue: iterable of ``(inputs, targets)`` batches.
        model: the supernet; it must provide ``_arch_parameters``,
            ``reset_arch_trainable``, ``reset_zero_grads`` and
            ``arch_param_grad_norm``.
        nasbench: benchmark wrapper forwarded to ``real_time_eval`` for the
            periodic progress reports.

    Returns:
        A list with one accumulated score per architecture parameter (a NumPy
        array once that parameter has received a gradient, ``0`` otherwise).
    """
    scores = [0 for _ in model._arch_parameters]
    # Largest arch-gradient norm seen so far; None until the first step.
    max_norm = None
    sum_constraint = args.constraint

    model.train()
    model.reset_arch_trainable(train=True)
    model_params = [p for _, p in model.named_parameters() if p is not None]

    for step, (inputs, targets) in enumerate(train_queue):
        model.reset_zero_grads()
        inputs = inputs.cuda()
        if args.rand_data:
            # Replace the images in place with standard-normal noise.
            inputs.normal_()
        targets = targets.cuda()
        if args.rand_label:
            # Randomly permute the labels within the batch.
            targets = targets[torch.randperm(targets.numel())]

        logits = model(inputs)
        task_loss = F.cross_entropy(logits, targets)
        grads = grad(task_loss, model_params, create_graph=True)

        if args.adaptive:
            constraint = sum_constraint / (step + 1)
        else:
            constraint = args.constraint

        loss, approx_trace = PenaltyLoss(grads, constraint)
        loss.backward()
        sum_constraint += approx_trace.item()

        # Out-of-place running maximum: never writes into the tensor returned
        # by the model, unlike torch.max(..., out=...) on an aliased tensor.
        grad_norm = model.arch_param_grad_norm()
        max_norm = grad_norm if max_norm is None else torch.max(max_norm, grad_norm)

        # Dividing by a zero norm would turn the (presumably all-zero) gradients
        # into NaN and poison the scores for the rest of the search.
        normalize = bool(max_norm > 0)
        for i, alpha in enumerate(model._arch_parameters):
            if alpha.grad is None:
                continue
            if normalize:
                alpha.grad.div_(max_norm)
            scores[i] += alpha.grad.detach().cpu().numpy()

        if step % args.report_freq == 0:
            logging.info("Train %03d, penalty loss %f", step, loss.item())
            logging.info("Constraint: %f", constraint)
            real_time_eval(list(scores), nasbench=nasbench)

        if step == args.steps - 1:
            break

    return scores


def PenaltyLoss(grads, constraint, penalty=None):
    """Return the constrained squared-gradient objective and the raw trace proxy.

    ``approx_trace`` is the sum of squared entries over every tensor in
    ``grads`` (the squared L2 norm of the concatenated gradient). The loss is::

        approx_trace - penalty * relu(approx_trace - constraint)

    It equals ``approx_trace`` while that stays at or below ``constraint``, and
    is reduced by ``penalty`` for every unit above it.

    Args:
        grads: iterable of gradient tensors.
        constraint: threshold above which the trace is penalised.
        penalty: slope of the penalty. ``None`` (the default) uses
            ``args.penalty``, the ``--penalty`` command-line option.

    Returns:
        ``(loss, approx_trace)``.
    """
    if penalty is None:
        penalty = args.penalty
    approx_trace = sum((g ** 2).sum() for g in grads)
    loss = approx_trace - penalty * F.relu(approx_trace - constraint)
    return loss, approx_trace


def real_time_eval(arch, nasbench):
    """Evaluate ``arch`` on NAS-Bench and log the resulting metrics on one line.

    Delegates to ``naseval.eval_one_shot_model`` with the run's full
    configuration (``args.__dict__``). The result is unpacked into test error,
    validation error, runtime and parameter count, each presumably holding
    three values (the log template has three slots per metric). Returns ``None``.
    """
    metrics = naseval.eval_one_shot_model(
        config=args.__dict__,
        model_list=arch,
        nasbench=nasbench,
    )
    test, valid, runtime, params = metrics
    flat = tuple(test) + tuple(valid) + tuple(runtime) + tuple(params)
    template = ('TEST ERROR: %.3f, %.3f, %.3f | VALID ERROR: %.3f, %.3f, %.3f | '
                'RUNTIME: %f, %f, %f | PARAMS: %d, %d, %d')
    logging.info(template % flat)

if __name__ == "__main__":
    # Launch the search only when executed as a script.
    main()