import os
import sys
import time
import glob
import numpy as np
import torch
import utils
import logging
import argparse
import torch.nn as nn
import torch.utils
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

from torch.autograd import Variable
from model_search_op import Network
from architect import Architect

from genotypes import Genotype

import visualize

from copy import deepcopy
import codecs
import json
from numpy import linalg as LA
from analyze import Analyzer

import time
from torch.cuda.amp import autocast as autocast
import torch.cuda.amp as amp


# Command-line options for the architecture search. Learning rates given here are
# rescaled by adapt_batch_size() after parsing.
parser = argparse.ArgumentParser("cifar")
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--set', type=str, default='cifar10', help='dataset to search on: cifar100, otherwise cifar10')
parser.add_argument('--batch_size', type=int, default=590, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.1, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id (only logged; the search runs on cuda:0)')
parser.add_argument('--epochs', type=int, default=50, help='num of training epochs')
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path probability')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
parser.add_argument('--unrolled', action='store_true', default=False, help='use one-step unrolled validation loss')
parser.add_argument('--arch_learning_rate', type=float, default=6e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')

# Hessian / early-stopping options.
parser.add_argument('--report_freq_hessian', type=float, default=2, help='report frequency hessian')
parser.add_argument('--early_stop', type=int, default=1, choices=[0, 1, 2, 3], help='early stop DARTS based on dominant eigenvalue. 0: no 1: yes 2: simulate 3: adaptive regularization')
parser.add_argument('--window', type=int, default=5, help='window size of the local average')
parser.add_argument('--es_start_epoch', type=int, default=10, help='when to start considering early stopping')
parser.add_argument('--delta', type=int, default=4, help='number of previous local averages to consider in early stopping')
parser.add_argument('--factor', type=float, default=1.3, help='early stopping factor')
parser.add_argument('--extra_rollback_epochs', type=int, default=0, help='number of extra rollback epochs when deciding to increase regularization')
parser.add_argument('--compute_hessian', action='store_true', default=False, help='compute the Hessian every report_freq_hessian epochs and at the last epoch')
parser.add_argument('--max_weight_decay', type=float, default=0.03, help='maximum weight decay')
parser.add_argument('--mul_factor', type=float, default=10, help='multiplication factor')
parser.add_argument('--debug', action='store_true', default=False, help='debug mode: use placeholder eigenvalues instead of computing the Hessian')

# Parse the real command line (argparse's default when args=None is sys.argv[1:]).
args = parser.parse_args(sys.argv[1:])

def adapt_batch_size(namespace=None, base_batch_size=256):
  """Scale the learning rates linearly with the batch size, in place.

  ``learning_rate``, ``learning_rate_min`` and ``arch_learning_rate`` are each
  multiplied by ``batch_size / base_batch_size``.

  Args:
    namespace: object holding ``batch_size`` and the three rates; defaults to
      the module-level ``args`` parsed from the command line.
    base_batch_size: reference batch size the unscaled rates are expressed for.

  Note: not idempotent -- calling it twice scales the rates twice.
  """
  ns = args if namespace is None else namespace
  for name in ('learning_rate', 'learning_rate_min', 'arch_learning_rate'):
    # Keep the (rate * batch_size) / base evaluation order so results stay bit-identical.
    setattr(ns, name, getattr(ns, name) * ns.batch_size / base_batch_size)

# Rescale the parsed learning rates for the chosen batch size before anything logs them.
adapt_batch_size()

# Every run writes into its own timestamped experiment directory, together with
# a copy of the *.py scripts in the working directory.
args.save = 'exps/search-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Log to stdout and to <save>/log.txt using the same record format.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# TensorBoard event files live under <save>/tb.
tensorbard_logger = SummaryWriter(args.save + '/tb')

logging.info('save file: ' + args.save)
tensorbard_logger.add_text('save file', args.save)

# Number of output classes: 100 for --set cifar100, 10 for anything else.
CIFAR_CLASSES = 100 if args.set == 'cifar100' else 10

# Global training-step counter shared by train() and infer() for TensorBoard x-axes.
global_iter = 0

def main():
  """Run the architecture search end to end.

  Builds the supernet with a fixed starting topology, splits the CIFAR training
  set into a weight-training part and an architecture-search part, then for
  ``args.epochs`` epochs trains (``train``), optionally validates (``infer``)
  and saves ``<args.save>/weights.pt``. Stops early when the eigenvalue tracker
  reports ``stop_search``. Exits with status 1 when no CUDA device is available.
  """
  if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)

  np.random.seed(args.seed)
  # NOTE(review): the search always runs on cuda:0; args.gpu is only logged, so
  # pick the physical GPU via CUDA_VISIBLE_DEVICES.
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  print(device)
  torch.cuda.set_device(device)

  cudnn.benchmark = True
  torch.manual_seed(args.seed)
  cudnn.enabled = True
  torch.cuda.manual_seed(args.seed)
  logging.info('gpu device = %d' % args.gpu)
  tensorbard_logger.add_text('gpu device', str(args.gpu))
  logging.info("args = %s", args)
  tensorbard_logger.add_text('args', str(args))

  criterion = nn.CrossEntropyLoss().cuda()

  # Fixed cell topology handed to Network(topo=...): every edge is a skip_connect.
  genotype = Genotype(
      normal=[('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 0), ('skip_connect', 1),
              ('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 1), ('skip_connect', 0)],
      normal_concat=range(2, 6),
      reduce=[('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 1), ('skip_connect', 2),
              ('skip_connect', 1), ('skip_connect', 3), ('skip_connect', 1), ('skip_connect', 0)],
      reduce_concat=range(2, 6))

  model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion, topo=genotype)
  model = model.cuda()
  param_size_mb = utils.count_parameters_in_MB(model)
  logging.info("param size = %fMB", param_size_mb)
  tensorbard_logger.add_text('param size (MB)', str(param_size_mb))

  optimizer = torch.optim.SGD(
      model.parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay)

  train_transform, valid_transform = utils._data_transforms_cifar10(args)
  if args.set == 'cifar100':
    train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
  else:
    train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)

  # The first train_portion of the training set updates the weights; the rest is
  # the search split used for architecture steps and validation.
  num_train = len(train_data)
  indices = list(range(num_train))
  split = int(np.floor(args.train_portion * num_train))

  train_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
      pin_memory=True, num_workers=2)

  valid_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
      pin_memory=True, num_workers=2)

  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
      optimizer, float(args.epochs), eta_min=args.learning_rate_min)

  architect = Architect(model, args)

  scaler = amp.GradScaler()

  analyser = Analyzer(args, model)

  la_tracker = utils.EVLocalAvg(args.window, args.report_freq_hessian,
                                args.epochs)

  errors_dict = {'train_acc': [], 'train_loss': [], 'valid_acc': [],
                 'valid_loss': []}

  total_time = []

  for epoch in range(args.epochs):
    is_last_epoch = epoch == args.epochs - 1

    # LR the optimizer actually uses this epoch. get_lr() called outside
    # scheduler.step() does not return it; the scheduler is stepped at the end
    # of the epoch, after the optimizer has stepped.
    lr = scheduler.get_last_lr()[0]
    logging.info('epoch %d lr %e', epoch, lr)
    tensorbard_logger.add_scalar('train/lr', lr, epoch)

    genotype = model.genotype()
    logging.info('genotype = %s', genotype)
    visualize.plot(genotype.normal, "normal", tensorboard_logger=tensorbard_logger, epoch=epoch)
    visualize.plot(genotype.reduce, "reduction", tensorboard_logger=tensorbard_logger, epoch=epoch)

    print(F.softmax(model.alphas_normal, dim=-1))
    print(F.softmax(model.alphas_reduce, dim=-1))
    print(F.softmax(model.betas_normal[2:5], dim=-1))

    # training
    train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, epoch,
                                 analyser=analyser, local_avg_tracker=la_tracker, scaler=scaler,
                                 total_time=total_time)
    logging.info('train_acc %f', train_acc)
    tensorbard_logger.add_scalar('train/train_acc', train_acc, epoch)
    tensorbard_logger.add_scalar('train/search_time(s)', sum(total_time), epoch)
    if is_last_epoch:
      logging.info(f'Total search time is {sum(total_time)} s')
      tensorbard_logger.add_text('Total_search_time', f'{sum(total_time)} s')

    scheduler.step()

    # validation: every epoch when early stopping is enabled, otherwise only after
    # the last epoch. NaN placeholders keep errors_dict aligned with epoch indices
    # (previously valid_acc/valid_obj were unbound here and the append crashed).
    valid_acc, valid_obj = float('nan'), float('nan')
    if args.early_stop or is_last_epoch:
      valid_acc, valid_obj = infer(valid_queue, model, criterion)
      logging.info('valid_acc %f', valid_acc)
      tensorbard_logger.add_scalar('valid/valid_acc', valid_acc, epoch)

    utils.save(model, os.path.join(args.save, 'weights.pt'))

    errors_dict['train_acc'].append(100 - train_acc)
    errors_dict['train_loss'].append(train_obj)
    errors_dict['valid_acc'].append(100 - valid_acc)
    errors_dict['valid_loss'].append(valid_obj)

    if la_tracker.stop_search:
      # set the following to the values they had at stop_epoch
      errors_dict['valid_acc'] = errors_dict['valid_acc'][:la_tracker.stop_epoch + 1]
      genotype = la_tracker.stop_genotype
      valid_acc = 100 - errors_dict['valid_acc'][la_tracker.stop_epoch]
      logging.info(
        'Decided to stop the search at epoch %d (Current epoch: %d)',
        la_tracker.stop_epoch, epoch
      )
      logging.info(
        'Validation accuracy at stop epoch: %f', valid_acc
      )
      logging.info(
        'Genotype at stop epoch: %s', genotype
      )
      break


