import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.models as models
import vision_transformer_separateqkv
import swin_transformer_separateqkv
import timm
from copy import deepcopy
import csv
import numpy as np

def str2bool(value):
    """Convert a command-line string such as ``"True"``/``"false"``/``"1"`` to a bool.

    ``argparse`` with ``type=bool`` returns ``True`` for every non-empty string
    (``bool("False") is True``), so boolean-valued options use this instead.
    An empty string maps to ``False`` (matching ``bool("")``); real ``bool``
    values are passed through unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


def parse():
    """Build the command-line parser and return the parsed arguments.

    The ``--arch`` choices are every lowercase callable exported by
    ``torchvision.models`` plus the custom ``vit_s_16``, ``convnext_t`` and
    ``swin_t`` architectures handled in ``main``.
    """
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))
    model_names += ['vit_s_16', 'convnext_t', 'swin_t']
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('data', metavar='DIR', nargs='*',
                        help='path(s) to dataset (if one path is provided, it is assumed\n' +
                       'to have subdirectories named "train" and "val"; alternatively,\n' +
                       'train and val paths can be specified directly by providing both paths as arguments)')
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                        choices=model_names,
                        help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        metavar='N', help='mini-batch size per process (default: 256)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='Initial learning rate.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--gradient-accumulation', default=1, type=int, help='gradient accumulation number')

    parser.add_argument('--betas', default=(0.9, 0.999), type=float, nargs='+', metavar='BETA',
                        help='adamw betas')
    parser.add_argument('--adamw', action='store_true',
                        help='Use AdamW optimizer instead of SGD momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--print-freq', '-p', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')
    parser.add_argument('--deterministic', action='store_true')

    # torch.distributed.launch passes --local-rank (PyTorch >= 2.0) or
    # --local_rank (older versions); torchrun passes neither and only sets
    # $LOCAL_RANK. Accept both spellings and fall back to the environment.
    parser.add_argument("--local_rank", "--local-rank", dest="local_rank",
                        default=int(os.environ.get("LOCAL_RANK", 0)), type=int,
                        help='local GPU index of this process (default: $LOCAL_RANK or 0)')
    parser.add_argument('--sync_bn', action='store_true',
                        help='enabling apex sync BN.')

    parser.add_argument('--opt-level', type=str, default=None, help='O1 float16 with constant scale of 128, O2 float16 with default gradscaling but 128 for initial gradient norm measure, O3 bflaot16 automatic mixed precision without gradscaling probably requries Ampere or higher GPU architecture. O3 recommended, O1 and O2 does not seem to work on recent pytorch versions')
    # type=bool would turn "--channels-last False" into True; parse the text instead.
    parser.add_argument('--channels-last', type=str2bool, default=False,
                        help='use channels_last memory format (true/false, default: false)')
    parser.add_argument('--warmup', type=int, default=0,
                        help='Number of warmup epochs')
    parser.add_argument('--scale-lr', action='store_true',
                        help='adjust layer-wise lr.')
    parser.add_argument('--grad-model-path', default='', type=str,
                        help='path to initial gradient model checkpoint')
    parser.add_argument('--grad-clip', action='store_true',
                        help='Perform gradient clipping, default 1.0')
    parser.add_argument('--data-augmentation', action='store_true',
                        help='Perform data augmentation, RandAug: 2, 10, Mixup alpha 0.5 probability 0.5')
    parser.add_argument('--label-smoothing', action='store_true',
                        help='Perform label smoothing 0.1')

    args = parser.parse_args()
    return args

def to_python_float(t):
    """Return the Python scalar held by a one-element tensor ``t``.

    Uses ``t.item()`` when available; otherwise falls back to indexing the
    first element (``t[0]``), a backward-compatibility path for objects that
    predate ``Tensor.item()``.
    """
    return t.item() if hasattr(t, 'item') else t[0]

def _build_model():
    """Instantiate the network selected by ``args.arch`` (still on the CPU)."""
    if args.arch == 'vit_s_16':
        print("=> creating model '{}'".format(args.arch))
        # With --data-augmentation the ViT is built without dropout/drop-path.
        if args.data_augmentation:
            return vision_transformer_separateqkv.vit_small_patch16_224()
        return vision_transformer_separateqkv.vit_small_patch16_224(
            drop_rate=0.1, proj_drop_rate=0.1, pos_drop_rate=0.1, drop_path_rate=0.1)
    if args.arch == 'convnext_t':
        print("=> creating model '{}'".format(args.arch))
        if args.scale_lr:
            # layer scaling initialised to 1 when layer-wise LR scaling is used
            return timm.create_model('convnext_tiny', drop_path_rate=0.1, ls_init_value=1)
        return timm.create_model('convnext_tiny', drop_path_rate=0.1)
    if args.arch == 'swin_t':
        print("=> creating model '{}'".format(args.arch))
        return swin_transformer_separateqkv.swin_tiny_patch4_window7_224(drop_path_rate=0.2)
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        return models.__dict__[args.arch](pretrained=True)
    print("=> creating model '{}'".format(args.arch))
    return models.__dict__[args.arch]()


def _load_checkpoint(model, optimizer):
    """Restore model/optimizer state, ``args.start_epoch`` and ``best_prec1`` from ``args.resume``.

    Prints a message and changes nothing if ``args.resume`` is not a file.
    Kept in its own function so the loaded checkpoint dict (and the GPU
    tensors it references) is released as soon as loading finishes.
    """
    global best_prec1
    if not os.path.isfile(args.resume):
        print("=> no checkpoint found at '{}'".format(args.resume))
        return
    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.gpu))
    args.start_epoch = checkpoint['epoch']
    best_prec1 = checkpoint['best_prec1']
    model.load_state_dict(checkpoint['state_dict'])
    # Also restores each saved param group's 'lr'/'lr_scale', overriding any
    # layer-wise scales assigned before loading.
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(args.resume, checkpoint['epoch']))


def _build_loaders():
    """Create the timm training and evaluation data loaders.

    Side effects on ``args``: sets ``args.mixup`` and ``args.aa`` from
    ``--data-augmentation`` and replaces the ``--label-smoothing`` flag with
    its numeric value (0.1 or 0.0).

    Returns:
        ``(loader_train, loader_eval, mixup_active)``
    """
    # One path: a root with split sub-folders (timm's folder dataset searches
    # for the split directory under root). Two paths: explicit train/val roots.
    train_root = args.data[0]
    eval_root = args.data[0] if len(args.data) == 1 else args.data[1]

    # timm dataloader with almost basic inception-style preprocessing:
    # bilinear interpolation, inception mean/std of 0.5 instead of the default.
    dataset_train = timm.data.create_dataset(
        '', root=train_root, split='train', is_training=True,
        class_map='',
        download=True,
        batch_size=args.batch_size,
        repeats=0)
    dataset_eval = timm.data.create_dataset(
        '', root=eval_root, split='validation', is_training=False,
        class_map='',
        download=True,
        batch_size=0)

    if args.data_augmentation:
        args.mixup = 0.5
        args.aa = 'rand-m15-n2'  # internally limited to m10 by timm
    else:
        args.mixup = 0.0
        args.aa = None

    args.label_smoothing = 0.1 if args.label_smoothing else 0.0

    mixup_active = args.mixup > 0
    collate_fn = None
    if mixup_active:
        # The mixup collate also applies the label smoothing; it is the
        # prefetcher variant, matching use_prefetcher=True below.
        collate_fn = timm.data.FastCollateMixup(
            mixup_alpha=args.mixup, cutmix_alpha=0.0, cutmix_minmax=None,
            prob=1.0, switch_prob=0.5, mode='batch',
            label_smoothing=args.label_smoothing, num_classes=1000)

    loader_train = timm.data.create_loader(
        dataset_train,
        input_size=(3, 224, 224),
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=True,
        no_aug=False,  # different from args.data_augmentation, basic aug has color jitter
        re_prob=0.0,
        re_mode='pixel',
        re_count=1,
        re_split=False,
        scale=[0.08, 1.0],
        ratio=[3./4., 4./3.],
        hflip=0.5,
        vflip=0.0,
        color_jitter=[32./255., 0.0, 0.5],
        auto_augment=args.aa,
        num_aug_repeats=0,
        num_aug_splits=0,
        interpolation='bilinear',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=False,
        use_multi_epochs_loader=False,
        worker_seeding='all',
    )

    loader_eval = timm.data.create_loader(
        dataset_eval,
        input_size=(3, 224, 224),
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=True,
        interpolation='bilinear',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=0.875,
        pin_memory=False,
    )
    return loader_train, loader_eval, mixup_active


def main():
    """Train (or, with ``--evaluate``, only validate) the selected model.

    Sets the module-level globals ``args`` and ``best_prec1`` (read by the
    helpers in this file), ``DDP`` when distributed training or sync-BN is
    requested, and ``scaler`` when a float16 ``--opt-level`` (O1/O2) is used.
    """
    global best_prec1, args, scaler, DDP
    best_prec1 = 0
    args = parse()

    if not args.data:
        raise Exception("error: No data set provided")

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed or args.sync_bn:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError as err:
            raise ImportError("Distributed Data Parallel import fail.") from err

    # --opt-level defaults to None (plain fp32); only O1/O2 use a GradScaler.
    opt_level = args.opt_level or ""
    use_fp16_scaler = "O1" in opt_level or "O2" in opt_level

    print("opt_level = {}".format(args.opt_level))
    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    cudnn.benchmark = True
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)

    args.gpu = 0
    args.world_size = 1
    args.rank = 0
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    # Checkpoint files are written by global rank 0 only, so processes never
    # race on the same path.
    is_writer = args.rank == 0

    args.total_batch_size = args.world_size * args.batch_size
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    model = _build_model()

    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if hasattr(torch, 'channels_last') and hasattr(torch, 'contiguous_format'):
        memory_format = torch.channels_last if args.channels_last else torch.contiguous_format
        model = model.cuda().to(memory_format=memory_format)
    else:
        model = model.cuda()

    # Custom init for fresh runs with layer-wise LR scaling. It must run
    # before DDP wraps the model: DDP broadcasts rank 0's parameters when it
    # is constructed, whereas re-initialising afterwards would leave ranks with
    # different weights (their RNG seeds differ under --deterministic).
    if args.scale_lr and args.resume == "":
        reinit(model)

    # One param group per tensor so adjust_learning_rate can apply a
    # per-tensor 'lr_scale' (filled in below when --scale-lr is set).
    param_list = [{'params': [p], 'lr_scale': 1.0, 'weight_decay': args.weight_decay}
                  for p in model.parameters()]

    if args.adamw:
        optimizer = torch.optim.AdamW(param_list, args.lr, betas=(args.betas[0], args.betas[1]),
                                      weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(param_list, args.lr, momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    assert len(param_list) == len(list(model.parameters()))

    if use_fp16_scaler:
        # Constant grad scale (growth effectively disabled) to avoid the inf
        # handling logic at the beginning of training; 128 is the value used
        # by the NVIDIA ResNet-50 ImageNet example.
        scaler = torch.amp.GradScaler("cuda", init_scale=128.0, growth_factor=1.1,
                                      backoff_factor=0.9, growth_interval=99999999999)

    if args.distributed:
        model = DDP(model)

    model = torch.compile(model)

    if args.resume:
        _load_checkpoint(model, optimizer)

    loader_train, loader_eval, mixup_active = _build_loaders()

    if mixup_active:
        from timm.loss import SoftTargetCrossEntropy
        # Mixup produces soft targets (already label-smoothed by the collate).
        train_loss_fn = SoftTargetCrossEntropy()
    else:
        train_loss_fn = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    if args.evaluate:
        validate(loader_eval, model, validate_loss_fn)
        return

    if is_writer:
        save_checkpoint({
            'epoch': 0,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': 0,
            'optimizer': optimizer.state_dict(),
        }, False, filename='init.pth.tar')

    if args.scale_lr:
        if args.grad_model_path:
            grad_model = deepcopy(model)
            grad_model.train()
            print("=> loading gradient model checkpoint '{}'".format(args.grad_model_path))
            grad_model.load_state_dict(torch.load(
                args.grad_model_path,
                map_location=lambda storage, loc: storage.cuda(args.gpu))['state_dict'])
        else:
            grad_model = grad_norm(loader_train, model, train_loss_fn, optimizer)
        lr_model = write_grad_stat(grad_model, model=model)
        # lr_model's parameters line up with model's (and hence with the
        # one-tensor param groups); each holds its tensor's LR scale.
        with torch.no_grad():
            lr_scales = iter(lr_model.parameters())
            for group in optimizer.param_groups:
                for _ in group['params']:
                    group['lr_scale'] = torch.mean(next(lr_scales)).item()
        del lr_model
        if is_writer:
            save_checkpoint({
                'epoch': 0,
                'arch': args.arch,
                'state_dict': grad_model.state_dict(),
            }, False, filename='initgradnorm.pth.tar')
        del grad_model

    if "O2" in opt_level:
        # O2 keeps the constant scale only for the gradient-norm pass above;
        # training itself uses a GradScaler with default settings.
        scaler = torch.amp.GradScaler("cuda")

    total_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):
        # A DistributedSampler only reshuffles when it is told the epoch.
        sampler = getattr(loader_train, 'sampler', None)
        if args.distributed and hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)

        avg_train_time = train(loader_train, model, train_loss_fn, optimizer, epoch)
        total_time.update(avg_train_time)

        prec1, prec5 = validate(loader_eval, model, validate_loss_fn)

        if args.local_rank == 0 and epoch == args.epochs - 1:
            print('##Top-1 {0}\n'
                  '##Top-5 {1}\n'
                  '##Perf  {2}'.format(
                  prec1,
                  prec5,
                  args.total_batch_size / total_time.avg))

        # Best-model tracking is disabled: 'best_prec1' is always stored as 0.
        if is_writer:
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': 0,
                'optimizer': optimizer.state_dict(),
            }, False, filename='checkpoint.pth.tar')

def _unwrap_model(model):
    """Strip ``torch.compile`` (``_orig_mod``) and DDP/DataParallel (``module``) wrappers."""
    while True:
        if isinstance(model, (nn.parallel.DistributedDataParallel, nn.DataParallel)):
            model = model.module
        elif hasattr(model, '_orig_mod'):
            model = model._orig_mod
        else:
            return model


def reinit(model, arch=None):
    """Re-initialise ``model`` in place for the layer-wise LR scaling runs.

    * BatchNorm2d / SyncBatchNorm / LayerNorm (including LayerNorm2d
      subclasses): weight = 1, bias = 0 (skipped when the layer has no
      affine parameters).
    * Conv2d / Linear: weight ~ N(0, std = fan_in ** -0.5), bias = 0.
    * Architecture-specific extras: ViT ``cls_token``/``pos_embed`` and Swin
      ``relative_position_bias_table``/``absolute_pos_embed`` are zeroed.

    Args:
        model: module to re-initialise; ``torch.compile`` and
            DDP/DataParallel wrappers are unwrapped before architecture
            attributes are accessed.
        arch: architecture name; defaults to the global ``args.arch``.
    """
    if arch is None:
        arch = args.arch

    norm_types = (nn.BatchNorm2d, nn.SyncBatchNorm, nn.LayerNorm)
    for m in model.modules():
        if isinstance(m, norm_types):
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
            nn.init.normal_(m.weight, std=1 / fan_in**0.5)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    base = _unwrap_model(model)
    if arch in ('resnet50', 'convnext_t', 'convnext_tiny'):
        pass  # no architecture-specific parameters
    elif arch in ('vit_s_16', 'vit_s_16_separateqkv'):
        nn.init.zeros_(base.cls_token)
        nn.init.zeros_(base.pos_embed)
    elif arch == 'swin_t':
        for name, p in base.named_parameters():
            if 'relative_position_bias_table' in name or 'absolute_pos_embed' in name:
                nn.init.zeros_(p)
    else:
        print("model not checked for special layers")

def write_grad_stat(gradnormmodel, filename="stat_gradnorm", model = None):
    """Turn accumulated gradient statistics into per-tensor learning-rate scales.

    For each parameter tensor ``q`` of ``gradnormmodel`` the raw scale is
    ``1 / sqrt(mean(q))``. All raw scales are then divided by their
    element-count-weighted average, so that the scale averaged over every
    scalar parameter is 1.

    Args:
        gradnormmodel: module whose parameters hold gradient statistics (e.g.
            the mean |grad| returned by ``grad_norm``); expected non-negative,
            since a square root is taken. Must have the same parameter
            ordering as ``model``.
        filename: unused; kept for backward compatibility.
        model: template module to copy; required when ``args.scale_lr`` is set.

    Returns:
        A deep copy of ``model`` whose every parameter is filled with its
        tensor's scale, or ``None`` when ``args.scale_lr`` is not set.

    Side effect: local rank 0 appends three rows to ``lwlr.csv``: parameter
    names, element counts and per-tensor scales.
    """
    if not args.scale_lr:
        return None
    assert model is not None

    names = []
    numels = []
    for name, q in gradnormmodel.named_parameters():
        names.append(name)
        numels.append(torch.numel(q))

    lr_model = deepcopy(model)
    lr_model.train()
    for lm in lr_model.parameters():
        nn.init.ones_(lm)

    scales = []
    with torch.no_grad():
        weighted_sum = 0
        total_numel = 0
        for lm, q in zip(lr_model.parameters(), gradnormmodel.parameters()):
            # NOTE(review): a tensor whose statistic is all zeros (e.g. it never
            # received a gradient) gives mean(q) == 0 and an infinite scale,
            # which also poisons the normaliser below.
            lm.copy_(torch.div(lm, torch.pow(torch.mean(q), 0.5)))
            weighted_sum += torch.sum(lm).item()
            total_numel += torch.numel(q)
        pwratio = weighted_sum / total_numel
        for lm in lr_model.parameters():
            lm.copy_(lm / pwratio)
            # Each tensor is constant, so its mean is the tensor's scale.
            scales.append(torch.mean(lm).item())

    if args.local_rank == 0:
        # newline='' as required by the csv module.
        with open("lwlr.csv", 'a', newline='') as file:
            writer = csv.writer(file, delimiter=',')
            writer.writerow(names)
            writer.writerow(numels)
            writer.writerow(scales)
    return lr_model

def grad_norm(loader_train, model, criterion, optimizer):
    """Average the absolute gradients of ``model`` over one pass of ``loader_train``.

    The model is run forward/backward in train mode but never stepped. At
    every gradient-accumulation boundary ``|p.grad|`` (unscaled first when a
    GradScaler is in use) is added into a zeroed deep copy of the model, and
    the gradients are cleared. The sums are finally divided by the number of
    batches (not the number of accumulation boundaries); this rescales all
    entries uniformly.

    Args:
        loader_train: training loader yielding ``(input, target)`` batches.
        model: the (possibly DDP-wrapped / compiled) network.
        criterion: loss function.
        optimizer: used only to zero gradients and for GradScaler unscaling.

    Returns:
        A copy of ``model`` whose parameters hold the averaged ``|grad|``.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()

    # --opt-level may be None; only O1/O2 go through the GradScaler.
    use_scaler = args.opt_level is not None and ("O1" in args.opt_level or "O2" in args.opt_level)

    grad_model = deepcopy(model)
    grad_model.train()
    for q in grad_model.parameters():
        nn.init.zeros_(q)
        assert not torch.is_nonzero(torch.sum(torch.abs(q)))

    model.train()
    end = time.time()

    train_loader_len = len(loader_train)
    for i, (images, target) in enumerate(loader_train):
        if args.channels_last:
            images = images.contiguous(memory_format=torch.channels_last)

        # compute output
        if args.opt_level is not None:
            if use_scaler:
                autocast_context = torch.amp.autocast(device_type="cuda")
            else:
                autocast_context = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
            with autocast_context:
                output = model(images)
                loss = criterion(output, target)
        else:
            output = model(images)
            loss = criterion(output, target)
        loss /= args.gradient_accumulation

        is_boundary = (i + 1) % args.gradient_accumulation == 0
        if use_scaler:
            scaler.scale(loss).backward()
            if is_boundary:
                # Unscale so the accumulated |grad| is in true units; no
                # optimizer step is taken here.
                scaler.unscale_(optimizer)
                scaler.update()
        else:
            loss.backward()
        if is_boundary:
            with torch.no_grad():
                for p, q in zip(model.parameters(), grad_model.parameters()):
                    if p.grad is not None:
                        q.add_(torch.abs(p.grad), alpha=1)
            optimizer.zero_grad()

        if i % args.print_freq == 0 or i == train_loader_len - 1:
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data)
            else:
                reduced_loss = loss.data

            # to_python_float incurs a host<->device sync
            losses.update(to_python_float(reduced_loss) * args.gradient_accumulation, images.size(0))

            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()
            if args.local_rank == 0:
                print('Gradnorm: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {2:.3f} ({3:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})'.format(
                       i, train_loader_len,
                       args.world_size * args.batch_size / batch_time.val,
                       args.world_size * args.batch_size / batch_time.avg,
                       batch_time=batch_time, loss=losses))

    # Discard gradients of a trailing, incomplete accumulation window.
    optimizer.zero_grad()
    with torch.no_grad():
        for q in grad_model.parameters():
            q.mul_(1 / train_loader_len)
    return grad_model

