import argparse
import os
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import resnet as resnet

from torch.utils.data import Dataset, DataLoader
import util
from warnings import simplefilter
from GradualWarmupScheduler import *
from augmentations import augment_list
from PIL import Image
from wide_resnet import WideResNet
from datasets import WeightedAugmentedDataset, IndexedDataset, get_train_dataset, get_val_dataset
from models import MNIST_Net
from jacobian import compute_jacobian


# Strategy names accepted by --subset_algo.
SUBSET_ALGOS = [
    'coreset-weighted',
    'coreset-uniform',
    'all',
    'random',
    'largest-ce-loss',
]
# Strategy names accepted by --augment_algo.
AUGMENT_ALGOS = [
    'coreset-uniform',
    'coreset-weighted',
    'none',
    'same',
    'random',
    'all',
    'largest-ce-loss',
]

# Silence FutureWarning messages and numpy floating-point error reports.
simplefilter(action='ignore', category=FutureWarning)
np.seterr(all='ignore')

# Ask CUDA to enumerate devices by PCI bus id.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152

# Sorted names of the lower-case, callable ``resnet*`` attributes exposed by
# the local ``resnet`` module.
resnet_model_names = sorted(
    name
    for name, obj in vars(resnet).items()
    if name.islower()
    and not name.startswith("__")
    and name.startswith("resnet")
    and callable(obj)
)

# Three-channel normalisation transform (per-channel mean / std).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Command-line interface of the training script.  ``main()`` calls
# ``parser.parse_args()`` and stores the result in the module global ``args``.
parser = argparse.ArgumentParser(description='Propert ResNets for CIFAR10 in pytorch')

# --- model / optimisation -------------------------------------------------
parser.add_argument('--arch', '-a', metavar='ARCH', default='wide-resnet-28-10',
                    help='model architecture: ')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', '-m', type=float, metavar='M', default=0.9,
                    help='momentum')
# The help text now states the actual default (1e-4).
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
# The help text now states the actual default (100).
parser.add_argument('--print-freq', '-p', default=100, type=int,
                    metavar='N', help='print frequency (default: 100)')

# --- checkpointing / evaluation --------------------------------------------
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--half', dest='half', action='store_true',
                    help='use half-precision(16-bit) ')
parser.add_argument('--save-dir', dest='save_dir',
                    help='The directory used to save the trained models',
                    default='save_temp', type=str)
parser.add_argument('--save-every', dest='save_every',
                    help='Saves checkpoints at every specified number of epochs',
                    type=int, default=300)  # default=10)

# --- subset / augmentation selection ---------------------------------------
parser.add_argument('-R', '--R', dest='R', type=int, metavar='R',
                    help='interval to select subset', default=1)
parser.add_argument('--gpu', default='0', type=str, help='The GPU to be used')
parser.add_argument('--dataset', dest='dataset', type=str, default="CIFAR10", help='Dataset to use')
parser.add_argument('--augment_algo', dest='augment_algo', type=str, default="none", help='Augmentation algorithm', choices=AUGMENT_ALGOS)
parser.add_argument('--subset_algo', dest='subset_algo', type=str, default="all", help='Subset algorithm', choices=SUBSET_ALGOS)
parser.add_argument('--subset_size', '-s', dest='subset_size', type=float, help='size of the subset', default=1.0)
parser.add_argument('--augment_size', '-as', dest='augment_size', type=float, help='size of the augmentation set', default=1.0)
parser.add_argument('--st_grd', '-stg', type=float, help='stochastic greedy', default=0)
parser.add_argument('--smtk', type=int, help='smtk', default=1)
# ``choices`` must list each optimizer separately.  The previous single
# string 'sgd, adam, adagrad' made argparse reject every explicit --ig value.
parser.add_argument('--ig', type=str, help='ig method', default='sgd', choices=['sgd', 'adam', 'adagrad'])
parser.add_argument('--lr_schedule', '-lrs', type=str, help='learning rate schedule', default='mile',
                    choices=['mile', 'exp', 'cnt', 'step', 'cosine'])
parser.add_argument('--gamma', type=float, default=-1, help='learning rate decay parameter')
parser.add_argument('--lag', type=int, help='update lags', default=1)
parser.add_argument('--runs', type=int, help='num runs', default=1)
parser.add_argument('--warm', '-w', dest='warm_start', action='store_true', help='warm start learning rate ')
parser.add_argument('--start-subset', '-st', default=0, type=int, metavar='N', help='start subset selection')
parser.add_argument('--save_subset', dest='save_subset', action='store_true', help='save_subset')
parser.add_argument('-C', '--augment-c', dest='C', type=int, default=1, help='Number of augmented examples to generate per example')
parser.add_argument('-L', '--augment-l', dest='L', type=int, default=2, help='Number of augmentations to perform per generated example')
parser.add_argument('-S', '--augment-s', dest='S', type=int, default=1, help='Number of augmented examples to select')
parser.add_argument('--gamma-coreset', default=0.0, type=float,
                    metavar='GAMMA', help='Weight of diversity term in coreset')
