import os
import sys
import time
import glob
import numpy as np
import torch
import utils
import logging
import argparse
import torch.nn as nn
import torch.utils
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

from torch.autograd import Variable
from model_search_topo import Network
from architect import Architect

import visualize

import time
from torch.cuda.amp import autocast as autocast
import torch.cuda.amp as amp


parser = argparse.ArgumentParser("cifar")
# --- data -------------------------------------------------------------------
# (An alternative setup used --data ../data.)
parser.add_argument('--data', type=str, default='/dev/shm/data', help='location of the data corpus')
# Only 'cifar100' is special-cased elsewhere in this script; anything else used
# to fall back to CIFAR-10 silently, so the accepted values are restricted.
parser.add_argument('--set', type=str, default='cifar10', choices=['cifar10', 'cifar100'],
                    help='dataset to search on (cifar10 or cifar100)')
# --- optimisation -----------------------------------------------------------
# batch_size values tried previously, annotated with the candidate-operation
# set each was used for: 256 (all ops), 397 (4 ops), 575 (3 ops, no conv),
# 680 (skip and none), 800 (skip). The current default, 730, was also
# annotated "skip". The learning rates below are rescaled by batch_size / 256
# in adapt_batch_size().
parser.add_argument('--batch_size', type=int, default=730, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.1, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
# Used as a modulus on the batch index, so it is an integer count of batches.
parser.add_argument('--report_freq', type=int, default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--epochs', type=int, default=1, help='num of training epochs')
# --- model ------------------------------------------------------------------
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
# NOTE: the line in main() that applies drop_path_prob is currently commented out.
parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path probability')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
# (A seed of 2 was used previously.)
parser.add_argument('--seed', type=int, default=3, help='random seed')
# NOTE: the gradient-clipping call in train() is currently commented out.
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
# --- architecture search ----------------------------------------------------
parser.add_argument('--unrolled', action='store_true', default=False, help='use one-step unrolled validation loss')
parser.add_argument('--arch_learning_rate', type=float, default=6e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
args = parser.parse_args()

def adapt_batch_size():
  """Rescale the learning rates stored in ``args`` linearly with the batch size.

  ``learning_rate``, ``learning_rate_min`` and ``arch_learning_rate`` are
  each multiplied by ``args.batch_size / 256`` in place. In other words, the
  configured values are interpreted as the rates for a batch of 256. Scaling
  the epoch count is not done; it exists only as a commented-out line.
  """
  for attr in ('learning_rate', 'learning_rate_min', 'arch_learning_rate'):
    # Keep the (value * batch_size) / 256 evaluation order so the floating
    # point result is exactly what the straightforward formula gives.
    setattr(args, attr, getattr(args, attr) * args.batch_size / 256)
  # args.epochs = int(args.epochs * args.batch_size / 256)

adapt_batch_size()

# Each run writes into its own timestamped directory; every *.py file in the
# working directory is handed to utils.create_exp_dir as scripts_to_save.
args.save = f"exps/search-{args.save}-{time.strftime('%Y%m%d-%H%M%S')}"
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Log INFO and above to stdout, and mirror every record into <save>/log.txt.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# TensorBoard event files go under <save>/tb. The (misspelled) name is kept
# because the functions in this module refer to it.
tensorbard_logger = SummaryWriter(args.save + '/tb')

logging.info('save file: ' + args.save)
tensorbard_logger.add_text('save file', args.save)

# Number of output classes for the search network.
CIFAR_CLASSES = 100 if args.set == 'cifar100' else 10

# Step counter used as the TensorBoard x-axis by infer(); train() declares it
# global.
global_iter = 0

def main():
  """Run the architecture search on CIFAR-10/100.

  Steps:
    1. Select the GPU and seed the NumPy / torch RNGs.
    2. Build the search ``Network``, its SGD optimizer and a cosine LR schedule.
    3. Split the CIFAR training set into a weight-training part
       (``args.train_portion``) and a held-out part.
    4. For every epoch:
       - log and plot the current genotype,
       - train one epoch via ``train()``,
       - log and plot the genotype again,
       - advance the LR schedule.

  Exits with status 1 when no CUDA device is available.
  """
  # Restrict this process to the requested GPU. CUDA reads this variable once,
  # when its runtime is initialised, and torch.cuda.is_available() below
  # already queries the device count. The variable must therefore be set first;
  # otherwise --gpu is silently ignored. With a single visible device it shows
  # up as cuda:0, which is where the plain .cuda() calls below place tensors.
  os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
  if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)

  np.random.seed(args.seed)
  cudnn.benchmark = True
  torch.manual_seed(args.seed)
  cudnn.enabled = True
  torch.cuda.manual_seed(args.seed)
  logging.info('gpu device = %d' % args.gpu)
  tensorbard_logger.add_text('gpu device', str(args.gpu))
  logging.info("args = %s", args)
  tensorbard_logger.add_text('args', str(args))

  criterion = nn.CrossEntropyLoss().cuda()
  model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
  # To warm-start from a checkpoint:
  # utils.load(model, os.path.join('exps/search-EXP-20200605-165712/', 'weights.pt'))
  model = model.cuda().to(memory_format=torch.channels_last)
  param_size_mb = utils.count_parameters_in_MB(model)
  logging.info("param size = %fMB", param_size_mb)
  tensorbard_logger.add_text('param size (MB)', str(param_size_mb))

  optimizer = torch.optim.SGD(
      model.parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay)

  # Both loaders read the CIFAR *training* split with the training transforms.
  # The first train_portion of it trains the weights; the remainder feeds the
  # architecture step inside train().
  train_transform, valid_transform = utils._data_transforms_cifar10(args)
  dataset_cls = dset.CIFAR100 if args.set == 'cifar100' else dset.CIFAR10
  train_data = dataset_cls(root=args.data, train=True, download=True, transform=train_transform)

  num_train = len(train_data)
  indices = list(range(num_train))
  split = int(np.floor(args.train_portion * num_train))

  train_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
      pin_memory=True, num_workers=2)

  valid_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
      pin_memory=True, num_workers=2)

  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
      optimizer, float(args.epochs), eta_min=args.learning_rate_min)

  architect = Architect(model, args)
  scaler = amp.GradScaler()

  for epoch in range(args.epochs):
    # LR in effect for this epoch. The scheduler is stepped at the END of the
    # epoch, after the optimizer updates; this is the order PyTorch >= 1.1
    # expects. Stepping it before training would skip the initial LR: with
    # --epochs 1 the whole run would train at learning_rate_min.
    # get_last_lr() replaces the deprecated get_lr().
    lr = scheduler.get_last_lr()[0]
    logging.info('epoch %d lr %e', epoch, lr)
    tensorbard_logger.add_scalar('train/lr', lr, epoch)

    genotype = model.genotype()
    logging.info('genotype = %s', genotype)
    visualize.plot(genotype.normal, "normal", tensorboard_logger=tensorbard_logger, epoch=epoch)
    visualize.plot(genotype.reduce, "reduction", tensorboard_logger=tensorbard_logger, epoch=epoch)

    print(F.softmax(model.alphas_normal, dim=-1))
    print(F.softmax(model.alphas_reduce, dim=-1))
    print(F.softmax(model.betas_normal[2:5], dim=-1))
    #model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
    # training
    train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, epoch, scaler)
    # logging.info('train_acc %f', train_acc)
    # tensorbard_logger.add_scalar('train/train_acc', train_acc, epoch)

    genotype = model.genotype()
    logging.info('genotype after updating = %s', genotype)
    visualize.plot(genotype.normal, "normal after updating", tensorboard_logger=tensorbard_logger, epoch=epoch)
    visualize.plot(genotype.reduce, "reduction after updating", tensorboard_logger=tensorbard_logger, epoch=epoch)

    # Advance the cosine schedule once per epoch, after this epoch's updates.
    scheduler.step()

    # Validation and checkpointing are currently disabled:
    # # validation
    # if args.epochs-epoch<=1:
    #   valid_acc, valid_obj = infer(valid_queue, model, criterion)
    #   logging.info('valid_acc %f', valid_acc)
    #   tensorbard_logger.add_scalar('valid/valid_acc', valid_acc, epoch)
    #
    # utils.save(model, os.path.join(args.save, 'weights.pt'))


