import os
import sys
import time
import glob
import numpy as np
import torch
import utils
import logging
import argparse
import torch.nn as nn
import genotypes
from genotypes import Genotype
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torch_optimizer as optim

import multiprocessing

from torch.autograd import Variable
from model import NetworkCIFAR as Network

import visualize

import random

import torch.cuda.amp as amp

from sync_batchnorm import convert_model


# Command-line options for training a fixed genotype on CIFAR-10/100.
parser = argparse.ArgumentParser("cifar")

# --- data ---------------------------------------------------------------
# Previously used default for --data: '../data'.
parser.add_argument('--data', type=str, default='/dev/shm/data', help='location of the data corpus')
# Only the exact value 'cifar100' selects CIFAR-100; main() falls back to
# CIFAR-10 for every other value.
parser.add_argument('--set', type=str, default='cifar10', help='dataset to train on: cifar10 or cifar100')

# --- optimisation -------------------------------------------------------
# Batch sizes tried previously as defaults: 1032, 990, 810, 750, 96, 220, 85,
# 75, 140, 500, 240, 290, 120, 210, 330, 55, 88, 80.
parser.add_argument('--batch_size', type=int, default=400, help='batch size')
# Previously tried default for --learning_rate: 0.001.
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
# Kept as float for backward compatibility; it is used as `step % report_freq`.
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--gpu', type=str, default='0', help='gpu device id')
# Previously tried default for --epochs: 1200.
parser.add_argument('--epochs', type=int, default=600, help='num of training epochs')

# --- network ------------------------------------------------------------
parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')

# --- regularisation -----------------------------------------------------
# --auxiliary / --cutout are `store_true` flags that already default to True,
# so on their own they could never be switched off. The --no_* variants write
# to the same destination and make disabling possible without changing the
# behaviour of existing command lines.
parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
parser.add_argument('--no_auxiliary', dest='auxiliary', action='store_false', help='disable the auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
parser.add_argument('--no_cutout', dest='cutout', action='store_false', help='disable cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path probability')

# --- experiment bookkeeping ---------------------------------------------
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=0, help='random seed')
# Previously tried default for --arch: 'DARTS'.
parser.add_argument('--arch', type=str, default='PCDARTS', help='which architecture to use')
# Previously tried default for --grad_clip: 5.
parser.add_argument('--grad_clip', type=float, default=20, help='gradient clipping')
# Parsed once at import time. The resulting namespace is shared module-wide:
# it is mutated by the setup code below (learning_rate, save) and read by
# main(), train() and infer().
args = parser.parse_args()

def adapt_batch_size():
  """Rescale ``args.learning_rate`` in place, proportionally to the batch size.

  The configured rate is multiplied by ``args.batch_size`` and divided by a
  reference batch size of 96. Mutates the module-level ``args``; returns None.
  """
  reference_batch_size = 96
  scaled_lr = args.learning_rate * args.batch_size / reference_batch_size
  args.learning_rate = scaled_lr
  # A matching epoch rescaling was tried and is left disabled:
  # args.epochs = int(args.epochs * args.batch_size / 256)

# Rescale the learning rate before anything logs or consumes `args`.
adapt_batch_size()

# Every run writes into its own timestamped directory; the *.py files found
# in the working directory are handed to utils.create_exp_dir alongside it.
args.save = 'eval-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Log to stdout and, through an extra handler on the root logger, to
# <save dir>/log.txt.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# TensorBoard event files live in the 'tb' sub-directory of the run directory.
tensorbard_logger = SummaryWriter(args.save+'/tb')

# Number of output classes: 100 only for --set cifar100, otherwise 10.
CIFAR_CLASSES = 100 if args.set == 'cifar100' else 10

# Count of training batches processed so far; incremented by train() and used
# as the x-axis value of the per-step TensorBoard scalars.
global_iter = 0

