import os
import sys
import time
import glob
import numpy as np
import torch
import utils
import logging
import argparse
import torch.nn as nn
import torch.utils
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

from torch.autograd import Variable
from model_search_topo import Network
from architect import Architect

from genotypes import Genotype

import visualize

from copy import deepcopy
import codecs
import json
from numpy import linalg as LA
from analyze import Analyzer


parser = argparse.ArgumentParser("cifar")
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--set', type=str, default='cifar100',
                    help="dataset to search on: 'cifar100', any other value selects cifar10")
# Batch sizes previously used for the different candidate-operation sets
# (learning rates are rescaled relative to 256 by adapt_batch_size):
#   all ops: 256 | 4 ops: 397 | 3 ops no conv: 575 | skip and none: 680
#   skip: 730 | max pool: 600 | sep_conv3*3: 480
parser.add_argument('--batch_size', type=int, default=730, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.1, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency (in training steps)')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--epochs', type=int, default=50, help='num of training epochs')
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path probability')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
parser.add_argument('--unrolled', action='store_true', default=False, help='use one-step unrolled validation loss')
parser.add_argument('--arch_learning_rate', type=float, default=6e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')

parser.add_argument('--report_freq_hessian', type=float, default=1, help='report frequency hessian')
parser.add_argument('--early_stop', type=int, default=1, choices=[0, 1, 2, 3], help='early stop DARTS based on dominant eigenvalue. 0: no 1: yes 2: simulate 3: adaptive regularization')
parser.add_argument('--window', type=int, default=5, help='window size of the local average')
parser.add_argument('--es_start_epoch', type=int, default=0, help='when to start considering early stopping')
parser.add_argument('--delta', type=int, default=4, help='number of previous local averages to consider in early stopping')
parser.add_argument('--factor', type=float, default=1.3, help='early stopping factor')
parser.add_argument('--extra_rollback_epochs', type=int, default=0, help='number of extra rollback epochs when deciding to increase regularization')
# --compute_hessian is on by default (store_true with default=True), so it needs
# an explicit negative flag to be switchable; both write the same dest.
parser.add_argument('--compute_hessian', action='store_true', default=True, help='compute the Hessian (default: on)')
parser.add_argument('--no_compute_hessian', dest='compute_hessian', action='store_false',
                    help='do not compute the Hessian')
parser.add_argument('--max_weight_decay', type=float, default=0.03, help='maximum weight decay')
parser.add_argument('--mul_factor', type=float, default=10, help='multiplication factor')
parser.add_argument('--debug', action='store_true', default=False,
                    help='skip the Hessian computation and use a synthetic eigenvalue instead')

# Parse the command line once at import time; every other part of the script reads this namespace.
args = parser.parse_args(sys.argv[1:])

def adapt_batch_size():
  """Rescale the learning rates in the global ``args`` linearly with the batch size.

  Each of ``learning_rate``, ``learning_rate_min`` and ``arch_learning_rate`` is
  multiplied in place by ``args.batch_size / 256``, i.e. relative to a reference
  batch size of 256. ``args.epochs`` is deliberately left unscaled.
  """
  for attr_name in ('learning_rate', 'learning_rate_min', 'arch_learning_rate'):
    # Evaluate as (value * batch_size) / 256 so the float result matches exactly.
    setattr(args, attr_name, getattr(args, attr_name) * args.batch_size / 256)

adapt_batch_size()

# Every run writes into its own timestamped experiment directory; the *.py files
# of the working directory are handed to utils.create_exp_dir as scripts_to_save.
args.save = 'exps/search-%s-%s' % (args.save, time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Root logger: INFO to stdout, plus the same format into <save>/log.txt.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# TensorBoard event files live under <save>/tb.
tensorbard_logger = SummaryWriter('{}/tb'.format(args.save))

logging.info('save file: %s', args.save)
tensorbard_logger.add_text('save file', args.save)

# Output classes of the selected dataset (cifar100 -> 100, everything else -> 10).
CIFAR_CLASSES = 100 if args.set == 'cifar100' else 10

# Step counter advanced by train(); train() and infer() use it as the TensorBoard x-axis.
global_iter = 0

def main():
  """Run the architecture search configured by the module-level ``args``.

  Builds the search network, splits the dataset's training set into a
  weight-training part (the first ``args.train_portion``) and a search part (the
  rest), then for every epoch logs the current genotype, runs ``train`` and, when
  required, ``infer``. Per-epoch errors are collected in ``errors_dict``. The loop
  stops early once ``la_tracker.stop_search`` is set. Exits with status 1 when no
  CUDA device is available.
  """
  # CUDA_VISIBLE_DEVICES only takes effect if it is set before the CUDA driver is
  # initialised, and torch.cuda.is_available() can already initialise it. Set it
  # first so that --gpu actually selects the device.
  os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
  if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)

  np.random.seed(args.seed)
  cudnn.benchmark = True
  torch.manual_seed(args.seed)
  cudnn.enabled = True
  torch.cuda.manual_seed(args.seed)
  logging.info('gpu device = %d' % args.gpu)
  tensorbard_logger.add_text('gpu device', str(args.gpu))
  logging.info("args = %s", args)
  tensorbard_logger.add_text('args', str(args))

  criterion = nn.CrossEntropyLoss().cuda()
  model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
  model = model.cuda()
  param_size_mb = utils.count_parameters_in_MB(model)
  logging.info("param size = %fMB", param_size_mb)
  tensorbard_logger.add_text('param size (MB)', str(param_size_mb))

  optimizer = torch.optim.SGD(
      model.parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay)

  train_transform, _ = utils._data_transforms_cifar10(args)
  dataset_cls = dset.CIFAR100 if args.set == 'cifar100' else dset.CIFAR10
  train_data = dataset_cls(root=args.data, train=True, download=True, transform=train_transform)

  # Both queues draw from the training set: indices [0, split) train the weights,
  # indices [split, num_train) drive the architecture updates and infer().
  num_train = len(train_data)
  indices = list(range(num_train))
  split = int(np.floor(args.train_portion * num_train))

  train_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
      pin_memory=True, num_workers=2)

  valid_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
      pin_memory=True, num_workers=2)

  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

  architect = Architect(model, args)

  analyser = Analyzer(args, model)

  la_tracker = utils.EVLocalAvg(args.window, args.report_freq_hessian,
                                args.epochs)

  errors_dict = {'train_acc': [], 'train_loss': [], 'valid_acc': [],
                 'valid_loss': []}

  for epoch in range(args.epochs):
    # Use the LR the optimizer will actually apply this epoch. Calling
    # scheduler.get_lr() outside scheduler.step() can report a different value.
    lr = optimizer.param_groups[0]['lr']
    logging.info('epoch %d lr %e', epoch, lr)
    tensorbard_logger.add_scalar('train/lr', lr, epoch)

    genotype = model.genotype()
    logging.info('genotype = %s', genotype)
    visualize.plot(genotype.normal, "normal", tensorboard_logger=tensorbard_logger, epoch=epoch)
    visualize.plot(genotype.reduce, "reduction", tensorboard_logger=tensorbard_logger, epoch=epoch)

    print(F.softmax(model.alphas_normal, dim=-1))
    print(F.softmax(model.alphas_reduce, dim=-1))
    print(F.softmax(model.betas_normal[2:5], dim=-1))

    # training
    train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, epoch, analyser=analyser, local_avg_tracker=la_tracker)
    logging.info('train_acc %f', train_acc)
    tensorbard_logger.add_scalar('train/train_acc', train_acc, epoch)

    # validation: every epoch when early stopping is enabled, otherwise only on
    # the final epoch. Initialise to None so the other epochs do not hit a NameError.
    valid_acc, valid_obj = None, None
    if args.early_stop or epoch >= args.epochs - 1:
      valid_acc, valid_obj = infer(valid_queue, model, criterion)
      logging.info('valid_acc %f', valid_acc)
      tensorbard_logger.add_scalar('valid/valid_acc', valid_acc, epoch)

    errors_dict['train_acc'].append(100 - train_acc)
    errors_dict['train_loss'].append(train_obj)
    # NaN placeholders keep exactly one entry per epoch, so indexing by
    # la_tracker.stop_epoch below stays aligned with epochs.
    errors_dict['valid_acc'].append(100 - valid_acc if valid_acc is not None else float('nan'))
    errors_dict['valid_loss'].append(valid_obj if valid_obj is not None else float('nan'))

    if la_tracker.stop_search:
      # set the following to the values they had at stop_epoch
      errors_dict['valid_acc'] = errors_dict['valid_acc'][:la_tracker.stop_epoch + 1]
      genotype = la_tracker.stop_genotype
      valid_acc = 100 - errors_dict['valid_acc'][la_tracker.stop_epoch]
      logging.info(
        'Decided to stop the search at epoch %d (Current epoch: %d)',
        la_tracker.stop_epoch, epoch
      )
      logging.info(
        'Validation accuracy at stop epoch: %f', valid_acc
      )
      logging.info(
        'Genotype at stop epoch: %s', genotype
      )
      break

    # Since PyTorch 1.1 the scheduler must be stepped after the epoch's
    # optimizer updates. Stepping first skipped the initial cosine LR value.
    scheduler.step()


