import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.models as models
import timm 
import lars
from copy import deepcopy
import csv
import numpy as np

def parse():
    """Build the command-line parser and return the parsed arguments.

    ``--arch`` choices are the lowercase callables exported by
    ``torchvision.models``. ``--local_rank`` (also accepted as ``--local-rank``)
    defaults to the ``LOCAL_RANK`` environment variable, or 0 when it is unset.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    def _str2bool(value):
        """Interpret a textual boolean given to an option that takes a value.

        argparse's ``type=bool`` maps every non-empty string (including
        "False") to True, so the text is interpreted explicitly instead.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))

    model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('data', metavar='DIR', nargs='*',
                        help='path(s) to dataset (if one path is provided, it is assumed\n' +
                       'to have subdirectories named "train" and "val"; alternatively,\n' +
                       'train and val paths can be specified directly by providing both paths as arguments)')
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                        choices=model_names,
                        help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        metavar='N', help='mini-batch size per process (default: 256)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='Initial learning rate.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--gradient-accumulation', default=1, type=int, help='gradient accumulation number')

    parser.add_argument('--betas', default=(0.9, 0.999), type=float, nargs='+', metavar='BETA',
                        help='adamw betas')
    parser.add_argument('--adamw', action='store_true',
                        help='Use AdamW optimizer instead of SGD momentum')
    parser.add_argument('--lars', action='store_true',
                        help='Use lars optimizer instead of SGD momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--print-freq', '-p', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')
    parser.add_argument('--deterministic', action='store_true')

    # Older launchers pass --local_rank, newer ones --local-rank, and torchrun
    # only sets the LOCAL_RANK environment variable; accept all three.
    parser.add_argument("--local_rank", "--local-rank", dest="local_rank",
                        default=int(os.environ.get("LOCAL_RANK", 0)), type=int)
    parser.add_argument('--sync_bn', action='store_true',
                        help='enabling apex sync BN.')

    parser.add_argument('--opt-level', type=str, default=None, help='O1 float16 with constant scale of 128, O2 float16 with default gradscaling but 128 for initial gradient norm measure, O3 bflaot16 automatic mixed precision without gradscaling probably requries Ampere or higher GPU architecture. O3 recommended, O1 and O2 does not seem to work on recent pytorch versions')
    # Takes an explicit value ("True"/"False"); parsed with _str2bool because
    # type=bool would turn "False" into True.
    parser.add_argument('--channels-last', type=_str2bool, default=False)
    parser.add_argument('--warmup', type=int, default=0,
                        help='Number of warmup epochs')
    parser.add_argument('--grad-clip', action='store_true',
                        help='Perform gradient clipping, default 1.0')
    parser.add_argument('--data-augmentation', action='store_true',
                        help='Perform data augmentation, RandAug: 2, 10, Mixup alpha 0.5 probability 0.5')
    parser.add_argument('--label-smoothing', action='store_true',
                        help='Perform label smoothing 0.1')

    args = parser.parse_args()
    return args

# Built upon the PyTorch, NVIDIA and timm ImageNet training code.

# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
    """Return the Python scalar held by ``t``.

    Objects exposing ``item()`` (e.g. one-element tensors) are converted with
    it; anything else is indexed at position 0 as a fallback.
    """
    if not hasattr(t, 'item'):
        return t[0]
    return t.item()

def main():
    """Parse arguments, build model/optimizer/data pipeline, then train and validate.

    Binds module-level globals used by the rest of this file: ``args`` (parsed
    CLI namespace), ``best_prec1``, ``scaler`` (a ``torch.amp.GradScaler``,
    only when ``--opt-level`` contains O1 or O2) and ``DDP`` (only when running
    distributed or with ``--sync_bn``). Each epoch trains, validates and, on
    local rank 0, overwrites ``checkpoint.pth.tar``.

    Raises:
        Exception: if no dataset path is given.
        ImportError: if DistributedDataParallel cannot be imported.
    """
    global best_prec1, args, scaler, DDP
    best_prec1 = 0
    args = parse()

    if not len(args.data):
        raise Exception("error: No data set provided")

    # --opt-level defaults to None and `"O1" in None` raises TypeError, so
    # normalise it once for the substring checks below.
    opt_level = args.opt_level or ""
    use_grad_scaler = "O1" in opt_level or "O2" in opt_level

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed or args.sync_bn:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError as err:
            raise ImportError("Distributed Data Parallel import fail.") from err

    print("opt_level = {}".format(args.opt_level))
    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    cudnn.benchmark = True
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if hasattr(torch, 'channels_last') and hasattr(torch, 'contiguous_format'):
        memory_format = torch.channels_last if args.channels_last else torch.contiguous_format
        model = model.cuda().to(memory_format=memory_format)
    else:
        model = model.cuda()

    # One param group per tensor so each can carry its own lr_scale and
    # weight_decay; adjust_learning_rate multiplies the scheduled lr by lr_scale.
    param_list = [{'params': [p], 'lr_scale': 1.0, 'weight_decay': args.weight_decay}
                  for p in model.parameters()]

    if args.adamw:
        optimizer = torch.optim.AdamW(param_list, args.lr, betas=(args.betas[0], args.betas[1]), weight_decay=args.weight_decay)
    elif args.lars:
        optimizer = lars.Lars(param_list, args.lr, momentum=args.momentum, weight_decay=args.weight_decay, trust_clip=True)
    else:
        optimizer = torch.optim.SGD(param_list, args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    assert len(param_list) == len(list(model.parameters()))

    if use_grad_scaler:
        # Loss scale starts at 128 (value used in NVIDIA's ResNet-50 ImageNet
        # example) and practically never grows (huge growth_interval), which
        # avoids the default scaler's inf-handling phase at the beginning.
        scaler = torch.amp.GradScaler("cuda", init_scale=128.0, growth_factor=1.1, backoff_factor=0.9, growth_interval=99999999999)

    if args.distributed:
        model = DDP(model)

    model = torch.compile(model)

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                global best_prec1
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                # Restoring optimizer state can override the per-group lrs set above.
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
        resume()

    # Data loading code. With one path, timm is expected to resolve the
    # "train"/"validation" split sub-folders itself (split search — verify
    # against the installed timm); with two paths, each one is the root of
    # its split and is now actually used.
    if len(args.data) == 1:
        train_root = eval_root = args.data[0]
    else:
        train_root, eval_root = args.data[0], args.data[1]

    # timm dataloader with almost basic inception-style preprocessing:
    # bilinear interpolation instead of random, and mean/std of 0.5
    # instead of the ImageNet defaults.
    dataset_train = timm.data.create_dataset(
        '', root=train_root, split='train', is_training=True,
        class_map='',
        download=True,
        batch_size=args.batch_size,
        repeats=0)
    dataset_eval = timm.data.create_dataset(
        '', root=eval_root, split='validation', is_training=False,
        class_map='',
        download=True,
        batch_size=0)

    if args.data_augmentation:
        args.mixup = 0.5
        args.aa = 'rand-m15-n2'  # rand magnitude internally limited to 10 by timm
    else:
        args.mixup = 0.0
        args.aa = None

    # The boolean flag is replaced by the smoothing factor actually used.
    args.label_smoothing = 0.1 if args.label_smoothing else 0.0

    mixup_active = args.mixup > 0
    collate_fn = None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=0.0, cutmix_minmax=None,
            prob=1.0, switch_prob=0.5, mode='batch',
            label_smoothing=args.label_smoothing, num_classes=1000)
        # The loaders use the prefetcher, so mixup is applied in the collate step.
        collate_fn = timm.data.FastCollateMixup(**mixup_args)

    loader_train = timm.data.create_loader(
        dataset_train,
        input_size=(3, 224, 224),
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=True,
        no_aug=False,  # different from args.data_augmentation, basic aug has color jitter
        re_prob=0.0,
        re_mode='pixel',
        re_count=1,
        re_split=False,
        scale=[0.08, 1.0],
        ratio=[3./4., 4./3.],
        hflip=0.5,
        vflip=0.0,
        color_jitter=[32./255., 0.0, 0.5],
        auto_augment=args.aa,
        num_aug_repeats=0,
        num_aug_splits=0,
        interpolation='bilinear',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=False,
        use_multi_epochs_loader=False,
        worker_seeding='all',
    )

    loader_eval = timm.data.create_loader(
        dataset_eval,
        input_size=(3, 224, 224),
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=True,
        interpolation='bilinear',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=0.875,
        pin_memory=False,
    )

    if mixup_active:
        # Mixup produces soft targets, which need a soft-target loss.
        from timm.loss import SoftTargetCrossEntropy
        train_loss_fn = SoftTargetCrossEntropy()
    else:
        train_loss_fn = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    if args.evaluate:
        validate(loader_eval, model, validate_loss_fn)
        return

    # Only one process per node writes checkpoints; concurrent torch.save
    # calls to the same path can corrupt the file.
    if args.local_rank == 0:
        save_checkpoint({
                        'epoch': 0,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': 0,
                        'optimizer' : optimizer.state_dict(),
                    }, False, filename='init.pth.tar')

    if "O2" in opt_level:
        # reinit grad scaler with default values
        scaler = torch.amp.GradScaler("cuda")

    total_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        avg_train_time = train(loader_train, model, train_loss_fn, optimizer, epoch)
        total_time.update(avg_train_time)

        # evaluate on validation set
        [prec1, prec5] = validate(loader_eval, model, validate_loss_fn)

        if args.local_rank == 0:
            # Best-model tracking (is_best / model_best.pth.tar) is disabled;
            # best_prec1 is not updated during training.
            if epoch == args.epochs - 1:
                print('##Top-1 {0}\n'
                      '##Top-5 {1}\n'
                      '##Perf  {2}'.format(
                      prec1,
                      prec5,
                      args.total_batch_size / total_time.avg))

            save_checkpoint({
                        'epoch': epoch+1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': 0,
                        'optimizer' : optimizer.state_dict(),
                    }, False, filename='checkpoint.pth.tar')

def train(loader_train, model, criterion, optimizer, epoch):
    """Train ``model`` for one epoch and return the average logged iteration time.

    Reads the module-level ``args``, and ``scaler`` when ``--opt-level``
    contains O1/O2. Every ``args.print_freq`` iterations, and on the last one,
    it logs loss, top-1/top-5 precision and throughput. Precision is skipped
    under ``--data-augmentation``, where mixup yields soft targets. After each
    optimizer step that was not skipped for inf gradients, it accumulates a
    per-parameter step-size statistic for LARS/AdamW. On the last iteration,
    local rank 0 appends the epoch averages to ``train.csv`` and the mean step
    sizes to ``lars.csv`` or ``adam2m.csv``.

    Returns:
        float: mean of the per-window iteration-time estimates.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # --opt-level defaults to None; normalise before substring tests.
    opt_level = args.opt_level or ""
    use_grad_scaler = "O1" in opt_level or "O2" in opt_level
    is_lars = isinstance(optimizer, lars.Lars)
    is_adamw = isinstance(optimizer, torch.optim.AdamW)

    # switch to train mode
    model.train()
    end = time.time()
    train_loader_len = len(loader_train)

    # One running step-size sum per parameter, in param_groups order.
    stepsize = [0.0 for group in optimizer.param_groups for _ in group['params']]
    num_steps = 0
    found_inf = False
    for i, (input, target) in enumerate(loader_train):
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        adjust_learning_rate(optimizer, epoch, i, train_loader_len)

        if args.opt_level is not None:
            if use_grad_scaler:
                autocast_context = torch.amp.autocast(device_type="cuda")
            else:
                # any other opt level (e.g. O3) runs bfloat16 autocast
                autocast_context = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
            with autocast_context:
                output = model(input)
                loss = criterion(output, target)
        else:
            output = model(input)
            loss = criterion(output, target)
        loss /= args.gradient_accumulation

        # 1-based global iteration; the optimizer steps once every
        # ``gradient_accumulation`` iterations.
        global_iter = epoch * train_loader_len + i + 1
        is_step = global_iter % args.gradient_accumulation == 0

        if use_grad_scaler:
            scaler.scale(loss).backward()
            if is_step:
                if args.grad_clip:
                    # unscale first so the clip threshold applies to true gradients
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                # Reads GradScaler's private per-optimizer state to learn
                # whether this step was skipped for inf/NaN gradients.
                found_inf = sum(v.item() for v in scaler._per_optimizer_states[id(optimizer)]["found_inf_per_device"].values())
                scaler.update()
                optimizer.zero_grad()
        else:
            loss.backward()
            if is_step:
                if args.grad_clip:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

        if i % args.print_freq == 0 or i == train_loader_len - 1:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # Not done every iteration: it incurs an allreduce and host<->device syncs.
            if not args.data_augmentation:
                prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data)
                # prec1/prec5 only exist when accuracy was measured above
                if not args.data_augmentation:
                    prec1 = reduce_tensor(prec1)
                    prec5 = reduce_tensor(prec5)
            else:
                reduced_loss = loss.data

            # to_python_float incurs a host<->device sync; undo the
            # accumulation division so the logged loss is per-batch.
            losses.update(to_python_float(reduced_loss) * args.gradient_accumulation, input.size(0))
            if not args.data_augmentation:
                top1.update(to_python_float(prec1), input.size(0))
                top5.update(to_python_float(prec5), input.size(0))

            torch.cuda.synchronize()
            # Approximation: the window is assumed to span print_freq iterations
            # (not exact for the first and last windows).
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                if not args.data_augmentation:
                    print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       epoch, i, train_loader_len,
                       args.world_size*args.batch_size/batch_time.val,
                       args.world_size*args.batch_size/batch_time.avg,
                       batch_time=batch_time,
                       loss=losses, top1=top1, top5=top5))
                else:
                    print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})'.format(
                       epoch, i, train_loader_len,
                       args.world_size*args.batch_size/batch_time.val,
                       args.world_size*args.batch_size/batch_time.avg,
                       batch_time=batch_time,
                       loss=losses))

        if is_step and not found_inf:
            num_steps += 1
            lindex = 0
            for group in optimizer.param_groups:
                if is_adamw:
                    # Adam second-moment bias correction using the group's own
                    # beta2 (respects --betas instead of a hard-coded 0.999).
                    # The step count assumes every accumulation boundary since
                    # iteration 0 produced an optimizer step.
                    bias_correction2 = 1 - group['betas'][1] ** (global_iter / args.gradient_accumulation)
                for p in group['params']:
                    if is_lars:
                        stepsize[lindex] += optimizer.state[p]['lars_stepsize'].item()
                    elif is_adamw:
                        stepsize[lindex] += 1 / torch.mean((optimizer.state[p]["exp_avg_sq"] / bias_correction2).sqrt() + 1e-8).item()
                    lindex += 1

        if i == train_loader_len - 1:
            to_write = [losses.avg]
            if not args.data_augmentation:
                to_write.append(top1.avg)
                to_write.append(top5.avg)
            # Mean over real optimizer steps; max() avoids ZeroDivisionError
            # when no step was taken this epoch.
            stepsize = [s / max(num_steps, 1) for s in stepsize]
            # Only local rank 0 writes, so rows are not duplicated per process.
            if args.local_rank == 0:
                with open("train.csv", 'a', newline='') as file:
                    csv.writer(file, delimiter=',').writerow(to_write)
                if is_lars:
                    with open("lars.csv", 'a', newline='') as file:
                        csv.writer(file, delimiter=',').writerow(stepsize)
                elif is_adamw:
                    with open("adam2m.csv", 'a', newline='') as file:
                        csv.writer(file, delimiter=',').writerow(stepsize)
    return batch_time.avg

def validate(loader_eval, model, criterion):
    """Evaluate ``model`` on ``loader_eval`` and return ``[top1_avg, top5_avg]``.

    Loss and precision are averaged across processes when ``args.distributed``
    is set. Local rank 0 prints progress every ``args.print_freq`` batches and
    on the last one, and appends ``[loss, top1, top5]`` averages to
    ``test.csv``. Reads the module-level ``args``.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    val_loader_len = len(loader_eval)
    for i, (input, target) in enumerate(loader_eval):
        # compute output
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if args.distributed:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        losses.update(to_python_float(reduced_loss), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # TODO:  Change timings to mirror train().
        if args.local_rank == 0 and (i % args.print_freq == 0 or i == val_loader_len - 1):
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   i, val_loader_len,
                   args.world_size * args.batch_size / batch_time.val,
                   args.world_size * args.batch_size / batch_time.avg,
                   batch_time=batch_time, loss=losses,
                   top1=top1, top5=top5))

        # Only local rank 0 appends, so each evaluation produces one row
        # rather than one per process (values are already reduced).
        if i == val_loader_len - 1 and args.local_rank == 0:
            to_write = [losses.avg, top1.avg, top5.avg]
            with open("test.csv", 'a', newline='') as file:
                csv.writer(file, delimiter=',').writerow(to_write)

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
        .format(top1=top1, top5=top5))

    return [top1.avg, top5.avg]


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    When ``is_best`` is truthy, the written file is additionally copied to
    ``model_best.pth.tar`` in the current working directory.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Running statistics for a scalar stream.

    Keeps the latest value (``val``), the weighted sum (``sum``), the weighted
    sum of squares (``sum_square``), the total weight (``count``) and the
    weighted mean (``avg``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.sum_square = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.sum_square += (val ** 2) * n
        self.avg = self.sum / self.count

def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """Set each param group's ``lr`` from a linear-warmup + cosine schedule.

    For the first ``args.warmup`` epochs the factor rises linearly with the
    1-based global iteration. Afterwards it follows a half-cosine from 1 down
    to 0 over the remaining ``args.epochs - args.warmup`` epochs. The factor
    is clamped at 0 and multiplied by ``args.lr``. Groups with an ``lr_scale``
    entry get that multiple of the result. Reads the module-level ``args``.
    """
    warmup = args.warmup
    if epoch < warmup:
        factor = float(1 + step + epoch*len_epoch)/(float(warmup)*len_epoch)
    else:
        done = float(1 + step + (epoch - warmup)*len_epoch)
        span = (float(args.epochs) - warmup)*len_epoch
        factor = 0.5*(1 + np.cos(np.pi*done/span))

    base_lr = args.lr*max(factor, 0.0)

    for group in optimizer.param_groups:
        group['lr'] = base_lr*group['lr_scale'] if 'lr_scale' in group else base_lr


def accuracy(output, target, topk=(1,)):
    """Return precision@k (in percent) for every k in ``topk``.

    ``output`` holds per-class scores with classes along dim 1, and ``target``
    holds one class index per row. Each result is a one-element tensor equal
    to the percentage of rows whose target is among the k highest scores.
    """
    largest_k = max(topk)
    n_rows = target.size(0)

    top_idx = output.topk(largest_k, 1, True, True)[1].t()
    hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

    scale = 100.0 / n_rows
    return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale) for k in topk]


def reduce_tensor(tensor):
    """Return the mean of ``tensor`` over all processes of the default group.

    The input is cloned, so the caller's tensor is left untouched. The summed
    result is divided by ``args.world_size`` (module-level ``args``).
    """
    rt = tensor.clone()
    # ``dist.reduce_op`` is a deprecated alias; ``ReduceOp`` is the supported enum.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt

if __name__ == "__main__":
    # Launch training only when run as a script, not when imported.
    main()
