import os
import sys
import numpy as np
import time
import torch
import utils
import glob
import random
import logging
import argparse
import torch.nn as nn
import genotypes
from genotypes import Genotype
import torch.utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

from torch.autograd import Variable
from model import NetworkImageNet as Network

import visualize
import torch_optimizer as optim

import random
import torch.cuda.amp as amp
import math
from imagenet import HybridTrainPipe, HybridValPipe
from base import DALIDataloader
import multiprocessing
from thop import profile, clever_format
from ptflops import get_model_complexity_info

# Number of images in the ImageNet (ILSVRC2012) train / validation splits;
# passed to the DALI loaders as the dataset size.
IMAGENET_IMAGES_NUM_TRAIN = 1281167
IMAGENET_IMAGES_NUM_TEST = 50000


def _str2bool(value):
    """Convert a command-line token to a real boolean.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including ``"False"`` -- becomes ``True``.  This
    converter accepts the usual spellings instead and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, matching what bool('') returned before.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser("training imagenet")
# Defaults tried previously for --workers: 32, 28, 40.
parser.add_argument('--workers', type=int, default=72, help='number of workers to load dataset')
# Batch sizes used previously per hardware setup:
#   1x V100: 700      4x V100: 2200
#   1x 2080Ti: 95 / 230    2x 2080Ti: 460    4x 2080Ti: 900    7x 2080Ti: 1200 / 1024
#   1x 2080: 170      2x 2080: 330      7x 2080: 840      8x 2080: 950
parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.5, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-5, help='weight decay')
parser.add_argument('--report_freq', type=float, default=100, help='report frequency')
# A 40-epoch default was also used previously.
parser.add_argument('--epochs', type=int, default=250, help='num of training epochs')
# init_channels=38 and layers=18 were also tried previously.
parser.add_argument('--init_channels', type=int, default=48, help='num of init channels')
parser.add_argument('--layers', type=int, default=14, help='total number of layers')
# --auxiliary defaults to True, so on its own the flag cannot switch the tower
# off; --no_auxiliary writes False into the same destination.
parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
parser.add_argument('--no_auxiliary', dest='auxiliary', action='store_false', help='disable the auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--drop_path_prob', type=float, default=0, help='drop path probability')
parser.add_argument('--save', type=str, default='/export/data/lwangcg/PC-DARTS/ImagenetExps/', help='experiment name')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--arch', type=str, default='PCDARTS', help='which architecture to use')
parser.add_argument('--grad_clip', type=float, default=5., help='gradient clipping')
parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
parser.add_argument('--lr_scheduler', type=str, default='linear', help='lr scheduler, linear or cosine')
# Other data locations used previously:
#   /export/data/lwangcg/ILSVRC2012/
#   /data/lwangcg/ILSVRC2012/
#   /export/data/lwangcg/ILSVRC2012/tfrecord/
parser.add_argument('--tmp_data_dir', type=str, default='/dev/shm/ILSVRC2012/', help='temp data dir')
parser.add_argument('--note', type=str, default='try', help='note for this run')
parser.add_argument('--optimizer', type=str, default='SGDM', help='radam, ranger, SGDM')
# Accepts "--dali", "--dali True" and "--dali False"; the last one used to be
# parsed as True because of type=bool.
parser.add_argument('--dali', type=_str2bool, nargs='?', const=True, default=False, help='use dali')


# parse_known_args keeps unrecognised tokens in `unparsed` instead of aborting.
args, unparsed = parser.parse_known_args()

def adapt_batch_size():
    """Rescale ``args.learning_rate`` in place by ``args.batch_size / 1024``.

    The configured learning rate is treated as the value for a batch of 1024
    samples and is scaled linearly for any other batch size.
    """
    reference_batch = 1024
    scaled_lr = args.learning_rate * args.batch_size / reference_batch
    args.learning_rate = scaled_lr
    # Epoch count is intentionally left untouched; an earlier variant was:
    # args.epochs = int(args.epochs * args.batch_size / 256)

# Scale the learning rate for the configured batch size before args are used.
adapt_batch_size()
num_core = multiprocessing.cpu_count()

# Every run writes into its own directory: <save>eval-<note>-<timestamp>.
# The local *.py files are handed to create_exp_dir as scripts_to_save.
args.save = f'{args.save}eval-{args.note}-{time.strftime("%Y%m%d-%H%M%S")}'
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Log to stdout and, with the same format, to <save>/log.txt.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
    level=logging.INFO,
    stream=sys.stdout,
)
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# TensorBoard event files go to <save>/tb.
tensorbard_logger = SummaryWriter(f'{args.save}/tb')

# Number of output classes of the classifier.
CLASSES = 1000
# Running count of training steps; train() increments it and it is used as
# the TensorBoard step for per-iteration scalars.
global_iter = 0

class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy loss with label smoothing.

    The one-hot target is mixed with a uniform distribution:
    ``(1 - epsilon) * one_hot + epsilon / num_classes``.
    """

    def __init__(self, num_classes, epsilon):
        """Store the class count and the smoothing factor ``epsilon``."""
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """Return the smoothed cross-entropy averaged over the batch.

        ``inputs`` are logits with classes along dim 1; ``targets`` are class
        indices (they are expanded to one-hot with ``scatter_`` along dim 1).
        """
        log_probs = self.logsoftmax(inputs)
        # Build the one-hot encoding with the same shape/dtype as log_probs.
        one_hot = torch.zeros_like(log_probs)
        one_hot.scatter_(1, targets.unsqueeze(1), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        weighted = -smoothed * log_probs
        # Mean over the batch dimension, then sum over classes.
        return weighted.mean(0).sum()

def main():
    """Build the network, train it on ImageNet and validate after every epoch.

    In order, this function:
      * seeds numpy / torch and logs the parsed and unparsed arguments;
      * evaluates the hard-coded genotype string and plots both cells;
      * measures MACs / parameters with ptflops on a temporary model;
      * builds the real model (wrapped in ``nn.DataParallel`` when more than
        one GPU is visible) and the optimizer selected by ``args.optimizer``;
      * creates DALI or torchvision data loaders depending on ``args.dali``;
      * runs the epoch loop (LR schedule, SGDM warm-up, train, validate,
        checkpoint) and finally logs the best top-1 / top-5 accuracy.

    Raises:
        ValueError: if ``args.optimizer`` is not 'radam', 'ranger' or 'SGDM'.

    Exits with status 1 when CUDA is unavailable or ``args.lr_scheduler`` is
    neither 'cosine' nor 'linear'.
    """
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled=True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)
    logging.info("unparsed_args = %s", unparsed)
    tensorbard_logger.add_text('args', str(args))
    # Fixed: this entry used to record str(args) a second time.
    tensorbard_logger.add_text('unparsed_args', str(unparsed))
    num_gpus = torch.cuda.device_count()
    # The architecture is hard-coded here; --arch is currently not consulted
    # (see the commented-out lookup below). eval() only ever sees this literal.
    geno_script = \
    "Genotype(normal=[('sep_conv_5x5', 0), ('dil_conv_5x5', 1), ('dil_conv_5x5', 0), ('sep_conv_3x3', 1), ('sep_conv_5x5', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('sep_conv_5x5', 0)], normal_concat=range(2, 6), reduce=[('sep_conv_3x3', 0), ('sep_conv_5x5', 1), ('max_pool_3x3', 1), ('max_pool_3x3', 2), ('sep_conv_3x3', 1), ('dil_conv_5x5', 3), ('max_pool_3x3', 1), ('dil_conv_3x3', 0)], reduce_concat=range(2, 6))"
    # geno_script = geno_script.replace('skip_connect', 'sep_conv_3x3')
    genotype = eval(geno_script)
    # genotype = eval("genotypes.%s" % args.arch)
    print('---------Genotype---------')
    logging.info(genotype)
    print('--------------------------')
    visualize.plot(genotype.normal, "normal", tensorboard_logger=tensorbard_logger, epoch=0)
    visualize.plot(genotype.reduce, "reduction", tensorboard_logger=tensorbard_logger, epoch=0)

    # Complexity is measured on a throw-away model that is built with False in
    # the position where the real model receives args.auxiliary.
    model_tmp = Network(args.init_channels, CLASSES, args.layers, False, genotype)
    model_tmp.drop_path_prob = 0
    # Alternative measurement with thop:
    # input = torch.randn(1, 3, 224, 224)
    # macs, params = profile(model, inputs=(input,))
    # macs, params = clever_format([macs, params], "%.3f")
    macs, params = get_model_complexity_info(model_tmp, (3, 224, 224), as_strings=False,
                                             print_per_layer_stat=False, verbose=True)
    del model_tmp
    logging.info(f'macs: {float(macs)/1e6} M, params: {float(params/1e6)} M')
    tensorbard_logger.add_text('macs', f'{float(macs)/1e6} M')
    tensorbard_logger.add_text('params', f'{float(params/1e6)} M')

    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
    # model = model.to(memory_format=torch.channels_last)
    if num_gpus > 1:
        model = nn.DataParallel(model)
        model = model.cuda()
        # model = model.to(memory_format=torch.channels_last)
    else:
        model = model.cuda()
        # model = model.to(memory_format=torch.channels_last)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
    tensorbard_logger.add_text('param size (MB)', str(utils.count_parameters_in_MB(model)))

    # Plain cross-entropy is used for validation, the smoothed variant for training.
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()

    if args.optimizer == 'radam':
        optimizer = optim.RAdam(
            model.parameters(),
            1e-3,
            betas=(0.9, 0.999),
            weight_decay=args.weight_decay
        )
    elif args.optimizer == 'ranger':
        optimizer = optim.Ranger(
            model.parameters(),
            lr=1e-3 * args.batch_size / 100,
            alpha=0.5,
            k=6,
            N_sma_threshhold=5,
            betas=(.95, 0.999),
            eps=1e-5,
            weight_decay=0
        )
        # NOTE(review): step() is called before any backward pass; the reason
        # is not visible in this file -- confirm before removing.
        optimizer.step()
    elif args.optimizer == 'SGDM':
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay
            )
    else:
        # Previously an unknown name left `optimizer` unbound and the failure
        # surfaced later as an unrelated error. The comparison is case-sensitive.
        raise ValueError(
            "Unknown optimizer %r; expected 'radam', 'ranger' or 'SGDM'" % (args.optimizer,))

    data_dir = args.tmp_data_dir
    # data_dir = os.path.join(args.tmp_data_dir, 'imagenet')
    traindir = os.path.join(data_dir, 'train')
    validdir = os.path.join(data_dir, 'val')
    trainindexdir = '/export/data/lwangcg/ILSVRC2012/idx_files_train'
    valindexdir = '/export/data/lwangcg/ILSVRC2012/idx_files_val'
    if args.dali:
        # Collect only the files that sit directly inside each directory
        # (entries from sub-directories are ignored by the root check).
        for root, dirs, files in os.walk(traindir):
            if root == traindir:
                train_filenames = [os.path.join(root, file) for file in files]
        for root, dirs, files in os.walk(validdir):
            if root == validdir:
                val_filenames = [os.path.join(root, file) for file in files]
        for root, dirs, files in os.walk(trainindexdir):
            if root == trainindexdir:
                train_idx_filenames = [os.path.join(root, file) for file in files]
        for root, dirs, files in os.walk(valindexdir):
            if root == valindexdir:
                val_idx_filenames = [os.path.join(root, file) for file in files]
        pip_train = HybridTrainPipe(batch_size=args.batch_size, num_threads=1, device_id=0,
                                    data_dir=traindir, crop=224, world_size=1, local_rank=0, train_filenames=train_filenames, train_idx_filenames=train_idx_filenames)
        pip_val = HybridValPipe(batch_size=args.batch_size, num_threads=1, device_id=0, data_dir=validdir,
                                 crop=224, size=256, world_size=1, local_rank=0, val_filenames=val_filenames, val_idx_filenames=val_idx_filenames)
        train_queue = DALIDataloader(pipeline=pip_train, size=IMAGENET_IMAGES_NUM_TRAIN, batch_size=args.batch_size,
                                      onehot_label=True)
        valid_queue = DALIDataloader(pipeline=pip_val, size=IMAGENET_IMAGES_NUM_TEST, batch_size=args.batch_size,
                                     onehot_label=True)
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_data = dset.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4,
                    contrast=0.4,
                    saturation=0.4,
                    hue=0.2),
                transforms.ToTensor(),
                normalize,
            ]))
        valid_data = dset.ImageFolder(
            validdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.workers)

        valid_queue = torch.utils.data.DataLoader(
            valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)