parser.add_argument('--output_dir', dest='output_dir', type=str, default="output", help='Output directory')
parser.add_argument('--use_linear', dest='use_linear', action='store_true', help='Linear layer for coreset gradient approximation')


def get_model(arch, CLASS_NUM):
    """Build the network named ``arch`` and place it on the GPU.

    Args:
        arch: architecture name. Either one of ``resnet_model_names``,
            ``'wide-resnet-28-10'`` or ``'mnist'``.
        CLASS_NUM: number of classes, forwarded to the model constructor.

    Returns:
        The network wrapped in ``torch.nn.DataParallel`` after ``.cuda()``.

    Raises:
        ValueError: if ``arch`` is not one of the supported names.
    """
    if arch in resnet_model_names:
        # Look the constructor up by the ``arch`` argument, not by the global
        # ``args.arch``, so the function honours the value it was given.
        net = resnet.__dict__[arch](CLASS_NUM)
    elif arch == 'wide-resnet-28-10':
        net = WideResNet(28, 10, 0.3, CLASS_NUM)
    elif arch == 'mnist':
        net = MNIST_Net(CLASS_NUM)
    else:
        # Previously this fell through to ``model.cuda()`` with ``model``
        # unbound and failed with an unhelpful UnboundLocalError.
        raise ValueError("No such architecture: {}".format(arch))

    model = torch.nn.DataParallel(net)
    model.cuda()
    return model


