import collections
import datetime
from fractions import Fraction
from functools import partial
import itertools
import math
from pprint import pprint
import sys
import time
from typing import Any, Callable, Dict, Iterator, Mapping, Tuple

from absl import app
from absl import flags
from absl import logging
import fancyflags as ff
from flax import linen as nn
import jax
from jax import lax
from jax import numpy as jnp
from jax import random
from jax.tree_util import tree_reduce, tree_map, tree_leaves, tree_flatten, tree_unflatten
from ml_collections import ConfigDict
import numpy as np
import optax
import torch.utils.data
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm
import wandb

import image_jax
import resnet_jax
import util
import vgg_jax

# Top-level (scalar) flags: data location, evaluation and logging cadence.
flags.DEFINE_string('dataset', 'cifar10', 'Dataset to use.')
flags.DEFINE_string('input_stats', 'auto', 'Normalization statistics (or "auto" to match dataset).')
flags.DEFINE_string('dataset_root', 'data', 'Path to data.')
flags.DEFINE_bool('download', False, 'Download dataset.')
flags.DEFINE_integer('eval_batch_size', 128, 'Batch size to use during evaluation.')
flags.DEFINE_integer('num_workers', 4, 'num_workers for DataLoader')
flags.DEFINE_integer('prefetch_factor', 2, 'prefetch_factor for DataLoader')
flags.DEFINE_integer('log_interval', 20, 'Period for logging.')
flags.DEFINE_integer('power_interval', 200, 'Period for power method.')
flags.DEFINE_integer('power_batch_size', 1000, 'Maximum batch size for power iterations.')
flags.DEFINE_integer('seed', 0, 'Random seed.')

# Nested flag groups; each is read as a dict, e.g. FLAGS.model['arch'].
ff.DEFINE_dict(
    'model',
    arch=ff.String('resnet_v1_18'),
    bn=ff.Boolean(True),
    res_init_zero=ff.Boolean(True))
ff.DEFINE_dict(
    'train',
    batch_size=ff.Integer(128),
    num_epochs=ff.Integer(90),
    learning_rate=ff.Float(0.1),
    schedule=ff.String('const'),  # const, cos, piece
    schedule_step_epochs=ff.String('50,80'),
    schedule_step_scale=ff.Float(0.1),
    momentum=ff.Float(0.9),
    momentum_type=ff.String('polyak'),
    loss=ff.String('ce'),
    ce_smooth=ff.Float(0.0),
    mse_scale=ff.Float(1.0),
    weight_decay=ff.Float(0.0),
    weight_decay_vars=ff.String('kernel'),  # kernel, all
    mixup_beta=ff.Float(0.0),  # 0 => no mixup; 1 => uniform mixup; inf => 0.5 mixup
    aug=ff.String('cifar'),  # cifar
    aug_prob=ff.Float(0.0, 'Augmentation probability.'),
    gd=ff.Boolean(False, 'Full-batch gradient descent (takes subset of dataset).'),
    accum_steps=ff.Integer(0, 'Accumulate gradients over multiple steps. Disabled if zero.'),
    # The default of -1.0 can never be reached by an error rate, so early
    # stopping is effectively disabled unless this is set.
    stop_train_err=ff.Float(-1.0, 'Terminate when train_err <= stop_train_err (non-strict).'),
    sam_rho=ff.Float(0.0, 'Parameter rho in Sharpness-Aware Minimization.'))
# ff.DEFINE_auto('power', power_method)
# Keyword arguments forwarded to power_method() via **FLAGS.power.
# NOTE(review): rtol=100 makes the relative-change test in power_method() pass
# almost immediately (typically on the second iteration) -- confirm that this
# is intended (e.g. relying on warm-started vectors across calls).
ff.DEFINE_dict(
    'power',
    num_iters=ff.Integer(100),
    rtol=ff.Float(100),
    atol=ff.Float(100))
FLAGS = flags.FLAGS

# Names of the flags that merge_configs() copies into wandb.config (and reads
# back from it, so that sweep-provided values override flag defaults).
CONFIG_FLAGS = [
    'seed',
    'dataset',
    'input_stats',
    'model',
    'train',
    'power',
    'power_batch_size',
    'eval_batch_size',  # Should not affect anything.
    'log_interval',  # Affects appearance of curves.
    'power_interval',  # Affects appearance of curves.
    'num_workers',  # Cannot guarantee deterministic if > 1.
]

# Operators for which train() estimates three quantities via
# power_method_min_max(): the dominant, the minimum and the maximum eigenvalue
# (stored under '<op>', '<op>_min' and '<op>_max').
DERIVATIVES_MIN_MAX = [
    'hessian',
]

# Operators for which train() runs a single power_method() call.
DERIVATIVES_SIMPLE = [
    'gauss_newton',
    'ntk',
    'jacobian_jjt',
    'jacobian_jtj',
    'jacobian_eval_jjt',
    'jacobian_eval_jtj',
]

# Type aliases used in annotations below.
Dataset = torch.utils.data.Dataset
ModuleDef = Callable[..., nn.Module]


class OptimizationError(Exception):
    """Raised by train() when the objective becomes non-finite after the first step.

    main() catches this exception and logs a warning instead of crashing.
    """


def merge_configs():
    """Reconciles command-line flags with wandb.config (in both directions).

    Must be called after wandb.init(). Modifies FLAGS and wandb.config.

    Raises:
        ValueError: If wandb.config contains nested mappings (see assert_flat).
    """
    assert_flat(wandb.config)

    # Snapshot the flags listed in CONFIG_FLAGS.
    merged = ConfigDict({name: getattr(FLAGS, name) for name in CONFIG_FLAGS})
    merged.lock()  # Disallow adding or removing options.
    print('original flags_config:')
    pprint(merged.to_dict())

    # Flags -> wandb.config.
    # For a sweep, this sets unspecified params and does not override specified params.
    # A flat dict is used because wandb does not update recursively.
    flat_flags = dict(flatten_items(merged.items(), sep='.'))
    wandb.config.update(flat_flags)
    assert_flat(wandb.config)
    print('wandb.config:')
    pprint(wandb.config.as_dict())

    # wandb.config -> flags (picks up params set by a sweep).
    # Not necessary if all params were passed as flags.
    # Raises if new params were specified.
    merged.update_from_flattened_dict(wandb.config.as_dict())
    print('final flags_config:')
    pprint(merged.to_dict())

    # Write the merged values back into FLAGS in-place to avoid confusion.
    for name, value in merged.items():
        setattr(FLAGS, name, value)