#    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_period, gamma=args.gamma)
    # Only stepped when --lr_scheduler cosine is selected.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))
    best_acc_top1 = 0
    best_acc_top1_epoch = 0
    best_acc_top5 = 0
    best_acc_top5_epoch = 0
    # Base learning rate used by the warm-up ramp below.
    lr = args.learning_rate
    scaler = amp.GradScaler()
    acc_list = []
    # Extra attributes read by adjust_lr() (ranger branch) and by the early
    # stop check at the top of the epoch loop; 1e9 means "never".
    optimizer.epoch2change = False
    optimizer.change_epoch = 1e9
    optimizer.epoch2end = 1e9

    current_epoch = 0

    # To resume from a checkpoint, adapt and uncomment:
    # checkpoint = torch.load('/export/data/lwangcg/PC-DARTS/ImagenetExps/eval-try-20200914-115707/checkpoint.pth.tar')
    # model.load_state_dict(checkpoint['state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer'])
    # current_epoch = max(checkpoint['epoch'], 0)
    # best_acc_top1 = checkpoint['best_acc_top1']
    # acc_list = [best_acc_top1] * current_epoch
    # optimizer.epoch2change = True
    # optimizer.change_epoch = current_epoch
    # optimizer.epoch2end = 349 # 249
    # args.epochs = 349
    # logging.info(f'Epoch to change: {optimizer.change_epoch}')
    # tensorbard_logger.add_text('epoch_change', f'Epoch to change: {optimizer.change_epoch}')

    for epoch in range(current_epoch, args.epochs):
        if epoch > optimizer.epoch2end:
            break
        if args.lr_scheduler == 'cosine':
            scheduler.step()
            # get_last_lr() reports the rate the scheduler just applied;
            # get_lr() is not meant to be called outside of step().
            # NOTE(review): the SGDM warm-up below overwrites param_group['lr']
            # afterwards, which may interact with the scheduler's internal
            # update on later epochs -- confirm if cosine + warm-up is used.
            current_lr = scheduler.get_last_lr()[0]
        elif args.lr_scheduler == 'linear':
            current_lr = adjust_lr(optimizer, epoch)
        else:
            print('Wrong lr type, exit')
            sys.exit(1)
        logging.info('Epoch: %d lr %e', epoch, current_lr)
        tensorbard_logger.add_scalar('train/lr', current_lr, epoch)
        if args.optimizer == 'SGDM':
            # Linear warm-up over the first 5 epochs for large batches; this
            # replaces the rate chosen by the schedule above.
            if epoch < 5 and args.batch_size > 256:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr * (epoch + 1) / 5.0
                logging.info('Warming-up Epoch: %d, LR: %e', epoch, lr * (epoch + 1) / 5.0)
                tensorbard_logger.add_scalar('train/Warming_up_lr', lr * (epoch + 1) / 5.0, epoch)
        # Drop-path probability grows linearly with the epoch index; with
        # DataParallel the attribute lives on the wrapped module.
        if num_gpus > 1:
            model.module.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        else:
            model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        epoch_start = time.time()
        train_acc, train_obj = train(train_queue, model, criterion_smooth, optimizer, scaler)
        logging.info('Train_acc: %f', train_acc)
        tensorbard_logger.add_scalar('train/train_acc', train_acc, epoch)

        valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, criterion, scaler)
        logging.info('Valid_acc_top1: %f', valid_acc_top1)
        logging.info('Valid_acc_top5: %f', valid_acc_top5)
        tensorbard_logger.add_scalar('valid/valid_acc_top1', valid_acc_top1, epoch)
        tensorbard_logger.add_scalar('valid/valid_acc_top5', valid_acc_top5, epoch)
        # Covers both the training and the validation pass of this epoch.
        epoch_duration = time.time() - epoch_start
        logging.info('Epoch time: %ds.', epoch_duration)
        tensorbard_logger.add_scalar('train/epoch_duration(s)', epoch_duration, epoch)
        # "Best" for checkpointing purposes is decided by top-1 accuracy only.
        is_best = False
        if valid_acc_top5 > best_acc_top5:
            best_acc_top5 = valid_acc_top5
            best_acc_top5_epoch = epoch
        if valid_acc_top1 > best_acc_top1:
            best_acc_top1 = valid_acc_top1
            best_acc_top1_epoch = epoch
            is_best = True
        utils.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_acc_top1': best_acc_top1,
            'optimizer' : optimizer.state_dict(),
            }, is_best, args.save)
        acc_list.append(valid_acc_top1)
        # Disabled experiment: switch the ranger schedule to its decay phase.
        # if (epoch>=10 and acc_list[epoch] < min(acc_list[epoch-10: epoch]) * 0.9 and not optimizer.epoch2change) or epoch>=78:
        # if epoch >= 78:
        #     optimizer.epoch2change = True
        #     optimizer.change_epoch = epoch
        #     optimizer.epoch2end = math.ceil(epoch * 4 / 3) + 1
        #     logging.info(f'Epoch to change: {optimizer.change_epoch}')
        #     tensorbard_logger.add_text('epoch_change', f'Epoch to change: {optimizer.change_epoch}')


    logging.info('Best Top1 Accuracy: %f, at epoch %d', best_acc_top1, best_acc_top1_epoch)
    logging.info('Best Top5 Accuracy: %f, at epoch %d', best_acc_top5, best_acc_top5_epoch)
    tensorbard_logger.add_text('Best_Top1_Accuracy', 'Best_Top1_Accuracy_{}_at_epoch_{}'.format(best_acc_top1, best_acc_top1_epoch))
    tensorbard_logger.add_text('Best_Top5_Accuracy', 'Best_Top5_Accuracy_{}_at_epoch_{}'.format(best_acc_top5, best_acc_top5_epoch))
        