def main():
    """Parse the command line and run ``args.runs`` independent training runs.

    For every run a fresh model, optimizer and learning-rate scheduler are
    built.  On epochs that are a multiple of ``args.R`` a training subset and,
    unless ``--augment_algo none`` is given, an augmentation subset are
    (re)selected; the epochs in between reuse the previous selection.  After
    each epoch the model is validated on ``val_loader`` and on
    ``train_val_loader`` and all statistics gathered so far are written to an
    ``.npz`` file whose name starts with ``args.output_dir/args.dataset``.

    With ``--evaluate`` only a single validation pass is performed and the
    function returns.

    Side effects: sets the module globals ``args`` and ``best_prec1``, sets
    ``CUDA_VISIBLE_DEVICES``, creates ``args.save_dir`` and
    ``args.output_dir`` and writes checkpoint / result files.

    Raises:
        ValueError: for an unknown dataset name or an invalid combination of
            augmentation options.
    """

    global args, best_prec1
    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    print(f'--------- subset_algo: {args.subset_algo}, subset_size: {args.subset_size}, augment_algo: {args.augment_algo} augment_size: {args.augment_size}'
            f' R: {args.R} C: {args.C} L: {args.L} S: {args.S}  method: {args.ig}, moment: {args.momentum}, '
          f'lr_schedule: {args.lr_schedule}, stoch: {args.st_grd}, ---------------')

    # Make sure the checkpoint and result directories exist.
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)

    dataset = get_train_dataset(args.dataset)
    class_counts = {
        "CIFAR10": 10,
        "CIFAR100": 100,
        'SVHN': 10,
        'MNIST': 10,
        'Reduced-MNIST': 3,
        'CIFAR10-IMB': 10,
    }
    if args.dataset not in class_counts:
        raise ValueError("No such dataset: {}".format(args.dataset))
    CLASS_NUM = class_counts[args.dataset]

    TRAIN_NUM = len(dataset)
    print(f"TRAIN NUM: {TRAIN_NUM}")

    # optionally resume from a checkpoint
    # The weights are only remembered here and loaded into the model that is
    # built inside the run loop; loading them into a throw-away model (as was
    # done before) meant the checkpoint was never actually used.
    resume_state = None
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # NOTE: best_prec1 is reset to 0 at the start of every run below.
            best_prec1 = checkpoint['best_prec1']
            resume_state = checkpoint['state_dict']
            # Report the checkpoint path (previously printed args.evaluate).
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    indexed_dataset = IndexedDataset(args.dataset, train=True)
    train_dataset = WeightedAugmentedDataset(indexed_dataset.ds)

    indexed_loader = DataLoader(
        indexed_dataset,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        get_val_dataset(args.dataset),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    train_val_loader = torch.utils.data.DataLoader(
        get_train_dataset(args.dataset),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # Per-sample losses are needed for training because train() multiplies
    # them by the per-example weights before averaging.
    train_criterion = nn.CrossEntropyLoss(reduction='none').cuda()  # (Note)
    val_criterion = nn.CrossEntropyLoss().cuda()

    if args.half:
        # The model itself is converted inside the run loop, right after it
        # has been created.
        train_criterion.half()
        val_criterion.half()

    # Per-run / per-epoch statistics that are saved after every epoch.
    runs, best_run, best_run_loss, best_loss = args.runs, 0, 0, 1e10
    epochs = args.epochs
    train_loss, test_loss = np.zeros((runs, epochs)), np.zeros((runs, epochs))
    train_acc, test_acc = np.zeros((runs, epochs)), np.zeros((runs, epochs))
    train_time, data_time = np.zeros((runs, epochs)), np.zeros((runs, epochs))
    grd_time, sim_time = np.zeros((runs, epochs)), np.zeros((runs, epochs))
    not_selected_train = np.zeros((runs, epochs))
    not_selected_aug = np.zeros((runs, epochs))
    not_selected_all = np.zeros((runs, epochs))
    best_bs, best_gs = np.zeros(runs), np.zeros(runs)
    times_selected_train = np.zeros((runs, len(indexed_loader.dataset)))
    times_selected_aug = np.zeros((runs, len(indexed_loader.dataset)))
    times_selected_all = np.zeros((runs, len(indexed_loader.dataset)))
    avg_full_loss = np.zeros((runs, epochs))
    avg_loss_coreset_weighted = np.zeros((runs, epochs))
    avg_loss_coreset_uniform = np.zeros((runs, epochs))
    coreset_correct_ratio = np.zeros((runs, epochs))
    full_correct_ratio = np.zeros((runs, epochs))
    correct_coreset_label_breakdown = np.zeros((runs, epochs, CLASS_NUM))
    wrong_coreset_label_breakdown = np.zeros((runs, epochs, CLASS_NUM))
    correct_original_label_breakdown = np.zeros((runs, epochs, CLASS_NUM))
    wrong_original_label_breakdown = np.zeros((runs, epochs, CLASS_NUM))

    if args.save_subset:
        B = int(args.subset_size * TRAIN_NUM)
        selected_ndx = np.zeros((runs, epochs, B))
        selected_wgt = np.zeros((runs, epochs, B))

    # ``b`` is the decay factor handed to the schedulers; 'mile' and 'cosine'
    # fall back to 0.1 when --gamma was left at its -1 default.
    lr = args.lr
    if args.lr_schedule in ('mile', 'cosine') and args.gamma == -1:
        b = 0.1
    else:
        b = args.gamma

    print(f'lr schedule: {args.lr_schedule}, epochs: {args.epochs}')
    print(f'lr: {lr}, b: {b}')

    # The result file name depends only on the command line, so build it once.
    grd = f'subsetalgo_{args.subset_algo}_{args.subset_size}'
    grd += f'_augmentalgo_{args.augment_algo}_{args.augment_size}'
    grd += '_uselinear' if args.use_linear else ''
    grd += f'_C{args.C}_L{args.L}_S{args.S}_R{args.R}'
    grd += f'_gamma-div{args.gamma_coreset}'
    grd += f'_st_{args.st_grd}' if args.st_grd > 0 else ''
    grd += '_warm' if args.warm_start else ''
    folder = f'{os.path.join(args.output_dir, args.dataset)}'
    out_path = (f'{folder}_{args.ig}_moment_{args.momentum}_{args.arch}'
                f'_{grd}_{args.lr_schedule}_start_{args.start_subset}_lag_{args.lag}')
    if args.save_subset:
        out_path += '_subset'

    for run in range(runs):
        best_prec1_all, best_loss_all, prec1 = 0, 1e10, 0

        # Fresh model for every run; apply checkpoint weights and half
        # precision to *this* model so that --resume and --half take effect.
        model = get_model(args.arch, CLASS_NUM)
        if resume_state is not None:
            model.load_state_dict(resume_state)
        if args.half:
            model.half()

        best_prec1, best_loss = 0, 1e10

        # No subset has been selected yet in this run (see the epoch loop).
        subset = None

        if args.ig == 'adam':
            print('using adam')
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=args.weight_decay)
        elif args.ig == 'adagrad':
            optimizer = torch.optim.Adagrad(model.parameters(), lr, weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)

        # NOTE(review): PyTorch schedulers constructed with last_epoch != -1
        # expect 'initial_lr' in the optimizer param groups; with a non-zero
        # start epoch the 'exp' and 'mile' schedules may therefore fail to
        # construct -- confirm before relying on --resume / --start-epoch.
        if args.lr_schedule == 'exp':
            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=b, last_epoch=args.start_epoch - 1)
        elif args.lr_schedule == 'step':
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=b)
        elif args.lr_schedule == 'mile':
            milestones = np.array([100, 150])
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=milestones, last_epoch=args.start_epoch - 1, gamma=b)
        elif args.lr_schedule == 'cosine':
            # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
        else:  # constant lr
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.epochs, gamma=1.0)

        if args.warm_start:
            print('Warm start learning rate')
            lr_scheduler_f = GradualWarmupScheduler(optimizer, 1.0, 20, lr_scheduler)
        else:
            print('No Warm start')
            lr_scheduler_f = lr_scheduler

        if args.arch in ['resnet1202', 'resnet110']:
            # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up
            # then switch back. In this setup it will correspond for first epoch.
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr*0.1

        if args.evaluate:
            validate(val_loader, model, val_criterion)
            return

        for epoch in range(args.start_epoch, args.epochs):

            # train for one epoch
            print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))

            # Reuse the previous selection between selection epochs.  The
            # ``subset is not None`` guard forces a selection on the first
            # epoch of a run even when the start epoch is not a multiple of R
            # (previously this path hit an unbound ``subset``).
            if epoch % args.R != 0 and subset is not None:
                print(f"Using previous subset of length {len(subset)}")

                train_dataset.reset_augments()
                # train_dataset.init_subset(subset, weights=subset_weights, keep_augment=False)

                if args.save_subset:
                    selected_ndx[run, epoch], selected_wgt[run, epoch] = selected_ndx[run, epoch-1], selected_wgt[run, epoch-1]

                if args.augment_algo != 'none':
                    train_dataset.augment_subset(augment_subset, weights=augment_weights, C=args.C, L=args.L, S=args.S, model=model, bs=args.batch_size)
                    times_selected_aug[run][augment_subset] += 1
                    print(f"Augmenting previous subset of length {len(augment_subset)}")
                else:
                    times_selected_aug[run] = 0
                    print(f"No augmentation")

                # Carry the selection statistics of the previous epoch forward.
                not_selected_train[run, epoch] = not_selected_train[run, epoch-1]
                not_selected_aug[run, epoch] = not_selected_aug[run, epoch-1]
                not_selected_all[run, epoch] = not_selected_all[run, epoch-1]

                full_correct_ratio[run, epoch] = full_correct_ratio[run, epoch-1]
                coreset_correct_ratio[run, epoch] = coreset_correct_ratio[run, epoch-1]

                avg_full_loss[run, epoch] = avg_full_loss[run, epoch-1]
                avg_loss_coreset_uniform[run, epoch] = avg_loss_coreset_uniform[run, epoch-1]
                avg_loss_coreset_weighted[run, epoch] = avg_loss_coreset_weighted[run, epoch-1]

                grd_time[run, epoch] = grd_time[run, epoch-1]
                sim_time[run, epoch] = sim_time[run, epoch-1]

                correct_coreset_label_breakdown[run, epoch] = correct_coreset_label_breakdown[run, epoch-1]
                wrong_coreset_label_breakdown[run, epoch] = wrong_coreset_label_breakdown[run, epoch-1]
                correct_original_label_breakdown[run, epoch] = correct_original_label_breakdown[run, epoch-1]
                wrong_original_label_breakdown[run, epoch] = wrong_original_label_breakdown[run, epoch-1]

            else:
                # Start by running coresets if necessary:
                if 'coreset' in args.subset_algo or 'coreset' in args.augment_algo:
                    if 'coreset' in args.augment_algo and 'coreset' in args.subset_algo:
                        # NOTE: Does not support different sizes for coreset subset and coreset augmentation
                        assert args.subset_size == args.augment_size
                        B = int(args.subset_size * TRAIN_NUM)
                    elif 'coreset' in args.augment_algo and 'coreset' not in args.subset_algo:
                        B = int(args.augment_size * TRAIN_NUM)
                    else:
                        B = int(args.subset_size * TRAIN_NUM)

                    print(f"Running predictions for greedy coreset of size: {B}")
                    gradient_est, preds, labels = grad_predictions(indexed_loader, model, TRAIN_NUM, CLASS_NUM, args.use_linear)
                    coreset, coreset_weights, ordering_time, similarity_time = get_coreset(args, gradient_est, labels, TRAIN_NUM, B, CLASS_NUM, normalize_weights=True)
                    grd_time[run, epoch], sim_time[run, epoch] = ordering_time, similarity_time

                    full_correct_ratio_epoch, coreset_correct_ratio_epoch, full_loss, coreset_loss_uniform, coreset_loss_weighted = get_subset_loss_stats(preds, labels, coreset, weights=coreset_weights)
                    full_correct_ratio[run, epoch] = full_correct_ratio_epoch
                    coreset_correct_ratio[run, epoch] = coreset_correct_ratio_epoch
                    avg_full_loss[run, epoch] = full_loss
                    avg_loss_coreset_uniform[run, epoch] = coreset_loss_uniform
                    avg_loss_coreset_weighted[run, epoch] = coreset_loss_weighted

                elif 'largest-ce-loss' in args.augment_algo or 'largest-ce-loss' in args.subset_algo:
                    if 'largest-ce-loss' in args.augment_algo and 'largest-ce-loss' in args.subset_algo:
                        assert args.subset_size == args.augment_size
                        B = int(args.subset_size * TRAIN_NUM)
                    elif 'largest-ce-loss' in args.augment_algo and 'largest-ce-loss' not in args.subset_algo:
                        B = int(args.augment_size * TRAIN_NUM)
                    else:
                        B = int(args.subset_size * TRAIN_NUM)

                    print(f"Selecting {B} examples with highest CE loss")
                    criterion = torch.nn.NLLLoss(reduction='none')
                    _, preds, labels = grad_predictions(indexed_loader, model, TRAIN_NUM, CLASS_NUM)

                    with torch.no_grad():
                        preds = torch.from_numpy(preds)
                        labels = torch.from_numpy(labels).long()
                        losses = criterion(preds, labels)
                        losses = losses.cpu().numpy()

                    # Sort smallest loss to largest loss, select last B
                    losses_ce_subset = np.argsort(losses)[-B:]

                ###########################
                # Pick training examples #
                ###########################
                print("Choosing training examples")
                if args.subset_size >= 1 or epoch < args.start_subset or args.subset_algo == 'all':
                    print('Training on all the data')
                    subset = np.arange(TRAIN_NUM)
                    subset_weights = np.ones(TRAIN_NUM)
                    train_dataset.init_subset(subset, weights=subset_weights)
                    times_selected_train[run] += 1

                elif args.subset_size < 1:
                    B = int(args.subset_size * TRAIN_NUM)

                    if args.subset_algo == 'coreset-uniform':
                        subset = coreset
                        subset_weights = np.ones(len(coreset))

                    elif args.subset_algo == 'coreset-weighted':
                        subset = coreset
                        subset_weights = coreset_weights

                    elif args.subset_algo == 'largest-ce-loss':
                        subset = losses_ce_subset
                        subset_weights = np.ones(len(subset))

                    elif args.subset_algo == 'random':
                        subset, subset_weights = get_random_subset(B, TRAIN_NUM)

                    times_selected_train[run][subset] += 1

                    train_dataset.init_subset(subset, weights=subset_weights)

                    if args.save_subset:
                        selected_ndx[run, epoch], selected_wgt[run, epoch] = subset, subset_weights
                else:
                    raise NotImplementedError()

                # Percentage of training examples never selected so far.
                not_selected_train[run, epoch] = np.sum(times_selected_train[run] == 0) / len(times_selected_train[run]) * 100

                ###########################
                # Pick augmented examples #
                ###########################
                print("Choosing augmented examples")

                B_aug = int(args.augment_size * TRAIN_NUM)

                # Every branch divides the augmentation weights by args.S.
                if args.augment_algo == 'none':
                    pass
                elif args.augment_size >= 1 or args.augment_algo == 'all':
                    if args.augment_size < 1 and args.augment_algo == 'all':
                        raise ValueError("Cannot use all augment when augment size < 1")
                    print("Augmenting all training examples")
                    augment_subset = np.arange(TRAIN_NUM)
                    augment_weights = np.ones(TRAIN_NUM)
                    augment_weights /= args.S

                elif args.augment_algo == 'coreset-uniform':
                    print("Augmenting coreset (uniform weights)")
                    augment_subset = coreset.copy()
                    augment_weights = np.ones(len(coreset))
                    augment_weights /= args.S

                elif args.augment_algo == 'coreset-weighted':
                    print("Augmenting coreset (weighted)")
                    augment_weights = coreset_weights.copy()
                    augment_weights /= args.S
                    augment_subset = coreset.copy()

                elif args.augment_algo == 'same':
                    if args.subset_algo == 'random':
                        print("Augmenting same random subset")
                        augment_weights = subset_weights.copy()
                        augment_weights /= args.S
                        augment_subset = subset.copy()
                    else:
                        raise ValueError("Only can use same option for random subset")

                elif args.augment_algo == 'random':
                    augment_subset, augment_weights = get_random_subset(B_aug, TRAIN_NUM)
                    augment_weights /= args.S

                elif args.augment_algo == 'largest-ce-loss':
                    augment_weights = np.ones(B_aug)
                    augment_weights /= args.S
                    augment_subset = losses_ce_subset.copy()

                else:
                    raise ValueError("No such augment algo")

                if args.augment_algo != 'none':
                    train_dataset.augment_subset(augment_subset, weights=augment_weights, C=args.C, L=args.L, S=args.S, model=model, bs=args.batch_size)
                    times_selected_aug[run][augment_subset] += 1
                else:
                    times_selected_aug[run] = 0

                times_selected_all[run] = times_selected_train[run] + times_selected_aug[run]

                not_selected_aug[run, epoch] = np.sum(times_selected_aug[run] == 0) / len(times_selected_aug[run]) * 100
                not_selected_all[run, epoch] = np.sum(times_selected_all[run] == 0) / len(times_selected_all[run]) * 100

                print(f'Train: {np.sum(times_selected_train[run] == 0) / len(times_selected_train[run]) * 100:.3f} % not selected yet')
                print(f'Aug: {np.sum(times_selected_aug[run] == 0) / len(times_selected_aug[run]) * 100:.3f} % not selected yet')
                print(f'All: {np.sum(times_selected_all[run] == 0) / len(times_selected_all[run]) * 100:.3f} % not selected yet')

            # A new loader is built every epoch around the updated dataset.
            train_loader = DataLoader(
                train_dataset,
                batch_size=args.batch_size, shuffle=True,
                num_workers=args.workers, pin_memory=True)

            #if epoch == 0 or epoch == 50 or epoch == 100:
            #    svs = get_jacobian(train_loader, model)
            #    np.save(f"{args.dataset}_jacobian_{args.augment_algo}_{epoch}", svs)

            data_time[run, epoch], train_time[run, epoch] = train(
                train_loader, model, train_criterion, optimizer)

            lr_scheduler_f.step()

            # evaluate on validation set
            prec1, loss = validate(val_loader, model, val_criterion)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            # best_run = run if is_best else best_run
            best_prec1 = max(prec1, best_prec1)
            if best_prec1 > best_prec1_all:
                best_gs[run], best_bs[run] = lr, b
                best_prec1_all = best_prec1
            test_acc[run, epoch], test_loss[run, epoch] = prec1, loss

            ta, tl = validate(train_val_loader, model, val_criterion)
            # best_run_loss = run if tl < best_loss else best_run_loss
            best_loss = min(tl, best_loss)
            best_loss_all = min(best_loss_all, best_loss)
            train_acc[run, epoch], train_loss[run, epoch] = ta, tl

            if epoch > 0 and epoch % args.save_every == 0:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, filename=os.path.join(args.save_dir, 'checkpoint.th'))

            print(f'run: {run}, subset_algo: {args.subset_algo} subset_size: {args.subset_size}, epoch: {epoch}, prec1: {prec1}, loss: {tl:.3f}, '
                    f'g: {lr:.3f}, b: {b:.3f}, augment_algo: {args.augment_algo} augment_size: {args.augment_size} '
                  f'best_prec1_gb: {best_prec1}, best_loss_gb: {best_loss:.3f}, best_run: {best_run};  '
                  f'best_prec_all: {best_prec1_all}, best_loss_all: {best_loss_all:.3f}, '
                  f'best_g: {best_gs[run]:.3f}, best_b: {best_bs[run]:.3f}, '
                  f'not selected_train %:{not_selected_train[run][epoch]:.3f}, '
                  f'not selected_aug %:{not_selected_aug[run][epoch]:.3f}, '
                  f'not selected_all %:{not_selected_all[run][epoch]:.3f}, ',
                  f'full_correct_ratio: {full_correct_ratio[run][epoch]:.3f}, ',
                  f'coreset_correct_ratio: {coreset_correct_ratio[run][epoch]:.3f}, ')

            # Dump everything gathered so far after every epoch, so an
            # interrupted job still leaves usable results behind.
            results = dict(
                train_loss=train_loss, test_acc=test_acc, train_acc=train_acc, test_loss=test_loss,
                data_time=data_time, train_time=train_time, grd_time=grd_time, sim_time=sim_time,
                best_g=best_gs, best_b=best_bs, not_selected_train=not_selected_train,
                not_selected_aug=not_selected_aug, not_selected_all=not_selected_all,
                times_selected_train=times_selected_train, times_selected_aug=times_selected_aug,
                times_selected_all=times_selected_all,
                avg_full_loss=avg_full_loss, avg_loss_coreset_weighted=avg_loss_coreset_weighted,
                avg_loss_coreset_uniform=avg_loss_coreset_uniform,
                full_correct_ratio=full_correct_ratio, coreset_correct_ratio=coreset_correct_ratio,
                correct_coreset_label_breakdown=correct_coreset_label_breakdown,
                wrong_coreset_label_breakdown=wrong_coreset_label_breakdown,
                correct_original_label_breakdown=correct_original_label_breakdown,
                wrong_original_label_breakdown=wrong_original_label_breakdown,
            )
            if args.save_subset:
                # The selected indices / weights are only tracked with --save_subset.
                results.update(subsets=selected_ndx, subset_weights=selected_wgt)

            print(f'Saving the results to {out_path}')
            np.savez(out_path, **results)

    print(np.max(test_acc, 1), np.mean(np.max(test_acc, 1)),
          np.min(not_selected_all, 1), np.mean(np.min(not_selected_all, 1)),
          np.min(not_selected_train, 1), np.mean(np.min(not_selected_train, 1)),
          np.min(not_selected_aug, 1), np.mean(np.min(not_selected_aug, 1)))


