import os
import sys
import time
import glob
import numpy as np
import torch
import utils
import logging
import argparse
import torch.nn as nn
import genotypes
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torch_optimizer as optim

import multiprocessing

from torch.autograd import Variable
from model import NetworkCIFAR as Network

import visualize

import random

import torch.cuda.amp as amp


# Command-line configuration for training a fixed architecture on CIFAR.
# Short "earlier default" notes record values that previous experiments used.
parser = argparse.ArgumentParser("cifar")
# Earlier default: '../data'.
parser.add_argument('--data', type=str, default='/dev/shm/data', help='location of the data corpus')
# Earlier default: 'cifar100'.
parser.add_argument('--set', type=str, default='cifar10', help='dataset to train on: cifar10 or cifar100')
# Earlier defaults: 1032, 990, 810, 750, 96, 220, 85, 75, 170, 140, 240, 270, 330, 55, 88, 80.
parser.add_argument('--batch_size', type=int, default=200, help='batch size')
# Earlier default: 0.001.
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--gpu', type=str, default='0', help='gpu device id')
# Earlier default: 600.
parser.add_argument('--epochs', type=int, default=1200, help='num of training epochs')
parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
# --auxiliary and --cutout default to True, so passing them changes nothing;
# the --no_* counterparts are the only way to turn the features off.
parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
parser.add_argument('--no_auxiliary', dest='auxiliary', action='store_false', help='disable the auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
parser.add_argument('--no_cutout', dest='cutout', action='store_false', help='disable cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path probability')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=0, help='random seed')
# Earlier default: 'DARTS'.
parser.add_argument('--arch', type=str, default='PCDARTS', help='which architecture to use')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
args = parser.parse_args()

def adapt_batch_size():
  """Rescale ``args.learning_rate`` in place by ``args.batch_size / 96``.

  Operates on the module-level ``args`` namespace and returns nothing.
  """
  scaled_lr = args.learning_rate * args.batch_size / 96
  args.learning_rate = scaled_lr
  # A matching epoch rescale (epochs * batch_size / 256) is left disabled.

# Apply the batch-size-dependent learning-rate scaling before args are used.
adapt_batch_size()

# Every run writes into its own timestamped directory, together with a copy
# of the *.py files found in the current working directory.
args.save = f'eval-{args.save}-{time.strftime("%Y%m%d-%H%M%S")}'
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Log to stdout and, additionally, to <run dir>/log.txt.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# TensorBoard event files go to <run dir>/tb.
tensorbard_logger = SummaryWriter('{}/tb'.format(args.save))

# Number of target classes: 100 for cifar100, 10 for anything else.
CIFAR_CLASSES = 100 if args.set == 'cifar100' else 10

# Count of training batches processed so far; advanced by train().
global_iter = 0