def train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, epoch, scaler):
  """Run one epoch of mixed-precision weight training.

  NOTE: as currently configured this is a fixed-batch experiment.
    - Only the first batch from ``train_queue`` is moved to the GPU, and every
      step of the epoch trains on that same cached batch. Later batches are
      still produced by the loader but discarded.
    - The architecture step, gradient clipping and metric tracking are all
      commented out, so the returned averages come from meters that nothing
      in this function updates.

  Args:
    train_queue: iterable of ``(input, target)`` training batches (a
      DataLoader in main()).
    valid_queue: iterable over the held-out split. One batch is drawn at
      step 0 for the (disabled) ``architect.step`` call.
    model: search network; switched to train mode every step.
    architect: architecture optimizer. Only the disabled step uses it.
    criterion: loss function applied to ``(logits, target)``.
    optimizer: weight optimizer, stepped through ``scaler``.
    lr: current learning rate. Only the disabled architect step uses it.
    epoch: epoch index (not used in this function).
    scaler: ``torch.cuda.amp.GradScaler`` used for loss scaling.

  Returns:
    ``(top1.avg, objs.avg)`` from the accuracy/loss meters.
  """
  global global_iter
  objs = utils.AvgrageMeter()
  top1 = utils.AvgrageMeter()
  top5 = utils.AvgrageMeter()

  total_time = 0

  for step, (input, target) in enumerate(train_queue):
    model.train()
    # n = input.size(0)
    if step == 0:
      # Cache the first batch on the GPU; every later step reuses it.
      ori_input = input.cuda().to(memory_format=torch.channels_last)
      ori_target = target.cuda(non_blocking=True)

      # One batch from the held-out split for the architecture step. A fresh
      # iterator is created on each call, so each epoch starts from a newly
      # shuffled held-out batch. Loader errors propagate instead of being
      # swallowed.
      input_search, target_search = next(iter(valid_queue))
      input_search = input_search.cuda()
      target_search = target_search.cuda(non_blocking=True)

    # if total_time == 0:
    #   print('Time start.')
    # time_start = time.time()

    # architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled, scaler=scaler)

    # time_end = time.time()

    # total_time = total_time + time_end - time_start

    input = ori_input
    target = ori_target
    optimizer.zero_grad()
    with autocast():
      logits = model(input)
      loss = criterion(logits, target)

    print(f'loss at step {step}: {loss}')

    scaler.scale(loss).backward()
    # Gradient clipping is disabled. Under AMP the gradients must be unscaled
    # before clipping, i.e.:
    #   scaler.unscale_(optimizer)
    #   nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
    scaler.step(optimizer)
    scaler.update()

    # prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
    # objs.update(loss.data.item(), n)
    # top1.update(prec1.data.item(), n)
    # top5.update(prec5.data.item(), n)
    #
    # if step % args.report_freq == 0:
    #   logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
    #   tensorbard_logger.add_scalar('train/train_objs', objs.avg, global_iter)
    #   tensorbard_logger.add_scalar('train/train_top1', top1.avg, global_iter)
    #   tensorbard_logger.add_scalar('train/train_top5', top5.avg, global_iter)

    # global_iter = global_iter + 1

  print('time cost', total_time, 's')
  print('Time End.')
  return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
  """Evaluate ``model`` on every batch of ``valid_queue`` without gradients.

  Running averages of the loss and top-1 / top-5 accuracy are kept in
  ``utils.AvgrageMeter`` objects. Every ``args.report_freq`` batches they are
  written to the log and to TensorBoard (x-axis: the module-level
  ``global_iter``).

  Args:
    valid_queue: iterable of ``(input, target)`` batches.
    model: network to evaluate; it is put into eval mode.
    criterion: loss function applied to ``(logits, target)``.

  Returns:
    ``(top1_avg, loss_avg)``.
  """
  loss_meter = utils.AvgrageMeter()
  top1_meter = utils.AvgrageMeter()
  top5_meter = utils.AvgrageMeter()
  model.eval()

  with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(valid_queue):
      images = images.cuda()
      labels = labels.cuda(non_blocking=True)
      logits = model(images)
      loss = criterion(logits, labels)

      acc1, acc5 = utils.accuracy(logits, labels, topk=(1, 5))
      batch_size = images.size(0)
      loss_meter.update(loss.item(), batch_size)
      top1_meter.update(acc1.item(), batch_size)
      top5_meter.update(acc5.item(), batch_size)

      if batch_idx % args.report_freq != 0:
        continue
      logging.info('valid %03d %e %f %f', batch_idx,
                   loss_meter.avg, top1_meter.avg, top5_meter.avg)
      for tag, value in (('valid/valid_objs', loss_meter.avg),
                         ('valid/valid_top1', top1_meter.avg),
                         ('valid/valid_top5', top5_meter.avg)):
        tensorbard_logger.add_scalar(tag, value, global_iter)

  return top1_meter.avg, loss_meter.avg


if __name__ == "__main__":
  # Start the search only when run as a script. Argument parsing, the
  # experiment directory and the loggers are set up at module level, so those
  # still happen on import.
  main()