def train(train_loader, model, criterion, optimizer):
    """
        Train ``model`` for a single pass over ``train_loader``.

        Each batch yields ``(input, target, weight)``.  ``criterion`` is
        applied to the model output, the result is multiplied by ``weight``
        and averaged, and that scalar is back-propagated before one optimizer
        step.  Inputs are cast to half precision when ``args.half`` is set.

        Returns a tuple ``(total data-loading time, total batch time)`` in
        seconds, summed over all batches.
    """
    # switch to train mode
    model.train()

    timer_batch, timer_data = AverageMeter(), AverageMeter()
    loss_meter, acc_meter = AverageMeter(), AverageMeter()

    tick = time.time()
    for images, labels, sample_weights in train_loader:
        # time spent waiting for the loader
        timer_data.update(time.time() - tick)

        labels = labels.cuda()
        images_gpu = images.cuda()
        sample_weights = sample_weights.cuda()
        if args.half:
            images_gpu = images_gpu.half()

        # forward pass and weighted loss
        logits = model(images_gpu)
        weighted_loss = (criterion(logits, labels) * sample_weights).mean()  # (Note)

        # backward pass and parameter update
        optimizer.zero_grad()
        weighted_loss.backward()
        optimizer.step()

        # bookkeeping in full precision
        logits = logits.float()
        weighted_loss = weighted_loss.float()
        batch_size = images.size(0)
        prec1 = accuracy(logits.data, labels)[0]
        loss_meter.update(weighted_loss.item(), batch_size)
        acc_meter.update(prec1.item(), batch_size)

        # time for the whole batch, loading included
        timer_batch.update(time.time() - tick)
        tick = time.time()

    return timer_data.sum, timer_batch.sum


