import os
import sys
import time
import glob
import numpy as np
import torch
import utils
import logging
import argparse
import torch.nn as nn
import torch.utils
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import copy

from torch.autograd import Variable
from model_search_imagenet_topo import Network
from architect import Architect

from genotypes import Genotype
import visualize

from copy import deepcopy
import codecs
import json
from numpy import linalg as LA
from analyze import Analyzer


# Command-line interface. Defaults target a 1024-image batch; learning rates are
# rescaled for other batch sizes by adapt_batch_size() below.
parser = argparse.ArgumentParser("imagenet")
parser.add_argument('--workers', type=int, default=4, help='number of workers to load dataset')
parser.add_argument('--data', type=str, default='/tmp/cache/', help='location of the data corpus')
parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
# Batch sizes previously used on other hardware:
# parser.add_argument('--batch_size', type=int, default=1400, help='batch size') # 3 2080Ti
# parser.add_argument('--batch_size', type=int, default=1000, help='batch size') # 1 V100
# parser.add_argument('--batch_size', type=int, default=300, help='batch size') # 1 2080Ti
parser.add_argument('--learning_rate', type=float, default=0.5, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.0, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--epochs', type=int, default=1, help='num of training epochs')
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path probability')
parser.add_argument('--save', type=str, default='/export/data/lwangcg/PC-DARTS/ImagenetExps/', help='experiment name')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--unrolled', action='store_true', default=False, help='use one-step unrolled validation loss')
parser.add_argument('--arch_learning_rate', type=float, default=6e-3, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
parser.add_argument('--begin', type=int, default=35, help='epoch from which to start architecture updates (currently not enforced)')

# parser.add_argument('--tmp_data_dir', type=str, default='/cache/', help='temp data dir')
# parser.add_argument('--tmp_data_dir', type=str, default='/export/data/lwangcg/ILSVRC2012/SampledTrainDataset/', help='temp data dir')
parser.add_argument('--tmp_data_dir', type=str, default='/dev/shm/SampledTrainDataset/', help='temp data dir')
parser.add_argument('--note', type=str, default='try', help='note for this run')

# Options for the Hessian / early-stopping analysis.
parser.add_argument('--report_freq_hessian', type=float, default=1, help='report frequency hessian')
parser.add_argument('--early_stop', type=int, default=1, choices=[0, 1, 2, 3], help='early stop DARTS based on dominant eigenvalue. 0: no 1: yes 2: simulate 3: adaptive regularization')
parser.add_argument('--window', type=int, default=5, help='window size of the local average')
parser.add_argument('--es_start_epoch', type=int, default=0, help='when to start considering early stopping')
parser.add_argument('--delta', type=int, default=4, help='number of previous local averages to consider in early stopping')
parser.add_argument('--factor', type=float, default=1.3, help='early stopping factor')
parser.add_argument('--extra_rollback_epochs', type=int, default=0, help='number of extra rollback epochs when deciding to increase regularization')
parser.add_argument('--compute_hessian', action='store_true', default=False, help='whether to compute the Hessian')
parser.add_argument('--max_weight_decay', type=float, default=0.03, help='maximum weight decay')
parser.add_argument('--mul_factor', type=float, default=10, help='multiplication factor')
parser.add_argument('--debug', action='store_true', default=False, help='debug mode for the Hessian analysis: skip the Hessian computation (currently unused)')

# Parse the command line once, at import time; every function below reads its
# hyper-parameters from this module-level namespace.
args = parser.parse_args(sys.argv[1:])

def adapt_batch_size():
    """Scale the learning rates linearly with ``args.batch_size``.

    The default rates are tuned for a batch of 1024, so each rate becomes
    ``rate * batch_size / 1024``. The global ``args`` is mutated in place,
    so calling this twice would rescale twice.
    """
    for rate_name in ('learning_rate', 'learning_rate_min', 'arch_learning_rate'):
        setattr(args, rate_name, getattr(args, rate_name) * args.batch_size / 1024)
    # The epoch count is not rescaled:
    # args.epochs = int(args.epochs * args.batch_size / 1024)

# Rescale the learning rates for the chosen batch size (a no-op at the default
# of 1024). This multiplies args in place, so it must run exactly once.
adapt_batch_size()

# Give each run its own timestamped experiment directory. This is plain string
# concatenation, so args.save is expected to end with a path separator (the
# default does); otherwise the prefix is glued onto its last path component.
args.save = '{}search-{}-{}'.format(args.save, args.note, time.strftime("%Y%m%d-%H%M%S"))
# Create the directory; the *.py files of the working directory are passed as
# scripts_to_save (presumably copied there for reproducibility - see utils).
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Log to stdout and to <save>/log.txt using the same format.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# TensorBoard event files go to <save>/tb. The identifier is misspelled, but it
# is referenced throughout this module under this name.
tensorbard_logger = SummaryWriter(args.save+'/tb')

logging.info('save file: ' + args.save)
tensorbard_logger.add_text('save file', args.save)

# data_dir = os.path.join(args.tmp_data_dir, 'imagenet_search')
data_dir= args.tmp_data_dir
# Data preparation: we randomly sample 10% and 2.5% of the training set (per class)
# as the train and val splits, respectively. The sampling cannot use
# torch.utils.data.sampler.SubsetRandomSampler because ImageNet is too large, so the
# subsets are expected to exist on disk as <data_dir>/train and <data_dir>/val.
CLASSES = 1000

# Step counter shared by train() and infer(); used as the TensorBoard x-axis.
global_iter = 0


def main():
    """Search a cell architecture on the sampled ImageNet subset.

    Loads ``<data_dir>/train`` (used for the network-weight updates) and
    ``<data_dir>/val`` (used for the architecture updates). Builds the
    ``DataParallel``-wrapped search ``Network`` and runs ``args.epochs``
    epochs of :func:`train`. Along the way it logs the learning rate,
    genotypes, softmaxed architecture weights and training accuracy to the
    console, ``log.txt`` and TensorBoard. At the end it prints the
    accumulated training-step time.

    Exits the process with status 1 when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)
    tensorbard_logger.add_text('args', str(args))

    # Both splits are pre-sampled ImageFolder trees on disk.
    traindir = os.path.join(data_dir, 'train')
    valdir = os.path.join(data_dir, 'val')

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    criterion = nn.CrossEntropyLoss().cuda()

    # Both splits are consumed as training data (weights vs. architecture),
    # so both get the same random-crop/flip augmentation.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    train_data1 = dset.ImageFolder(traindir, train_transform)
    train_data2 = dset.ImageFolder(valdir, train_transform)
    num_train = len(train_data1)
    num_val = len(train_data2)
    logging.info('# images to train network: %d', num_train)
    logging.info('# images to validate network: %d', num_val)
    tensorbard_logger.add_text('# images to train network', str(num_train))
    tensorbard_logger.add_text('# images to validate network', str(num_val))

    model = Network(args.init_channels, CLASSES, args.layers, criterion)
    model = torch.nn.DataParallel(model).cuda()
    param_size_mb = utils.count_parameters_in_MB(model)
    logging.info("param size = %fMB", param_size_mb)
    tensorbard_logger.add_text('param size (MB)', str(param_size_mb))

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    optimizer_a = torch.optim.Adam(model.module.arch_parameters(),
                                   lr=args.arch_learning_rate, betas=(0.5, 0.999),
                                   weight_decay=args.arch_weight_decay)

    train_queue = torch.utils.data.DataLoader(
        train_data1, batch_size=args.batch_size, shuffle=True,
        pin_memory=True, num_workers=args.workers)
    valid_queue = torch.utils.data.DataLoader(
        train_data2, batch_size=args.batch_size, shuffle=True,
        pin_memory=True, num_workers=args.workers)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    analyser = Analyzer(args, model.module)
    la_tracker = utils.EVLocalAvg(args.window, args.report_freq_hessian,
                                  args.epochs)
    errors_dict = {'train_acc': [], 'train_loss': [], 'valid_acc': [],
                   'valid_loss': []}

    total_time = 0
    for epoch in range(args.epochs):
        # This epoch trains with the LR currently held by the optimizer. The
        # scheduler is stepped only *after* the epoch. Stepping it first
        # skipped the initial LR, and with the defaults (epochs=1,
        # learning_rate_min=0.0) ran the whole search with an LR of 0.
        current_lr = optimizer.param_groups[0]['lr']
        logging.info('Epoch: %d lr: %e', epoch, current_lr)
        tensorbard_logger.add_scalar('train/lr', current_lr, epoch)

        genotype = model.module.genotype()
        logging.info('genotype before updating = %s', genotype)
        visualize.plot(genotype.normal, "normal_before_updating", tensorboard_logger=tensorbard_logger, epoch=epoch)
        visualize.plot(genotype.reduce, "reduction_before_updating", tensorboard_logger=tensorbard_logger, epoch=epoch)
        arch_param = model.module.arch_parameters()
        logging.info(F.softmax(arch_param[0], dim=-1))
        logging.info(F.softmax(arch_param[1], dim=-1))

        # Training; total_time keeps accumulating across epochs.
        train_acc, train_obj, total_time = train(
            train_queue, valid_queue, model, optimizer, optimizer_a, criterion,
            current_lr, epoch, analyser=analyser, local_avg_tracker=la_tracker,
            errors_dict=errors_dict, total_time=total_time)
        scheduler.step()

        logging.info('Train_acc %f', train_acc)
        tensorbard_logger.add_scalar('train/train_acc', train_acc, epoch)

        genotype = model.module.genotype()
        logging.info('genotype after updating = %s', genotype)
        visualize.plot(genotype.normal, "normal after updating", tensorboard_logger=tensorbard_logger, epoch=epoch)
        visualize.plot(genotype.reduce, "reduction after updating", tensorboard_logger=tensorbard_logger, epoch=epoch)

    # Report the total timed search cost once all epochs have run (this used to
    # be a sys.exit() after the first epoch that left the logging above unreachable).
    print('time cost', total_time, 's')
    print('Time End.')
    tensorbard_logger.close()


def train(train_queue, valid_queue, model, optimizer, optimizer_a, criterion, lr, epoch, analyser, local_avg_tracker, iteration=1, errors_dict=None, total_time=0):
    """Run one epoch of alternating architecture and weight updates.

    Each minibatch of ``train_queue`` triggers two steps:

    1. One first-order architecture step (``optimizer_a``) on a minibatch
       drawn from ``valid_queue``, which is restarted whenever it runs out.
    2. One weight step (``optimizer``) on the training minibatch.

    Args:
        train_queue: loader yielding (images, labels) for the weight updates.
        valid_queue: loader yielding (images, labels) for the architecture updates.
        model: ``DataParallel``-wrapped search network whose ``module`` exposes
            ``arch_parameters()`` and ``parameters()``.
        optimizer: optimizer over the network weights.
        optimizer_a: optimizer over the architecture parameters.
        criterion: loss function applied to (logits, labels).
        lr, epoch, analyser, local_avg_tracker, iteration, errors_dict:
            currently unused (they served the Hessian / early-stopping
            analysis); kept for call-site compatibility.
        total_time: seconds of timed update steps carried over from earlier epochs.

    Returns:
        ``(top1_avg, loss_avg, total_time)``: running top-1 accuracy and loss
        over this epoch's training minibatches, plus the updated time total.
    """
    global global_iter
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()

    # Create the search iterator up front so the fallback below only has to
    # handle exhaustion, not a missing variable.
    valid_queue_iter = iter(valid_queue)

    for step, (inputs, targets) in enumerate(train_queue):
        model.train()
        n = inputs.size(0)

        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        # Next search minibatch; restart the loader when it is exhausted.
        try:
            input_search, target_search = next(valid_queue_iter)
        except StopIteration:
            valid_queue_iter = iter(valid_queue)
            input_search, target_search = next(valid_queue_iter)
        input_search = input_search.cuda(non_blocking=True)
        target_search = target_search.cuda(non_blocking=True)

        if total_time == 0:
            print('Time start.')
        # CUDA work is asynchronous: synchronize around the timed region so
        # total_time measures the GPU work, not just kernel-launch latency.
        torch.cuda.synchronize()
        time_start = time.time()

        # Architecture step (first-order: no unrolled weight update).
        optimizer_a.zero_grad()
        logits = model(input_search)
        loss_a = criterion(logits, target_search)
        loss_a.sum().backward()
        nn.utils.clip_grad_norm_(model.module.arch_parameters(), args.grad_clip)
        optimizer_a.step()

        # Weight step. zero_grad() also clears the weight gradients that
        # loss_a.backward() accumulated above.
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        nn.utils.clip_grad_norm_(model.module.parameters(), args.grad_clip)
        optimizer.step()

        torch.cuda.synchronize()
        total_time += time.time() - time_start

        prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        if step % args.report_freq == 0:
            logging.info('TRAIN Step: %03d Objs: %e R1: %f R5: %f', step, objs.avg, top1.avg, top5.avg)
            tensorbard_logger.add_scalar('train/train_objs', objs.avg, global_iter)
            tensorbard_logger.add_scalar('train/train_top1', top1.avg, global_iter)
            tensorbard_logger.add_scalar('train/train_top5', top5.avg, global_iter)

        global_iter += 1

    return top1.avg, objs.avg, total_time



def infer(valid_queue, model, criterion):
    """Evaluate ``model`` on ``valid_queue`` without updating it.

    Switches the model to eval mode (and leaves it there). Running averages
    are logged every ``args.report_freq`` steps. infer() does not advance
    ``global_iter``, so all TensorBoard points from one call share the same
    x-value.

    Returns:
        ``(top1_avg, loss_avg)`` over all evaluated samples.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()

    # Nothing here needs gradients, including the metric computation.
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)
            logits = model(inputs)
            loss = criterion(logits, targets)

            prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
            n = inputs.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
                tensorbard_logger.add_scalar('valid/valid_objs', objs.avg, global_iter)
                tensorbard_logger.add_scalar('valid/valid_top1', top1.avg, global_iter)
                tensorbard_logger.add_scalar('valid/valid_top5', top5.avg, global_iter)

    return top1.avg, objs.avg


if __name__ == '__main__':
    # Script entry point: run the architecture search.
    main()