def main(_):
    """Entry point for absl: sets up wandb, merges configs and runs training.

    An OptimizationError raised by train() is logged as a warning rather than
    propagated.
    """
    wandb.init()
    merge_configs()
    try:
        train()
    except OptimizationError as err:
        message = str(err)
        logging.warning('optimization failed: %s', message)


def train():
    """Trains a model according to FLAGS and logs metrics to wandb.

    Builds the train/val/power-method data loaders, the model, the optimizer
    and a set of jitted derivative operators, then runs the training loop.
    The loop evaluates on the validation set at init and after every epoch,
    estimates eigenvalues of several operators with the power method every
    `power_interval` multisteps, and logs averaged training statistics every
    `log_interval` multisteps.

    The loop ends when `num_epochs` epochs have been completed or when the
    epoch-level train error is <= FLAGS.train['stop_train_err'].

    Raises:
        ValueError: If the objective is not finite on the very first step.
        OptimizationError: If the objective becomes non-finite on a later
            step (raised after the metrics of that step have been logged).
    """
    rng = random.PRNGKey(FLAGS.seed)
    rng, rng_init, rng_data, rng_next_step, rng_torch, rng_power = random.split(rng, 6)
    torch.manual_seed(torch_seed(rng_torch))  # Should have no effect but just to be safe.

    rng_data, rng_train_data, rng_power_data = random.split(rng_data, 3)
    rng_data, rng_train_subset, rng_val_subset = random.split(rng_data, 3)
    num_classes, input_shape, (train_dataset, val_dataset) = setup_data()
    if FLAGS.train['gd']:
        # Take a batch-sized subset of each dataset.
        # Raises if batch_size is greater than len(dataset).
        # TODO: Load entire dataset into (GPU?) memory.
        train_dataset = torch.utils.data.Subset(train_dataset, np.array(random.choice(
            rng_train_subset, len(train_dataset), shape=(FLAGS.train['batch_size'],), replace=False)))
        val_dataset = torch.utils.data.Subset(val_dataset, np.array(random.choice(
            rng_val_subset, len(val_dataset), shape=(FLAGS.train['batch_size'],), replace=False)))
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=FLAGS.train['batch_size'],
        shuffle=True,
        pin_memory=False,
        num_workers=FLAGS.num_workers,
        prefetch_factor=FLAGS.prefetch_factor,
        generator=torch_generator(rng_train_data),
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=FLAGS.eval_batch_size,
        shuffle=False,
        pin_memory=False,
        num_workers=FLAGS.num_workers,
        prefetch_factor=FLAGS.prefetch_factor,
        drop_last=False)
    # Separate DataLoader for the power method (from train dataset).
    assert FLAGS.power_batch_size <= len(train_dataset)
    power_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=FLAGS.power_batch_size,
        shuffle=True,
        pin_memory=False,
        num_workers=FLAGS.num_workers,
        prefetch_factor=FLAGS.prefetch_factor,
        generator=torch_generator(rng_power_data),
        drop_last=True)

    power_data_iter = iter(infinite_loader(power_loader))
    # power_data_iter = InfiniteIter(power_loader)

    norm = nn.BatchNorm
    # Maps the train/eval mode to the kwargs passed to the norm layer.
    norm_kwargs = lambda train: {'use_running_average': not train}

    model = make_model(FLAGS.model, num_classes, input_shape, norm=norm)
    init_vars = model.init(rng_init, jnp.zeros((1,) + input_shape), norm_kwargs=norm_kwargs(train=True))
    # Split the variables into trainable params and everything else.
    init_vars, params = init_vars.pop('params')
    mutable_vars = init_vars
    del init_vars
    # NaN-filled stand-in with the same structure as mutable_vars, passed where
    # the values are expected to be unused (presumably so that any accidental
    # use would surface as NaN -- TODO confirm).
    dummy_mutable_vars = tree_map(lambda x: jnp.full_like(x, jnp.nan), mutable_vars)

    # Samplers for the initial power-method vector of each operator; the vector
    # lives in parameter space, input space or output space depending on the operator.
    sample_vec_params = lambda rng: normal_like_tree(rng, params)
    sample_vec_inputs = lambda rng: random.normal(rng, (FLAGS.power_batch_size, *input_shape))
    sample_vec_outputs = lambda rng: random.normal(rng, (FLAGS.power_batch_size, num_classes))
    sample_vec_fns = {
        'hessian': sample_vec_params,
        'hessian_min': sample_vec_params,
        'hessian_max': sample_vec_params,
        'gauss_newton': sample_vec_params,
        'ntk': sample_vec_outputs,
        'jacobian_jjt': sample_vec_outputs,
        'jacobian_eval_jjt': sample_vec_outputs,
        'jacobian_jtj': sample_vec_inputs,
        'jacobian_eval_jtj': sample_vec_inputs,
    }
    vec_rngs = dict(zip(sample_vec_fns, random.split(rng_power, len(sample_vec_fns))))
    # Current vectors, keyed by operator name. They are overwritten after each
    # power-method call below, so each call starts from the previous result.
    vecs = {k: sample_vec_fns[k](vec_rngs[k]) for k in itertools.chain(
        *[(op, f'{op}_min', f'{op}_max') for op in DERIVATIVES_MIN_MAX],
        DERIVATIVES_SIMPLE,
    )}

    steps_per_epoch = len(train_loader)
    total_steps = FLAGS.train['num_epochs'] * steps_per_epoch
    print('steps per epoch:', steps_per_epoch)
    print('total number of steps:', total_steps)
    print('total number of params:', sum(map(lambda x: math.prod(x.shape), tree_leaves(params))))
    print('number of linear layers:', sum(1 for x in tree_leaves(params) if x.ndim > 1))

    # Optimizer schedule uses number of multisteps not batches.
    multistep_size = max(1, FLAGS.train['accum_steps'])
    make_schedule = {
        'const': optax.constant_schedule,
        'cos': partial(optax.cosine_decay_schedule, decay_steps=ceildiv(total_steps, multistep_size)),
        'piece': partial(optax.piecewise_constant_schedule, boundaries_and_scales={
            ceildiv(int(n) * steps_per_epoch, multistep_size): FLAGS.train['schedule_step_scale']
            for n in FLAGS.train['schedule_step_epochs'].split(',')}),
    }[FLAGS.train['schedule']]
    lr_schedule = make_schedule(FLAGS.train['learning_rate'])
    # A momentum of zero is passed to optax.sgd as None.
    momentum = FLAGS.train['momentum'] or None
    nesterov = {'polyak': False, 'nesterov': True}[FLAGS.train['momentum_type']]
    # inject_hyperparams exposes the learning rate in the optimizer state (read by _get_opt_lr).
    tx = optax.inject_hyperparams(optax.sgd)(lr_schedule, momentum=momentum, nesterov=nesterov)
    if FLAGS.train['accum_steps']:
        tx = optax.MultiSteps(tx, FLAGS.train['accum_steps'], use_grad_mean=True)
    opt_state = tx.init(params)

    # Per-example loss, vmapped over the leading (example) dimension.
    loss_fn = {
        'ce': jax.vmap(partial(ce_loss, alpha=FLAGS.train['ce_smooth'])),
        'mse': jax.vmap(partial(mse_loss, scale=FLAGS.train['mse_scale'])),
    }[FLAGS.train['loss']]

    aug_fn = {
        'cifar': cifar_aug,
    }[FLAGS.train['aug']]
    if not FLAGS.train['aug_prob']:
        # Augmentation will never be used. Avoid computation.
        aug_fn = lambda _, im: im

    @jax.jit
    def step_obj_fn(params, mutable_vars, data, rng_step, mixup_beta: float, weight_decay: float, aug_prob: float):
        # Designed for use with jax.value_and_grad(..., has_aux=True).
        # First arg (wrt which derivative is taken) is params.
        # Returns scalar loss and one auxiliary output.
        # The auxiliary output is the tuple (mutated_vars, outputs, example_loss).
        inputs, targets = data
        model_vars = {'params': params, **mutable_vars}
        rng_mixup, rng_aug = random.split(rng_step)
        # Apply augmentation with probability.
        rng_aug, rng_use_aug = random.split(rng_aug)
        aug_inputs = aug_fn(rng_aug, inputs)
        use_aug = random.bernoulli(rng_use_aug, aug_prob, inputs.shape[:1])
        # Select between augmented and unaugmented along example dim.
        aug_inputs = jax.vmap(lax.select)(use_aug, aug_inputs, inputs)
        inputs = lax.select(aug_prob != 0, aug_inputs, inputs)
        # Apply mixup.
        # TODO: Could mixup with different augmentations for each.
        mixup_inputs, mixup_targets = mixup(rng_mixup, mixup_beta, (inputs, targets))
        inputs = lax.select(mixup_beta != 0, mixup_inputs, inputs)
        targets = lax.select(mixup_beta != 0, mixup_targets, targets)
        # Evaluate model.
        outputs, mutated_vars = model.apply(
            model_vars, inputs, norm_kwargs=norm_kwargs(train=True),
            mutable=list(mutable_vars.keys()))
        example_loss = loss_fn(targets, outputs)
        wd_loss = weight_decay_fn(params)
        objective = jnp.mean(example_loss)
        # The weight-decay term is only added when weight_decay is non-zero.
        objective = lax.select(weight_decay == 0, objective, objective + weight_decay * wd_loss)
        return objective, (mutated_vars, outputs, example_loss)

    def weight_decay_fn(params):
        """Returns half the sum of squares of the selected parameter leaves."""
        assert FLAGS.train['weight_decay_vars'] in ('kernel', 'all')
        wd_vars = list(tree_leaves(params))
        if FLAGS.train['weight_decay_vars'] == 'kernel':
            # 'kernel' keeps only leaves with more than one dimension.
            wd_vars = [x for x in wd_vars if x.ndim > 1]
        return 0.5 * sum([jnp.sum(jnp.square(x)) for x in wd_vars])

    @jax.jit
    def train_step(opt_state, params, mutable_vars, data, rng_step):
        """Returns updates for optimizer, params, mutable_vars."""
        # Take derivative wrt first arg (params).
        obj_value_and_grad_fn = jax.value_and_grad(step_obj_fn, has_aux=True)
        kwargs = dict(
            mixup_beta=FLAGS.train['mixup_beta'],
            weight_decay=FLAGS.train['weight_decay'],
            aug_prob=FLAGS.train['aug_prob'])

        (objective, aux), grads = obj_value_and_grad_fn(params, mutable_vars, data, rng_step, **kwargs)
        mutated_vars, outputs, example_loss = aux
        if FLAGS.train['sam_rho']:
            # The gradient is re-evaluated at params + sam_rho * (normalized gradient),
            # using the same data and rng. The objective, outputs and mutated_vars
            # that are returned still come from the first evaluation above.
            # (objective, aux), grads = obj_value_and_grad_fn(params, mutable_vars, data, rng_step, **kwargs)
            grads = tree_normalize(grads)
            noised_params = tree_map(lambda a, b: a + FLAGS.train['sam_rho'] * b, params, grads)
            _, grads = obj_value_and_grad_fn(noised_params, mutable_vars, data, rng_step, **kwargs)

        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return opt_state, params, mutated_vars, (objective, outputs, example_loss)

    @jax.jit
    def eval_model_fn(params, mutable_vars, inputs):
        # No need to pass mutable=['batch_stats']; stats are not "updated".
        return model.apply(
            {'params': params, **mutable_vars}, inputs,
            norm_kwargs=norm_kwargs(train=False))

    @jax.jit
    def train_model_fn(params, inputs):
        """Evaluates the model using batch-norm in online mode."""
        # Unfortunately, still need to pass batch_stats.
        # When use_running_average=False, nn.BatchNorm returns updates for batch_stats.
        # Here these updates are not used.
        outputs, _ = model.apply(
            {'params': params, **dummy_mutable_vars}, inputs,
            mutable=['batch_stats'],
            norm_kwargs=norm_kwargs(train=True))
        return outputs

    # def obj_fn(inputs, targets, params):  # Without batch_stats.
    #     return jnp.mean(loss_fn(targets, train_model_fn(params, inputs)))

    def obj_fn(data, params):
        """Evaluates data loss using batch-norm in online mode.

        Disables mixup, weight decay, augmentation.
        """
        # Provide dummy batch_stats and rng (both unused).
        dummy_rng = random.PRNGKey(0)
        objective, _ = step_obj_fn(
            params, dummy_mutable_vars, data, dummy_rng,
            mixup_beta=0.0, weight_decay=0.0, aug_prob=0.0)
        return objective

    # Matrix-vector products used by the power method. Each takes the data
    # and/or params that define the operator, followed by the vector to multiply.

    @jax.jit
    def obj_hvp(inputs, targets, params, d_params):
        # (D_w D_w L.F)
        return tree_hvp(partial(obj_fn, (inputs, targets)), params, d_params)

    @jax.jit
    def obj_hvp_add_id(inputs, targets, params, b, d_params):
        # Hessian-vector product plus b times the vector, i.e. (H + b I) d_params.
        return func_add_id(partial(obj_hvp, inputs, targets, params), b)(d_params)

    @jax.jit
    def obj_gnvp(inputs, targets, params, d_params):
        # (D_w F)' (D^2 L) (D_w F)
        f = lambda outputs: jnp.mean(loss_fn(targets, outputs))
        g = lambda params: train_model_fn(params, inputs)
        return tree_gnvp(f, g, params, d_params)

    @jax.jit
    def model_inputs_jtjvp(params, inputs, d_inputs):
        # (D_x F)' (D_x F)
        return jtjvp(partial(train_model_fn, params), inputs, d_inputs)

    @jax.jit
    def model_inputs_jjtvp(params, inputs, d_outputs):
        # (D_x F) (D_x F)'
        return jjtvp(partial(train_model_fn, params), inputs, d_outputs)

    @jax.jit
    def model_inputs_jtjvp_eval(params, mutable_vars, inputs, d_inputs):
        # (D_x F)' (D_x F)
        return jtjvp(partial(eval_model_fn, params, mutable_vars), inputs, d_inputs)

    @jax.jit
    def model_inputs_jjtvp_eval(params, mutable_vars, inputs, d_outputs):
        # (D_x F) (D_x F)'
        return jjtvp(partial(eval_model_fn, params, mutable_vars), inputs, d_outputs)

    @jax.jit
    def model_kvp(inputs, params, d_outputs):  # or model_params_jjtvp
        # (D_w F) (D_w F)'
        return tree_jjtvp(lambda params: train_model_fn(params, inputs), params, d_outputs)

    def derivative_fn(name, inputs, targets, params, mutable_vars):
        """Returns the vector -> vector multiplication function for operator `name`."""
        if name == 'hessian':
            return partial(obj_hvp, inputs, targets, params)
        elif name == 'gauss_newton':
            return partial(obj_gnvp, inputs, targets, params)
        elif name == 'ntk':
            return partial(model_kvp, inputs, params)
        elif name == 'jacobian_jjt':
            return partial(model_inputs_jjtvp, params, inputs)
        elif name == 'jacobian_jtj':
            return partial(model_inputs_jtjvp, params, inputs)
        elif name == 'jacobian_eval_jjt':
            return partial(model_inputs_jjtvp_eval, params, mutable_vars, inputs)
        elif name == 'jacobian_eval_jtj':
            return partial(model_inputs_jtjvp_eval, params, mutable_vars, inputs)
        else:
            raise ValueError('unknown operator', name)

    def derivative_add_id_fn(name, inputs, targets, params, mutable_vars):
        """Returns a function (b, vector) -> (operator + b * identity) applied to vector."""
        if name == 'hessian':
            return partial(obj_hvp_add_id, inputs, targets, params)
        else:
            raise ValueError('unknown operator', name)

    # Buffers of per-step outputs: one cleared every epoch, one every log interval.
    epoch_outputs = collections.defaultdict(list)
    interval_outputs = collections.defaultdict(list)
    epoch = 0  # Number of epochs completed.
    step = 0  # Number of steps completed.
    metrics = {}
    train_iter = None
    batch_data = None
    start_time = time.time()
    train_err = 1.0
    optimization_error = None  # Indicates divergence.
    frac_multistep = Fraction(0)  # Fractional multistep.

    while True:
        # Pre-step metrics (compute after step *and* at init).
        if train_iter is None and not optimization_error:  # At init or just finished an epoch.
            val_outputs = collections.defaultdict(list)
            for inputs, labels in tqdm(val_loader, f'val epoch {epoch}'):
                inputs, labels = torch2jax((inputs, labels))
                targets = nn.one_hot(labels, num_classes)
                scores = eval_model_fn(params, mutable_vars, inputs)
                loss = loss_fn(targets, scores)
                pred = jnp.argmax(scores, axis=-1)
                acc = jnp.equal(pred, labels)
                err = jnp.not_equal(pred, labels)
                batch_outputs = {'loss': loss, 'acc': acc, 'err': err}
                for k, v in batch_outputs.items():
                    val_outputs[k].append(v)
            metrics.update({'epoch/train/' + k: _mean_concat(v) for k, v in epoch_outputs.items()})
            metrics.update({'epoch/val/' + k: _mean_concat(v) for k, v in val_outputs.items()})
            elapsed = datetime.timedelta(seconds=int(time.time() - start_time))
            msg = []
            msg.append('val acc {:.2%}'.format(metrics['epoch/val/acc']))
            # Train metrics only exist once at least one epoch has been completed.
            if epoch > 0:
                msg.append('train acc {:.2%}'.format(metrics['epoch/train/acc']))
                msg.append('train obj {:.6g}'.format(metrics['epoch/train/objective']))
                msg.append(f'elapsed {elapsed!s}')
            print(f'epoch {epoch:d}: ' + ', '.join(msg), file=sys.stderr)
            if epoch > 0:
                train_err = metrics['epoch/train/err']
            # Clear buffer.
            epoch_outputs = collections.defaultdict(list)

        # Eigenvalue estimates. The condition also holds at init (frac_multistep == 0).
        if frac_multistep % FLAGS.power_interval == 0 and not optimization_error:
            inputs, labels = torch2jax(next(power_data_iter))
            targets = nn.one_hot(labels, num_classes)
            vals = {}
            with util.timer('all eigenvalues'):
                for op in DERIVATIVES_MIN_MAX:
                    op_min, op_max = f'{op}_min', f'{op}_max'
                    with util.timer(op):
                        ((vals[op], vecs[op]),
                         (vals[op_min], vecs[op_min]),
                         (vals[op_max], vecs[op_max])) = power_method_min_max(
                             derivative_fn(op, inputs, targets, params, mutable_vars),
                             derivative_add_id_fn(op, inputs, targets, params, mutable_vars),
                             vecs[op], vecs[op_min], vecs[op_max], **FLAGS.power)
                for op in DERIVATIVES_SIMPLE:
                    with util.timer(op):
                        vals[op], vecs[op] = power_method(
                            derivative_fn(op, inputs, targets, params, mutable_vars), vecs[op], **FLAGS.power)
            if 'ntk' in vals:
                # Scale NTK by sqrt of loss Hessian (assuming MSE loss).
                ntk_scale = 2 / (FLAGS.power_batch_size * num_classes)
                vals['ntk_scaled'] = ntk_scale * vals['ntk']
            # The jacobian_* operators are products of a Jacobian with its
            # transpose; report the square root of the estimated eigenvalue.
            for k in vals:
                if k.startswith('jacobian_'):
                    vals[k] = jnp.sqrt(jnp.abs(vals[k]))
            vals = {k: v.item() for k, v in vals.items()}
            # pprint(vals)
            metrics.update({'train/' + k: v for k, v in vals.items()})

        if metrics:
            # Use ceil in case training failed in the middle of a multi-step.
            wandb.log(metrics, step=math.ceil(frac_multistep))
        if optimization_error:
            raise optimization_error  # After logging step metrics.
        # Termination: all epochs done, or train error low enough.
        if (not epoch < FLAGS.train['num_epochs']) or train_err <= FLAGS.train['stop_train_err']:
            break
        metrics = {}
        if train_iter is None:
            # Start a new epoch. The first batch is fetched eagerly; thereafter the
            # next batch is fetched at the end of each step so that the end of the
            # epoch is detected before the pre-step metrics of the next iteration.
            train_iter = iter(tqdm(train_loader, f'train epoch {epoch}'))
            batch_data = next(train_iter)   # Should not raise.

        # Take step.
        # Examples for computing update.
        inputs, labels = torch2jax(batch_data)
        targets = nn.one_hot(labels, num_classes)
        rng_next_step, rng_step = random.split(rng_next_step)
        need_extra_eval = bool(FLAGS.train['mixup_beta'] or FLAGS.train['aug_prob'])
        # Examples for evaluation.
        scores, loss = None, None
        if need_extra_eval:
            # Need an extra eval without mixup/augmentation (before update to params).
            # Evaluate model without mixup or data aug.
            scores = train_model_fn(params, inputs)
            loss = loss_fn(targets, scores)
        # Take step (including mixup within batch).
        opt_lr = _get_opt_lr(opt_state)
        opt_state, params, mutable_vars, (objective, step_scores, step_loss) = train_step(
            opt_state, params, mutable_vars, (inputs, targets), rng_step)
        if not need_extra_eval:
            # Without mixup/augmentation the step's own outputs can be reused.
            scores, loss = step_scores, step_loss
        if not jnp.isfinite(objective):
            if step == 0:
                raise ValueError(f'objective not finite at init: {objective!s}')
            # Do not raise yet. Allow for objective to be logged.
            optimization_error = OptimizationError(f'objective not finite: {objective!s}')
        pred = jnp.argmax(scores, axis=-1)
        acc = jnp.equal(pred, labels)
        err = jnp.not_equal(pred, labels)
        step_outputs = {'loss': loss, 'acc': acc, 'err': err, 'objective': objective, 'opt_lr': opt_lr}
        for k, v in step_outputs.items():
            epoch_outputs[k].append(v)
            interval_outputs[k].append(v)

        step += 1  # Step completed.
        frac_multistep = Fraction(step, multistep_size)
        try:
            batch_data = next(train_iter)
        except StopIteration:
            epoch += 1  # Epoch completed.
            train_iter = None
            batch_data = None

        # Post-step metrics.
        # Log at each interval and on optimization failure.
        if frac_multistep % FLAGS.log_interval == 0 or optimization_error:
            metrics.update({'train/' + k: _mean_concat(v) for k, v in interval_outputs.items()})
            # Clear buffer.
            interval_outputs = collections.defaultdict(list)

