import os
import sys
import numpy as np
import time
import torch
import utils
import glob
import random
import logging
import argparse
import torch.nn as nn
import genotypes
from genotypes import Genotype
import torch.utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

from torch.autograd import Variable
from model import NetworkImageNet as Network

import visualize
import torch_optimizer as optim

import random
import torch.cuda.amp as amp
import math
from imagenet import HybridTrainPipe, HybridValPipe
from base import DALIDataloader
import multiprocessing

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

IMAGENET_IMAGES_NUM_TRAIN = 1281167
IMAGENET_IMAGES_NUM_TEST = 50000

# Module-level state used by train() and infer(): train() advances
# global_iter through a `global` statement, and both log to
# tensorbard_logger on rank 0. Defining them here guarantees the names exist
# at module scope; tensorbard_logger must be bound to a SummaryWriter before
# rank 0 starts logging.
global_iter = 0
tensorbard_logger = None


def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'.

    ``type=bool`` cannot be used for this: ``bool('False')`` is True.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # '' stays False, matching what bool('') used to return.
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser("training imagenet")
parser.add_argument('--workers', type=int, default=40, help='number of workers to load dataset') #32
# parser.add_argument('--batch_size', type=int, default=1024, help='bath size')
# parser.add_argument('--batch_size', type=int, default=300, help='batch size')   # 1x V100
# parser.add_argument('--batch_size', type=int, default=95, help='batch size')   # 1x 2080Ti
parser.add_argument('--batch_size', type=int, default=180, help='batch size')   # 1x 2080Ti
# parser.add_argument('--batch_size', type=int, default=650, help='batch size')   # 7x 2080Ti
# parser.add_argument('--batch_size', type=int, default=69, help='batch size')   # 1x 2080
parser.add_argument('--learning_rate', type=float, default=0.5, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-5, help='weight decay')
parser.add_argument('--report_freq', type=float, default=100, help='report frequency')
parser.add_argument('--epochs', type=int, default=250, help='num of training epochs')
# parser.add_argument('--epochs', type=int, default=1200, help='num of training epochs')
parser.add_argument('--init_channels', type=int, default=48, help='num of init channels')
parser.add_argument('--layers', type=int, default=14, help='total number of layers')
# --auxiliary is store_true with default=True, so on its own it can never turn
# the tower off; --disable_auxiliary writes False to the same destination.
parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
parser.add_argument('--disable_auxiliary', dest='auxiliary', action='store_false',
                    help='disable the auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--drop_path_prob', type=float, default=0, help='drop path probability')
parser.add_argument('--save', type=str, default='/export/data/lwangcg/PC-DARTS/ImagenetExps/', help='experiment name')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--arch', type=str, default='PCDARTS', help='which architecture to use')
parser.add_argument('--grad_clip', type=float, default=5., help='gradient clipping')
parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
parser.add_argument('--lr_scheduler', type=str, default='linear', help='lr scheduler, linear or cosine')
# parser.add_argument('--tmp_data_dir', type=str, default='/dev/shm/ILSVRC2012/', help='temp data dir')
parser.add_argument('--tmp_data_dir', type=str, default='/export/data/lwangcg/ILSVRC2012/', help='temp data dir')
# parser.add_argument('--tmp_data_dir', type=str, default='/export/data/lwangcg/ILSVRC2012/tfrecord/', help='temp data dir')
parser.add_argument('--note', type=str, default='try', help='note for this run')
parser.add_argument('--optimizer', type=str, default='ranger', help='radam, ranger, SGDM')
# _str2bool so that "--dali False" really disables DALI.
parser.add_argument('--dali', type=_str2bool, default=False, help='use dali')


args, unparsed = parser.parse_known_args()

def adapt_batch_size():
    """Rescale ``args.learning_rate`` in place for ``args.batch_size``.

    Applies the linear scaling rule with 1024 as the reference batch size:
    ``lr <- lr * batch_size / 1024``. The epoch count is left unchanged.
    """
    scaled_lr = args.learning_rate * args.batch_size
    args.learning_rate = scaled_lr / 1024


# Runs once at import time, right after the arguments are parsed.
adapt_batch_size()
# CPU core count of this machine.
num_core = multiprocessing.cpu_count()

