import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import os
import sys
import logging
import argparse
import numpy as np

import torch
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from torch.autograd import grad

import utils
from darts.model_search import Network


# Command-line configuration. ``args`` is consumed as a module-level global by the
# functions below (main, search, PenaltyLoss).
parser = argparse.ArgumentParser("cifar")
parser.add_argument("--data", type=str, default="../../data", help="location of the data corpus")
parser.add_argument("--batch_size", type=int, default=64, help="batch size")
parser.add_argument("--steps", type=int, default=100, help="steps to optimize architecture")
# Used as a modulus on the step index: a report is logged every `report_freq` steps.
parser.add_argument("--report_freq", type=float, default=10, help="report frequency")
parser.add_argument("--gpu", type=int, default=0, help="gpu device id")
parser.add_argument("--init_channels", type=int, default=16, help="num of init channels")
parser.add_argument("--layers", type=int, default=9, help="total number of layers")
parser.add_argument("--cutout", action="store_true", default=False, help="use cutout")
parser.add_argument("--gumbel", action="store_true", default=False, help="use gumbel")
parser.add_argument("--adaptive", action="store_true", default=False,
                    help="adapt the constraint to the running mean of observed gradient-norm traces")
parser.add_argument("--init_alphas", type=float, default=0, help="initial weights of branches")
parser.add_argument("--penalty", type=float, default=2,
                    help="penalty coefficient applied when the gradient-norm trace exceeds --constraint")
parser.add_argument("--constraint", type=float, default=100, help="constraint for regularizer")
parser.add_argument("--rand_label", action="store_true", default=False,
                    help="shuffle labels within each batch during search")
parser.add_argument("--rand_data", action="store_true", default=False,
                    help="replace input images with standard-normal noise")
parser.add_argument("--save", type=str, default="NASI", help="experiment name")
parser.add_argument("--seed", type=int, default=5, help="exp seed")
args = parser.parse_args()

# Experiment outputs live in search/<experiment name>-S<seed>.
args.save = f"search/{args.save}-S{args.seed}"
utils.create_exp_dir(args.save, scripts_to_save=None)

# Log both to stdout and to <save dir>/log.txt using the same message format.
log_format = "%(asctime)s %(message)s"
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt="%m/%d %I:%M:%S %p",
)
fh = logging.FileHandler(os.path.join(args.save, "log.txt"))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# Number of output classes of the search network (CIFAR-10).
CIFAR_CLASSES = 10


def main():
    """Run one architecture search on CIFAR-10 and log the selected architecture.

    Requires CUDA; without it, logs a message and exits with status 1. Otherwise:
    selects GPU ``args.gpu``, disables cuDNN autotuning and enables deterministic
    kernels, and seeds NumPy, PyTorch and PyTorch-CUDA with ``args.seed``. It then
    builds the search ``Network`` on the GPU, wraps the CIFAR-10 training split in
    a shuffled ``DataLoader`` and passes both to ``search``.
    """
    # Everything below places tensors on the GPU, so stop early without one.
    if not torch.cuda.is_available():
        logging.info("no gpu device available")
        sys.exit(1)

    # Reproducibility: fixed device, deterministic cuDNN, seeded RNGs.
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = False
    cudnn.deterministic = True
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(args.seed)

    logging.info("gpu device = %d", args.gpu)
    logging.info("args = %s", args)

    network = Network(
        args.init_channels,
        CIFAR_CLASSES,
        args.layers,
        init_alphas=args.init_alphas,
        gumbel=args.gumbel,
    ).cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(network))

    # Only the first transform returned by the helper is used here.
    transform, _ = utils._data_transforms_cifar10(args)
    loader = torch.utils.data.DataLoader(
        dset.CIFAR10(root=args.data, train=True, download=True, transform=transform),
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=2,
    )

    final_arch = search(loader, network)
    logging.info("Final architecture: %s", (final_arch,))
    