def adjust_lr(optimizer, epoch):
    """Compute the learning rate for ``epoch``, apply it and return it.

    * ``args.optimizer == 'ranger'``: a constant ``1e-3 * batch_size / 100``
      until ``optimizer.epoch2change`` is set, afterwards a half-cosine decay
      between ``optimizer.change_epoch`` and ``optimizer.epoch2end``.
    * otherwise: a linear decay of ``args.learning_rate`` that reaches its
      last regular step five epochs before the end, followed by a flatter
      slope over the final five epochs.

    The resulting value is written into every parameter group of ``optimizer``.
    """
    if args.optimizer == 'ranger':
        lr = 1e-3 * args.batch_size / 100
        if optimizer.epoch2change:
            elapsed = epoch - optimizer.change_epoch
            span = optimizer.epoch2end - optimizer.change_epoch
            lr = lr * (1 + math.cos(math.pi * elapsed / span)) / 2
    else:
        # The last 5 epochs use a smaller slope because a full linear step of
        # lr / epochs would be relatively large there.
        main_span = args.epochs - 5
        epochs_left = args.epochs - epoch
        if epochs_left > 5:
            lr = args.learning_rate * (main_span - epoch) / main_span
        else:
            lr = args.learning_rate * epochs_left / (main_span * 5)

    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr

def train(train_queue, model, criterion, optimizer, scaler):
    """Run one training epoch under autocast with a gradient scaler.

    For every batch the loss (plus the weighted auxiliary loss when
    ``args.auxiliary`` is set) is back-propagated through ``scaler`` and the
    optimizer is stepped. Running averages are logged every
    ``args.report_freq`` steps, and the module-level ``global_iter`` counter is
    incremented once per batch.

    Returns:
        ``(top1.avg, objs.avg)`` of the ``utils.AvgrageMeter`` instances that
        track top-1 precision and loss.
    """
    global global_iter
    loss_meter = utils.AvgrageMeter()
    top1_meter = utils.AvgrageMeter()
    top5_meter = utils.AvgrageMeter()
    step_time_meter = utils.AvgrageMeter()
    model.train()

    for step, (images, labels) in enumerate(train_queue):
        labels = labels.cuda(non_blocking=True)
        images = images.cuda(non_blocking=True)
        step_began = time.time()
        optimizer.zero_grad()
        with amp.autocast():
            # images = images.to(memory_format=torch.channels_last)
            logits, logits_aux = model(images)
            loss = criterion(logits, labels)
            if args.auxiliary:
                loss += args.auxiliary_weight * criterion(logits_aux, labels)
        scaler.scale(loss).backward()
        # Gradient clipping is currently disabled (args.grad_clip is unused here):
        # scaler.unscale_(optimizer)
        # nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        scaler.step(optimizer)
        scaler.update()
        step_time_meter.update(time.time() - step_began)

        prec1, prec5 = utils.accuracy(logits, labels, topk=(1, 5))
        batch_n = images.size(0)
        loss_meter.update(loss.item(), batch_n)
        top1_meter.update(prec1.item(), batch_n)
        top5_meter.update(prec5.item(), batch_n)

        if step % args.report_freq == 0:
            # Wall time since the previous report (0 for the very first one).
            now = time.time()
            duration = 0 if step == 0 else now - window_start
            window_start = time.time()
            logging.info('TRAIN Step: %03d Objs: %e R1: %f R5: %f Duration: %ds BTime: %.3fs',
                         step, loss_meter.avg, top1_meter.avg, top5_meter.avg,
                         duration, step_time_meter.avg)
            scalars = (
                ('train/train_objs', loss_meter.avg),
                ('train/train_top1', top1_meter.avg),
                ('train/train_top5', top5_meter.avg),
                ('train/train_duration', duration),
                ('train/train_batch_time', step_time_meter.avg),
            )
            for tag, value in scalars:
                tensorbard_logger.add_scalar(tag, value, global_iter)

        global_iter += 1

    return top1_meter.avg, loss_meter.avg