def train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, epoch, analyser, local_avg_tracker, iteration=1):
  """Run one search epoch: an architecture step, then a weight step, per training batch.

  For every minibatch of ``train_queue``, ``architect.step`` is called with that
  batch and the next minibatch of ``valid_queue`` (restarting ``valid_queue``
  when it runs out). The weights then take one gradient-clipped optimizer step
  on the training batch. Running loss/top-1/top-5 are logged every
  ``args.report_freq`` steps. The module-level ``global_iter`` is advanced once
  per step.

  If ``args.compute_hessian`` is set and this is a reporting epoch (every
  ``args.report_freq_hessian`` epochs, and always the last one), the analyser's
  ``compute_Hw``/``compute_dw`` results are appended as one JSON line to
  ``<args.save>/derivatives.json``. The largest (real part of the) eigenvalue of
  H is then passed to ``local_avg_tracker.update``. In ``args.debug`` mode a
  synthetic eigenvalue is used instead.

  Args:
    train_queue: loader of (input, target) batches for the weight updates.
    valid_queue: loader of (input, target) batches for the architecture updates.
    model: the search network.
    architect: object whose ``step`` updates the architecture parameters.
    criterion: loss function.
    optimizer: optimizer of the network weights.
    lr: current weight learning rate, forwarded to the architect and analyser.
    epoch: zero-based epoch index.
    analyser: object providing ``compute_Hw`` and ``compute_dw``.
    local_avg_tracker: receives ``(epoch, eigenvalue, genotype)`` on reporting epochs.
    iteration: only used in debug mode (eigenvalue 2.0 from epoch 8 when it is 1).

  Returns:
    (average top-1 accuracy, average loss) over this epoch's training batches.
  """
  global global_iter
  objs = utils.AvgrageMeter()
  top1 = utils.AvgrageMeter()
  top5 = utils.AvgrageMeter()

  train_iter = iter(train_queue)
  valid_iter = iter(valid_queue)
  for step, (inputs, targets) in enumerate(train_iter):
    model.train()
    n = inputs.size(0)
    inputs = inputs.cuda()
    targets = targets.cuda(non_blocking=True)

    # Next search minibatch, cycling valid_queue when it is exhausted. Only
    # StopIteration is caught: a bare except would also swallow KeyboardInterrupt
    # and genuine loader/CUDA errors.
    try:
      input_search, target_search = next(valid_iter)
    except StopIteration:
      valid_iter = iter(valid_queue)
      input_search, target_search = next(valid_iter)
    input_search = input_search.cuda()
    target_search = target_search.cuda(non_blocking=True)

    # Always true when called from main(); kept so callers can pass a negative
    # epoch to skip the architecture update.
    if epoch >= 0:
      architect.step(inputs, targets, input_search, target_search, lr, optimizer, unrolled=args.unrolled)

    optimizer.zero_grad()
    logits = model(inputs)
    loss = criterion(logits, targets)

    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
    optimizer.step()

    prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
    objs.update(loss.data.item(), n)
    top1.update(prec1.data.item(), n)
    top5.update(prec5.data.item(), n)

    if step % args.report_freq == 0:
      logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
      tensorbard_logger.add_scalar('train/train_objs', objs.avg, global_iter)
      tensorbard_logger.add_scalar('train/train_top1', top1.avg, global_iter)
      tensorbard_logger.add_scalar('train/train_top5', top5.avg, global_iter)

    global_iter = global_iter + 1

  is_report_epoch = (epoch % args.report_freq_hessian == 0) or (epoch == (args.epochs - 1))
  if args.compute_hessian and is_report_epoch:
    # One freshly sampled training batch. Iterating the loader directly avoids
    # deep-copying the DataLoader together with its whole dataset.
    h_input, h_target = next(iter(train_queue))
    h_input = h_input.cuda()
    h_target = h_target.cuda(non_blocking=True)

    if not args.debug:
      # The search batch is the last one drawn in the loop above.
      H = analyser.compute_Hw(h_input, h_target, input_search, target_search,
                              lr, optimizer, False)
      g = analyser.compute_dw(h_input, h_target, input_search, target_search,
                              lr, optimizer, False)
      g = torch.cat([x.view(-1) for x in g])
      hessian = H.cpu().data.numpy()

      state = {'epoch': epoch,
               'H': hessian.tolist(),
               'g': g.cpu().data.numpy().tolist(),
              }

      with codecs.open(os.path.join(args.save,
                                    'derivatives.json'),
                                    'a', encoding='utf-8') as file:
        json.dump(state, file, separators=(',', ':'))
        file.write('\n')

      # LA.eigvals returns complex values when H is not exactly symmetric. Take
      # the real parts so ev is a plain float for the '%f' log and the tracker.
      ev = float(np.max(np.real(LA.eigvals(hessian))))
    else:
      # Synthetic eigenvalue: a jump from 0.1 to 2.0 at epoch 8 (first iteration only).
      ev = 2.0 if (epoch >= 8 and iteration == 1) else 0.1
    logging.info('CURRENT EV: %f', ev)
    local_avg_tracker.update(epoch, ev, model.genotype())
    tensorbard_logger.add_scalar('max_eigen_value', abs(ev), epoch)
    # NOTE: local_avg_tracker.early_stop(...) is not invoked here, so stop_search
    # is only set if the tracker decides it inside update().

  return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
  """Evaluate ``model`` over all of ``valid_queue`` with gradients disabled.

  Running averages are logged every ``args.report_freq`` batches, both to the
  logger and to TensorBoard (x-axis: the current ``global_iter``).

  Returns:
    (top-1 accuracy average, loss average), each weighted by batch size.
  """
  loss_meter = utils.AvgrageMeter()
  acc1_meter = utils.AvgrageMeter()
  acc5_meter = utils.AvgrageMeter()
  model.eval()

  with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(valid_queue):
      images = images.cuda()
      labels = labels.cuda(non_blocking=True)
      scores = model(images)
      batch_loss = criterion(scores, labels)

      acc1, acc5 = utils.accuracy(scores, labels, topk=(1, 5))
      batch_size = images.size(0)
      for meter, value in ((loss_meter, batch_loss), (acc1_meter, acc1), (acc5_meter, acc5)):
        meter.update(value.data.item(), batch_size)

      if batch_idx % args.report_freq != 0:
        continue
      logging.info('valid %03d %e %f %f', batch_idx, loss_meter.avg, acc1_meter.avg, acc5_meter.avg)
      for tag, meter in (('valid/valid_objs', loss_meter),
                         ('valid/valid_top1', acc1_meter),
                         ('valid/valid_top5', acc5_meter)):
        tensorbard_logger.add_scalar(tag, meter.avg, global_iter)

  return acc1_meter.avg, loss_meter.avg


if __name__ == "__main__":
  # Only start the search when executed as a script, not on import.
  main()