def get_jacobian(train_loader, model):
    """
    Collect flattened Jacobians for every batch of ``train_loader``.

    For each batch the inputs are moved to the GPU and handed to
    ``compute_jacobian``.  Every tensor it yields is flattened from dimension 2
    onward and its first two dimensions are merged, giving a 2-D tensor; these
    are concatenated along dimension 1.  The per-batch results are stacked
    along axis 0 into one NumPy array.

    NOTE(review): the exact layout returned by ``compute_jacobian`` is not
    visible here -- presumably one tensor per parameter; confirm.
    """
    per_batch = []
    for inputs, _target, _ in train_loader:
        jac = compute_jacobian(model, inputs.cuda())
        flattened = torch.cat(
            [j.flatten(start_dim=2).flatten(end_dim=1) for j in jac], 1
        )
        per_batch.append(flattened.cpu().numpy())
        # Clear gradients after each batch.
        model.zero_grad()

    return np.concatenate(per_batch)


def validate(val_loader, model, criterion):
    """
    Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Each batch is expected to unpack into ``(input, target)``.  ``criterion``
    must return a single-element loss because ``.item()`` is called on it.
    Inputs are cast to half precision when the module-level ``args.half`` is
    set.  The average top-1 precision is printed at the end.

    Returns:
        A tuple ``(top1_avg, loss_avg)`` of sample-weighted averages.
    """
    timer = AverageMeter()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    # Enable evaluation-mode behaviour of the model.
    model.eval()

    tick = time.time()
    with torch.no_grad():
        for inputs, target in val_loader:
            target = target.cuda()
            input_var = inputs.cuda()
            if args.half:
                input_var = input_var.half()

            output = model(input_var)
            loss = criterion(output, target)

            # Book-keeping is done in float32 regardless of model precision.
            batch_size = inputs.size(0)
            prec1 = accuracy(output.float().data, target)[0]
            loss_meter.update(loss.float().item(), batch_size)
            acc_meter.update(prec1.item(), batch_size)

            timer.update(time.time() - tick)
            tick = time.time()

    print(f' * Prec@1 {acc_meter.avg:.3f}')

    return acc_meter.avg, loss_meter.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """
    Serialize ``state`` to ``filename`` with ``torch.save``.

    ``is_best`` is accepted for interface compatibility but is not used here.
    """
    torch.save(obj=state, f=filename)