def infer(valid_queue, model, criterion, scaler):
    """Evaluate ``model`` on ``valid_queue`` without gradient tracking.

    ``scaler`` is accepted for signature symmetry with ``train`` but is not
    used. Running averages are logged every ``args.report_freq`` steps; the
    TensorBoard step is the current value of the module-level ``global_iter``,
    which this function does not advance.

    Returns:
        ``(top1.avg, top5.avg, objs.avg)`` of the ``utils.AvgrageMeter``
        instances that track top-1 / top-5 precision and loss.
    """
    loss_meter = utils.AvgrageMeter()
    top1_meter = utils.AvgrageMeter()
    top5_meter = utils.AvgrageMeter()
    model.eval()

    for step, (images, labels) in enumerate(valid_queue):
        images = images.cuda()
        labels = labels.cuda(non_blocking=True)
        with amp.autocast(), torch.no_grad():
            # images = images.to(memory_format=torch.channels_last)
            logits, _ = model(images)
            loss = criterion(logits, labels)

        prec1, prec5 = utils.accuracy(logits, labels, topk=(1, 5))
        batch_n = images.size(0)
        loss_meter.update(loss.item(), batch_n)
        top1_meter.update(prec1.item(), batch_n)
        top5_meter.update(prec5.item(), batch_n)

        if step % args.report_freq == 0:
            # Wall time since the previous report (0 for the very first one).
            now = time.time()
            duration = 0 if step == 0 else now - window_start
            window_start = time.time()
            logging.info('VALID Step: %03d Objs: %e R1: %f R5: %f Duration: %ds',
                         step, loss_meter.avg, top1_meter.avg, top5_meter.avg, duration)
            scalars = (
                ('valid/valid_objs', loss_meter.avg),
                ('valid/valid_top1', top1_meter.avg),
                ('valid/valid_top5', top5_meter.avg),
                ('valid/valid_duration_', duration),
            )
            for tag, value in scalars:
                tensorbard_logger.add_scalar(tag, value, global_iter)

    return top1_meter.avg, top5_meter.avg, loss_meter.avg


# Start training only when this file is executed as a script. Note that the
# module-level setup above (argument parsing, experiment directory, logging)
# runs on import as well.
if __name__ == '__main__':
    main()
