import os
import sys
import time
import glob
import numpy as np
import torch
import utils
import logging
import argparse
import torch.nn as nn
import torch.utils
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import copy

from torch.autograd import Variable
from model_search_imagenet_topo import Network
from architect import Architect

from genotypes import Genotype
import visualize
import torch.cuda.amp as amp

from copy import deepcopy
import codecs
import json
from numpy import linalg as LA
from analyze import Analyzer


# Command-line interface of the search script. Every option is read through
# the module-level ``args`` namespace created at the bottom of this section.
parser = argparse.ArgumentParser("imagenet")

# --- data loading -----------------------------------------------------------
# Other worker counts tried previously: 4, 32.
parser.add_argument('--workers', type=int, default=72, help='number of workers to load dataset')
parser.add_argument('--data', type=str, default='/tmp/cache/', help='location of the data corpus')
# Other batch sizes tried previously: 1500, 1400 (3x 2080Ti), 1000 (1x V100).
parser.add_argument('--batch_size', type=int, default=1024, help='batch size') # 1 2080Ti

# --- optimisation of the network weights ------------------------------------
parser.add_argument('--learning_rate', type=float, default=0.5, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.0, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
# A previous default was 50 epochs.
parser.add_argument('--epochs', type=int, default=1, help='num of training epochs')

# --- search network ---------------------------------------------------------
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path probability')
# NOTE: the value is used as a path *prefix*; the run name is appended to it
# without a separator when ``args.save`` is rebuilt after parsing.
parser.add_argument('--save', type=str, default='/export/data/lwangcg/PC-DARTS/ImagenetExps/', help='experiment name')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')

# --- optimisation of the architecture parameters ----------------------------
parser.add_argument('--unrolled', action='store_true', default=False, help='use one-step unrolled validation loss')
parser.add_argument('--arch_learning_rate', type=float, default=6e-3, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
parser.add_argument('--begin', type=int, default=35, help='epoch start to optimize architecture')

# Root directory that holds the 'train' and 'val' image folders.
# Other locations used previously: '/cache/',
# '/export/data/lwangcg/ILSVRC2012/SampledTrainDataset/',
# '/dev/shm/SampledTrainDataset/', '/export/data/lwangcg/ILSVRC2012/'.
parser.add_argument('--tmp_data_dir', type=str, default='/dev/shm/ILSVRC2012/', help='temp data dir')
parser.add_argument('--note', type=str, default='try', help='note for this run')

# --- Hessian tracking / early stopping --------------------------------------
parser.add_argument('--report_freq_hessian', type=float, default=1, help='report frequency hessian')
parser.add_argument('--early_stop', type=int, default=1, choices=[0, 1, 2, 3], help='early stop DARTS based on dominant eigenvalue. 0: no 1: yes 2: simulate 3: adaptive regularization')
parser.add_argument('--window', type=int, default=5, help='window size of the local average')
parser.add_argument('--es_start_epoch', type=int, default=0, help='when to start considering early stopping')
parser.add_argument('--delta', type=int, default=4, help='number of previous local averages to consider in early stopping')
parser.add_argument('--factor', type=float, default=1.3, help='early stopping factor')
parser.add_argument('--extra_rollback_epochs', type=int, default=0, help='number of extra rollback epochs when deciding to increase regularization')
parser.add_argument('--compute_hessian', action='store_true', default=False, help='compute or not Hessian')
parser.add_argument('--max_weight_decay', type=float, default=0.03, help='maximum weight decay')
parser.add_argument('--mul_factor', type=float, default=10, help='multiplication factor')
# In debug mode the training loop skips the Hessian computation and feeds
# placeholder eigenvalues to the local-average tracker instead.
parser.add_argument('--debug', action='store_true', default=False, help='debug mode: use simulated eigenvalues instead of computing the Hessian')

args = parser.parse_args()

def adapt_batch_size():
  """Scale the learning rates in ``args`` linearly with the batch size.

  ``args.learning_rate``, ``args.learning_rate_min`` and
  ``args.arch_learning_rate`` are each replaced in place by
  ``value * args.batch_size / 1024``, so a batch size of 1024 leaves them
  untouched. ``args.epochs`` is not rescaled.
  """
  for option in ('learning_rate', 'learning_rate_min', 'arch_learning_rate'):
    setattr(args, option, getattr(args, option) * args.batch_size / 1024)

# Rescale the learning rates for the chosen batch size before anything else
# reads them from ``args``.
adapt_batch_size()

# Unique output location of this run: <save prefix>search-<note>-<timestamp>.
args.save = f'{args.save}search-{args.note}-{time.strftime("%Y%m%d-%H%M%S")}'
# Presumably creates the directory and copies the listed scripts into it;
# see utils.create_exp_dir for the exact behaviour.
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Log to stdout and, through an extra handler on the root logger, to
# <args.save>/log.txt. Only the stdout configuration sets a custom datefmt.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# TensorBoard event files go to the 'tb' sub-directory of the run directory.
tensorbard_logger = SummaryWriter(f'{args.save}/tb')

logging.info('save file: %s', args.save)
tensorbard_logger.add_text('save file', args.save)

# Root folder holding the 'train' and 'val' image directories.
# Note from the original authors: 10% and 2.5% of the training set (per
# class) are randomly sampled as the train and val splits, respectively.
# torch.utils.data.sampler.SubsetRandomSampler cannot be used for this
# sampling because ImageNet is too large.
data_dir = args.tmp_data_dir

# Number of output classes passed to the search network.
CLASSES = 1000

# Running count of training iterations across epochs; used as the TensorBoard
# step of the per-iteration scalars.
global_iter = 0


def main():
    """Run the architecture search and log its progress.

    Builds ``ImageFolder`` datasets from ``<data_dir>/train`` and
    ``<data_dir>/val``, wraps the search ``Network`` in ``DataParallel`` and
    trains it for ``args.epochs`` epochs. Each epoch logs the current
    genotype and architecture weights, calls ``train`` (which updates the
    architecture parameters on the ``val`` folder and the network weights on
    the ``train`` folder), and optionally evaluates with ``infer``.

    Validation runs every epoch when ``args.early_stop`` is non-zero and
    otherwise only in the last epoch. The loop ends early when the
    local-average tracker reports ``stop_search``.

    Exits the process with status 1 when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)
    tensorbard_logger.add_text('args', str(args))

    # Dataset preparation.
    traindir = os.path.join(data_dir, 'train')
    valdir = os.path.join(data_dir, 'val')

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    # Both search splits use the training-time augmentation; train_data2 is
    # the 'val' folder and feeds the architecture updates.
    train_data1 = dset.ImageFolder(traindir, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_data2 = dset.ImageFolder(valdir, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    # Same 'val' folder with deterministic evaluation transforms.
    valid_data = dset.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    num_train = len(train_data1)
    num_val = len(train_data2)
    print('# images to train network: %d' % num_train)
    print('# images to validate network: %d' % num_val)
    tensorbard_logger.add_text('# images to train network', str(num_train))
    tensorbard_logger.add_text('# images to validate network', str(num_val))

    model = Network(args.init_channels, CLASSES, args.layers, criterion)
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
    tensorbard_logger.add_text('param size (MB)', str(utils.count_parameters_in_MB(model)))

    # SGD for the network weights, Adam for the architecture parameters.
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    optimizer_a = torch.optim.Adam(model.module.arch_parameters(),
               lr=args.arch_learning_rate, betas=(0.5, 0.999),
               weight_decay=args.arch_weight_decay)

    # NOTE: test_queue is currently not consumed by the loop below; it is
    # kept so a held-out evaluation can be switched on without re-plumbing.
    test_queue = torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=args.batch_size,
                        shuffle=False,
                        pin_memory=True,
                        num_workers=args.workers)

    train_queue = torch.utils.data.DataLoader(
        train_data1, batch_size=args.batch_size, shuffle=True,
        pin_memory=True, num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        train_data2, batch_size=args.batch_size, shuffle=True,
        pin_memory=True, num_workers=args.workers)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    analyser = Analyzer(args, model.module)

    la_tracker = utils.EVLocalAvg(args.window, args.report_freq_hessian,
                                  args.epochs)

    errors_dict = {'train_acc': [], 'train_loss': [], 'valid_acc': [],
                   'valid_loss': []}

    scaler = amp.GradScaler()
    total_time = 0
    for epoch in range(args.epochs):
        # Read the learning rate the optimizer will actually use this epoch.
        # The scheduler is stepped at the END of the epoch: stepping it first
        # skips the initial value of the schedule, and with T_max == 1 (the
        # default --epochs) it would drop the rate to eta_min before any
        # weight update happened.
        current_lr = scheduler.get_last_lr()[0]
        logging.info('Epoch: %d lr: %e', epoch, current_lr)
        tensorbard_logger.add_scalar('train/lr', current_lr, epoch)

        genotype = model.module.genotype()
        logging.info('genotype = %s', genotype)
        visualize.plot(genotype.normal, "normal", tensorboard_logger=tensorbard_logger, epoch=epoch)
        visualize.plot(genotype.reduce, "reduction", tensorboard_logger=tensorbard_logger, epoch=epoch)
        arch_param = model.module.arch_parameters()
        logging.info(F.softmax(arch_param[0], dim=-1))
        logging.info(F.softmax(arch_param[1], dim=-1))

        # Training; the current learning rate is forwarded for the analyser.
        train_acc, train_obj, total_time = train(
            train_queue, valid_queue, model, optimizer, optimizer_a,
            criterion, current_lr, epoch,
            analyser=analyser, local_avg_tracker=la_tracker,
            scaler=scaler, total_time=total_time)
        print(f'Current time cost after epoch {epoch} is {total_time} s')
        logging.info('Train_acc %f', train_acc)
        tensorbard_logger.add_scalar('train/train_acc', train_acc, epoch)

        # Validation: every epoch when early stopping is enabled, otherwise
        # only in the final epoch. NaN marks epochs without validation so the
        # per-epoch lists stay aligned with the epoch index.
        valid_acc = valid_obj = float('nan')
        if args.early_stop or args.epochs - epoch <= 1:
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
            logging.info('Valid_acc %f', valid_acc)
            tensorbard_logger.add_scalar('valid/valid_acc', valid_acc, epoch)

        # The 'acc' entries store error rates (100 - accuracy).
        errors_dict['train_acc'].append(100 - train_acc)
        errors_dict['train_loss'].append(train_obj)
        errors_dict['valid_acc'].append(100 - valid_acc)
        errors_dict['valid_loss'].append(valid_obj)

        if la_tracker.stop_search:
            # Set the following to the values they had at stop_epoch.
            errors_dict['valid_acc'] = errors_dict['valid_acc'][:la_tracker.stop_epoch + 1]
            genotype = la_tracker.stop_genotype
            valid_acc = 100 - errors_dict['valid_acc'][la_tracker.stop_epoch]
            logging.info(
                'Decided to stop the search at epoch %d (Current epoch: %d)',
                la_tracker.stop_epoch, epoch
            )
            logging.info(
                'Validation accuracy at stop epoch: %f', valid_acc
            )
            logging.info(
                'Genotype at stop epoch: %s', genotype
            )
            break

        # Advance the cosine schedule for the next epoch.
        scheduler.step()

    print('time cost', total_time, 's')
    print('Time End.')

def train(train_queue, valid_queue, model, optimizer, optimizer_a, criterion, lr, epoch, analyser, local_avg_tracker, iteration=1, scaler=None, total_time=0):
    """Train for one epoch, alternating architecture and weight updates.

    For every batch of ``train_queue`` one batch is drawn from
    ``valid_queue`` (restarting that loader when it is exhausted). The
    architecture optimizer steps on the loss of the search batch, then the
    weight optimizer steps on the loss of the training batch. Both passes run
    under autocast with a shared gradient scaler that is updated once per
    iteration.

    When ``args.compute_hessian`` is set, the analyser is queried after the
    loop on the configured epochs, its outputs are appended to
    ``derivatives.json`` in the run directory and the largest eigenvalue is
    handed to ``local_avg_tracker``.

    Args:
        train_queue: loader of (input, target) batches for the weight update.
        valid_queue: loader of (input, target) batches for the architecture
            update.
        model: ``DataParallel``-wrapped search network (``model.module`` is
            used to read the genotype).
        optimizer: optimizer of the network weights.
        optimizer_a: optimizer of the architecture parameters.
        criterion: loss function applied to ``(logits, target)``.
        lr: learning rate forwarded to the analyser; the optimizer steps do
            not read it.
        epoch: index of the current epoch.
        analyser: object providing ``compute_Hw`` and ``compute_dw``.
        local_avg_tracker: tracker updated with ``(epoch, ev, genotype)``.
        iteration: only consulted in debug mode when choosing the simulated
            eigenvalue.
        scaler: ``GradScaler`` shared by both optimizers; a fresh one is
            created when omitted.
        total_time: seconds accumulated so far over the timed section.

    Returns:
        Tuple ``(top1 average, loss average, updated total_time)``.
    """
    global global_iter
    if scaler is None:
        # The default used to crash on the first scaler call.
        scaler = amp.GradScaler()

    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()

    # Created lazily on the first step so the loaders' iterators are built in
    # the same order as before (training loader first).
    valid_queue_iter = None

    for step, (images, target) in enumerate(train_queue):
        model.train()
        n = images.size(0)

        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # Draw the next search batch, restarting the loader once it runs out.
        # Only StopIteration triggers the restart; any other failure (worker
        # crash, interrupt) propagates instead of being silently retried.
        if valid_queue_iter is None:
            valid_queue_iter = iter(valid_queue)
        try:
            input_search, target_search = next(valid_queue_iter)
        except StopIteration:
            valid_queue_iter = iter(valid_queue)
            input_search, target_search = next(valid_queue_iter)
        input_search = input_search.cuda(non_blocking=True)
        target_search = target_search.cuda(non_blocking=True)

        if total_time == 0:
            print('Time start.')
        time_start = time.time()

        # Architecture step on the search batch.
        optimizer_a.zero_grad()
        with amp.autocast():
            logits = model(input_search)
            loss_a = criterion(logits, target_search)
        scaler.scale(loss_a).backward()
        scaler.step(optimizer_a)

        # Weight step on the training batch.
        optimizer.zero_grad()
        with amp.autocast():
            logits = model(images)
            loss = criterion(logits, target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)

        # One scale update per iteration, after both optimizers have stepped.
        scaler.update()

        time_end = time.time()
        total_time = total_time + time_end - time_start

        # Statistics refer to the training batch (second forward pass).
        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        if step % args.report_freq == 0:
            logging.info('TRAIN Step: %03d Objs: %e R1: %f R5: %f', step, objs.avg, top1.avg, top5.avg)
            tensorbard_logger.add_scalar('train/train_objs', objs.avg, global_iter)
            tensorbard_logger.add_scalar('train/train_top1', top1.avg, global_iter)
            tensorbard_logger.add_scalar('train/train_top5', top5.avg, global_iter)

        global_iter = global_iter + 1

    hessian_epoch = (epoch % args.report_freq_hessian == 0) or (epoch == (args.epochs - 1))
    if args.compute_hessian and hessian_epoch:
        # Fetch one fresh training batch from a copy of the loader; the copy
        # is released right away in both the debug and non-debug paths.
        _data_loader = deepcopy(train_queue)
        images, target = next(iter(_data_loader))
        del _data_loader

        images = images.cuda()
        target = target.cuda(non_blocking=True)

        if not args.debug:
            # input_search / target_search are the last search batch of the
            # loop above.
            H = analyser.compute_Hw(images, target, input_search, target_search,
                                    lr, optimizer, False)
            g = analyser.compute_dw(images, target, input_search, target_search,
                                    lr, optimizer, False)
            g = torch.cat([x.view(-1) for x in g])

            hessian = H.cpu().data.numpy()
            state = {'epoch': epoch,
                     'H': hessian.tolist(),
                     'g': g.cpu().data.numpy().tolist(),
                     }

            # One JSON document per line, appended across epochs.
            with codecs.open(os.path.join(args.save,
                                          'derivatives.json'),
                             'a', encoding='utf-8') as file:
                json.dump(state, file, separators=(',', ':'))
                file.write('\n')

            # Dominant eigenvalue used by the early-stopping tracker.
            ev = max(LA.eigvals(hessian))
        else:
            # Debug mode: simulated eigenvalue instead of a real Hessian.
            ev = 0.1
            if epoch >= 8 and iteration == 1:
                ev = 2.0
        logging.info('CURRENT EV: %f', ev)
        local_avg_tracker.update(epoch, ev, model.module.genotype())
        tensorbard_logger.add_scalar('max_eigen_value', abs(ev), epoch)

    return top1.avg, objs.avg, total_time



def infer(valid_queue, model, criterion):
    """Evaluate ``model`` on every batch of ``valid_queue``.

    The model is put in eval mode and run under autocast with gradients
    disabled. Running averages of loss, top-1 and top-5 accuracy are logged
    (and written to TensorBoard at the current ``global_iter``) every
    ``args.report_freq`` steps.

    Returns:
        Tuple ``(top1 average, loss average)`` over the whole queue.
    """
    loss_meter = utils.AvgrageMeter()
    top1_meter = utils.AvgrageMeter()
    top5_meter = utils.AvgrageMeter()
    model.eval()

    for batch_idx, (images, labels) in enumerate(valid_queue):
        images = images.cuda()
        labels = labels.cuda(non_blocking=True)
        batch_len = images.size(0)

        with amp.autocast(), torch.no_grad():
            logits = model(images)
            loss = criterion(logits, labels)

        prec1, prec5 = utils.accuracy(logits, labels, topk=(1, 5))
        loss_meter.update(loss.data.item(), batch_len)
        top1_meter.update(prec1.data.item(), batch_len)
        top5_meter.update(prec5.data.item(), batch_len)

        # Nothing to report on this step.
        if batch_idx % args.report_freq != 0:
            continue

        logging.info('valid %03d %e %f %f', batch_idx, loss_meter.avg, top1_meter.avg, top5_meter.avg)
        for tag, meter in (('valid/valid_objs', loss_meter),
                           ('valid/valid_top1', top1_meter),
                           ('valid/valid_top5', top5_meter)):
            tensorbard_logger.add_scalar(tag, meter.avg, global_iter)

    return top1_meter.avg, loss_meter.avg


# Script entry point. Only the call to main() is guarded; the module-level
# argument parsing and logging setup above also run when the file is imported.
if __name__ == "__main__":
    main()