def train(loader_train, model, criterion, optimizer, epoch):
    """Train ``model`` for one epoch and return the average per-iteration time.

    Args:
        loader_train: training loader yielding ``(input, target)`` batches.
        model: the (possibly DDP-wrapped / compiled) network.
        criterion: training loss.
        optimizer: optimizer whose LR is set every iteration by
            ``adjust_learning_rate``.
        epoch: zero-based epoch index; also keeps the gradient-accumulation
            counter continuous across epochs.

    Returns:
        Average seconds per iteration, as measured at the logging points.

    Every ``args.print_freq`` iterations (and on the last one) the loss and,
    unless ``--data-augmentation`` is set (mixup targets are soft), top-1/top-5
    accuracy are averaged across ranks and printed by local rank 0. The epoch
    averages are appended to ``train.csv`` by local rank 0.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # --opt-level may be None; only O1/O2 go through the GradScaler.
    use_scaler = args.opt_level is not None and ("O1" in args.opt_level or "O2" in args.opt_level)
    track_accuracy = not args.data_augmentation

    # switch to train mode
    model.train()
    end = time.time()
    train_loader_len = len(loader_train)
    for i, (images, target) in enumerate(loader_train):
        if args.channels_last:
            images = images.contiguous(memory_format=torch.channels_last)

        adjust_learning_rate(optimizer, epoch, i, train_loader_len)

        if args.opt_level is not None:
            if use_scaler:
                autocast_context = torch.amp.autocast(device_type="cuda")
            else:
                autocast_context = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
            with autocast_context:
                output = model(images)
                loss = criterion(output, target)
        else:
            output = model(images)
            loss = criterion(output, target)
        loss /= args.gradient_accumulation

        # The step counter is global, so accumulation windows may span epochs.
        is_update_step = (epoch * train_loader_len + i + 1) % args.gradient_accumulation == 0
        if use_scaler:
            scaler.scale(loss).backward()
            if is_update_step:
                if args.grad_clip:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        else:
            loss.backward()
            if is_update_step:
                if args.grad_clip:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

        if i % args.print_freq == 0 or i == train_loader_len - 1:
            # Only every print_freq iterations: these metrics incur an
            # allreduce and host<->device syncs.
            if track_accuracy:
                prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data)
                if track_accuracy:
                    prec1 = reduce_tensor(prec1)
                    prec5 = reduce_tensor(prec5)
            else:
                reduced_loss = loss.data

            # to_python_float incurs a host<->device sync
            losses.update(to_python_float(reduced_loss) * args.gradient_accumulation, images.size(0))
            if track_accuracy:
                top1.update(to_python_float(prec1), images.size(0))
                top5.update(to_python_float(prec5), images.size(0))

            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                fmt = ('Epoch: [{0}][{1}/{2}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                       'Speed {3:.3f} ({4:.3f})\t'
                       'Loss {loss.val:.10f} ({loss.avg:.4f})')
                if track_accuracy:
                    fmt += ('\tPrec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                            'Prec@5 {top5.val:.3f} ({top5.avg:.3f})')
                print(fmt.format(
                    epoch, i, train_loader_len,
                    args.world_size * args.batch_size / batch_time.val,
                    args.world_size * args.batch_size / batch_time.avg,
                    batch_time=batch_time,
                    loss=losses, top1=top1, top5=top5))

        if i == train_loader_len - 1 and args.local_rank == 0:
            to_write = [losses.avg]
            if track_accuracy:
                to_write.extend([top1.avg, top5.avg])
            # newline='' as required by the csv module.
            with open("train.csv", 'a', newline='') as file:
                csv.writer(file, delimiter=',').writerow(to_write)

    return batch_time.avg

def validate(loader_eval, model, criterion):
    """Evaluate ``model`` on ``loader_eval`` and return ``[top1_avg, top5_avg]``.

    Loss and top-1/top-5 accuracy are averaged across ranks per batch when
    distributed. Local rank 0 prints progress every ``args.print_freq``
    batches and appends the final averages to ``test.csv``; every rank prints
    the final summary line.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    val_loader_len = len(loader_eval)
    for i, (images, target) in enumerate(loader_eval):
        with torch.no_grad():
            output = model(images)
            loss = criterion(output, target)

        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if args.distributed:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        batch_size = images.size(0)
        losses.update(to_python_float(reduced_loss), batch_size)
        top1.update(to_python_float(prec1), batch_size)
        top5.update(to_python_float(prec5), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        is_last = i == val_loader_len - 1
        # TODO:  Change timings to mirror train().
        if args.local_rank == 0 and (i % args.print_freq == 0 or is_last):
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   i, val_loader_len,
                   args.world_size * args.batch_size / batch_time.val,
                   args.world_size * args.batch_size / batch_time.avg,
                   batch_time=batch_time, loss=losses,
                   top1=top1, top5=top5))

        if is_last and args.local_rank == 0:
            # newline='' as required by the csv module.
            with open("test.csv", 'a', newline='') as file:
                csv.writer(file, delimiter=',').writerow([losses.avg, top1.avg, top5.avg])

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return [top1.avg, top5.avg]


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Write ``state`` to ``filename`` with ``torch.save``.

    When ``is_best`` is true, the written file is also copied to
    ``model_best.pth.tar`` in the current directory.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter:
    """Track the latest value plus a count-weighted running sum, mean and sum of squares."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.sum_square = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.sum_square += n * val ** 2
        self.avg = self.sum / self.count

def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """Set each param group's LR from a linear-warmup then cosine-decay schedule.

    The schedule position is ``step + 1`` iterations into ``epoch`` (with
    ``len_epoch`` iterations per epoch). During the first ``args.warmup``
    epochs the multiplier rises linearly to 1. Afterwards it follows half a
    cosine from 1 towards 0 over the remaining ``args.epochs - args.warmup``
    epochs. ``args.lr`` times the multiplier (clamped at 0) is written to every
    group, further multiplied by the group's ``'lr_scale'`` when present.
    """
    warmup = args.warmup
    if epoch < warmup:
        factor = float(1 + step + epoch*len_epoch)/(float(warmup)*len_epoch)
    else:
        cosine_epoch = epoch - warmup
        cosine_span = float(args.epochs) - warmup
        progress = float(1 + step + cosine_epoch*len_epoch)
        factor = 0.5*(1+np.cos(np.pi*progress/(cosine_span*len_epoch)))

    base_lr = args.lr*max(factor, 0.0)

    for group in optimizer.param_groups:
        group['lr'] = base_lr*group['lr_scale'] if 'lr_scale' in group else base_lr


def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for every k in ``topk``.

    ``output`` holds per-class scores along dim 1 and ``target`` the true
    class indices. Each result is a one-element tensor.
    """
    largest_k = max(topk)
    num_samples = target.size(0)

    # Shape (largest_k, batch): row j holds every sample's j-th best class.
    ranked = output.topk(largest_k, 1, True, True).indices.t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / num_samples)
            for k in topk]


def reduce_tensor(tensor):
    """Return the mean of ``tensor`` over all ranks; the input is not modified.

    A clone is summed with ``all_reduce`` and divided by ``args.world_size``.
    This is a collective call: every process must reach it.
    """
    rt = tensor.clone()
    # dist.reduce_op is deprecated (FutureWarning); ReduceOp is the supported enum.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt

if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()
