import argparse
import os
import shutil
import sys
import time
from types import SimpleNamespace

import numpy as np
import random
import scipy

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import resnet
import resnet_cifar10
from tinyimagenet import TinyImageNet
#from vit_pytorch import ViT

from historic_sampler import HistoricSampler

# Architectures accepted by --arch: the CIFAR-style ResNet depths, the
# ImageNet-style ResNet depths, and a Vision Transformer ("vit").
model_names = (
    [f"resnet{depth}" for depth in (20, 32, 44, 56, 110, 1202)]
    + [f"resnet{depth}" for depth in (18, 34, 50, 101, 152)]
    + ["vit"]
)

print(model_names)

# Command-line interface. Several options are only forwarded to the
# HistoricSampler / InfoBatch constructors in main(); their help says so.
parser = argparse.ArgumentParser(
    description='Train ResNets/ViT with per-sample data selection on '
                'CIFAR10/100, Tiny ImageNet and ImageNet in pytorch')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet32',
                    choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) +
                    ' (default: resnet32)')
parser.add_argument('--data',
                    default='cifar10',
                    choices=['cifar10', 'cifar100', 'imagenet', 'imagenet100', 'tinyimagenet'],
                    help='dataset to train on (default: cifar10)')
parser.add_argument('--optimizer',
                    default='sgd',
                    choices=['sgd', 'lars'],
                    help='optimizer (default: sgd)')
parser.add_argument('--data_path', type=str, default='./data', help='dataset path')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--budget-epochs', default=-1, type=int, metavar='N',
                    help='number of total budget epochs to run; the sample budget is '
                         'this times the training-set size (default: -1, meaning '
                         'epochs * selection_factor)')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--full-sample-freq', default=-1, type=int, metavar='N',
                    help='forwarded to HistoricSampler (default: -1)')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('--eval-batch-size', default=128, type=int,
                    help='eval mini-batch size (default: 128)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--lr_schedule', default='cosine', choices=['cosine', 'onecycle'],
                    help='learning-rate schedule over the sample budget (default: cosine)')
parser.add_argument('--min_lr', type=float, default=1e-4, help='minimum learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--historic_beta', default=0.9, type=float,
                    help='decay factor of the per-sample history (used for its bias '
                         'correction); also forwarded to HistoricSampler (default: 0.9)')
parser.add_argument('--print-freq', '-p', default=50, type=int,
                    metavar='N', help='print frequency (default: 50)')
parser.add_argument('--eval-print-freq', default=50, type=int,
                    metavar='N', help='eval print frequency (default: 50)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--dynamic_ratio', action='store_true',
                    help='forwarded to HistoricSampler')
parser.add_argument('--heuristic_ratio', action='store_true',
                    help='forwarded to HistoricSampler')
parser.add_argument('--heuristic_max', action='store_true',
                    help='forwarded to HistoricSampler')
parser.add_argument('--count_discount', default=1, type=float,
                    help='forwarded to HistoricSampler (default: 1)')
parser.add_argument('--drop_last', action='store_true',
                    help='whether to drop the last batch in each epoch')
parser.add_argument('--unbiased_correction', action='store_true',
                    help='forwarded to InfoBatch')
parser.add_argument('--tiny_scaled', action='store_true',
                    help='tinyimagenet: train at 224x224 with a torchvision model '
                         'instead of at 32x32')
parser.add_argument('--tiny_resized_crop', action='store_true',
                    help='whether to use RandomResizedCrop for tinyimagenet')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--half', dest='half', action='store_true',
                    help='use half-precision(16-bit) ')
parser.add_argument('--save-dir', dest='save_dir',
                    help='The directory used to save the trained models',
                    default='save_temp', type=str)
parser.add_argument('--save-every', dest='save_every',
                    help='Saves checkpoints at every specified number of epochs',
                    type=int, default=10)
parser.add_argument('--selection_policy',
                    default='all',
                    choices=['all', 'fast_forward', 'static_bernoulli', 'dynamic_bernoulli',
                             'historic_dynamic_bernoulli', 'fix_budget_historic_bernoulli',
                             'dynamic_budget_historic_bernoulli', 'fix_budget_oracle_bernoulli',
                             'dynamic_budget_oracle_bernoulli', 'fix_variance_oracle_bernoulli',
                             'static_choice_wo_replacement', 'fix_budget_oracle_choice',
                             'oracle_rank_bernoulli', 'historic_rank_bernoulli',
                             'fix_budget_historic_choice', 'infobatch', 'historic_sampler',
                             'fix_budget_extra_bernoulli', 'extra_rank_bernoulli',
                             'fix_budget_extra_choice'],
                    help="The policy for data selection")
parser.add_argument('--warmup_factor',
                    type=float,
                    default=0.1,
                    help='fraction of the sample budget spent in linear LR warmup; 0 '
                         'selects plain cosine annealing regardless of --lr_schedule '
                         '(default: 0.1)')
parser.add_argument('--selection_warmup',
                    type=float,
                    default=-1,
                    help='forwarded to HistoricSampler (default: -1)')
parser.add_argument('--selection_factor',
                    type=float,
                    default=1,
                    help='fraction of each batch selected for training; also scales the '
                         'default sample budget (default: 1)')
