import copy
import gc
import os
import warnings

import argparse
import collections
import shutil
import sys
import time

import numpy as np
import setproctitle
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from datetime import datetime


# import models
import dist_utils, resnet, vgg, experimental_utils, dataloader
from lbfgs import LBFGSOptimizer
import kfac
from sgd import SGDOptimizer
from fp16util import *
from logger import TensorboardLogger, FileLogger
from meter import AverageMeter, NetworkMeter, TimeMeter
from utils.model_utils import get_grad_norm, get_param_norm

from ptflops import get_model_complexity_info

# Per-channel RGB mean/std normalization shared by the CIFAR train and
# validation transforms in DataManager_CIFAR.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

def get_parser():
    """Build the command-line parser for ImageNet/CIFAR training.

    Options are grouped as: data/model selection, optimization and LR
    schedule, fp16 / distributed setup, logging and auto-shutdown, LBFGS
    hyper-parameters, and KFAC hyper-parameters.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    # ---- data / model -------------------------------------------------------
    parser.add_argument('--dataset', type=str, default='imagenet', choices=['cifar10', 'cifar100', 'imagenet'],
                        help='name of dataset')
    parser.add_argument('--data', metavar='DIR', help='path to dataset')
    parser.add_argument('--model', type=str, default='resnet50', choices=['resnet50', 'vgg11', 'vgg11_bn'])
    parser.add_argument('--phases', type=str,
                        help='Python expression (evaluated) giving the epoch order of data resize and learning '
                             'rate schedule: [{"ep":0,"sz":128,"bs":64},{"ep":5,"lr":1e-2}]')

    # ---- optimization -------------------------------------------------------
    parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'LBFGS', 'KFAC'],
                        help='optimization method')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--decay-period', default=10, type=int,
                        help='learning rate decay period in epochs (default: 10)')
    parser.add_argument('--lr-decay', nargs='+', type=int, default=[25, 35, 40, 45, 50],
                        help='epoch intervals to decay lr (default: 25,35,40,45,50)')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--max-epoch', default=100, type=int, help='max epochs')
    parser.add_argument('--init-bn0', action='store_true', help='Initialize running batch norm mean to 0')
    parser.add_argument('--print-freq', '-p', default=20, type=int,
                        metavar='N', help='log/print every this many steps (default: 20)')
    parser.add_argument('--no-bn-wd', action='store_true', help='Remove batch norm from weight decay')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')

    # ---- precision / distributed -------------------------------------------
    parser.add_argument('--fp16', action='store_true', help='Run model in fp16 mode (default: off)')
    parser.add_argument('--loss-scale', type=float, default=1024,
                        help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
    parser.add_argument('--distributed', action='store_true', help='Run distributed training (default: off)')
    parser.add_argument('--dist-url', default='env://', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--local_rank', default=0, type=int,
                        help='Used for multi-process training. Can either be manually set ' +
                             'or automatically set by using \'python -m multiproc\'.')

    # ---- logging / instance lifecycle --------------------------------------
    parser.add_argument('--logdir', default='', type=str,
                        help='where logs go')
    parser.add_argument('--skip-auto-shutdown', action='store_true',
                        help='Do not shut down the instance at the end of training or on failure')
    parser.add_argument('--auto-shutdown-success-delay-mins', default=10, type=int,
                        help='how long to wait until shutting down on success')
    parser.add_argument('--auto-shutdown-failure-delay-mins', default=60, type=int,
                        help='how long to wait before shutting down on error')

    parser.add_argument('--short-epoch', action='store_true',
                        help='make epochs short (for debugging)')

    parser.add_argument('--grad-clip', default=0.05, type=float,
                        help='gradient clipping (passed to LBFGS as kl_clip)')

    # ---- LBFGS hyper parameters --------------------------------------------
    parser.add_argument('--stat-decay-param', default=0.9, type=float,
                        help='stat decay for parameters')
    parser.add_argument('--stat-decay-grad', default=0.9, type=float,
                        help='stat decay for gradients')
    parser.add_argument('--update-freq', default=200, type=int,
                        help='update frequency for Hessian approximation')
    parser.add_argument('--history-size', default=20, type=int,
                        help='history size for LBFGS-related vectors')
    parser.add_argument('--lbfgs-damping', default=0.2, type=float,
                        help='LBFGS damping factor')

    # ---- KFAC hyper parameters ---------------------------------------------
    parser.add_argument('--kfac-update-freq', type=int, default=100,
                        help='iters between kfac inv ops (0 disables kfac) (default: 100)')
    parser.add_argument('--kfac-cov-update-freq', type=int, default=10,
                        help='iters between kfac cov ops (default: 10)')
    parser.add_argument('--kfac-update-freq-alpha', type=float, default=10,
                        help='KFAC update freq multiplier (default: 10)')
    parser.add_argument('--kfac-update-freq-decay', nargs='+', type=int, default=None,
                        help='KFAC update freq decay schedule (default None)')
    parser.add_argument('--use-inv-kfac', action='store_true', default=False,
                        help='Use inverse KFAC update instead of eigen (default False)')
    parser.add_argument('--stat-decay', type=float, default=0.95,
                        help='Alpha value for covariance accumulation (default: 0.95)')
    parser.add_argument('--damping', type=float, default=0.001,
                        help='KFAC damping factor (default: 0.001)')
    parser.add_argument('--damping-alpha', type=float, default=0.5,
                        help='KFAC damping decay factor (default: 0.5)')
    parser.add_argument('--damping-decay', nargs='+', type=int, default=None,
                        help='KFAC damping decay schedule (default None)')
    parser.add_argument('--kl-clip', type=float, default=0.001,
                        help='KL clip (default: 0.001)')
    parser.add_argument('--skip-layers', nargs='+', type=str, default=[],
                        help='Layer types to ignore registering with KFAC (default: [])')
    # store_true with default=True could never be switched off; the paired
    # --no-... flag below writes False into the same destination.
    parser.add_argument('--coallocate-layer-factors', action='store_true', default=True,
                        help='Compute A and G for a single layer on the same worker (default: on).')
    parser.add_argument('--no-coallocate-layer-factors', dest='coallocate_layer_factors',
                        action='store_false',
                        help='Allow the A and G factors of a layer to be computed on different workers.')
    parser.add_argument('--kfac-comm-method', type=str, default='comm-opt',
                        help='KFAC communication optimization strategy. One of comm-opt, '
                             'mem-opt, or hybrid-opt. (default: comm-opt)')
    parser.add_argument('--kfac-grad-worker-fraction', type=float, default=0.25,
                        help='Fraction of workers to compute the gradients '
                             'when using hybrid-opt (default: 0.25)')
    return parser


# Enable cuDNN benchmark (autotuner) mode for convolutions.
cudnn.benchmark = True
# Command-line options are parsed at import time; the functions below read
# this module-level `args` directly.
args = get_parser().parse_args()
# os.sched_setaffinity(os.getpid(), {i for i in range(49, 97)})

# Only want master rank logging to tensorboard
# is_master: True for non-distributed runs, or when dist_utils.env_rank() is 0
# (presumably the global rank read from the launcher environment — confirm).
# is_rank0: True for local rank 0 (the first process on this machine).
is_master = (not args.distributed) or (dist_utils.env_rank() == 0)
is_rank0 = args.local_rank == 0
tb = TensorboardLogger(args.logdir, is_master=is_master)
log = FileLogger(args.logdir, is_master=is_master, is_rank0=is_rank0)


def main():
    """Build model, optimizer and data from ``args``; then evaluate or train.

    In order:
    - cancel any pending shutdown;
    - optionally join the distributed process group;
    - build the model (optionally fp16 / DDP) and print its FLOPs and parameter count;
    - build the optimizer (SGD, LBFGS, or SGD preconditioned by KFAC);
    - optionally resume from ``args.resume``;
    - build the data manager and the LR scheduler;
    - either evaluate once (``--evaluate``) or train from ``args.start_epoch`` up to
      ``args.max_epoch``, validating every epoch and checkpointing every 5 epochs on
      local rank 0.

    Returns:
        The ``validate`` result ``(top1, top5, loss)`` when ``--evaluate`` is
        set; otherwise None.

    Raises:
        ValueError: unknown KFAC communication method or optimizer name.
    """
    os.system('shutdown -c')  # cancel previous shutdown command
    log.console(args)
    tb.log('sizes/world', dist_utils.env_world_size())

    str_process_name = "ResNet50-ImageNet-" + args.optimizer + " (using DDP):" + str(args.local_rank)
    setproctitle.setproctitle(str_process_name)

    if args.distributed:
        log.console('Distributed initializing process group. Rank = %d' % args.local_rank)
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=dist_utils.env_world_size())
        assert (dist_utils.env_world_size() == dist.get_world_size())
        log.console("Distributed: success (%d/%d)" % (args.local_rank, dist.get_world_size()))

    # Input resolution and class count: ImageNet defaults, 32x32 for CIFAR.
    sz_img = 224
    nclass = 1000
    if args.dataset == 'cifar10' or args.dataset == 'cifar100':
        sz_img = 32
        nclass = 10 if args.dataset == 'cifar10' else 100

    log.console("Loading model")
    if 'resnet' in args.model:
        # NOTE(review): nclass is not passed here, so the ResNet head keeps
        # whatever size the resnet module defaults to.
        model = resnet.__dict__[args.model](bn0=args.init_bn0).cuda()
    elif 'vgg' in args.model:
        model = vgg.__dict__[args.model](nclass).cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        model = dist_utils.DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
    best_top5 = 93  # only save models over 93%. Otherwise it stops to save every time

    # calculate the parameter size
    flops, params = get_model_complexity_info(model, (3, sz_img, sz_img), as_strings=False,
                                              print_per_layer_stat=False, verbose=False)
    if args.local_rank == 0:
        print('{:<30}  {:<8}'.format('Computational complexity: ', flops))
        print('{:<30}  {:<8}'.format('Number of parameters: ', params))

    # Given that allreduce requires 2 x bytes on the network compared to the value it operates on
    model_network_size = params * 4 * 2
    if args.local_rank == 0:
        print(model_network_size)
    log.verbose("  parameter_network_size: %f" % model_network_size)

    # train() reads these globals on the fp16 path.
    global model_params, master_params
    if args.fp16:
        model_params, master_params = prep_param_lists(model)
    else:
        # A list (not the one-shot generator returned by parameters()) so the
        # same parameters can safely be iterated more than once.
        model_params = master_params = list(model.parameters())

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    if args.optimizer == 'SGD':
        preconditioner = None
        optimizer = SGDOptimizer(master_params, lr=args.lr, momentum=args.momentum,
                                 weight_decay=args.weight_decay)
    elif args.optimizer == 'LBFGS':
        preconditioner = None
        optimizer = LBFGSOptimizer(master_params, lr=args.lr, momentum=args.momentum,
                                   weight_decay=args.weight_decay, mm_p=args.stat_decay_param,
                                   mm_g=args.stat_decay_grad, update_freq=args.update_freq,
                                   hist_sz=args.history_size, decay_period=args.decay_period,
                                   damping=args.lbfgs_damping, kl_clip=args.grad_clip)
    elif args.optimizer == 'KFAC':
        if args.kfac_comm_method == 'comm-opt':
            comm_method = kfac.CommMethod.COMM_OPT
        elif args.kfac_comm_method == 'mem-opt':
            comm_method = kfac.CommMethod.MEM_OPT
        elif args.kfac_comm_method == 'hybrid-opt':
            comm_method = kfac.CommMethod.HYBRID_OPT
        else:
            raise ValueError('Unknwon KFAC Comm Method: {}'.format(
                    args.kfac_comm_method))
        # KFAC preconditions the gradients; the SGD optimizer then applies them.
        preconditioner = kfac.KFAC(
            model,
            damping=args.damping,
            factor_decay=args.stat_decay,
            factor_update_freq=args.kfac_cov_update_freq,
            inv_update_freq=args.kfac_update_freq,
            kl_clip=args.kl_clip,
            lr=args.lr,
            batch_first=True,
            comm_method=comm_method,
            distribute_layer_factors=not args.coallocate_layer_factors,
            grad_scaler=args.grad_scaler if 'grad_scaler' in args else None,
            grad_worker_fraction=args.kfac_grad_worker_fraction,
            skip_layers=args.skip_layers,
            use_eigen_decomp=not args.use_inv_kfac,
        )
        kfac_param_scheduler = kfac.KFACParamScheduler(
            preconditioner,
            damping_alpha=args.damping_alpha,
            damping_schedule=args.damping_decay,
            update_freq_alpha=args.kfac_update_freq_alpha,
            update_freq_schedule=args.kfac_update_freq_decay
        )
        optimizer = SGDOptimizer(
            master_params,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay
        )
    else:
        raise ValueError('[ERROR] choose a valid optmizer!')

    if args.resume:
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch']
        best_top5 = checkpoint['best_top5']
        optimizer.load_state_dict(checkpoint['optimizer'])

    # save script so we can reproduce from logs
    shutil.copy2(os.path.realpath(__file__), f'{args.logdir}')

    log.console("Creating data loaders (this could take up to 10 minutes if volume needs to be warmed up)")
    # NOTE(security): eval() executes arbitrary code from the command line.
    # ast.literal_eval would suffice if phases are always plain literals.
    phases = eval(args.phases)
    if args.dataset == 'imagenet':
        dm = DataManager_ImageNet([copy.deepcopy(p) for p in phases if 'bs' in p])
    elif args.dataset == 'cifar10' or args.dataset == 'cifar100':
        dm = DataManager_CIFAR([copy.deepcopy(p) for p in phases if 'bs' in p])

    def lr_lambda(epoch):
        # Halve the learning rate every `decay_period` epochs.
        return 0.5 ** (epoch // args.decay_period)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    start_time = datetime.now()  # Loading start to after everything is loaded
    if args.evaluate:
        dm.set_epoch(0)
        return validate(dm.val_dl, model, criterion, 0, start_time)

    if args.distributed:
        log.console('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    # When resuming, advance the per-epoch schedulers to `start_epoch`, so the
    # run continues the schedule instead of restarting it.
    for _ in range(args.start_epoch):
        scheduler.step()
        if args.optimizer == 'KFAC':
            kfac_param_scheduler.step()

    log.event("~~epoch\thours\ttop1\ttop5\n")
    # Apply every data phase that would have been active before `start_epoch`
    # (always at least epoch 0), so the loaders match an uninterrupted run.
    for replay_epoch in range(max(args.start_epoch, 1)):
        dm.set_epoch(replay_epoch)
    _, _, loss_init = validate(dm.val_dl, model, criterion, 0, start_time, istrain=False)
    wd_scaling = 1.0
    for epoch in range(args.start_epoch, args.max_epoch):
        dm.set_epoch(epoch)

        time_epoch_start = time.time()
        loss_i = train(dm.trn_dl, model, criterion, optimizer, preconditioner, scheduler, epoch, wd_scaling)
        log.verbose("#############################one epoch time cost: %f" % (time.time() - time_epoch_start))

        top1, top5, _ = validate(dm.val_dl, model, criterion, epoch, start_time, istrain=True)

        scheduler.step()
        if args.optimizer == 'KFAC':
            kfac_param_scheduler.step()

        time_diff = (datetime.now() - start_time).total_seconds() / 3600.0
        log.event(f'~~{epoch}\t{time_diff:.5f}\t\t{top1:.3f}\t\t{top5:.3f}\n')

        best_top5 = max(top5, best_top5)
        # Ratio of this epoch's training loss to the initial validation loss.
        wd_scaling = loss_i / loss_init
        if args.local_rank == 0:
            # Saving a separate best-model checkpoint is disabled; a periodic
            # checkpoint is written every 5 epochs instead.
            if epoch % 5 == 0:
                save_checkpoint(epoch, model, best_top5, optimizer, filename=args.logdir+f'/epoch{epoch}_checkpoint.tar')


def train(trn_loader, model, criterion, optimizer, preconditioner, scheduler, epoch, wd_scaling):
    """Run one training epoch and return the sample-weighted average loss.

    Args:
        trn_loader: iterable of ``(input, target)`` batches. ``len()`` is used
            for progress logging.
        model: network being trained (possibly fp16 and/or DDP-wrapped).
        criterion: loss applied to ``(output, target)``.
        optimizer: SGDOptimizer or LBFGSOptimizer. LBFGS's ``step`` also
            receives ``epoch`` and ``batch``.
        preconditioner: ``kfac.KFAC`` instance when ``args.optimizer == 'KFAC'``,
            else None.
        scheduler: unused here; the caller steps it once per epoch.
        epoch: current epoch index, used for logging and LBFGS.
        wd_scaling: currently unused; kept for interface compatibility.

    Returns:
        ``losses.avg``: average training loss over the epoch. In distributed
        mode it is averaged across workers.

    Raises:
        NotImplementedError: when ``--fp16`` is combined with ``--optimizer KFAC``.
    """
    if args.fp16 and args.optimizer == 'KFAC':
        # The fp16 branch below has no KFAC preconditioning step, so KFAC
        # would not be applied and its norm diagnostics would never be set.
        raise NotImplementedError('KFAC preconditioning is not supported together with --fp16')

    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # KFAC diagnostics; only computed on the (non-fp16) KFAC path.
    pn_kfac = gn_before_kfac = gn_after_kfac = None

    # switch to train mode
    model.train()
    for i, (input, target) in enumerate(trn_loader):
        if args.short_epoch and (i > 10):
            break
        input = input.cuda()
        target = target.cuda()
        batch_num = i + 1
        timer.batch_start()

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # compute gradient and do an optimizer step
        if args.fp16:
            # Static loss scaling (see --loss-scale): scale before backward,
            # then unscale the master gradients before stepping.
            loss = loss * args.loss_scale
            model.zero_grad()
            loss.backward()
            model_grads_to_master_grads(model_params, master_params)
            for param in master_params:
                param.grad.data = param.grad.data / args.loss_scale
            if args.optimizer == 'SGD':
                optimizer.step()
            elif args.optimizer == 'LBFGS':
                optimizer.step(epoch=epoch, batch=i)
            master_params_to_model_params(model_params, master_params)
            loss = loss / args.loss_scale
        else:
            optimizer.zero_grad()
            # start to all_reduce with the order of bucket
            loss.backward()
            if args.optimizer == 'SGD':
                optimizer.step()
            elif args.optimizer == 'LBFGS':
                optimizer.step(epoch=epoch, batch=i)
            elif args.optimizer == 'KFAC':
                # Record norms around preconditioning for TensorBoard.
                pn_kfac = get_param_norm(model)
                gn_before_kfac = get_grad_norm(model)
                preconditioner.step()
                gn_after_kfac = get_grad_norm(model)
                optimizer.step()

        # Train batch done. Logging results
        timer.batch_end()
        corr1, corr5 = correct(output.data, target, topk=(1, 5))
        reduced_loss, batch_total = to_python_float(loss.data), to_python_float(input.size(0))

        # Must keep track of global batch size
        # since not all machines are guaranteed equal batches at the end of an epoch
        if args.distributed:
            metrics = torch.tensor([batch_total, reduced_loss, corr1, corr5]).float().cuda()
            batch_total, reduced_loss, corr1, corr5 = dist_utils.sum_tensor(metrics).cpu().numpy()
            reduced_loss = reduced_loss / dist_utils.env_world_size()
        top1acc = to_python_float(corr1) * (100.0 / batch_total)
        top5acc = to_python_float(corr5) * (100.0 / batch_total)

        losses.update(reduced_loss, batch_total)
        top1.update(top1acc, batch_total)
        top5.update(top5acc, batch_total)

        should_print = (batch_num % args.print_freq == 0) or (batch_num == len(trn_loader))
        if args.local_rank == 0 and should_print:
            tb.log_trn_loss(losses.avg, top1.avg, top5.avg)
            # Separate name so the model `output` is not shadowed by the log line.
            msg = (f'Epoch: [{epoch}][{batch_num}/{len(trn_loader)}]\t'
                   f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                   f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                   f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                   f'Data {timer.data_time.val:.3f} ({timer.data_time.avg:.3f})\t')
            log.verbose(msg)

        if args.optimizer == 'LBFGS' and optimizer.update_dg_dp and optimizer.start_lbfgs:
            tb.log_lbfgs(optimizer.rho_list[-1], optimizer.h0, optimizer.tao_before, optimizer.tao_after, optimizer.gn_before, optimizer.gn_after, optimizer.pn)
        if args.optimizer == 'KFAC' and i % 100 == 0:
            tb.log_kfac(gn_before_kfac, gn_after_kfac, pn_kfac)
        if args.optimizer == 'SGD' and optimizer.steps % 100 == 0:
            tb.log_sgd(optimizer.gn, optimizer.pn)

        tb.update_step_count(batch_total)

    return losses.avg


def validate(val_loader, model, criterion, epoch, start_time, istrain=True):
    """Evaluate ``model`` over ``val_loader``.

    In distributed mode each batch goes through ``distributed_predict`` so the
    metrics are summed across workers. Otherwise the batch is scored locally
    under ``torch.no_grad()``. ``start_time`` is accepted but not used.

    Args:
        istrain: when true, the epoch's average accuracies and evaluation wall
            time are logged to TensorBoard together with ``epoch``.

    Returns:
        ``(top1_avg, top5_avg, loss_avg)``, each weighted by batch size.
    """
    batch_timer = TimeMeter()
    loss_meter, acc1_meter, acc5_meter = AverageMeter(), AverageMeter(), AverageMeter()

    model.eval()
    t_eval_begin = time.time()

    for step, (images, labels) in enumerate(val_loader):
        images, labels = images.cuda(), labels.cuda()
        if args.short_epoch and step > 10:
            break
        step_no = step + 1
        batch_timer.batch_start()
        if not args.distributed:
            with torch.no_grad():
                logits = model(images)
                loss = criterion(logits, labels).data
            batch_total = images.size(0)
            top1acc, top5acc = accuracy(logits.data, labels, topk=(1, 5))
        else:
            top1acc, top5acc, loss, batch_total = distributed_predict(images, labels, model, criterion)

        # Eval batch done. Logging results
        batch_timer.batch_end()
        weight = to_python_float(batch_total)
        loss_meter.update(to_python_float(loss), weight)
        acc1_meter.update(to_python_float(top1acc), weight)
        acc5_meter.update(to_python_float(top5acc), weight)

        due = (step_no % args.print_freq == 0) or (step_no == len(val_loader))
        if not (args.local_rank == 0 and due):
            continue
        log.verbose(f'Test:  [{epoch}][{step_no}/{len(val_loader)}]\t'
                    f'Time {batch_timer.batch_time.val:.3f} ({batch_timer.batch_time.avg:.3f})\t'
                    f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                    f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                    f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})')

    if istrain:
        tb.log_eval(acc1_meter.avg, acc5_meter.avg, time.time() - t_eval_begin)
        tb.log('epoch', epoch)

    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg


def distributed_predict(input, target, model, criterion):
    """Score a possibly-empty local batch, then sum the metrics across workers.

    A worker whose batch is empty still takes part in the reduction, with all
    local contributions set to zero, because the test set is not always large
    enough to give every GPU a batch.

    Returns:
        ``(top1_pct, top5_pct, mean_loss, global_batch_total)``. ``mean_loss``
        is averaged over the workers that actually had a batch.
    """
    local_count = input.size(0)
    loss = corr1 = corr5 = 0
    had_batch = 0

    if local_count:
        with torch.no_grad():
            logits = model(input)
            loss = criterion(logits, target).data
        # measure accuracy and record loss
        had_batch = 1
        corr1, corr5 = correct(logits.data, target, topk=(1, 5))

    packed = torch.tensor([local_count, had_batch, loss, corr1, corr5]).float().cuda()
    summed = dist_utils.sum_tensor(packed).cpu().numpy()
    batch_total, valid_batches, reduced_loss, corr1, corr5 = summed
    reduced_loss = reduced_loss / valid_batches

    pct_per_sample = 100.0 / batch_total
    return corr1 * pct_per_sample, corr5 * pct_per_sample, reduced_loss, batch_total


class DataManager_ImageNet():
    """Preloads ImageNet loaders for every data phase and switches them by epoch.

    ``phases`` is a list of dicts, each with ``'ep'`` (epoch at which it takes
    effect) and ``'bs'`` (batch size). Unless ``'keep_dl'`` is set, a phase
    also needs ``'sz'`` (image size) and may give ``'trndir'`` / ``'valdir'``
    sub-directories under ``args.data``; its loaders are built up front.
    A ``'keep_dl'`` phase only changes the batch size of the current train loader.
    """

    def __init__(self, phases):
        self.phases = self.preload_phase_data(phases)

    def set_epoch(self, epoch):
        """Activate the phase starting at ``epoch`` (if any) and forward ``epoch`` to the samplers."""
        cur_phase = self.get_phase(epoch)
        if cur_phase:
            self.set_data(cur_phase)
        # getattr: no sampler exists until some phase has been activated.
        trn_smp = getattr(self, 'trn_smp', None)
        val_smp = getattr(self, 'val_smp', None)
        if hasattr(trn_smp, 'set_epoch'):
            trn_smp.set_epoch(epoch)
        if hasattr(val_smp, 'set_epoch'):
            val_smp.set_epoch(epoch)

    def get_phase(self, epoch):
        """Return the first remaining phase whose ``'ep'`` equals ``epoch``, else None."""
        return next((p for p in self.phases if p['ep'] == epoch), None)

    def set_data(self, phase):
        """Switch to ``phase``'s preloaded loaders, or only resize the batch for a ``keep_dl`` phase.

        Non-``keep_dl`` phases are removed from ``self.phases`` after
        activation, so each one is applied at most once.
        """
        if phase.get('keep_dl', False):
            log.event(f'Batch size changed: {phase["bs"]}')
            tb.log_size(phase['bs'])
            self.trn_dl.update_batch_size(phase['bs'])
            return

        log.event(
            f'Dataset changed.\nImage size: {phase["sz"]}\nBatch size: {phase["bs"]}\nTrain Directory: {phase["trndir"]}\nValidation Directory: {phase["valdir"]}')
        tb.log_size(phase['bs'], phase['sz'])

        self.trn_dl, self.val_dl, self.trn_smp, self.val_smp = phase['data']
        self.phases.remove(phase)

        # clear memory before we begin training
        gc.collect()

    def preload_phase_data(self, phases):
        """Resolve directories and build loaders (stored under ``'data'``) for each non-``keep_dl`` phase."""
        for phase in phases:
            if not phase.get('keep_dl', False):
                self.expand_directories(phase)
                phase['data'] = self.preload_data(**phase)
        return phases

    def expand_directories(self, phase):
        """Turn the phase's relative ``trndir``/``valdir`` into ``args.data``-based train/val paths."""
        trndir = phase.get('trndir', '')
        valdir = phase.get('valdir', trndir)
        phase['trndir'] = args.data + trndir + '/train'
        phase['valdir'] = args.data + valdir + '/val'

    def preload_data(self, ep, sz, bs, trndir, valdir, **kwargs):  # dummy ep var to prevent error
        """Pre-initializes data-loaders. Use set_data to start using it.

        ``ep`` is accepted only so a phase dict can be splatted in directly.
        The validation batch size is ``max(bs, 512)`` for 128px images,
        ``max(bs, 256)`` for 224px images and ``max(bs, 128)`` otherwise.

        Returns:
            The result of ``dataloader.get_loaders``. ``set_data`` unpacks it as
            ``(trn_dl, val_dl, trn_smp, val_smp)``.
        """
        kwargs.pop('lr', None)  # in case we mix schedule and data phases
        if sz == 128:
            val_bs = max(bs, 512)
        elif sz == 224:
            val_bs = max(bs, 256)
        else:
            val_bs = max(bs, 128)
        return dataloader.get_loaders(trndir, valdir, bs=bs, val_bs=val_bs, sz=sz, workers=args.workers,
                                      fp16=args.fp16, distributed=args.distributed, **kwargs)