class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes:
        val: value passed to the most recent ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count``, or ``0`` while ``count`` is zero.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average.

        Args:
            val: the new observation (e.g. a per-batch mean).
            n: how many samples ``val`` stands for; weights it in the average.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n == 0 on a fresh meter (e.g. an empty batch) would
        # otherwise divide by zero; keep the average at 0 until data arrives.
        self.avg = self.sum / self.count if self.count else 0


def grad_predictions(loader, model, TRAIN_NUM, CLASS_NUM, use_linear=False):
    """
    Run the model over ``loader`` and build a per-sample gradient estimate.

    Each batch is expected to unpack into ``(input, target, idx)``; ``idx`` is
    used to scatter the batch results into buffers of length ``TRAIN_NUM``.
    Inputs are cast to half precision when the module-level ``args.half`` is
    set.

    The estimate is ``softmax(output) - one_hot(label)``.  When ``use_linear``
    is true the model is additionally asked for an embedding, and the estimate
    is extended with the product of every component of that difference with
    every embedding component.  ``model.module.emb_dim()`` supplies the
    embedding width.

    Returns:
        ``(gradient_est, preds, labels)`` as NumPy arrays.
    """
    timer = AverageMeter()

    # Enable evaluation-mode behaviour of the model.
    model.eval()

    preds = torch.zeros(TRAIN_NUM, CLASS_NUM).cuda()
    labels = torch.zeros(TRAIN_NUM, dtype=torch.int).cuda()
    if use_linear:
        lasts = torch.zeros(TRAIN_NUM, model.module.emb_dim()).cuda()

    softmax = nn.Softmax(dim=1)

    tick = time.time()
    with torch.no_grad():
        for inputs, target, idx in loader:
            input_var = inputs.cuda()
            target = target.cuda()
            if args.half:
                input_var = input_var.half()

            if use_linear:
                logits, emb = model(input_var, use_linear=True)
                lasts[idx, :] = emb
            else:
                logits = model(input_var, use_linear=False)
            preds[idx, :] = softmax(logits)
            labels[idx] = target.int()

            timer.update(time.time() - tick)
            tick = time.time()

    # Row i of the identity matrix is the one-hot encoding of class i.
    one_hot = torch.eye(CLASS_NUM).cuda()[labels.long()]
    g0 = preds - one_hot
    if use_linear:
        emb_dim = model.module.emb_dim()
        # Pair every class component with every embedding component.
        g1 = torch.repeat_interleave(g0, emb_dim, dim=1) * lasts.repeat(1, CLASS_NUM)
        gradient_est = torch.cat((g0, g1), dim=1)
    else:
        gradient_est = g0

    return gradient_est.cpu().data.numpy(), preds.cpu().data.numpy(), labels.cpu().data.numpy()


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: 2-D tensor of scores, one row per sample and one column per
            class.
        target: 1-D tensor with the true class index of every sample.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one tensor per entry of ``topk`` holding the percentage
        (0-100) of samples whose true class is among the ``k`` best scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk highest scores per sample, best first.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` derives from a transposed tensor and may be non-contiguous,
        # in which case `.view(-1)` raises for k > 1; `.reshape(-1)` copies
        # only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def get_coreset(args, gradient_est, labels, TRAIN_NUM, B, CLASS_NUM, normalize_weights=True):
    """
    Select a weighted subset of size ``B`` via ``util.get_orders_and_weights``.

    If the selection raises ``ValueError`` the error is printed and a uniformly
    random subset (with unit weights) is used instead; both timings are then
    reported as 0.  With ``normalize_weights`` the weights are rescaled so that
    they sum to the number of weights.  A warning is printed when the
    resulting subset does not contain exactly ``B`` elements.

    Returns:
        ``(subset, subset_weights, ordering_time, similarity_time)``.
    """
    try:
        (subset, subset_weights, _, _,
         ordering_time, similarity_time) = util.get_orders_and_weights(
            B, gradient_est, 'euclidean',
            smtk=args.smtk, no=0, y=labels, stoch_greedy=args.st_grd,
            equal_num=False, gamma=args.gamma_coreset, num_classes=CLASS_NUM)
    except ValueError as err:
        print(err)
        print("WARNING: ValueError from coreset selection, choosing random subset for this epoch")
        subset, subset_weights = get_random_subset(B, TRAIN_NUM)
        ordering_time = similarity_time = 0

    if normalize_weights:
        subset_weights = subset_weights / np.sum(subset_weights) * len(subset_weights)

    selected = len(subset)
    if selected != B:
        print(f"!!WARNING!! Selected subset of size {selected} instead of {B}")
    print(f'FL time: {ordering_time:.3f}, Sim time: {similarity_time:.3f}')

    return subset, subset_weights, ordering_time, similarity_time


def get_random_subset(B, TRAIN_NUM):
    """
    Draw ``B`` distinct indices uniformly at random from ``range(TRAIN_NUM)``.

    The indices are the first ``B`` entries of a shuffled ``arange``, so at
    most ``TRAIN_NUM`` indices are returned.  Every selected index gets
    weight 1.

    Returns:
        ``(subset, subset_weights)`` as NumPy arrays of equal length.
    """
    print(f'Selecting {B} element from the random subset of size: {TRAIN_NUM}')
    indices = np.arange(TRAIN_NUM)
    np.random.shuffle(indices)
    chosen = indices[:B]

    return chosen, np.ones(chosen.shape[0])


def get_subset_loss_stats(preds, labels, subset, weights=None):
    """
    Compare accuracy and loss on the full data against a selected subset.

    Args:
        preds: NumPy array of per-class scores, one row per sample.
        labels: NumPy integer array with the class index of every sample.
        subset: indices into ``preds``/``labels`` selecting the subset.
        weights: optional per-subset-element weights; when given, a weighted
            mean of the subset loss is appended to the result.

    Returns:
        ``(full_correct_ratio, subset_correct_ratio, full_loss_mean,
        subset_loss_uniform_mean)`` and, only when ``weights`` is not ``None``,
        a fifth element ``subset_loss_weighted_mean``.

    NOTE(review): ``NLLLoss`` expects log-probabilities.  If ``preds`` holds
    softmax probabilities the reported "loss" is ``-p[label]`` rather than
    ``-log p[label]`` -- confirm with the caller before relying on the scale.
    """
    criterion = torch.nn.NLLLoss(reduction='none')
    full_correct_ratio = np.sum(preds.argmax(1) == labels) / len(preds)

    subset_preds = preds[subset]
    subset_labels = labels[subset]
    subset_correct_ratio = np.sum(subset_preds.argmax(1) == subset_labels) / len(subset_preds)

    # The inputs are host arrays and the result is needed on the host, so the
    # loss is evaluated on the CPU; a GPU round trip would only add transfers
    # and make the function unusable without CUDA.
    with torch.no_grad():
        full_loss = criterion(
            torch.from_numpy(preds), torch.from_numpy(labels).long()
        ).numpy()

    # The unreduced loss is per-sample, so the subset loss is just a selection.
    subset_loss = full_loss[subset]

    full_loss_mean = full_loss.mean()
    subset_loss_uniform_mean = subset_loss.mean()

    if weights is None:
        return full_correct_ratio, subset_correct_ratio, full_loss_mean, subset_loss_uniform_mean

    subset_loss_weighted_mean = (subset_loss * weights).mean()
    return (full_correct_ratio, subset_correct_ratio, full_loss_mean,
            subset_loss_uniform_mean, subset_loss_weighted_mean)


# Invoke main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()