def setup_data() -> Tuple[int, Tuple[int, int, int], Tuple[Dataset, Dataset]]:
    """Creates the train and validation datasets selected by FLAGS.dataset.

    The same transform (ToTensor followed by Normalize) is used for both
    splits; no torchvision augmentation is applied here.

    Returns:
        A tuple (num_classes, input_shape, (train_dataset, val_dataset)).

    Raises:
        KeyError: If FLAGS.dataset (or a non-'auto' FLAGS.input_stats) is not
            one of the known dataset names.
    """
    name = FLAGS.dataset

    class_counts = {
        'cifar10': 10,
        'cifar100': 100,
    }
    num_classes = class_counts[name]

    input_shapes = {
        'cifar10': (32, 32, 3),
        'cifar100': (32, 32, 3),
    }
    input_shape = input_shapes[name]

    # Per-channel (mean, std) of each dataset.
    known_stats = {
        'cifar10': ((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
        'cifar100': ((0.5071, 0.4866, 0.4409), (0.2673, 0.2564, 0.2762)),
    }
    # 'auto' means: use the statistics of the dataset being loaded.
    stats_name = name if FLAGS.input_stats == 'auto' else FLAGS.input_stats
    mean, std = known_stats[stats_name]

    transform_eval = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_train = transform_eval

    dataset_classes = {
        'cifar10': datasets.CIFAR10,
        'cifar100': datasets.CIFAR100,
    }
    dataset_cls = dataset_classes[name]
    train_dataset = dataset_cls(
        FLAGS.dataset_root, train=True, transform=transform_train, download=FLAGS.download)
    val_dataset = dataset_cls(
        FLAGS.dataset_root, train=False, transform=transform_eval, download=FLAGS.download)

    return num_classes, input_shape, (train_dataset, val_dataset)


def make_model(
        config: dict,
        num_classes: int,
        input_shape: Tuple[int, int, int],
        norm: ModuleDef = nn.BatchNorm) -> nn.Module:
    """Constructs the network named by config['arch'].

    Args:
        config: Model options; reads 'arch', 'bn' and (for resnets) 'res_init_zero'.
        num_classes: Passed to the model constructor.
        input_shape: Not used by this function.
        norm: Normalization module passed to the model constructor.

    Raises:
        ValueError: If the architecture is not known.
    """
    constructors = {
        'resnet_v1_18': resnet_jax.ResNet18,
        'vgg11': vgg_jax.VGG11,
    }
    try:
        constructor = constructors[config['arch']]
    except KeyError as ex:
        raise ValueError('unknown architecture', ex)

    # Only the resnet constructors receive the no_init_zero option.
    extra_kwargs = {}
    if 'resnet' in config['arch']:
        extra_kwargs['no_init_zero'] = not config['res_init_zero']

    use_bn = config['bn']
    return constructor(num_classes=num_classes, norm=norm, no_bn=not use_bn, **extra_kwargs)


# https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
def hvp(f, primals, tangents):
    """Returns the Hessian-vector product of f at `primals` along `tangents`.

    Computed as the forward-mode derivative (jvp) of the gradient of f.
    """
    grad_f = jax.grad(f)
    _, hessian_times_tangent = jax.jvp(grad_f, primals, tangents)
    return hessian_times_tangent


def tree_hvp(f, x, v):
    """Hessian-vector product where x and v are PyTrees of the same structure.

    Returns a PyTree with the structure of x.
    """
    x_leaves, treedef = tree_flatten(x)
    v_leaves = tree_flatten_check(v, treedef)

    def f_of_leaves(leaves):
        # Single-argument form (list of leaves) as required by jax.grad().
        return f(tree_unflatten(treedef, leaves))

    hv_leaves = hvp(f_of_leaves, (x_leaves,), (v_leaves,))
    return tree_unflatten(treedef, hv_leaves)


def func_add_id(f, b):
    """Returns the function x -> f(x) + b * x, applied leaf-wise over PyTrees."""

    def shifted(x):
        fx = f(x)
        # (f + b*id)(x)
        return tree_map(lambda fx_leaf, x_leaf: fx_leaf + b*x_leaf, fx, x)

    return shifted


# def loss_hvp(targets, x, dx):
#     return hvp(lambda x: jnp.mean(loss_fn(targets, x)), (x,), (dx,))


def tree_gnvp(f, g, x, v):
    """Gauss-Newton-vector product J' H J v for the composition f(g(x)).

    J is the Jacobian of g at x and H is the Hessian of f at g(x).
    x and v are PyTrees of the same structure; so is the result.
    """
    x_leaves, treedef = tree_flatten(x)
    v_leaves = tree_flatten_check(v, treedef)

    def g_of_leaves(*leaves):
        # Flat positional-args form for jax.jvp() / jax.vjp().
        return g(tree_unflatten(treedef, leaves))

    # J v
    outputs, j_v = jax.jvp(g_of_leaves, x_leaves, v_leaves)
    # H (J v)
    h_j_v = hvp(f, (outputs,), (j_v,))
    # J' (H J v)
    _, pullback = jax.vjp(g_of_leaves, *x_leaves)
    result_leaves = pullback(h_j_v)
    return tree_unflatten(treedef, result_leaves)


def jjtvp(f, x, v):
    """Returns J J' v, where J is the Jacobian of f at x (v lives in output space)."""
    _, pullback = jax.vjp(f, x)
    (jt_v,) = pullback(v)  # J' v
    return jax.jvp(f, (x,), (jt_v,))[1]  # J (J' v)


def jtjvp(f, x, v):
    """Returns J' J v, where J is the Jacobian of f at x (v lives in input space)."""
    j_v = jax.jvp(f, (x,), (v,))[1]  # J v
    _, pullback = jax.vjp(f, x)
    (jt_j_v,) = pullback(j_v)  # J' (J v)
    return jt_j_v


def tree_jjtvp(f, x, v):
    """Returns J J' v, where J is the Jacobian of f with respect to the PyTree x.

    v has the structure of the output of f.
    """
    x_leaves, treedef = tree_flatten(x)

    def f_of_leaves(*leaves):
        # Flat positional-args form for jax.vjp().
        return f(tree_unflatten(treedef, leaves))

    _, pullback = jax.vjp(f_of_leaves, *x_leaves)
    jt_v_leaves = pullback(v)  # J' v, one entry per leaf of x.
    return jax.jvp(f_of_leaves, x_leaves, list(jt_v_leaves))[1]


def power_method(mul_fn, v, *, num_iters: int = 100, rtol: float = 1e-3, atol: float = 1e-6):
    """Runs power iteration on the operator `mul_fn`, starting from `v`.

    Each iteration computes w = mul_fn(v), the value <v, w>, and normalizes w.
    Iteration stops early once the change in the value relative to
    (|previous value| + atol) drops below rtol.

    Args:
        mul_fn: Maps a vector (PyTree) to a vector of the same structure.
        v: Starting vector (PyTree).
        num_iters: Maximum number of iterations.
        rtol: Threshold for the relative-change stopping test.
        atol: Added to the denominator of the relative change.

    Returns:
        The pair (value, normalized vector) from the last iteration performed.
    """
    last_val = None
    for _ in range(num_iters):
        product = mul_fn(v)
        val = sum(tree_leaves(tree_map(jnp.vdot, v, product)))
        # TODO: Could use w_norm here?
        w = tree_normalize_safe(product, v)
        converged = (
            last_val is not None
            and abs(last_val - val) / (abs(last_val) + atol) < rtol)
        if converged:
            break
        v, last_val = w, val
    return val, w


def power_method_min_max(mul_fn, mul_add_fn, vec, vec_min, vec_max, safety=1.5, **kwargs):
    """Estimates the dominant, minimum and maximum eigenvalues with power iteration.

    First runs power_method() on `mul_fn` to obtain the dominant value. If it
    is non-zero, a second run is made on the shifted operator
    partial(mul_add_fn, c) with c = -safety * value, and c is subtracted from
    the result to obtain the value at the other end.

    Args:
        mul_fn: Operator passed to the first power_method() call.
        mul_add_fn: Called as mul_add_fn(c, vector); presumably computes
            mul_fn(vector) + c * vector -- verify against caller.
        vec: Starting vector for the first run.
        vec_min: Starting vector for the second run when the dominant value is positive.
        vec_max: Starting vector for the second run otherwise.
        safety: Multiplier applied to the dominant value to form the shift.
        **kwargs: Forwarded to power_method().

    Returns:
        ((val, vec), (val_min, vec_min), (val_max, vec_max)).
    """
    val, vec = power_method(mul_fn, vec, **kwargs)
    dominant = (val, vec)
    if val == 0:
        return dominant, dominant, dominant

    dominant_is_positive = val > 0
    shift = -safety * val  # Add shift (subtract largest eigenvalue).
    start_vec = vec_min if dominant_is_positive else vec_max
    shifted_val, other_vec = power_method(partial(mul_add_fn, shift), start_vec, **kwargs)
    other = (shifted_val - shift, other_vec)

    if dominant_is_positive:
        # Dominant value is the maximum; the second run found the minimum.
        return dominant, other, dominant
    # Dominant value is the minimum; the second run found the maximum.
    return dominant, dominant, other


def tree_normalize(w):
    """Scales every leaf of the PyTree w by 1 / tree_norm(w)."""
    inv_norm = 1. / tree_norm(w)
    return tree_map(lambda leaf: inv_norm * leaf, w)


def tree_normalize_safe(w, prev_w):
    """Normalizes the PyTree w, falling back to prev_w if the norm of w is zero."""
    norm = tree_norm(w)
    inv_norm = 1. / norm
    scaled = tree_map(lambda leaf: inv_norm * leaf, w)
    norm_is_zero = norm == 0

    def pick(leaf, prev_leaf):
        return lax.select(norm_is_zero, prev_leaf, leaf)

    return tree_map(pick, scaled, prev_w)


def tree_norm(w):
    """Returns the square root of the sum of squares over all leaves of w."""
    leaf_totals = tree_leaves(tree_map(sum_squares, w))
    total = sum(leaf_totals)
    return jnp.sqrt(total)


def sum_squares(x, axis=None):
    """Returns the sum of the squared elements of x along `axis` (all axes if None)."""
    squared = jnp.square(x)
    return jnp.sum(squared, axis=axis)


def normal_like_tree(rng, tree):
    """Samples a standard-normal PyTree with the structure and leaf shapes of `tree`.

    The key is split once per leaf, so every leaf gets an independent sample.
    """
    leaves, treedef = tree_flatten(tree)
    shapes = [jnp.shape(leaf) for leaf in leaves]
    keys = random.split(rng, len(shapes))
    samples = [random.normal(key, shape) for key, shape in zip(keys, shapes)]
    return tree_unflatten(treedef, samples)


def ceildiv(a, b):
    """Returns ceil(a / b), computed exactly with rational arithmetic."""
    quotient = Fraction(a, b)
    return math.ceil(quotient)


def torch2jax(data):
    """Converts an (inputs, labels) pair of torch tensors to JAX arrays.

    The third-from-last axis of `inputs` is moved to the last position
    (presumably channels-first to channels-last -- verify against caller).
    """
    torch_inputs, torch_labels = data
    labels = jnp.asarray(torch_labels.numpy())
    inputs = jnp.moveaxis(jnp.asarray(torch_inputs.numpy()), -3, -1)
    return inputs, labels


def ce_loss(targets, scores, alpha=0.0):
    """Softmax cross-entropy between logits `scores` and label-smoothed `targets`.

    Args:
        targets: Target distribution (one-hot in train()).
        scores: Logits.
        alpha: Label-smoothing coefficient passed to optax.smooth_labels.
    """
    smoothed_targets = optax.smooth_labels(targets, alpha)
    return optax.softmax_cross_entropy(logits=scores, labels=smoothed_targets)


def mse_loss(targets, scores, scale=1.0):
    """Mean squared difference between `scores` and `scale * targets`."""
    residual = scores - scale * targets
    return jnp.mean(jnp.square(residual))


def torch_seed(rng: random.PRNGKey) -> int:
    """Draws an integer seed for torch from a JAX PRNG key.

    Args:
        rng: JAX PRNG key.

    Returns:
        A Python int in [-0x8000_0000, 0x7fff_ffff); the upper bound is
        exclusive, as in jax.random.randint.
    """
    # 64-bit dtypes (jnp.int64) are not enabled by default.
    # https://github.com/google/jax#current-gotchas
    # minval, maxval = -0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff
    minval, maxval = -0x8000_0000, 0x7fff_ffff
    return random.randint(rng, (), minval, maxval).item()


def torch_generator(rng: random.PRNGKey) -> torch.Generator:
    """Returns a new torch.Generator seeded from the JAX PRNG key `rng`."""
    seed = torch_seed(rng)
    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator


def flatten_items(items, sep: str = '.', prefix: str = '') -> Iterator[Tuple[str, Any]]:
    """Yields (flat_key, value) pairs from a possibly nested sequence of items.

    Any value that has an `items` attribute is treated as a nested mapping and
    flattened recursively; its keys are joined to the parent key with `sep`.

    Args:
        items: Iterable of (key, value) pairs, e.g. `mapping.items()`.
        sep: Separator inserted between nested keys.
        prefix: String prepended to every key that is yielded.

    Yields:
        (key, value) pairs in which no value has an `items` attribute.
    """
    for k, v in items:
        if hasattr(v, 'items'):
            yield from flatten_items(v.items(), sep=sep, prefix=f'{prefix}{k}{sep}')
        else:
            yield f'{prefix}{k}', v


def assert_flat(mapping: Mapping):
    """Raises ValueError if any value of `mapping` is itself a Mapping.

    The exception carries the list of offending keys as its second argument.
    """
    nested_keys = []
    for key, value in mapping.items():
        if isinstance(value, Mapping):
            nested_keys.append(key)
    if nested_keys:
        raise ValueError('wandb.config contains nested parameters', nested_keys)


def mixup(rng, beta, data):
    """Applies mixup to every leaf of `data` with shared coefficients and permutation.

    Args:
        rng: JAX PRNG key.
        beta: Parameter passed to mixup_sample.
        data: PyTree of arrays that share the same leading (batch) size.
    """
    theta, perm = mixup_sample(rng, beta, tree_batch_size(data))
    blend = partial(mixup_apply, theta, perm)
    return tree_map(blend, data)


def tree_batch_size(tree):
    """Returns the leading-axis size shared by all leaves of `tree`.

    Raises:
        ValueError: If the leaves do not have exactly one distinct leading
            size (this includes a tree without leaves).
    """
    sizes = [leaf.shape[0] for leaf in tree_leaves(tree)]
    distinct_sizes = set(sizes)
    try:
        (batch_size,) = distinct_sizes
    except ValueError:
        raise ValueError('no unique batch size', sizes)
    return batch_size


def mixup_sample(rng, beta: float, batch_size: int, dtype=jnp.float32) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Samples mixup coefficients and a permutation for a batch.

    Args:
        rng: JAX PRNG key.
        beta: Both shape parameters of the Beta distribution. For beta == 0 a
            fair coin flip is used instead.
        batch_size: Number of examples.
        dtype: dtype of the coefficients.

    Returns:
        (theta, perm): theta holds `batch_size` coefficients, each replaced by
        max(theta, 1 - theta); perm is a random permutation of range(batch_size).
    """
    key_theta, key_perm = random.split(rng)
    # random.beta(0, 0) returns nan instead of Bernoulli
    coin_flips = random.bernoulli(key_theta, 0.5, (batch_size,)).astype(dtype)
    beta_draws = random.beta(key_theta, beta, beta, (batch_size,), dtype=dtype)
    theta = lax.select(beta == 0, coin_flips, beta_draws)
    # Fold onto [0.5, 1] so that the un-permuted example gets the larger weight.
    theta = jnp.maximum(theta, 1.0 - theta)
    perm = random.permutation(key_perm, batch_size)
    return theta, perm


def mixup_apply(theta: jnp.ndarray, perm: jnp.ndarray, arr: jnp.ndarray) -> jnp.ndarray:
    """Returns theta[i] * arr[i] + (1 - theta[i]) * arr[perm[i]] for each example i."""
    # TODO: Might be better to use vmap?
    # Per-example scaling: the weight is indexed by the leading axis only.
    per_example = 'i,i...->i...'
    kept = jnp.einsum(per_example, theta, arr)
    mixed_in = jnp.einsum(per_example, 1.0 - theta, arr[perm])
    return kept + mixed_in


def tree_flatten_check(tree, treedef, **kwargs):
    """Flattens `tree` and asserts that its structure equals `treedef`.

    Returns the list of leaves. Extra kwargs are forwarded to tree_flatten.
    """
    leaves, actual_treedef = tree_flatten(tree, **kwargs)
    assert actual_treedef == treedef
    return leaves


def tree_flatten_same(trees, **kwargs):
    """Flattens several PyTrees, asserting that they all share one structure.

    Returns:
        (flat_trees, treedef): a tuple with the leaf list of each tree, and
        the common treedef (None if `trees` is empty).
    """
    flat_trees = []
    treedef = None
    for index, tree in enumerate(trees):
        leaves, current_treedef = tree_flatten(tree, **kwargs)
        assert current_treedef is not None
        if index == 0:
            # The first tree defines the reference structure.
            treedef = current_treedef
        else:
            assert current_treedef == treedef
        flat_trees.append(leaves)
    return tuple(flat_trees), treedef


def infinite_loader(loader):
    """Generator that yields the items of `loader` over and over, without end.

    `loader` is re-iterated from the start each time it is exhausted.
    """
    while True:
        for batch in loader:
            yield batch


class InfiniteIter:
    """Iterator that cycles over `loader` forever.

    iter(loader) is not called until the first item is requested, and is
    called again each time the underlying iterator is exhausted.

    If `loader` yields no items at all, next() raises StopIteration.
    """

    def __init__(self, loader):
        self.loader = loader
        self.it = None  # Current iterator over the loader; created lazily.

    def __iter__(self):
        """Returns self, so that the object can be used in for-loops and iter()."""
        return self

    def __next__(self):
        """Avoid calling iter() until necessary."""
        if self.it is None:
            self.it = iter(self.loader)
        try:
            return next(self.it)
        except StopIteration:
            # Exhausted: start over and return the first item of the new pass.
            self.it = iter(self.loader)
        return next(self.it)


def cifar_aug(rng, im):
    """Applies a random pad-and-crop followed by a random flip along axis 2.

    Uses padding (4, 4), fill value 0.0 and a crop size of (32, 32).
    """
    key_crop, key_flip = random.split(rng)
    cropped = image_jax.random_pad_crop(key_crop, im, fill=0.0, padding=(4, 4), size=(32, 32))
    return image_jax.random_flip(key_flip, cropped, axis=2)


def _get_opt_lr(opt_state):
    """Reads the current learning rate from the optimizer state.

    A MultiStepsState wrapper is unwrapped first; the learning rate is then
    read from the `hyperparams` of the (inner) state.
    """
    wrapped = isinstance(opt_state, optax.MultiStepsState)
    inner_state = opt_state.inner_opt_state if wrapped else opt_state
    return inner_state.hyperparams['learning_rate']


def _concat_maybe_scalar(values):
    return jnp.concatenate([jnp.array(x, ndmin=1, copy=False) for x in values])


def _mean_concat(values):
    """Returns the mean over all elements of `values` as a Python scalar."""
    combined = _concat_maybe_scalar(values)
    return jnp.mean(combined).item()


# Script entry point: absl parses the flags and then calls main().
if __name__ == '__main__':
    app.run(main)