def train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, epoch, analyser, local_avg_tracker, iteration=1, scaler=None, total_time=None):
  """Run one search epoch, optionally followed by a Hessian report.

  For every training batch, an architecture step is taken on a batch from the
  search split (from epoch 15 on), then a mixed-precision weight step on the
  training batch.

  Args:
    train_queue: loader for the weight-training split.
    valid_queue: loader for the search split; cycled when exhausted.
    model, architect, criterion, optimizer: supernet, its architecture
      optimizer, the loss, and the weight optimizer.
    lr: current weight learning rate, forwarded to ``architect.step``.
    epoch: current epoch index.
    analyser: provides ``compute_Hw``/``compute_dw``, whose results are
      appended to ``<args.save>/derivatives.json``.
    local_avg_tracker: receives the largest eigenvalue of H each report.
    iteration: only used by the ``--debug`` placeholder eigenvalue.
    scaler: GradScaler for the weight step; a disabled (pass-through) scaler is
      used for the weight step when None. Forwarded unchanged to ``architect.step``.
    total_time: list that per-step wall-clock durations are appended to; a
      fresh list is used when None.

  Returns:
    (top-1 accuracy, average loss) over this epoch's training batches.
  """
  global global_iter
  if total_time is None:
    total_time = []
  weight_scaler = scaler if scaler is not None else amp.GradScaler(enabled=False)
  objs = utils.AvgrageMeter()
  top1 = utils.AvgrageMeter()
  top5 = utils.AvgrageMeter()

  # Search-split iterator: created on the first step (after the first training
  # batch is drawn) and re-created whenever the search split runs out.
  valid_queue_iter = None
  for step, (inputs, targets) in enumerate(train_queue):
    model.train()
    n = inputs.size(0)
    inputs = inputs.cuda()
    targets = targets.cuda(non_blocking=True)

    if valid_queue_iter is None:
      valid_queue_iter = iter(valid_queue)
    try:
      input_search, target_search = next(valid_queue_iter)
    except StopIteration:
      valid_queue_iter = iter(valid_queue)
      input_search, target_search = next(valid_queue_iter)
    input_search = input_search.cuda()
    target_search = target_search.cuda(non_blocking=True)

    if not total_time:
      logging.info('Time start.')
    # NOTE(review): host-side timing without torch.cuda.synchronize(); GPU work
    # may still be in flight when the step duration is recorded.
    time_start = time.time()

    # Weights-only warm-up: architecture steps start at epoch 15.
    if epoch >= 15:
      architect.step(inputs, targets, input_search, target_search, lr, optimizer, unrolled=args.unrolled, scaler=scaler)

    # Mixed-precision weight step.
    # NOTE(review): args.grad_clip is not applied here; with AMP it would need
    # weight_scaler.unscale_(optimizer) before clipping.
    optimizer.zero_grad()
    with autocast():
      logits = model(inputs)
      loss = criterion(logits, targets)
    weight_scaler.scale(loss).backward()
    weight_scaler.step(optimizer)
    weight_scaler.update()

    total_time.append(time.time() - time_start)

    prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
    objs.update(loss.item(), n)
    top1.update(prec1.item(), n)
    top5.update(prec5.item(), n)

    if step % args.report_freq == 0:
      logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
      tensorbard_logger.add_scalar('train/train_objs', objs.avg, global_iter)
      tensorbard_logger.add_scalar('train/train_top1', top1.avg, global_iter)
      tensorbard_logger.add_scalar('train/train_top5', top5.avg, global_iter)

    global_iter += 1

  if args.compute_hessian and ((epoch % args.report_freq_hessian == 0) or (epoch == args.epochs - 1)):
    # One fresh training batch. iter() is enough; deep-copying the loader
    # duplicated the whole dataset just to read a single batch.
    inputs, targets = next(iter(train_queue))
    inputs = inputs.cuda()
    targets = targets.cuda(non_blocking=True)

    if not args.debug:
      # input_search / target_search are the search batch from the last training step.
      H = analyser.compute_Hw(inputs, targets, input_search, target_search,
                              lr, optimizer, False)
      g = analyser.compute_dw(inputs, targets, input_search, target_search,
                              lr, optimizer, False)
      g = torch.cat([x.view(-1) for x in g])

      state = {'epoch': epoch,
               'H': H.cpu().data.numpy().tolist(),
               'g': g.cpu().data.numpy().tolist()}

      # One JSON object per line, appended across epochs.
      with codecs.open(os.path.join(args.save, 'derivatives.json'),
                       'a', encoding='utf-8') as fp:
        json.dump(state, fp, separators=(',', ':'))
        fp.write('\n')

      # early stopping signal: largest eigenvalue of H
      ev = max(LA.eigvals(H.cpu().data.numpy()))
    else:
      # debug placeholder: 0.1, jumping to 2.0 from epoch 8 when iteration == 1
      ev = 0.1
      if epoch >= 8 and iteration == 1:
        ev = 2.0
    logging.info('CURRENT EV: %f', ev)
    local_avg_tracker.update(epoch, ev, model.genotype())
    # Logged per epoch; the step index used before was identical every epoch,
    # so each report overwrote the previous one.
    tensorbard_logger.add_scalar('max_eigen_value', abs(ev), epoch)

    # NOTE(review): local_avg_tracker.early_stop(...) is not called here; confirm
    # main()'s la_tracker.stop_search check is still expected to fire.

  return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
  """Evaluate ``model`` on every batch of ``valid_queue`` without gradients.

  The forward pass runs under autocast. Running averages are logged every
  ``args.report_freq`` batches to stdout and to TensorBoard at the current
  ``global_iter``.

  Returns:
    (top-1 accuracy, average loss) over all evaluated batches.
  """
  model.eval()
  loss_meter, acc1_meter, acc5_meter = (utils.AvgrageMeter() for _ in range(3))

  with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(valid_queue):
      images = images.cuda()
      labels = labels.cuda(non_blocking=True)
      batch_size = images.size(0)

      with autocast():
        logits = model(images)
        loss = criterion(logits, labels)

      top1_acc, top5_acc = utils.accuracy(logits, labels, topk=(1, 5))
      loss_meter.update(loss.item(), batch_size)
      acc1_meter.update(top1_acc.item(), batch_size)
      acc5_meter.update(top5_acc.item(), batch_size)

      if batch_idx % args.report_freq != 0:
        continue

      logging.info('valid %03d %e %f %f', batch_idx, loss_meter.avg, acc1_meter.avg, acc5_meter.avg)
      for tag, value in (('valid/valid_objs', loss_meter.avg),
                         ('valid/valid_top1', acc1_meter.avg),
                         ('valid/valid_top5', acc5_meter.avg)):
        tensorbard_logger.add_scalar(tag, value, global_iter)

  return acc1_meter.avg, loss_meter.avg


if __name__ == "__main__":
  # Start the search only when this file is executed as a script.
  main()