def main():
  """Train the hard-coded genotype on CIFAR-10/100 and checkpoint every epoch.

  All hyper-parameters come from the module-level ``args``. For each epoch the
  function runs ``train`` on the ``train=True`` split and ``infer`` on the
  ``train=False`` split, logs the results to ``logging`` and TensorBoard, and
  calls ``utils.save_checkpoint``, flagging the checkpoint as best whenever the
  accuracy returned by ``infer`` improves.

  Exits the process with status 1 when no CUDA device is available.
  """
  if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)

  # Seed NumPy and PyTorch (CPU and current CUDA device) from --seed.
  np.random.seed(args.seed)
  cudnn.benchmark = True
  torch.manual_seed(args.seed)
  cudnn.enabled=True
  torch.cuda.manual_seed(args.seed)
  # NOTE: --gpu is currently not applied. The explicit device selection
  # (CUDA_VISIBLE_DEVICES / torch.cuda.set_device(args.gpu)) was disabled, so
  # every visible GPU is used through nn.DataParallel below.
  logging.info("args = %s", args)
  tensorbard_logger.add_text('args', str(args))
  num_gpus = torch.cuda.device_count()
  logging.info("num_gpus = %s", num_gpus)
  tensorbard_logger.add_text('num_gpus', str(num_gpus))

  # genotype = eval("genotypes.%s" % args.arch)
  # joint search
  # genotype = Genotype(
  #   normal=[('dil_conv_3x3', 0), ('skip_connect', 1), ('skip_connect', 2), ('sep_conv_5x5', 0), ('dil_conv_3x3', 0),
  #           ('sep_conv_3x3', 1), ('sep_conv_5x5', 0), ('dil_conv_5x5', 2)], normal_concat=range(2, 6),
  #   reduce=[('sep_conv_5x5', 0), ('max_pool_3x3', 1), ('sep_conv_5x5', 0), ('sep_conv_3x3', 2), ('sep_conv_5x5', 2),
  #           ('sep_conv_5x5', 0), ('sep_conv_3x3', 1), ('dil_conv_3x3', 3)], reduce_concat=range(2, 6))

  # joint search (not the one found)
  # genotype = Genotype(
  #   normal=[('sep_conv_3x3', 1), ('dil_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_5x5', 1),
  #           ('max_pool_3x3', 0), ('avg_pool_3x3', 0), ('sep_conv_5x5', 2)], normal_concat=range(2, 6),
  #   reduce=[('dil_conv_5x5', 1), ('sep_conv_3x3', 0), ('skip_connect', 1), ('dil_conv_5x5', 2), ('dil_conv_5x5', 2),
  #           ('dil_conv_5x5', 0), ('dil_conv_5x5', 3), ('skip_connect', 0)], reduce_concat=range(2, 6))

  # 3 ops
  # genotype = Genotype(
  #   normal=[('skip_connect', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 2), ('avg_pool_3x3', 1), ('sep_conv_3x3', 2),
  #           ('dil_conv_3x3', 3), ('dil_conv_5x5', 4), ('sep_conv_5x5', 2)], normal_concat=range(2, 6),
  #   reduce=[('sep_conv_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 1), ('dil_conv_5x5', 2), ('dil_conv_5x5', 2),
  #           ('dil_conv_5x5', 1), ('dil_conv_5x5', 1), ('sep_conv_5x5', 3)], reduce_concat=range(2, 6))

  # 4 ops
  # genotype = Genotype(
  #   normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_5x5', 0), ('sep_conv_5x5', 1), ('sep_conv_3x3', 2),
  #           ('sep_conv_3x3', 1), ('max_pool_3x3', 1), ('max_pool_3x3', 3)], normal_concat=range(2, 6),
  #   reduce=[('dil_conv_5x5', 1), ('dil_conv_3x3', 0), ('max_pool_3x3', 2), ('dil_conv_3x3', 1), ('sep_conv_3x3', 1),
  #           ('dil_conv_5x5', 2), ('max_pool_3x3', 4), ('dil_conv_5x5', 2)], reduce_concat=range(2, 6))

  # skip
  # genotype = Genotype(
  #   normal=[('skip_connect', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('skip_connect', 2), ('sep_conv_5x5', 1),
  #           ('sep_conv_3x3', 3), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0)], normal_concat=range(2, 6),
  #   reduce=[('sep_conv_5x5', 1), ('dil_conv_3x3', 0), ('sep_conv_3x3', 1), ('max_pool_3x3', 0), ('sep_conv_3x3', 0),
  #           ('max_pool_3x3', 3), ('avg_pool_3x3', 0), ('dil_conv_5x5', 3)], reduce_concat=range(2, 6))

  # sep_conv3
  # genotype = Genotype(
  #   normal=[('skip_connect', 0), ('dil_conv_3x3', 1), ('sep_conv_5x5', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
  #           ('sep_conv_5x5', 1), ('sep_conv_3x3', 0), ('sep_conv_5x5', 1)], normal_concat=range(2, 6),
  #   reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('max_pool_3x3', 2), ('max_pool_3x3', 1), ('sep_conv_5x5', 1),
  #           ('max_pool_3x3', 0), ('skip_connect', 1), ('dil_conv_3x3', 2)], reduce_concat=range(2, 6))

  # skip & none
  # genotype = Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], normal_concat=range(2, 6), reduce=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2)], reduce_concat=range(2, 6))

  # Active architecture: an all-'skip_connect' template whose operation name is
  # substituted below, so only the connectivity of the template is kept.
  geno_script = \
  "Genotype(normal=[('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 0), ('skip_connect', 2), ('skip_connect', 0), ('skip_connect', 2), ('skip_connect', 4), ('skip_connect', 0)], normal_concat=range(2, 6), reduce=[('skip_connect', 1), ('skip_connect', 0), ('skip_connect', 0), ('skip_connect', 2), ('skip_connect', 1), ('skip_connect', 3), ('skip_connect', 0), ('skip_connect', 1)], reduce_concat=range(2, 6))"
  # Disabled variant: replace each 'skip_connect' with a random primitive.
  # PRIMITIVES = [
  #   'max_pool_3x3',
  #   'avg_pool_3x3',
  #   'skip_connect',
  #   'sep_conv_3x3',
  #   'sep_conv_5x5',
  #   'dil_conv_3x3',
  #   'dil_conv_5x5'
  # ]
  # while geno_script.find('skip_connect')!=-1:
  #   geno_script = geno_script.replace('skip_connect', random.choice(PRIMITIVES), 1)
  geno_script = geno_script.replace('skip_connect', 'sep_conv_3x3')
  # geno_script = geno_script.replace('max_pool_3x3', 'sep_conv_3x3')
  print('architecture: ' + str(geno_script))
  # eval() is applied to the literal defined above only, never to external
  # input; it resolves `Genotype` through the module-level import.
  genotype = eval(geno_script)
  # print('architecture: ' + str(genotype))
  visualize.plot(genotype.normal, "normal", tensorboard_logger=tensorbard_logger, epoch=0)
  visualize.plot(genotype.reduce, "reduction", tensorboard_logger=tensorbard_logger, epoch=0)

  # genotype = Genotype(
  #   normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2), ('sep_conv_3x3', 0),
  #           ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], normal_concat=range(2, 6),
  #   reduce=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 2), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2),
  #           ('sep_conv_3x3', 3), ('sep_conv_3x3', 3), ('sep_conv_3x3', 2)], reduce_concat=range(2, 6))

  model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
  if num_gpus > 1:
    # Disabled alternatives:
    # model = nn.DataParallel(model, device_ids=[int(x) for x in args.gpu.split(',')])
    # model = convert_model(model)
    model = nn.DataParallel(model)
  model = model.cuda()
  # utils.load(model, os.path.join('eval-EXP-20200701-174306/', 'weights.pt'))

  param_size_mb = utils.count_parameters_in_MB(model)
  logging.info("param size = %fMB", param_size_mb)
  tensorbard_logger.add_text('param size (MB)', str(param_size_mb))

  criterion = nn.CrossEntropyLoss()
  criterion = criterion.cuda()
  optimizer = torch.optim.SGD(
      model.parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay
      )

  # Disabled alternative optimizer:
  # optimizer = optim.RAdam(
  #   model.parameters(),
  #   lr=args.learning_rate,
  #   betas=(args.momentum, 0.999),
  #   eps=1e-8,
  #   weight_decay=args.weight_decay
  #   )

  # The same transform factory is used for both datasets.
  train_transform, valid_transform = utils._data_transforms_cifar10(args)
  dataset_cls = dset.CIFAR100 if args.set == 'cifar100' else dset.CIFAR10
  train_data = dataset_cls(root=args.data, train=True, download=True, transform=train_transform)
  valid_data = dataset_cls(root=args.data, train=False, download=True, transform=valid_transform)

  # One loader worker per CPU core (a fixed num_workers=2 was used before).
  num_workers = multiprocessing.cpu_count()
  train_queue = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)

  valid_queue = torch.utils.data.DataLoader(
    valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)

  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))
  # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 600)

  scaler = amp.GradScaler()
  best_acc = 0
  best_acc_epoch = 0

  for epoch in range(args.epochs):
    # get_last_lr() reports the rate the optimizer is actually using.
    # get_lr() is only meant to be called from inside scheduler.step() and
    # yields misleading values for the recursive cosine schedule otherwise.
    current_lr = scheduler.get_last_lr()[0]
    logging.info('epoch %d lr %e', epoch, current_lr)
    tensorbard_logger.add_scalar('train/lr', current_lr, epoch)

    # Drop-path probability grows linearly from 0 towards --drop_path_prob.
    # Under DataParallel the attribute has to be set on the wrapped module.
    net = model.module if num_gpus > 1 else model
    net.drop_path_prob = args.drop_path_prob * epoch / args.epochs

    train_acc, train_obj = train(train_queue, model, criterion, optimizer, scaler)
    logging.info('train_acc %f', train_acc)
    tensorbard_logger.add_scalar('train/train_acc', train_acc, epoch)

    valid_acc, valid_obj = infer(valid_queue, model, criterion)
    logging.info('valid_acc %f', valid_acc)
    tensorbard_logger.add_scalar('valid/valid_acc', valid_acc, epoch)

    # Earlier saving scheme (disabled): keep weights_<epoch>.pt whenever
    # valid_acc came within 0.1 of a running maximum that started at 96, plus
    # an always-overwritten weights.pt via utils.save(...).

    is_best = valid_acc > best_acc
    if is_best:
      best_acc = valid_acc
      best_acc_epoch = epoch
    utils.save_checkpoint({
      'epoch': epoch + 1,
      'state_dict': model.state_dict(),
      'best_acc': best_acc,
      'optimizer': optimizer.state_dict(),
    }, is_best, args.save)

    scheduler.step()

  logging.info('Best Top1 Accuracy: %f, at epoch %d', best_acc, best_acc_epoch)
  tensorbard_logger.add_text('Best_Top1_Accuracy',
                             'Best_Top1_Accuracy_{}_at_epoch_{}'.format(best_acc, best_acc_epoch))
  # The summary above is written right before the process ends; flush so it
  # is on disk instead of waiting for the writer's periodic flush.
  tensorbard_logger.flush()