parser.add_argument('--smooth_factor',
                    type=float,
                    default=0,
                    help='weight for blending per-sample importance estimates with their '
                         'batch mean; also forwarded to HistoricSampler (default: 0)')
parser.add_argument('--label_smoothing',
                    type=float,
                    default=0,
                    help='label smoothing for the cross-entropy losses (default: 0)')
parser.add_argument('--sampling_end',
                    type=float,
                    default=1,
                    help='fraction of the sample budget after which selection switches '
                         'to policy "all" (default: 1)')
parser.add_argument('--clip_scaling',
                    type=float,
                    default=1e10,
                    help='upper bound on a per-sample loss scaling factor, as a multiple '
                         'of 1/selection_size (default: 1e10)')
parser.add_argument('--seed', default=0, type=int,
                    metavar='N', help='random seed (default: 0)')
parser.add_argument('--lr_step_type', default='selected',
                    choices=['seen', 'selected'])
parser.add_argument('--not_divide_prob', action='store_true',
                    help='scale selected samples uniformly by 1/selection_size instead '
                         'of by their inverse selection probability')
parser.add_argument('--count_duplicate', action='store_true',
                    help='Counts duplicate towards number of samples selected')
parser.add_argument('--scale_lr', action='store_true',
                    help='Scales lr by estimated gain')
parser.add_argument('--history_type', default='grad_sq',
                    choices=['grad', 'grad_sq', 'grad_gain', 'grad_sq_gain'],
                    help='The history statistics used')

# Best validation metrics observed so far (updated in main()).
best_prec1 = 0
best_loss = 1e10
# Run-wide counters shared with train(): accumulated training time, samples
# selected for the update, samples seen, and the total sample budget that
# main() sets from the CLI arguments.
global_total_time = global_total_selected = global_total_seen = global_total_samples = 0
# Independent RNG streams: one for per-sample selection draws, one handed to
# the training DataLoader. Both are reseeded by set_random_seed().
sampling_generator, data_generator = torch.Generator(), torch.Generator()

def set_random_seed(seed):
    """Seed every RNG used by the script and force deterministic kernels.

    Seeds Python's ``random``, NumPy, torch (CPU and all CUDA devices) and the
    module-level ``sampling_generator`` / ``data_generator``, disables cuDNN
    autotuning, and enables ``torch.use_deterministic_algorithms``.

    Args:
        seed: integer seed applied to all of the above.
    """
    # Deterministic cuBLAS (CUDA >= 10.2) needs a fixed workspace size; without
    # this variable, use_deterministic_algorithms(True) makes cuBLAS calls
    # raise at runtime. setdefault keeps any value the user already exported.
    # It only takes effect if set before the first cuBLAS handle is created,
    # which is why this runs at the start of setup.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # The module-level generators are mutated in place (not rebound), so no
    # ``global`` declaration is needed.
    sampling_generator.manual_seed(seed)
    data_generator.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