class DataManager_CIFAR():
    """Builds CIFAR-10/100 train and validation loaders from the first data phase.

    Only ``phases[0]`` is used. Its ``'bs'`` sets the batch size of both
    loaders; ``'ep'`` and ``'sz'`` are accepted but ignored. The dataset is
    downloaded to ``args.data`` if missing. In distributed mode the training
    set is sharded across ranks with a ``DistributedSampler``, while every
    rank still iterates the full validation set.
    """

    def __init__(self, phases):
        log.console(phases[0])
        self.dataset = datasets.CIFAR10 if args.dataset == 'cifar10' else datasets.CIFAR100
        self.trn_smp = None
        self.preload_data(**phases[0])

    def set_epoch(self, epoch):
        """Forward ``epoch`` to the training sampler (if any) so each epoch reshuffles its shard."""
        sampler = getattr(self, 'trn_smp', None)
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)

    def preload_data(self, ep, sz, bs, **kwargs):
        """Create ``self.trn_dl`` / ``self.val_dl`` (and ``self.trn_smp`` when distributed)."""
        trn_set = self.dataset(root=args.data, train=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ]), download=True)
        # Without a sampler, every rank would iterate the entire training set
        # each epoch. Shuffling is delegated to the sampler when one is used.
        self.trn_smp = (torch.utils.data.distributed.DistributedSampler(trn_set)
                        if args.distributed else None)
        self.trn_dl = torch.utils.data.DataLoader(
            trn_set,
            batch_size=bs, shuffle=self.trn_smp is None, sampler=self.trn_smp,
            num_workers=args.workers, pin_memory=True)
        self.val_dl = torch.utils.data.DataLoader(
            self.dataset(root=args.data, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=bs, shuffle=False,
            num_workers=args.workers, pin_memory=True)

# ### Learning rate scheduler
def create_lr_schedule(workers, decay_schedule, alpha=0.1):
    """Build a step-decay learning-rate multiplier suitable for ``LambdaLR``.

    Args:
        workers: unused; kept for interface compatibility.
        decay_schedule: iterable of epochs (or None) at which the LR is
            multiplied by ``alpha``. The caller's object is not modified.
        alpha: decay factor applied once per milestone reached.

    Returns:
        ``lr_schedule(epoch)``, returning ``alpha`` raised to the number of
        milestones ``<= epoch`` (computed by repeated multiplication).
    """
    # Private sorted copy, so the caller's list is never re-sorted in place.
    milestones = sorted(() if decay_schedule is None else decay_schedule)

    def lr_schedule(epoch):
        lr_adj = 1.
        for e in milestones:
            if epoch >= e:
                lr_adj *= alpha
        return lr_adj
    return lr_schedule
class Scheduler():
    """Per-batch learning-rate schedule built from epoch phases.

    Each phase dict carries ``'ep'`` and ``'lr'``, both normalised to lists by
    ``listify`` (the dicts are modified in place):
    - A single ``lr`` value is a constant rate from the phase's start epoch on.
    - Two ``lr`` values ramp linearly over the phase's ``[start, end]`` epochs.
      The ramp is per batch unless the phase has an ``'epoch_step'`` key.

    Phases are searched from last to first, so they are expected in
    increasing start-epoch order.
    """
    def __init__(self, optimizer, phases):
        self.optimizer = optimizer
        self.current_lr = None
        self.phases = list(map(self.format_phase, phases))
        self.tot_epochs = max(max(ph['ep']) for ph in self.phases)

    def format_phase(self, phase):
        """Listify ``'ep'`` then ``'lr'``; a linear (two-rate) phase must also give an end epoch."""
        for key in ('ep', 'lr'):
            phase[key] = listify(phase[key])
        if len(phase['lr']) == 2:
            assert (len(phase['ep']) == 2), 'Linear learning rates must contain end epoch'
        return phase

    def linear_phase_lr(self, phase, epoch, batch_curr, batch_tot):
        """Interpolated LR for a two-rate phase at (``epoch``, ``batch_curr``)."""
        (lr_start, lr_end), (ep_start, ep_end) = phase['lr'], phase['ep']
        # An 'epoch_step' phase only moves the LR at epoch boundaries.
        batch_offset = 0 if 'epoch_step' in phase else batch_curr
        return self.calc_linear_lr(lr_start, lr_end, epoch - ep_start, batch_offset,
                                   ep_end - ep_start, batch_tot)

    def calc_linear_lr(self, lr_start, lr_end, epoch_curr, batch_curr, epoch_tot, batch_tot):
        """Linear interpolation from ``lr_start`` to ``lr_end`` over ``epoch_tot * batch_tot`` steps."""
        total_steps = epoch_tot * batch_tot
        done_steps = epoch_curr * batch_tot + batch_curr
        return lr_start + done_steps * ((lr_end - lr_start) / total_steps)

    def get_current_phase(self, epoch):
        """Latest phase whose start epoch is ``<= epoch``; raises if there is none."""
        phase = next((ph for ph in reversed(self.phases) if epoch >= ph['ep'][0]), None)
        if phase is None:
            raise Exception('Epoch out of range')
        return phase

    def get_lr(self, epoch, batch_curr, batch_tot):
        """LR for the given position: constant for one-rate phases, interpolated otherwise."""
        phase = self.get_current_phase(epoch)
        rates = phase['lr']
        if len(rates) != 1:
            return self.linear_phase_lr(phase, epoch, batch_curr, batch_tot)
        return rates[0]  # constant learning rate

    def update_lr(self, epoch, batch_num, batch_tot):
        """Apply the scheduled LR to every optimizer param group if it changed, and log it."""
        lr = self.get_lr(epoch, batch_num, batch_tot)
        if self.current_lr == lr:
            return
        if batch_num in (1, batch_tot):
            log.event(f'Changing LR from {self.current_lr} to {lr}')

        self.current_lr = lr
        for group in self.optimizer.param_groups:
            group['lr'] = lr

        tb.log("sizes/lr", lr)
        tb.log("sizes/momentum", args.momentum)


# Convert a scalar-like metric to a plain Python number. item() is a fairly
# recent tensor method, so indexing element 0 is kept as a fallback for
# older tensor-like objects.
def to_python_float(t):
    if isinstance(t, (float, int)):
        return t
    return t.item() if hasattr(t, 'item') else t[0]


def save_checkpoint(epoch, model, best_top5, optimizer, is_best=False, filename='checkpoint.pth.tar'):
    """Save model/optimizer state to ``filename`` with ``torch.save``.

    The stored ``'epoch'`` is ``epoch + 1``, i.e. the epoch to resume from.
    When ``is_best`` is set, the file is also copied into ``args.logdir``
    under its base name, unless it is already there.
    """
    state = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_top5': best_top5,
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, filename)
    if is_best:
        # Join only the base name: `filename` may already include logdir, and
        # an empty logdir must not turn into an absolute "/<name>" path.
        best_path = os.path.join(args.logdir, os.path.basename(filename))
        # Copying a file onto itself raises shutil.SameFileError.
        if os.path.abspath(best_path) != os.path.abspath(filename):
            shutil.copyfile(filename, best_path)


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy percentages, one tensor per k in ``topk``.

    Each value is ``100 / batch_size`` times the hit count that ``correct``
    reports for that k.
    """
    pct_per_sample = 100.0 / target.size(0)
    hit_counts = correct(output, target, topk)
    return [hits.float().mul_(pct_per_sample) for hits in hit_counts]


def correct(output, target, topk=(1,)):
    """Count, for each k in ``topk``, the samples whose target is in the top-k scores.

    Args:
        output: score tensor indexed as ``[sample, class]`` (top-k is taken
            along dim 1).
        target: tensor of class indices, one per sample.
        topk: sequence of k values.

    Returns:
        A list with one 1-element tensor (``keepdim=True``) of hit counts per
        k, in ``topk`` order.
    """
    maxk = max(topk)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    hits = pred.eq(target.view(1, -1).expand_as(pred))
    # reshape, not view: `hits` comes from a transposed tensor and may be
    # non-contiguous, in which case view(-1) raises on recent PyTorch.
    return [hits[:k].reshape(-1).sum(0, keepdim=True) for k in topk]


def listify(p=None, q=None):
    """Coerce ``p`` to a list-like, repeating a single element to length ``q``.

    Args:
        p: None (becomes ``[]``), an iterable (used as-is, not copied), or a
            scalar (wrapped in a one-element list).
        q: target length: an int, None (meaning 1), or a sized object whose
            ``len`` is used.

    Returns:
        ``p`` normalised as above. When it has exactly one element, that
        element is repeated ``n`` times; otherwise it is returned unchanged.
    """
    # collections.Iterable was removed in Python 3.10; the ABC lives in collections.abc.
    from collections.abc import Iterable

    if p is None:
        p = []
    elif not isinstance(p, Iterable):
        p = [p]
    n = q if type(q) == int else 1 if q is None else len(q)
    if len(p) == 1:
        p = p * n
    return p


if __name__ == '__main__':
    try:
        with warnings.catch_warnings():
            # UserWarnings are silenced for the whole training run.
            warnings.simplefilter("ignore", category=UserWarning)
            main()
        if not args.skip_auto_shutdown:
            os.system(f'sudo shutdown -h -P +{args.auto_shutdown_success_delay_mins}')
    except Exception as e:
        import traceback

        # Full traceback (header, frames and exception line) to stdout, then
        # record the error in the event log.
        traceback.print_exc(file=sys.stdout)
        log.event(e)
        # On failure, wait --auto-shutdown-failure-delay-mins (default 60)
        # before shutting down, leaving time to inspect the instance.
        if not args.skip_auto_shutdown:
            os.system(f'sudo shutdown -h -P +{args.auto_shutdown_failure_delay_mins}')
    finally:
        # Flush/close TensorBoard even on KeyboardInterrupt / SystemExit.
        tb.close()