def train(train_queue, model, criterion, optimizer, scaler):
  """Run one training epoch with automatic mixed precision.

  Args:
    train_queue: iterable yielding ``(inputs, targets)`` batches.
    model: callable returning a ``(logits, logits_aux)`` pair.
    criterion: loss callable applied to ``(logits, targets)``.
    optimizer: optimizer whose step is driven through ``scaler``.
    scaler: gradient scaler used to scale the loss and step the optimizer.

  Returns:
    ``(top1.avg, objs.avg)`` of the meters accumulated over the epoch.

  Side effects: increments the module-level ``global_iter`` once per batch and
  writes progress to ``logging`` and TensorBoard every ``args.report_freq``
  steps.
  """
  global global_iter
  loss_meter = utils.AvgrageMeter()
  top1_meter = utils.AvgrageMeter()
  top5_meter = utils.AvgrageMeter()
  model.train()

  for step, (images, labels) in enumerate(train_queue):
    images = images.cuda()
    labels = labels.cuda(non_blocking=True)

    optimizer.zero_grad()
    with amp.autocast():
      logits, logits_aux = model(images)
      loss = criterion(logits, labels)
      if args.auxiliary:
        # Auxiliary-head loss, weighted by --auxiliary_weight.
        loss += args.auxiliary_weight*criterion(logits_aux, labels)

    # Backpropagate the scaled loss, then unscale the gradients before
    # clipping so that --grad_clip is applied to the unscaled gradients.
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
    scaler.step(optimizer)
    scaler.update()

    prec1, prec5 = utils.accuracy(logits, labels, topk=(1, 5))
    batch_count = images.size(0)
    for meter, value in ((loss_meter, loss), (top1_meter, prec1), (top5_meter, prec5)):
      meter.update(value.data.item(), batch_count)

    if step % args.report_freq == 0:
      logging.info('train %03d %e %f %f', step, loss_meter.avg, top1_meter.avg, top5_meter.avg)
      # The step index is part of the tag; global_iter is the x-axis value.
      step_suffix = str(step)
      tensorbard_logger.add_scalar('train/train_objs_'+step_suffix, loss_meter.avg, global_iter)
      tensorbard_logger.add_scalar('train/train_top1_'+step_suffix, top1_meter.avg, global_iter)
      tensorbard_logger.add_scalar('train/train_top5_'+step_suffix, top5_meter.avg, global_iter)

    global_iter += 1

  return top1_meter.avg, loss_meter.avg