def search(train_queue, model):
    """Accumulate per-architecture-parameter scores and return the chosen genotype.

    For at most ``args.steps`` batches of ``train_queue`` (none if ``args.steps <= 0``):

    1. Move the batch to the GPU. Optionally replace the inputs with standard-normal
       noise (``--rand_data``) and/or shuffle the labels within the batch
       (``--rand_label``).
    2. Differentiate the cross-entropy loss w.r.t. all model parameters with
       ``create_graph=True``, so the gradients themselves stay differentiable.
    3. Back-propagate ``PenaltyLoss(grads, constraint)``. The constraint is fixed at
       ``args.constraint``, or with ``--adaptive`` it is the running mean of
       ``args.constraint`` and the traces seen at earlier steps.
    4. Divide each architecture-parameter gradient by the running maximum of
       ``model.arch_param_grad_norm()`` and add it to that parameter's score.

    Args:
        train_queue: iterable yielding ``(inputs, targets)`` batches.
        model: search network exposing ``_arch_parameters``, ``reset_zero_grads()``,
            ``arch_param_grad_norm()`` and ``genotype(scores)``.

    Returns:
        ``model.genotype(scores)`` for the accumulated scores.
    """
    from itertools import islice

    # One zero score tensor per architecture parameter (same shape/dtype/device).
    scores = [torch.zeros_like(a) for a in model._arch_parameters]
    # Running max of the arch-gradient norm; None until the first step sets it.
    max_norm = None

    # Running sum for the adaptive constraint, seeded with the user constraint.
    sum_constraint = args.constraint

    model.train()
    model_params = list(model.parameters())

    # islice enforces the step budget (also for steps <= 0) without pulling an
    # extra batch from the loader.
    batches = islice(train_queue, max(args.steps, 0))
    for step, (inputs, targets) in enumerate(batches):
        model.reset_zero_grads()
        inputs = inputs.cuda()
        if args.rand_data:
            inputs.normal_()
        targets = targets.cuda()
        if args.rand_label:
            targets = targets[torch.randperm(targets.numel())]

        logits = model(inputs)
        task_loss = F.cross_entropy(logits, targets)
        # create_graph=True so PenaltyLoss can be back-propagated through the grads.
        grads = grad(task_loss, model_params, create_graph=True)

        constraint = sum_constraint / (step + 1) if args.adaptive else args.constraint

        loss, approx_trace = PenaltyLoss(grads, constraint)
        loss.backward()
        sum_constraint += approx_trace.item()

        grad_norm = model.arch_param_grad_norm()
        if max_norm is None:
            max_norm = grad_norm.clone()
        else:
            torch.max(grad_norm, max_norm, out=max_norm)

        # Skip normalisation while the running max is zero: 0/0 would turn the
        # scores into NaN for the rest of the search.
        if max_norm > 0:
            for score, alpha in zip(scores, model._arch_parameters):
                if alpha.grad is not None:
                    alpha.grad.div_(max_norm)
                    score += alpha.grad  # in-place, updates the tensor held in `scores`

        if args.report_freq > 0 and step % args.report_freq == 0:
            logging.info("Train %03d, penalty loss %f", step, loss.item())
            logging.info("Selected architecture: %s", (model.genotype(scores),))
            logging.info("Constraint: %f", constraint)
            for idx, score in enumerate(scores):
                logging.info("Scores[%d]:\n%s", idx, score)

    return model.genotype(scores)


def PenaltyLoss(grads, constraint, penalty=None):
    """Return the penalised squared-gradient-norm objective and the raw norm.

    ``approx_trace`` is the sum of squared entries over every tensor in ``grads``,
    i.e. the squared L2 norm of the concatenated gradient. The name suggests it is
    used as a trace approximation (e.g. of a kernel/Fisher matrix); confirm against
    the method description. The returned loss equals ``approx_trace`` while it is
    at most ``constraint``. Above that it is reduced by
    ``penalty * (approx_trace - constraint)``.

    Args:
        grads: iterable of gradient tensors (e.g. the tuple returned by
            ``torch.autograd.grad``).
        constraint: threshold above which the penalty applies.
        penalty: penalty coefficient; defaults to the CLI value ``args.penalty``
            when ``None``.

    Returns:
        ``(loss, approx_trace)`` as tensors.

    Raises:
        ValueError: if ``grads`` is empty.
    """
    if penalty is None:
        penalty = args.penalty
    squared_norms = [(g ** 2).sum() for g in grads]
    if not squared_norms:
        # sum([]) would be the int 0 and F.relu would then fail obscurely.
        raise ValueError("PenaltyLoss requires at least one gradient tensor")
    approx_trace = sum(squared_norms)
    loss = approx_trace - penalty * F.relu(approx_trace - constraint)
    return loss, approx_trace

# Call main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