class IndexedDataset(torch.utils.data.Dataset):
    """Dataset wrapper that prepends each item's index to the item itself.

    If ``dataset[i]`` yields ``(x, y)``, then ``IndexedDataset(dataset)[i]``
    yields ``(i, x, y)``. The training loop uses the index to look up
    per-sample statistics.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        item = self.dataset[index]
        return (index,) + tuple(item)

class SequentialLR(torch.optim.lr_scheduler.SequentialLR):
    """``SequentialLR`` whose ``step`` jumps to an absolute position.

    ``step(epoch)`` finds the child scheduler whose milestone window contains
    ``epoch``. It steps that scheduler with ``epoch`` re-based to the start of
    the window, then records the resulting learning rates. ``epoch`` and the
    milestones use whatever unit the caller chooses; they need not be epochs.
    """

    def step(self, epoch):
        import bisect

        self.last_epoch = epoch
        position = bisect.bisect_right(self._milestones, self.last_epoch)
        active = self._schedulers[position]

        # Position 0 starts at zero; later windows start at the previous milestone.
        if position == 0:
            local_epoch = self.last_epoch
        else:
            local_epoch = self.last_epoch - self._milestones[position - 1]
        active.step(local_epoch)

        self._last_lr = active.get_last_lr()

def has_custom_sampler(selection_policy):
    """Return True for policies whose dataset wrapper supplies its own sampler.

    For these policies the DataLoader is given ``train_data.sampler`` and must
    not shuffle on its own.
    """
    custom_sampler_policies = ("infobatch", "historic_sampler")
    return selection_policy in custom_sampler_policies

def _build_datasets(args):
    """Build the train/validation datasets selected by ``args.data``.

    Args:
        args: parsed CLI namespace; reads ``data``, ``data_path``,
            ``tiny_scaled`` and ``tiny_resized_crop``.

    Returns:
        ``(train_data, val_data, im_size, num_classes)`` where ``im_size`` is
        the ``(height, width)`` the model is built for.

    Raises:
        ValueError: if ``args.data`` is not a supported dataset.
    """
    if args.data in ("cifar10", "cifar100"):
        if args.data == "cifar10":
            dataset_cls, num_classes = datasets.CIFAR10, 10
            normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                             std=[0.2470, 0.2435, 0.2616])
        else:
            dataset_cls, num_classes = datasets.CIFAR100, 100
            normalize = transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
                                             std=[0.2673, 0.2564, 0.2762])

        train_data = dataset_cls(root=args.data_path, train=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ]), download=True)
        val_data = dataset_cls(root=args.data_path, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
        return train_data, val_data, (32, 32), num_classes

    if args.data == "tinyimagenet":
        # ImageNet channel statistics are reused for Tiny ImageNet.
        normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225))
        if args.tiny_scaled:
            # Full-resolution pipeline, paired with a torchvision model in _build_model.
            train_tf = [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]
            val_tf = [transforms.Resize(256), transforms.CenterCrop(224)]
        elif args.tiny_resized_crop:
            train_tf = [transforms.RandomResizedCrop(32), transforms.RandomHorizontalFlip()]
            val_tf = [transforms.CenterCrop(56), transforms.Resize(32)]
        else:
            # Downscale to 32x32 and use CIFAR-style augmentation.
            train_tf = [transforms.Resize(32), transforms.RandomHorizontalFlip(),
                        transforms.RandomCrop(32, 4)]
            val_tf = [transforms.Resize(32)]

        tiny_root = os.path.join(args.data_path, 'tiny-imagenet-200')
        train_data = datasets.ImageFolder(
            root=os.path.join(tiny_root, 'train'),
            transform=transforms.Compose(train_tf + [transforms.ToTensor(), normalize]))
        # NOTE(review): expects a pre-processed validation split under
        # 'val_processed' in ImageFolder layout; confirm how it is generated.
        val_data = datasets.ImageFolder(
            root=os.path.join(tiny_root, 'val_processed'),
            transform=transforms.Compose(val_tf + [transforms.ToTensor(), normalize]))
        im_size = (224, 224) if args.tiny_scaled else (32, 32)
        return train_data, val_data, im_size, 200

    if args.data in ("imagenet", "imagenet100"):
        # NOTE(review): imagenet100 loads datasets.ImageNet from the same
        # data_path and only changes num_classes; presumably data_path then
        # points at a 100-class copy. Confirm.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_data = datasets.ImageNet(root=args.data_path, split="train", transform=transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
        val_data = datasets.ImageNet(root=args.data_path, split="val", transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
        num_classes = 1000 if args.data == "imagenet" else 100
        return train_data, val_data, (224, 224), num_classes

    raise ValueError("unsupported dataset: {}".format(args.data))


def _build_model(args, im_size, num_classes):
    """Instantiate ``args.arch`` for the chosen dataset and move it to the GPU.

    * ``vit``: a ``vit_pytorch.ViT`` sized for ``im_size`` with ``num_classes`` outputs.
    * ImageNet data or ``--tiny_scaled``: the torchvision model of that name.
    * Otherwise: ``resnet.ResNet<depth>(num_classes=...)`` if the local
      ``resnet`` module defines it, else ``resnet_cifar10.<arch>()``. Both are
      wrapped in ``DataParallel``.
    """
    if args.arch == "vit":
        # Imported lazily so vit_pytorch is only required when a ViT is requested.
        from vit_pytorch import ViT
        model = ViT(
            image_size=im_size[0],
            patch_size=4,
            num_classes=num_classes,
            dim=512,
            depth=6,
            heads=8,
            mlp_dim=512,
            dropout=0.1,
            emb_dropout=0.1)
    elif args.data in ["imagenet", "imagenet100"] or args.tiny_scaled:
        # NOTE(review): num_classes is not passed, so the library's default
        # head size is used. Confirm this is intended for imagenet100/tinyimagenet.
        model = torchvision.models.__dict__[args.arch](progress=False)
    else:
        # "resnet18" -> resnet.ResNet18. Architectures the local resnet module
        # lacks (presumably the CIFAR depths such as resnet20) fall back to
        # resnet_cifar10.
        model_fn = resnet.__dict__.get("ResNet" + args.arch[6:])
        if model_fn is not None:
            model = torch.nn.DataParallel(model_fn(num_classes=num_classes))
        else:
            # NOTE(review): these constructors are called without num_classes.
            model = torch.nn.DataParallel(resnet_cifar10.__dict__[args.arch]())
    model.cuda()
    return model


def _build_lr_scheduler(args, optimizer, total_samples):
    """Build the LR scheduler; its horizon is measured in samples.

    ``warmup_factor == 0`` gives plain cosine annealing over ``total_samples``.
    Otherwise ``lr_schedule`` selects one of two schedules: cosine annealing
    after a linear warmup over the first ``warmup_factor`` fraction of the
    budget, or ``OneCycleLR``.

    Raises:
        ValueError: if ``args.lr_schedule`` is not recognised.
    """
    if args.warmup_factor == 0:
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_samples,
                                                          eta_min=args.min_lr)
    if args.lr_schedule == "cosine":
        warmup_samples = total_samples * args.warmup_factor
        # NOTE(review): recent torch releases may reject LinearLR(start_factor=0);
        # confirm against the torch version in use.
        warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0, total_iters=warmup_samples)
        cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, total_samples * (1 - args.warmup_factor), eta_min=args.min_lr)
        return SequentialLR(optimizer, [warmup_scheduler, cosine_scheduler],
                            milestones=[warmup_samples])
    if args.lr_schedule == "onecycle":
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer, args.lr, total_steps=total_samples,
            div_factor=25, final_div_factor=10000, pct_start=0.3)
    raise ValueError("unsupported lr_schedule: {}".format(args.lr_schedule))


def main():
    """Entry point: parse arguments, build everything, and train to a sample budget.

    Training length is measured in selected samples. ``train`` is called
    epoch after epoch until ``global_total_selected`` (advanced inside
    ``train``) reaches ``global_total_samples``. Each time another
    ``len(train_data)`` samples have been selected, the model is validated
    and checkpointed.
    """
    print(" ".join(sys.argv))

    global args, best_prec1, best_loss, global_total_samples
    args = parser.parse_args()
    print(args)

    set_random_seed(args.seed)

    os.makedirs(args.save_dir, exist_ok=True)

    train_data, val_data, im_size, num_classes = _build_datasets(args)

    # Total number of samples that may be selected for training.
    num_samples = len(train_data)
    if args.budget_epochs != -1:
        global_total_samples = num_samples * args.budget_epochs
    else:
        args.budget_epochs = int(args.epochs * args.selection_factor)
        global_total_samples = int(num_samples * args.epochs * args.selection_factor)
    print("Total Samples: {total_samples}".format(total_samples=global_total_samples))

    # Every training item must carry its dataset index. The custom-sampler
    # policies use their own wrapper, which also exposes ``.sampler``.
    if args.selection_policy == 'infobatch':
        from infobatch import InfoBatch
        train_data = IndexedDataset(train_data)
        train_data = InfoBatch(train_data, args.epochs, global_total_samples, 0.5, 0.875, args.unbiased_correction)
    elif args.selection_policy == 'historic_sampler':
        train_data = HistoricSampler(train_data, args.budget_epochs, args.historic_beta, args.smooth_factor,
                                     sampling_generator, args.full_sample_freq, 1 - args.selection_factor, 10,
                                     args.selection_warmup, dynamic_ratio=args.dynamic_ratio,
                                     heuristic_ratio=args.heuristic_ratio, count_discount=args.count_discount,
                                     heuristic_max=args.heuristic_max)
    else:
        train_data = IndexedDataset(train_data)

    custom_sampler = has_custom_sampler(args.selection_policy)
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size, shuffle=(not custom_sampler),
        generator=data_generator,
        sampler=train_data.sampler if custom_sampler else None,
        num_workers=args.workers, drop_last=args.drop_last, pin_memory=(args.data != "imagenet"))

    val_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=args.eval_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=(args.data != "imagenet"))

    # Per-sample history buffers (indexed by dataset index) plus scalar
    # statistics that train() reads and updates.
    train_state = SimpleNamespace()
    train_state.history_approx_grad = torch.zeros(len(train_loader.dataset), device="cuda")
    train_state.history_approx_grad_sq = torch.zeros(len(train_loader.dataset), device="cuda")
    train_state.history_count = torch.zeros(len(train_loader.dataset), device="cuda", dtype=torch.int)
    train_state.grad_sq_ema = 0
    train_state.grad_sq_ema_count = 0
    train_state.max_scaling = 0
    train_state.max_scaling_all = 0

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing).cuda()
    per_sample_criterion = nn.CrossEntropyLoss(reduction='none', label_smoothing=args.label_smoothing).cuda()

    model = _build_model(args, im_size, num_classes)
    print("Model Parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

    # optionally resume from a checkpoint
    # NOTE(review): only model weights, epoch and best_prec1 are restored; the
    # sample counters and optimizer/scheduler state are not. A resumed run
    # therefore starts a fresh sample budget.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.half:
        model.half()
        criterion.half()

    if args.optimizer == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == "lars":
        from lars import Lars
        optimizer = Lars(model.parameters(), lr=args.lr,
                         momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        raise ValueError("unsupported optimizer: {}".format(args.optimizer))

    lr_scheduler = _build_lr_scheduler(args, optimizer, global_total_samples)

    if args.arch in ['resnet1202', 'resnet110']:
        # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up
        # then switch back. In this setup it will correspond for first epoch.
        # NOTE(review): nothing in main() restores the lr; presumably the
        # scheduler steps inside train() overwrite it. Confirm.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * 0.1

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    epoch = args.start_epoch - 1
    next_eval_num_samples = num_samples
    # The loader for the post-sampling_end phase is built once. Rebuilding it
    # every epoch after args.selection_policy had been overwritten with "all"
    # silently dropped the InfoBatch sampler from the second such epoch on.
    switched_to_all = False
    while global_total_selected < global_total_samples:
        epoch += 1

        if not switched_to_all and global_total_selected >= global_total_samples * args.sampling_end:
            # TODO need a separate parameter for batch_size
            train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=128, shuffle=(args.selection_policy != 'infobatch'),
                generator=data_generator,
                sampler=train_data.sampler if args.selection_policy == 'infobatch' else None,
                num_workers=args.workers, pin_memory=True, drop_last=args.drop_last)
            args.selection_policy = "all"
            switched_to_all = True

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(args, train_state, train_loader, model, criterion, per_sample_criterion, optimizer, lr_scheduler, epoch)

        # Evaluate each time another full dataset's worth of samples has been selected.
        if global_total_selected >= next_eval_num_samples:
            prec1, loss = validate(val_loader, model, criterion)
            next_eval_num_samples += num_samples

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            best_loss = min(loss, best_loss)

            print('Epoch: [{epoch}]\t'
                  'Average Time: {average_time}\t'
                  'Total Time: {total_time}\t'
                  'Total Selected: {total_selected}\t'
                  'Best Prec@1: {best_prec1:.3f}\t'
                  'Best Loss: {best_loss:.4f}\t'
                  .format(
                      epoch=epoch, average_time=global_total_time / (epoch + 1), total_time=global_total_time,
                      total_selected=global_total_selected, best_prec1=best_prec1, best_loss=best_loss))

            if epoch > 0 and epoch % args.save_every == 0:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, filename=os.path.join(args.save_dir, 'checkpoint.th'))

            save_checkpoint({
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best, filename=os.path.join(args.save_dir, 'model.th'))


def train(args, train_state, train_loader, model, criterion, per_sample_criterion, optimizer, lr_scheduler, epoch):
    """
        Run one train epoch
    """

    global sampling_generator
    global data_generator
    global global_total_time
    global global_total_selected
    global global_total_seen
    global global_total_samples

    selection_policy = args.selection_policy
    selection_factor = args.selection_factor
    smooth_factor = args.smooth_factor
    lr_step_type = args.lr_step_type

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    sampling_entropy_meter = AverageMeter()
    estimated_gain = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    all_count = 0
    selected_count = 0
    filter_count = 0
    epoch_max_scaling = 0
    for i, (_indices, _input, _target) in enumerate(train_loader):
        if global_total_selected >= global_total_samples:
            break
        # measure data loading time
        data_time.update(time.time() - end)

        _target = _target.cuda()
        _input_var = _input.cuda()
        _target_var = _target
        if args.half:
            _input_var = _input_var.half()

        if selection_policy == "all" or has_custom_sampler(args.selection_policy):
            input = _input
            target = _target
            input_var = _input_var
            target_var = _target_var
            indices = _indices
            scaling_factor = 1/len(_indices)
            counts = 1
        else:
            if "oracle" in selection_policy:
                _output = model(_input_var)
                sample_loss = per_sample_criterion(_output, _target_var)
                sample_approx_grad = torch.linalg.norm(torch.autograd.grad(sample_loss.sum(), _output)[0], dim=-1).detach()/torch.sqrt(torch.tensor(2)) # normalize to [0, 1]
                sample_approx_grad_sq = sample_approx_grad**2
                sample_approx_grad_sq = (1-smooth_factor)*sample_approx_grad_sq + smooth_factor*sample_approx_grad_sq.mean()
                sample_approx_grad = torch.sqrt(sample_approx_grad_sq)
                sample_approx_grad = sample_approx_grad.double()
            elif "extra" in selection_policy:
                model.eval()
                _output = model(_input_var)
                model.train()

                values = nn.functional.softmax(_output.detach(), dim=-1)
                values[torch.arange(values.size(0)), _target_var] -= 1

                sample_approx_grad = torch.linalg.norm(values, dim=-1).detach()/torch.sqrt(torch.tensor(2)) # normalize to [0, 1]
                assert (sample_approx_grad<=1).all()
                sample_approx_grad_sq = sample_approx_grad**2
                sample_approx_grad_sq = (1-smooth_factor)*sample_approx_grad_sq + smooth_factor*sample_approx_grad_sq.mean()
                sample_approx_grad = torch.sqrt(sample_approx_grad_sq)
                sample_approx_grad = sample_approx_grad.double()
            else:
                history_count = train_state.history_count[_indices]
                if "sq" in args.history_type:
                    sample_approx_grad_sq = train_state.history_approx_grad_sq[_indices]
                    sample_approx_grad_sq = sample_approx_grad_sq/(1-torch.pow(args.historic_beta, history_count))
                else:
                    sample_approx_grad = train_state.history_approx_grad[_indices]
                    sample_approx_grad = sample_approx_grad/(1-torch.pow(args.historic_beta, history_count))
                    sample_approx_grad_sq = sample_approx_grad**2
                sample_approx_grad_sq[history_count==0] = 1
                sample_approx_grad_sq = (1-smooth_factor)*sample_approx_grad_sq + smooth_factor*sample_approx_grad_sq.mean()
                #discounting smooth factor
                #factor = (1-torch.pow(0.9, history_count))/(1-torch.pow(0.9, history_count+smooth_factor))
                #sample_approx_grad_sq = factor*sample_approx_grad_sq + (1-factor)*sample_approx_grad_sq.mean()
                sample_approx_grad = torch.sqrt(sample_approx_grad_sq)
                sample_approx_grad = sample_approx_grad.double()

            if selection_policy == 'static_bernoulli':
                sample_prob = torch.ones_like(_indices, device="cuda")*selection_factor
            elif selection_policy == 'dynamic_budget_historic_bernoulli':
                ref = sample_approx_grad/torch.sqrt(sample_approx_grad_sq.mean())
                sample_prob = sample_approx_grad/ref*selection_factor
                sample_prob = torch.minimum(sample_prob, torch.tensor(1))
            elif selection_policy == 'historic_dynamic_bernoulli':
                sample_prob = sample_approx_grad/torch.sqrt(train_state.grad_sq_ema+1e-8)*selection_factor
                sample_prob = torch.minimum(sample_prob, torch.tensor(1))
            elif selection_policy == 'static_choice_wo_replacement':
                choice_prob = torch.ones_like(_indices, device="cuda").double()/len(_indices)
            elif selection_policy in ['oracle_rank_bernoulli', 'historic_rank_bernoulli', 'extra_rank_bernoulli']:
                sort_idx = torch.argsort(sample_approx_grad)
                sample_prob = torch.zeros_like(_indices, dtype=torch.float, device="cuda")
                sample_prob[sort_idx] = (torch.arange(len(_indices), dtype=torch.float, device="cuda")+1)/len(_indices)
            elif selection_policy in ['fix_budget_oracle_choice', 'fix_budget_historic_choice', 'fix_budget_extra_choice']:
                choice_prob = sample_approx_grad/sample_approx_grad.sum()
            elif selection_policy in ['fix_budget_oracle_bernoulli', 'fix_budget_historic_bernoulli', 'fix_budget_extra_bernoulli']:
                val, _ = torch.sort(sample_approx_grad)
                budget = len(val)*selection_factor

                s = torch.cumsum(val, 0)/val # s >= 1
                full = len(val)-1-torch.arange(len(val), device="cuda")
                prob_v = (budget-full)/s
                ref = (val/prob_v)[(0 <= prob_v)&(prob_v <= 1)].min()

                assert ref > 0, f"ref value {ref} {prob_v}"

                sample_prob = sample_approx_grad/ref
                sample_prob = torch.minimum(sample_prob, torch.tensor(1))
            elif selection_policy == 'fix_variance_oracle_bernoulli':
                var = (1-selection_factor)/selection_factor*torch.sum(sample_approx_grad**2)

                lb = 0
                rb = 1/selection_factor
                for _ in range(50):
                    mid = (lb+rb)/2

                    _prob = sample_approx_grad/mid
                    _prob = torch.minimum(_prob, torch.tensor(1))

                    _var = torch.sum((1-_prob)/_prob*sample_approx_grad**2)

                    if _var > var:
                        rb = mid
                    else:
                        lb = mid
                ref = mid

                sample_prob = sample_approx_grad/ref
                sample_prob = torch.minimum(sample_prob, torch.tensor(1))
            elif selection_policy == 'dynamic_budget_oracle_bernoulli':
                ref = sample_approx_grad/torch.sqrt(sample_approx_grad_sq.mean())
                sample_prob = sample_approx_grad/ref*selection_factor
                sample_prob = torch.minimum(sample_prob, torch.tensor(1))
            else:
                raise ValueError

            if "choice" in selection_policy:
                selection_size = int(len(_indices)*selection_factor)
                choices = np.random.choice(np.arange(len(_indices)), size=selection_size, replace="wo_replacement" not in selection_policy, p=choice_prob.cpu())
                sampling_entropy = selection_size*scipy.stats.entropy(choice_prob.cpu())
                selection_mask, counts = np.unique(choices, return_counts=True)
                scaling_factor = 1/(choice_prob*selection_size)/len(_indices)
                ref_scaling = 1/selection_size
                scaling_factor = torch.minimum(scaling_factor, torch.tensor(ref_scaling*args.clip_scaling, device="cuda"))
                train_state.max_scaling_all = max(train_state.max_scaling_all, scaling_factor.max())

                base_approx_variance = sample_approx_grad_sq.mean()/len(_indices)
                approx_variance = (choice_prob*sample_approx_grad_sq*scaling_factor**2).sum()*selection_size+(1-1/selection_size)*base_approx_variance
                ref_approx_variance = sample_approx_grad_sq.mean()/selection_size+(1-1/selection_size)*base_approx_variance

                scaling_factor = scaling_factor[selection_mask]
                if args.not_divide_prob:
                    scaling_factor = ref_scaling
                scaling_factor *= torch.tensor(counts, device="cuda")
                gain = ref_approx_variance/approx_variance
                if args.scale_lr:
                    scaling_factor *= gain
            else:
                selection_size = int(len(_indices)*selection_factor)
                selection_mask = np.arange(len(_indices))[torch.rand(len(_indices), generator=sampling_generator) < sample_prob.cpu()]
                sampling_entropy = scipy.stats.entropy(torch.stack([sample_prob, 1-sample_prob]).cpu()).sum()
                counts = 1
                scaling_factor = 1/sample_prob/len(_indices)
                ref_scaling = 1/selection_size
                scaling_factor = torch.minimum(scaling_factor, torch.tensor(ref_scaling*args.clip_scaling, device="cuda"))
                train_state.max_scaling_all = max(train_state.max_scaling_all, scaling_factor.max())

                base_approx_variance = sample_approx_grad_sq.mean()/len(_indices)
                approx_variance = ((1-sample_prob)*sample_prob*sample_approx_grad_sq*scaling_factor**2).sum()+base_approx_variance
                ref_approx_variance = sample_approx_grad_sq.mean()/selection_size+(1-1/selection_size)*base_approx_variance

                scaling_factor = scaling_factor[selection_mask]
                if args.not_divide_prob:
                    scaling_factor = ref_scaling
                gain = ref_approx_variance/approx_variance
                if args.scale_lr:
                    scaling_factor *= gain
            estimated_gain.update(gain)
            sampling_entropy_meter.update(sampling_entropy)

            input = _input[selection_mask]
            target = _target[selection_mask]
            input_var = _input_var[selection_mask]
            target_var = _target_var[selection_mask]
            indices = _indices[selection_mask]

        all_count += len(_indices)
        if not args.count_duplicate:
            selected_count += len(indices)
            global_total_selected += len(indices)
        else:
            selected_count += counts.sum()
            global_total_selected += counts.sum()
        global_total_seen += len(_indices) * selection_factor

        # compute output
        #if selection_policy not in ['fix_budget_oracle_bernoulli', 'dynamic_budget_oracle_bernoulli']:
        if 'oracle' not in selection_policy:
            output = model(input_var)
        else:
            output = _output[selection_mask]
        sample_loss = per_sample_criterion(output, target_var)
        #sample_approx_grad = 1-torch.exp(-sample_loss.detach())
        sample_approx_grad = torch.linalg.norm(torch.autograd.grad(sample_loss.sum(), output, retain_graph=True)[0], dim=-1).detach()/torch.sqrt(torch.tensor(2))

        grad_sq_mean = (sample_approx_grad**2*scaling_factor).sum()
        #grad_sq_mean = (sample_approx_grad**2).mean()
        train_state.grad_sq_ema = 0.99*train_state.grad_sq_ema + 0.01*grad_sq_mean
        train_state.grad_sq_ema_count += 1
        try:
            epoch_max_scaling = max(epoch_max_scaling, scaling_factor.max())
            train_state.max_scaling = max(train_state.max_scaling, scaling_factor.max())
        except:
            epoch_max_scaling = max(epoch_max_scaling, scaling_factor)
            train_state.max_scaling = max(train_state.max_scaling, scaling_factor)

        if "gain" in args.history_type:
            grad_sq_ema = train_state.grad_sq_ema/(1-0.99**train_state.grad_sq_ema_count)
            train_state.history_approx_grad[indices] = args.historic_beta*train_state.history_approx_grad[indices]+(1-args.historic_beta)*sample_approx_grad/torch.sqrt(grad_sq_ema)
            train_state.history_approx_grad_sq[indices] = args.historic_beta*train_state.history_approx_grad_sq[indices]+(1-args.historic_beta)*(sample_approx_grad**2/grad_sq_ema)
        else:
            train_state.history_approx_grad[indices] = args.historic_beta*train_state.history_approx_grad[indices]+(1-args.historic_beta)*sample_approx_grad
            train_state.history_approx_grad_sq[indices] = args.historic_beta*train_state.history_approx_grad_sq[indices]+(1-args.historic_beta)*sample_approx_grad**2

        train_state.history_count[indices] += torch.tensor(counts, device="cuda")

        if not has_custom_sampler(args.selection_policy):
            loss = (sample_loss*scaling_factor).sum()
        elif args.selection_policy == 'infobatch':
            loss = train_loader.dataset.update(sample_loss)
        elif args.selection_policy == 'historic_sampler':
            loss = train_loader.dataset.update(indices, output, sample_loss, target_var)
        else:
            raise ValueError

        #loss = (sample_loss).mean()

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if lr_step_type == "selected":
            lr_scheduler.step(min(global_total_selected, global_total_samples-1))
        elif lr_step_type == "seen":
            lr_scheduler.step(min(global_total_seen, global_total_samples-1))
        else:
            raise ValueError

        output = output.float()
        loss = loss.float()
        # measure accuracy and record loss
        prec1 = accuracy(output.data, target)[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        global_total_time += time.time() - end
        end = time.time()

        if (i+1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Sampling Entropy {entropy.val:.4f} ({entropy.avg:.4f})\t'
                  'Estimated Gain {estimated_gain.val:.4f} ({estimated_gain.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Grad_Sq_Ema: {grad_sq_ema}\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, entropy=sampling_entropy_meter, estimated_gain=estimated_gain, loss=losses, top1=top1, grad_sq_ema=train_state.grad_sq_ema))

    heuristic_ratio = torch.minimum(torch.sqrt(train_state.history_approx_grad_sq/train_state.history_approx_grad_sq.mean()), torch.tensor(1)).mean()
    smooth_heuristic_ratio = torch.minimum(torch.sqrt((1-smooth_factor)*train_state.history_approx_grad_sq/train_state.history_approx_grad_sq.mean()+smooth_factor), torch.tensor(1)).mean()



    history_count = train_state.history_count

    if "sq" in args.history_type:
        sample_approx_grad_sq = train_state.history_approx_grad_sq

        sample_approx_grad_sq = sample_approx_grad_sq/(1-torch.pow(args.historic_beta, history_count))
    else:
        sample_approx_grad = train_state.history_approx_grad
        sample_approx_grad = sample_approx_grad/(1-torch.pow(args.historic_beta, history_count))
        sample_approx_grad_sq = sample_approx_grad**2

    sample_approx_grad_sq[history_count==0] = 1

    history_count = history_count.float()
    quantiles = torch.arange(21, device="cuda")/20.0

    print('Epoch: [{epoch}]\t'
          'Historic Data Quantile: {historic_data_quantile}\t'
          'Historic Data Mean: {historic_data_mean}\t'
          'Historic Count Quantile: {historic_count_quantile}\t'
          'Historic Count Mean: {historic_count_mean}\t'
        .format(epoch=epoch,
                historic_data_quantile=torch.quantile(sample_approx_grad_sq, quantiles), historic_data_mean=sample_approx_grad_sq.mean(),
                historic_count_quantile=torch.quantile(train_state.history_count.float(), quantiles), historic_count_mean=train_state.history_count.float().mean()
        ))

    print('Epoch: [{epoch}]\t'
        'Selected: {selected_count}\t'
        'All: {all_count}\t'
        'Filtered: {filter_count}\t'
        'Epoch Max Scaling: {epoch_max_scaling}\t'
        'Max Scaling: {max_scaling}\t'
        'Max Scaling All: {max_scaling_all}\t'
        'Heuristic Ratio: {heuristic_ratio}\t'
        'Smooth Heuristic Ratio: {smooth_heuristic_ratio}\t'
        'Train Loss: {loss.avg:.4f}\t'
        'Train Prec@1: {top1.avg:.3f}'
        .format(epoch=epoch, selected_count=selected_count, all_count=all_count, filter_count=filter_count, epoch_max_scaling=epoch_max_scaling, max_scaling=train_state.max_scaling, max_scaling_all=train_state.max_scaling_all, heuristic_ratio=heuristic_ratio, smooth_heuristic_ratio=smooth_heuristic_ratio, loss=losses, top1=top1))


def validate(val_loader, model, criterion):
    """
    Evaluate ``model`` on ``val_loader`` and report loss / top-1 precision.

    Runs under ``torch.no_grad()`` with the model switched to eval mode.
    Every batch is moved to CUDA. Inputs are cast to half precision when the
    module-level ``args.half`` flag is set. Progress is printed every
    ``args.eval_print_freq`` batches, followed by a one-line summary.

    Args:
        val_loader: iterable yielding ``(input, target)`` batches.
        model: network to evaluate.
        criterion: loss called as ``criterion(output, target)``. Presumably it
            returns a scalar tensor, since ``.item()`` is called on its result.

    Returns:
        Tuple ``(top1_avg, loss_avg)``: batch-size-weighted averages of the
        top-1 precision (percent) and of the loss over the whole loader.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()

    last_tick = time.time()
    with torch.no_grad():
        for step, (images, labels) in enumerate(val_loader):
            labels = labels.cuda()
            images_gpu = images.cuda()
            labels_gpu = labels.cuda()  # labels already moved above; separate name kept
            if args.half:
                images_gpu = images_gpu.half()

            # Forward pass and loss.
            logits = model(images_gpu)
            batch_loss = criterion(logits, labels_gpu)

            # Cast to float32 before computing metrics.
            logits = logits.float()
            batch_loss = batch_loss.float()

            # Metrics are weighted by the number of samples in this batch.
            batch_size = images.size(0)
            prec1 = accuracy(logits.data, labels)[0]
            losses.update(batch_loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)

            batch_time.update(time.time() - last_tick)
            last_tick = time.time()

            if (step + 1) % args.eval_print_freq != 0:
                continue
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      step, len(val_loader), batch_time=batch_time, loss=losses,
                      top1=top1))

    print('Validation Result * Prec@1 {top1.avg:.3f} Loss {loss.avg:.4f}'
          .format(top1=top1, loss=losses))

    return top1.avg, losses.avg

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` using ``torch.save``.

    Args:
        state: object to persist (whatever the caller passes, e.g. a dict).
        is_best: accepted for call-site compatibility but currently ignored.
            No separate "best model" copy is written.
        filename: destination path. Any existing file is overwritten.
    """
    torch.save(obj=state, f=filename)

class AverageMeter:
    """Track the latest value and the count-weighted running mean of a metric.

    Attributes (all start at 0):
        val: most recent value passed to ``update``.
        sum: running total of ``val * n`` over all updates.
        count: running total of the weights ``n``.
        avg: ``sum / count`` as of the last update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute the precision@k (in percent) for each k in ``topk``.

    Args:
        output: score tensor; top-k is taken along dim 1, so rows are samples
            and columns are classes.
        target: ground-truth class indices, one per row of ``output``.
        topk: k values to evaluate. ``max(topk)`` must not exceed the number
            of columns of ``output``.

    Returns:
        List with one 0-dim float tensor per k: the percentage of rows whose
        target is among the k highest-scoring columns.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # pred becomes (maxk, batch) after the transpose. Row j holds each
    # sample's j-th best class.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` derives from a transposed tensor and may be
        # non-contiguous, so ``view(-1)`` fails for k > 1. ``reshape`` copies
        # only when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == '__main__':
    # Script entry point. main() is expected to be defined elsewhere in this module.
    main()