class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy against label-smoothed one-hot targets.

    The true class receives ``1 - epsilon + epsilon / num_classes`` of the
    target mass and every other class ``epsilon / num_classes``. The loss is
    the batch mean of the per-sample cross-entropy against that distribution.
    """

    def __init__(self, num_classes, epsilon):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """Return the smoothed loss for logits ``inputs`` and class indices ``targets``.

        Classes live on dim 1 of ``inputs``. ``targets`` holds one integer
        index per row of ``inputs``.
        """
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs)
        one_hot.scatter_(1, targets.unsqueeze(1), 1)
        uniform_share = self.epsilon / self.num_classes
        smoothed = (1 - self.epsilon) * one_hot + uniform_share
        per_class = -smoothed * log_probs
        # Mean over the batch, then sum over classes == batch mean of per-sample CE.
        return per_class.mean(0).sum()

def setup(rank, world_size):
    """Initialise the NCCL process group for worker ``rank`` of ``world_size``.

    The process is bound to CUDA device ``rank`` before the group is created.
    Otherwise NCCL and any implicit current-device CUDA work (e.g.
    ``torch.cuda.manual_seed``) would target device 0 in every process.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    torch.cuda.set_device(rank)
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Tear down the default process group initialised by ``setup``."""
    return dist.destroy_process_group()

def _list_top_level_files(directory):
    """Return the paths of the non-directory entries directly inside ``directory``.

    Only the first ``os.walk`` level is consumed, so sub-directories are never
    traversed.

    Raises:
        FileNotFoundError: if ``directory`` cannot be listed.
    """
    try:
        root, _, files = next(os.walk(directory))
    except StopIteration:
        # os.walk silently yields nothing for a missing/unreadable directory.
        raise FileNotFoundError(directory) from None
    return [os.path.join(root, file) for file in files]


def _build_data_queues(rank, world_size):
    """Build this process's training and validation data iterables.

    Returns:
        ``(train_queue, valid_queue, train_sampler)``. ``train_sampler`` is the
        ``DistributedSampler`` feeding ``train_queue`` on the torchvision path,
        or ``None`` on the DALI path.
    """
    data_dir = args.tmp_data_dir
    # data_dir = os.path.join(args.tmp_data_dir, 'imagenet')
    traindir = os.path.join(data_dir, 'train')
    validdir = os.path.join(data_dir, 'val')
    if args.dali:
        trainindexdir = '/export/data/lwangcg/ILSVRC2012/idx_files_train'
        valindexdir = '/export/data/lwangcg/ILSVRC2012/idx_files_val'
        train_filenames = _list_top_level_files(traindir)
        val_filenames = _list_top_level_files(validdir)
        train_idx_filenames = _list_top_level_files(trainindexdir)
        val_idx_filenames = _list_top_level_files(valindexdir)
        # NOTE(review): the DALI pipes are built with device_id=0, world_size=1,
        # local_rank=0 on every rank, so this path does not shard across GPUs.
        # Confirm before combining --dali with more than one GPU.
        pip_train = HybridTrainPipe(batch_size=args.batch_size, num_threads=1, device_id=0,
                                    data_dir=traindir, crop=224, world_size=1, local_rank=0,
                                    train_filenames=train_filenames, train_idx_filenames=train_idx_filenames)
        pip_val = HybridValPipe(batch_size=args.batch_size, num_threads=1, device_id=0, data_dir=validdir,
                                crop=224, size=256, world_size=1, local_rank=0, val_filenames=val_filenames,
                                val_idx_filenames=val_idx_filenames)
        train_queue = DALIDataloader(pipeline=pip_train, size=IMAGENET_IMAGES_NUM_TRAIN, batch_size=args.batch_size,
                                     onehot_label=True)
        valid_queue = DALIDataloader(pipeline=pip_val, size=IMAGENET_IMAGES_NUM_TEST, batch_size=args.batch_size,
                                     onehot_label=True)
        return train_queue, valid_queue, None

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_data = dset.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
                hue=0.2),
            transforms.ToTensor(),
            normalize,
        ]))
    valid_data = dset.ImageFolder(
        validdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    # Each rank trains on its own 1/world_size shard of the training set;
    # without a sampler every process iterates (and shuffles) the full dataset.
    # args.batch_size is per process, so the global batch is
    # args.batch_size * world_size.
    train_sampler = torch.utils.data.DistributedSampler(
        train_data, num_replicas=world_size, rank=rank, shuffle=True)
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, sampler=train_sampler,
        pin_memory=True, num_workers=args.workers)

    # Validation stays unsharded: every rank evaluates the full set and only
    # rank 0 reports.
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)
    return train_queue, valid_queue, train_sampler


def _build_optimizer(model):
    """Create the optimizer selected by ``args.optimizer`` for ``model``.

    Raises:
        ValueError: if ``args.optimizer`` is not 'radam', 'ranger' or 'SGDM'.
    """
    if args.optimizer == 'radam':
        return optim.RAdam(
            model.parameters(),
            1e-3,
            betas=(0.9, 0.999),
            weight_decay=args.weight_decay
        )
    if args.optimizer == 'ranger':
        optimizer = optim.Ranger(
            model.parameters(),
            lr=1e-3,
            alpha=0.5,
            k=6,
            N_sma_threshhold=5,
            betas=(.95, 0.999),
            eps=1e-5,
            weight_decay=0
        )
        # NOTE(review): stepping before any backward pass is presumably a no-op
        # update (no gradients exist yet). It likely exists to silence PyTorch's
        # "lr_scheduler.step() before optimizer.step()" warning. Confirm.
        optimizer.step()
        return optimizer
    if args.optimizer == 'SGDM':
        return torch.optim.SGD(
            model.parameters(),
            args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay
        )
    raise ValueError('Unknown optimizer {!r}; expected radam, ranger or SGDM'.format(args.optimizer))


def main(rank, world_size):
    """Train and evaluate the network in one DDP worker process.

    Args:
        rank: process index supplied by ``mp.spawn``. It is also used as this
            worker's CUDA device index.
        world_size: total number of worker processes.

    Only rank 0 creates the experiment directory, configures logging, writes
    TensorBoard summaries and saves checkpoints.
    """
    # train() and infer() read these as module globals (train() also
    # increments global_iter), so bind them at module scope on every rank.
    global global_iter, tensorbard_logger
    setup(rank, world_size)

    args.save = '{}eval-{}-{}'.format(args.save, args.note, time.strftime("%Y%m%d-%H%M%S"))
    tensorbard_logger = None
    if rank == 0:
        utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        tensorbard_logger = SummaryWriter(args.save + '/tb')

    CLASSES = 1000
    global_iter = 0

    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    if rank == 0:
        logging.info("args = %s", args)
        logging.info("unparsed_args = %s", unparsed)
        tensorbard_logger.add_text('args', str(args))
        tensorbard_logger.add_text('unparsed_args', str(unparsed))
    geno_script = \
        "Genotype(normal=[('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 1), ('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 3), ('skip_connect', 2), ('skip_connect', 3)], normal_concat=range(2, 6), reduce=[('skip_connect', 1), ('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 0), ('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 0), ('skip_connect', 3)], reduce_concat=range(2, 6))"
    geno_script = geno_script.replace('skip_connect', 'sep_conv_3x3')
    # genotype = eval("genotypes.%s" % args.arch)
    # eval() of a hard-coded literal only; never pass external input here.
    genotype = eval(geno_script)
    if rank == 0:
        print('---------Genotype---------')
        logging.info(genotype)
        print('--------------------------')
        visualize.plot(genotype.normal, "normal", tensorboard_logger=tensorbard_logger, epoch=0)
        visualize.plot(genotype.reduce, "reduction", tensorboard_logger=tensorbard_logger, epoch=0)

    train_queue, valid_queue, train_sampler = _build_data_queues(rank, world_size)

    best_acc_top1 = 0
    best_acc_top1_epoch = 0
    best_acc_top5 = 0
    best_acc_top5_epoch = 0
    lr = args.learning_rate
    scaler = amp.GradScaler()

    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype).to(rank)
    model = DDP(model, device_ids=[rank])
    if rank == 0:
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
        tensorbard_logger.add_text('param size (MB)', str(utils.count_parameters_in_MB(model)))

    criterion = nn.CrossEntropyLoss()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    optimizer = _build_optimizer(model)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))

    for epoch in range(args.epochs):
        if train_sampler is not None:
            # Give every epoch a different shuffle of the per-rank shards.
            train_sampler.set_epoch(epoch)
        if args.lr_scheduler == 'cosine':
            scheduler.step()
            # get_last_lr() is the rate step() just applied. get_lr() is an
            # internal helper that warns, and can return a different value,
            # when called outside step().
            current_lr = scheduler.get_last_lr()[0]
        elif args.lr_scheduler == 'linear':
            current_lr = adjust_lr(optimizer, epoch)
        else:
            print('Wrong lr type, exit')
            sys.exit(1)
        if rank == 0:
            logging.info('Epoch: %d lr %e', epoch, current_lr)
            tensorbard_logger.add_scalar('train/lr', current_lr, epoch)
        if args.optimizer == 'SGDM':
            if epoch < 5 and args.batch_size > 256:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr * (epoch + 1) / 5.0
                if rank == 0:
                    logging.info('Warming-up Epoch: %d, LR: %e', epoch, lr * (epoch + 1) / 5.0)
                    tensorbard_logger.add_scalar('train/Warming_up_lr', lr * (epoch + 1) / 5.0, epoch)
        # The model is always DDP-wrapped, so set the attribute on the inner
        # network (model.module), which is presumably where forward() reads it.
        model.module.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        epoch_start = time.time()
        train_acc, train_obj = train(train_queue, model, criterion_smooth, optimizer, scaler, rank)
        if rank == 0:
            logging.info('Train_acc: %f', train_acc)
            tensorbard_logger.add_scalar('train/train_acc', train_acc, epoch)
        valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, criterion, scaler, rank)
        if rank == 0:
            logging.info('Valid_acc_top1: %f', valid_acc_top1)
            logging.info('Valid_acc_top5: %f', valid_acc_top5)
            tensorbard_logger.add_scalar('valid/valid_acc_top1', valid_acc_top1, epoch)
            tensorbard_logger.add_scalar('valid/valid_acc_top5', valid_acc_top5, epoch)
            epoch_duration = time.time() - epoch_start
            logging.info('Epoch time: %ds.', epoch_duration)
            tensorbard_logger.add_scalar('train/epoch_duration(s)', epoch_duration, epoch)
            is_best = False
            if valid_acc_top5 > best_acc_top5:
                best_acc_top5 = valid_acc_top5
                best_acc_top5_epoch = epoch
            if valid_acc_top1 > best_acc_top1:
                best_acc_top1 = valid_acc_top1
                best_acc_top1_epoch = epoch
                is_best = True
            utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc_top1': best_acc_top1,
                'optimizer' : optimizer.state_dict(),
                }, is_best, args.save)
    if rank == 0:
        # %f, not %d: the accuracies are floats and %d would truncate them.
        logging.info('Best Top1 Accuracy: %f, at epoch %d', best_acc_top1, best_acc_top1_epoch)
        logging.info('Best Top5 Accuracy: %f, at epoch %d', best_acc_top5, best_acc_top5_epoch)
        # add_text requires (tag, text); the tag was missing before.
        tensorbard_logger.add_text('best_top1',
                                   'Best Top1 Accuracy: {}, at epoch {}'.format(best_acc_top1, best_acc_top1_epoch))
        tensorbard_logger.add_text('best_top5',
                                   'Best Top5 Accuracy: {}, at epoch {}'.format(best_acc_top5, best_acc_top5_epoch))
        tensorbard_logger.close()
    cleanup()

def adjust_lr(optimizer, epoch):
    """Set and return the learning rate for ``epoch`` under the 'linear' schedule.

    * ``args.optimizer == 'ranger'``: flat 1e-3 for the first 75% of
      ``args.epochs``, then a half-cosine anneal across the remaining 25%.
    * otherwise: linear decay from ``args.learning_rate``, with a smaller
      slope over the last 5 epochs. Runs of 5 epochs or fewer use plain
      linear decay.

    Every param group of ``optimizer`` receives the new rate.
    """
    if args.optimizer == 'ranger':
        flat_epochs = args.epochs * 0.75
        if epoch <= flat_epochs:
            lr = 1e-3
        else:
            # Normalise by the anneal span, not the total epoch count, so the
            # cosine actually approaches 0 by the final epoch instead of
            # stopping near 85% of the peak rate.
            anneal_epochs = args.epochs - flat_epochs
            lr = 1e-3 * (1 + math.cos(math.pi * (epoch - flat_epochs) / anneal_epochs)) / 2
    elif args.epochs <= 5:
        # Too short for the two-slope schedule below, which would divide by
        # zero (epochs == 5) or yield negative rates (epochs < 5).
        lr = args.learning_rate * (args.epochs - epoch) / args.epochs
    elif args.epochs - epoch > 5:
        # Smaller slope for the last 5 epochs because lr * 1/250 is relatively large
        lr = args.learning_rate * (args.epochs - 5 - epoch) / (args.epochs - 5)
    else:
        lr = args.learning_rate * (args.epochs - epoch) / ((args.epochs - 5) * 5)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def train(train_queue, model, criterion, optimizer, scaler, rank):
    """Run one mixed-precision training epoch.

    Args:
        train_queue: iterable of ``(input, target)`` batches.
        model: network whose forward returns ``(logits, logits_aux)``.
        criterion: loss applied to the main logits and, when
            ``args.auxiliary`` is set, to the auxiliary logits.
        optimizer: optimizer stepped through ``scaler``.
        scaler: ``amp.GradScaler`` used for loss scaling.
        rank: CUDA device the batches are moved to; only rank 0 logs.

    Returns:
        ``(top1_avg, loss_avg)`` over the batches seen by this rank.
    """
    global global_iter
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    batch_time = utils.AvgrageMeter()
    model.train()

    for step, (images, target) in enumerate(train_queue):
        target = target.to(rank, non_blocking=True)
        images = images.to(rank, non_blocking=True)
        # Timed on every rank: batch_time.update() below runs on all of them.
        b_start = time.time()
        optimizer.zero_grad()
        with amp.autocast():
            logits, logits_aux = model(images)
            loss = criterion(logits, target)
            if args.auxiliary:
                loss_aux = criterion(logits_aux, target)
                loss += args.auxiliary_weight * loss_aux

        scaler.scale(loss).backward()
        # The gradients are still multiplied by the loss scale here. Unscale
        # them in place so grad_clip bounds the true gradient norm.
        # scaler.step() then skips its own unscaling for this optimizer.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        scaler.step(optimizer)
        scaler.update()
        batch_time.update(time.time() - b_start)

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        n = images.size(0)
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)
        if rank == 0 and step % args.report_freq == 0:
            end_time = time.time()
            if step == 0:
                duration = 0
            else:
                duration = end_time - start_time
            start_time = time.time()
            logging.info('TRAIN Step: %03d Objs: %e R1: %f R5: %f Duration: %ds BTime: %.3fs',
                         step, objs.avg, top1.avg, top5.avg, duration, batch_time.avg)
            tensorbard_logger.add_scalar('train/train_objs', objs.avg, global_iter)
            tensorbard_logger.add_scalar('train/train_top1', top1.avg, global_iter)
            tensorbard_logger.add_scalar('train/train_top5', top5.avg, global_iter)
            tensorbard_logger.add_scalar('train/train_duration', duration, global_iter)
            tensorbard_logger.add_scalar('train/train_batch_time', batch_time.avg, global_iter)

        global_iter = global_iter + 1

    return top1.avg, objs.avg


def infer(valid_queue, model, criterion, scaler, rank):
    """Evaluate ``model`` over ``valid_queue`` under autocast with gradients disabled.

    Args:
        valid_queue: iterable of ``(input, target)`` batches.
        model: network whose forward returns a pair; only the first element
            (the main logits) is used.
        criterion: loss applied to the main logits.
        scaler: unused here; accepted so the signature mirrors ``train``.
        rank: CUDA device the batches are moved to; only rank 0 logs.

    Returns:
        ``(top1_avg, top5_avg, loss_avg)`` over the batches seen by this rank.
    """
    loss_meter = utils.AvgrageMeter()
    top1_meter = utils.AvgrageMeter()
    top5_meter = utils.AvgrageMeter()
    model.eval()

    for step, (images, labels) in enumerate(valid_queue):
        images = images.to(rank, non_blocking=True)
        labels = labels.to(rank, non_blocking=True)
        with amp.autocast(), torch.no_grad():
            logits, _ = model(images)
            loss = criterion(logits, labels)

        prec1, prec5 = utils.accuracy(logits, labels, topk=(1, 5))
        batch_size = images.size(0)
        loss_meter.update(loss.data.item(), batch_size)
        top1_meter.update(prec1.data.item(), batch_size)
        top5_meter.update(prec5.data.item(), batch_size)

        # Only rank 0 reports, and only every args.report_freq steps.
        if rank != 0 or step % args.report_freq != 0:
            continue
        now = time.time()
        duration = 0 if step == 0 else now - start_time
        start_time = time.time()
        logging.info('VALID Step: %03d Objs: %e R1: %f R5: %f Duration: %ds',
                     step, loss_meter.avg, top1_meter.avg, top5_meter.avg, duration)
        tensorbard_logger.add_scalar('valid/valid_objs', loss_meter.avg, global_iter)
        tensorbard_logger.add_scalar('valid/valid_top1', top1_meter.avg, global_iter)
        tensorbard_logger.add_scalar('valid/valid_top5', top5_meter.avg, global_iter)
        tensorbard_logger.add_scalar('valid/valid_duration_', duration, global_iter)

    return top1_meter.avg, top5_meter.avg, loss_meter.avg


if __name__ == '__main__':
    n_gpus = torch.cuda.device_count()
    if n_gpus == 0:
        # mp.spawn with nprocs=0 would start no workers and return, so the
        # script would end "successfully" without training. Fail loudly instead.
        sys.exit('No GPU device available')
    # One worker per visible GPU; mp.spawn passes each worker its index as `rank`.
    mp.spawn(main,
             args=(n_gpus,),
             nprocs=n_gpus,
             join=True)