def main():
  """Train the hard-coded genotype on CIFAR and checkpoint after every epoch.

  Configuration comes from the module-level ``args``. Progress is written via
  ``logging`` and the module-level TensorBoard writer. The process exits with
  status 1 when no CUDA device is available.

  NOTE: ``args.arch`` and ``args.gpu`` are not consulted in this function; the
  genotype is fixed below and all visible CUDA devices are used.
  """
  if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)

  np.random.seed(args.seed)
  cudnn.benchmark = True
  torch.manual_seed(args.seed)
  cudnn.enabled = True
  torch.cuda.manual_seed(args.seed)
  logging.info("args = %s", args)
  tensorbard_logger.add_text('args', str(args))
  num_gpus = torch.cuda.device_count()
  logging.info("num_gpus = %s", num_gpus)
  tensorbard_logger.add_text('num_gpus', str(num_gpus))

  from genotypes import Genotype

  # Alternative genotypes from earlier experiments, kept for reference.
  # joint search
  # genotype = Genotype(
  #   normal=[('dil_conv_3x3', 0), ('skip_connect', 1), ('skip_connect', 2), ('sep_conv_5x5', 0), ('dil_conv_3x3', 0),
  #           ('sep_conv_3x3', 1), ('sep_conv_5x5', 0), ('dil_conv_5x5', 2)], normal_concat=range(2, 6),
  #   reduce=[('sep_conv_5x5', 0), ('max_pool_3x3', 1), ('sep_conv_5x5', 0), ('sep_conv_3x3', 2), ('sep_conv_5x5', 2),
  #           ('sep_conv_5x5', 0), ('sep_conv_3x3', 1), ('dil_conv_3x3', 3)], reduce_concat=range(2, 6))

  # joint search (not the one found)
  # genotype = Genotype(
  #   normal=[('sep_conv_3x3', 1), ('dil_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_5x5', 1),
  #           ('max_pool_3x3', 0), ('avg_pool_3x3', 0), ('sep_conv_5x5', 2)], normal_concat=range(2, 6),
  #   reduce=[('dil_conv_5x5', 1), ('sep_conv_3x3', 0), ('skip_connect', 1), ('dil_conv_5x5', 2), ('dil_conv_5x5', 2),
  #           ('dil_conv_5x5', 0), ('dil_conv_5x5', 3), ('skip_connect', 0)], reduce_concat=range(2, 6))

  # 3 ops
  # genotype = Genotype(
  #   normal=[('skip_connect', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 2), ('avg_pool_3x3', 1), ('sep_conv_3x3', 2),
  #           ('dil_conv_3x3', 3), ('dil_conv_5x5', 4), ('sep_conv_5x5', 2)], normal_concat=range(2, 6),
  #   reduce=[('sep_conv_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 1), ('dil_conv_5x5', 2), ('dil_conv_5x5', 2),
  #           ('dil_conv_5x5', 1), ('dil_conv_5x5', 1), ('sep_conv_5x5', 3)], reduce_concat=range(2, 6))

  # 4 ops
  # genotype = Genotype(
  #   normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_5x5', 0), ('sep_conv_5x5', 1), ('sep_conv_3x3', 2),
  #           ('sep_conv_3x3', 1), ('max_pool_3x3', 1), ('max_pool_3x3', 3)], normal_concat=range(2, 6),
  #   reduce=[('dil_conv_5x5', 1), ('dil_conv_3x3', 0), ('max_pool_3x3', 2), ('dil_conv_3x3', 1), ('sep_conv_3x3', 1),
  #           ('dil_conv_5x5', 2), ('max_pool_3x3', 4), ('dil_conv_5x5', 2)], reduce_concat=range(2, 6))

  # skip
  # genotype = Genotype(
  #   normal=[('skip_connect', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('skip_connect', 2), ('sep_conv_5x5', 1),
  #           ('sep_conv_3x3', 3), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0)], normal_concat=range(2, 6),
  #   reduce=[('sep_conv_5x5', 1), ('dil_conv_3x3', 0), ('sep_conv_3x3', 1), ('max_pool_3x3', 0), ('sep_conv_3x3', 0),
  #           ('max_pool_3x3', 3), ('avg_pool_3x3', 0), ('dil_conv_5x5', 3)], reduce_concat=range(2, 6))

  # sep_conv3
  # genotype = Genotype(
  #   normal=[('skip_connect', 0), ('dil_conv_3x3', 1), ('sep_conv_5x5', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
  #           ('sep_conv_5x5', 1), ('sep_conv_3x3', 0), ('sep_conv_5x5', 1)], normal_concat=range(2, 6),
  #   reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('max_pool_3x3', 2), ('max_pool_3x3', 1), ('sep_conv_5x5', 1),
  #           ('max_pool_3x3', 0), ('skip_connect', 1), ('dil_conv_3x3', 2)], reduce_concat=range(2, 6))

  # skip & none
  # genotype = Genotype(
  #   normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
  #           ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], normal_concat=range(2, 6),
  #   reduce=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1),
  #           ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2)], reduce_concat=range(2, 6))

  # (unlabelled in the original)
  # genotype = Genotype(
  #   normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2), ('sep_conv_3x3', 0),
  #           ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], normal_concat=range(2, 6),
  #   reduce=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 2), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2),
  #           ('sep_conv_3x3', 3), ('sep_conv_3x3', 3), ('sep_conv_3x3', 2)], reduce_concat=range(2, 6))

  # Active genotype. Built directly rather than through eval() of a string:
  # the value is a constant, so eval added nothing but indirection.
  genotype = Genotype(
    normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2),
            ('sep_conv_3x3', 1), ('sep_conv_3x3', 3), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0)],
    normal_concat=range(2, 6),
    reduce=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
            ('sep_conv_3x3', 0), ('sep_conv_3x3', 3), ('sep_conv_3x3', 0), ('sep_conv_3x3', 3)],
    reduce_concat=range(2, 6))
  print('architecture: ' + str(genotype))
  visualize.plot(genotype.normal, "normal", tensorboard_logger=tensorbard_logger, epoch=0)
  visualize.plot(genotype.reduce, "reduction", tensorboard_logger=tensorbard_logger, epoch=0)

  model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
  if num_gpus > 1:
    model = nn.DataParallel(model)
  model = model.cuda()
  # With DataParallel the real network sits behind ``.module``; attributes
  # such as drop_path_prob must be set on that inner object.
  core_model = model.module if num_gpus > 1 else model

  param_mb = utils.count_parameters_in_MB(model)
  logging.info("param size = %fMB", param_mb)
  tensorbard_logger.add_text('param size (MB)', str(param_mb))

  criterion = nn.CrossEntropyLoss()
  criterion = criterion.cuda()
  optimizer = torch.optim.SGD(
      model.parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay
      )

  # The same transform helper is used for both datasets.
  train_transform, valid_transform = utils._data_transforms_cifar10(args)
  dataset_cls = dset.CIFAR100 if args.set == 'cifar100' else dset.CIFAR10
  train_data = dataset_cls(root=args.data, train=True, download=True, transform=train_transform)
  valid_data = dataset_cls(root=args.data, train=False, download=True, transform=valid_transform)

  num_workers = multiprocessing.cpu_count()
  train_queue = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)

  valid_queue = torch.utils.data.DataLoader(
    valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)

  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))

  scaler = amp.GradScaler()
  best_acc = 0
  best_acc_epoch = 0

  for epoch in range(args.epochs):
    # get_last_lr() returns the rate computed by the most recent step();
    # get_lr() is only meant to be called from inside step() and warns
    # (and can report a different value) when used here.
    current_lr = scheduler.get_last_lr()[0]
    logging.info('epoch %d lr %e', epoch, current_lr)
    tensorbard_logger.add_scalar('train/lr', current_lr, epoch)
    # Drop-path probability ramps linearly from 0 towards args.drop_path_prob.
    core_model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

    train_acc, _ = train(train_queue, model, criterion, optimizer, scaler)
    logging.info('train_acc %f', train_acc)
    tensorbard_logger.add_scalar('train/train_acc', train_acc, epoch)

    valid_acc, _ = infer(valid_queue, model, criterion)
    logging.info('valid_acc %f', valid_acc)
    tensorbard_logger.add_scalar('valid/valid_acc', valid_acc, epoch)

    is_best = valid_acc > best_acc
    if is_best:
      best_acc = valid_acc
      best_acc_epoch = epoch
    utils.save_checkpoint({
      'epoch': epoch + 1,
      'state_dict': model.state_dict(),
      'best_acc': best_acc,
      'optimizer': optimizer.state_dict(),
    }, is_best, args.save)

    scheduler.step()

  logging.info('Best Top1 Accuracy: %f, at epoch %d', best_acc, best_acc_epoch)
  tensorbard_logger.add_text('Best_Top1_Accuracy',
                             'Best_Top1_Accuracy_{}_at_epoch_{}'.format(best_acc, best_acc_epoch))