def infer(valid_queue, model, criterion):
  """Evaluate ``model`` on ``valid_queue`` without computing gradients.

  Args:
    valid_queue: iterable yielding ``(inputs, targets)`` batches.
    model: callable returning a pair whose first element is the logits.
    criterion: loss callable applied to ``(logits, targets)``.

  Returns:
    ``(top1.avg, objs.avg)`` of the meters accumulated over all batches.

  Side effects: puts the model in eval mode and writes progress to ``logging``
  and TensorBoard every ``args.report_freq`` steps, using the current value of
  the module-level ``global_iter`` as the x-axis value.
  """
  loss_meter = utils.AvgrageMeter()
  top1_meter = utils.AvgrageMeter()
  top5_meter = utils.AvgrageMeter()
  model.eval()

  with torch.no_grad():
    for step, (images, labels) in enumerate(valid_queue):
      images = images.cuda()
      labels = labels.cuda(non_blocking=True)
      with amp.autocast():
        # The second output of the model is ignored during evaluation.
        logits = model(images)[0]
        loss = criterion(logits, labels)

      prec1, prec5 = utils.accuracy(logits, labels, topk=(1, 5))
      batch_count = images.size(0)
      for meter, value in ((loss_meter, loss), (top1_meter, prec1), (top5_meter, prec5)):
        meter.update(value.data.item(), batch_count)

      if step % args.report_freq == 0:
        logging.info('valid %03d %e %f %f', step, loss_meter.avg, top1_meter.avg, top5_meter.avg)
        step_suffix = str(step)
        tensorbard_logger.add_scalar('valid/valid_objs_'+step_suffix, loss_meter.avg, global_iter)
        tensorbard_logger.add_scalar('valid/valid_top1_'+step_suffix, top1_meter.avg, global_iter)
        tensorbard_logger.add_scalar('valid/valid_top5_'+step_suffix, top5_meter.avg, global_iter)

  return top1_meter.avg, loss_meter.avg


if __name__ == "__main__":
  # Start training only when the file is executed as a script.
  main()