def train(train_queue, model, criterion, optimizer, scaler):
  """Run one mixed-precision training epoch over ``train_queue``.

  Args:
    train_queue: iterable yielding ``(images, target)`` batches.
    model: network whose forward pass returns ``(logits, logits_aux)``.
    criterion: loss callable applied to ``(logits, target)``.
    optimizer: optimizer whose step is driven through ``scaler``.
    scaler: gradient scaler providing ``scale``/``unscale_``/``step``/``update``.

  Returns:
    ``(top1.avg, objs.avg)``: the running top-1 and loss averages kept by the
    ``utils.AvgrageMeter`` instances over the epoch.

  Side effects:
    Increments the module-level ``global_iter`` once per batch and writes
    scalars to the module-level TensorBoard writer every
    ``args.report_freq`` steps.
  """
  global global_iter
  objs = utils.AvgrageMeter()
  top1 = utils.AvgrageMeter()
  top5 = utils.AvgrageMeter()
  model.train()

  for step, (images, target) in enumerate(train_queue):
    # Both copies are asynchronous, matching how the targets were already
    # transferred; useful when the loader supplies pinned memory.
    images = images.cuda(non_blocking=True)
    target = target.cuda(non_blocking=True)

    optimizer.zero_grad()
    with amp.autocast():
      logits, logits_aux = model(images)
      loss = criterion(logits, target)
      if args.auxiliary:
        loss_aux = criterion(logits_aux, target)
        loss += args.auxiliary_weight*loss_aux
    scaler.scale(loss).backward()
    # Gradients must be unscaled before clipping so that args.grad_clip is
    # compared against their true magnitude.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
    scaler.step(optimizer)
    scaler.update()

    prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
    n = images.size(0)
    objs.update(loss.item(), n)
    top1.update(prec1.item(), n)
    top5.update(prec5.item(), n)

    if step % args.report_freq == 0:
      logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
      # NOTE(review): the step index is part of the tag, so every reported
      # step gets its own TensorBoard series - confirm this is intended.
      tensorbard_logger.add_scalar('train/train_objs_'+str(step), objs.avg, global_iter)
      tensorbard_logger.add_scalar('train/train_top1_'+str(step), top1.avg, global_iter)
      tensorbard_logger.add_scalar('train/train_top5_'+str(step), top5.avg, global_iter)

    global_iter += 1

  return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
  """Evaluate ``model`` on ``valid_queue`` without computing gradients.

  Args:
    valid_queue: iterable yielding ``(images, target)`` batches.
    model: network whose forward pass returns a pair; only the first element
      (the logits) is used here.
    criterion: loss callable applied to ``(logits, target)``.

  Returns:
    ``(top1.avg, objs.avg)``: the running top-1 and loss averages kept by the
    ``utils.AvgrageMeter`` instances over the whole queue.

  Side effects:
    Writes scalars to the module-level TensorBoard writer every
    ``args.report_freq`` steps, using the current ``global_iter`` as x value.
  """
  objs = utils.AvgrageMeter()
  top1 = utils.AvgrageMeter()
  top5 = utils.AvgrageMeter()
  model.eval()

  with torch.no_grad():
    for step, (images, target) in enumerate(valid_queue):
      # Asynchronous copies for both tensors, consistent with train().
      images = images.cuda(non_blocking=True)
      target = target.cuda(non_blocking=True)
      with amp.autocast():
        logits, _ = model(images)
        loss = criterion(logits, target)

      prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
      n = images.size(0)
      objs.update(loss.item(), n)
      top1.update(prec1.item(), n)
      top5.update(prec5.item(), n)

      if step % args.report_freq == 0:
        logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
        # NOTE(review): the step index is part of the tag, so every reported
        # step gets its own TensorBoard series - confirm this is intended.
        tensorbard_logger.add_scalar('valid/valid_objs_'+str(step), objs.avg, global_iter)
        tensorbard_logger.add_scalar('valid/valid_top1_'+str(step), top1.avg, global_iter)
        tensorbard_logger.add_scalar('valid/valid_top5_'+str(step), top5.avg, global_iter)

  return top1.avg, objs.avg


# Start training only when this file is executed as a script.
if __name__ == "__main__":
  main()

