import jax
from .optimizers import get_one
import gc
import time
from .locking import NFSLock
from pathlib import Path
import jax.numpy as jnp
from jax.tree_util import tree_flatten
import numpy as np
import dill as pickle
from .distributed import MemCastedDDict, CastedDDict, DDict
from tqdm import tqdm
from .dlpack import dlpack_gpu2cpu, make_io_stream
from functools import partial, cache
from .jittools import set_jit, maybe_jit
from .optimizers import TrainState
# Turn on JIT for functions wrapped with the project's `maybe_jit` helper
# (from .jittools; presumably a global toggle — confirm there).
set_jit(True)
# Print jax arrays as just their shape/dtype so logging a pytree does not dump
# array contents.
# NOTE(review): concrete jax arrays may define __repr__/__str__ on a subclass,
# in which case patching the base class here has no effect — verify.
jax.numpy.ndarray.__repr__ = lambda x: f'ndarray(shape={x.shape}, dtype={x.dtype})'
jax.numpy.ndarray.__str__ = lambda x: f'ndarray(shape={x.shape}, dtype={x.dtype})'
# set_jit(True)
# Summarize (elide) the contents of any printed array with more than one element.
jnp.set_printoptions(threshold=1)
import os
# username = os.environ.get('USER', 'unknown')
# jax.config.update("jax_compilation_cache_dir", f"/tmp/{username}-jax-cache")

from .chungus import f

class NotANumberError(Exception):
    """Raised when a computed pytree (e.g. the epsilon cotangents) contains NaNs."""

#region Pytree utils

@maybe_jit
def add_trees(x, y):
    """Leaf-wise sum of two pytrees with matching structure.

    float0 leaves of ``x`` are returned unchanged (they cannot be added), and a
    ``None`` on either side yields ``None``.
    """
    # When jitted, this only prints while tracing (i.e. once per compilation).
    print('Compiling add_trees')

    def safe_add(a, b):
        # Check for None before touching `.dtype`: None has no dtype.
        if a is None or b is None:
            return None
        if a.dtype == jax.float0:
            return a
        return a + b

    return jax.tree_util.tree_map(safe_add, x, y)

@maybe_jit
def subtract_trees(x, y):
    """Leaf-wise difference ``x - y`` of two pytrees with matching structure.

    float0 leaves of ``x`` are returned unchanged (they cannot be subtracted),
    and a ``None`` on either side yields ``None``.
    """
    # When jitted, this only prints while tracing (i.e. once per compilation).
    print('Compiling subtract_trees')

    def safe_subtract(a, b):
        # Check for None before touching `.dtype`: None has no dtype.
        if a is None or b is None:
            return None
        if a.dtype == jax.float0:
            return a
        return a - b

    return jax.tree_util.tree_map(safe_subtract, x, y)

def to_gpu(x):
    """Place ``x`` (an array or pytree) on the first GPU device."""
    first_gpu = jax.devices('gpu')[0]
    return jax.device_put(x, first_gpu)

def safe_zeros(x):
    """Return a zero array shaped like ``x``.

    Python ints and float0 arrays are passed through unchanged.
    """
    # `isinstance` runs first so plain ints never reach the `.dtype` access.
    keep_as_is = isinstance(x, int) or x.dtype == jax.float0
    return x if keep_as_is else jnp.zeros_like(x)

@maybe_jit
def pytree_gbs(pytree):
    """Total ``nbytes`` over all leaves of ``pytree``, in gigabytes (1e9 bytes)."""
    leaves = jax.tree_util.tree_leaves(pytree)
    return sum(leaf.nbytes for leaf in leaves) * 1e-9

def pytree_size(v):
    """Sum of ``nbytes`` over the leaves of ``v``, in gigabytes (1e9 bytes).

    Leaves without an ``nbytes`` attribute (e.g. Python scalars) count as 0.
    """
    total_bytes = 0
    for leaf in jax.tree_util.tree_leaves(v):
        total_bytes = total_bytes + (leaf.nbytes if hasattr(leaf, 'nbytes') else 0)
    return total_bytes * 1e-9

def wipe_rep(x):
    """Attach ``__repr__``/``__str__`` instance attributes returning 'per_sample_loss'.

    ``x`` is mutated in place and returned.

    NOTE(review): ``repr()``/``str()`` look special methods up on the *type*,
    not the instance, so these attributes do not change what ``repr(x)`` or
    ``str(x)`` produce; only explicit ``x.__repr__()`` calls see them.
    """
    for dunder in ('__repr__', '__str__'):
        setattr(x, dunder, lambda: 'per_sample_loss')
    return x

def delete_state(k, state_dict):
    """Remove key ``k`` from ``state_dict`` in place.

    A missing key raises whatever ``del`` raises for that mapping.
    """
    target = state_dict
    del target[k]

def safe_diff(x, y):
    """Absolute element-wise difference ``|x - y|`` for a single pytree leaf.

    Python ints and float0 arrays are returned unchanged.
    """
    # The int check must come first: a plain int has no `.dtype`, so testing
    # the dtype first raises AttributeError for int leaves.
    if isinstance(x, int) or x.dtype == jax.float0:
        return x

    return jnp.abs(x - y)

def pytree_diff_stats(pytree1, pytree2):
    """``(max, mean, std)`` of ``|pytree1 - pytree2|`` over params and opt_state.

    ``pytree1``'s ``params`` and ``opt_state`` are grafted onto ``pytree2``, so
    those are the only fields that can differ between the compared trees.
    """
    grafted = pytree2.replace(params=pytree1.params,
                              opt_state=pytree1.opt_state)
    return _pytree_diff_stats(grafted, pytree2)

def _pytree_diff_stats(pytree1, pytree2):
    """Return ``(max, mean, std)`` of the leaf-wise ``safe_diff`` of two pytrees.

    All diff leaves are flattened and concatenated into one vector first.
    """
    diff_tree = jax.tree_util.tree_map(safe_diff, pytree1, pytree2)
    leaves, _ = tree_flatten(diff_tree)
    everything = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
    return jnp.max(everything), jnp.mean(everything), jnp.std(everything)

def pytree_isnan(pytree):
    """Return True if any non-float0 leaf of ``pytree`` contains a NaN.

    Leaves are checked in flatten order; the scan stops at the first hit.
    """
    return any(
        leaf.dtype != jax.float0 and bool(jnp.any(jnp.isnan(leaf)))
        for leaf in jax.tree_util.tree_leaves(pytree)
    )

def get_stage_and_step(saved_states, saved_dstates):
    """Decide where an okazaki run should resume.

    Returns ``('backward', key)`` when a state-cotangent checkpoint exists
    (exactly one is expected); otherwise ``('forward', latest saved step)``.
    """
    if len(saved_dstates) == 0:
        assert len(saved_states) > 0
        return 'forward', max(saved_states)

    print(saved_dstates)
    assert len(saved_dstates) == 1
    return 'backward', saved_dstates[0]

def make_save_iterations(train_its, serialize_k):
    """Steps at which a state checkpoint is kept.

    These are every ``serialize_k``-th step starting at 0, plus the final step
    ``train_its``.
    """
    return {*range(0, train_its, serialize_k), train_its}

#endregion

def okazaki_forward(start_i, saved_states, train_its, train_batch_maker,
                    serialize_k, per_sample_loss, minibs, leave_out_samples=None,
                    leave_out_weight=None):
    """Run the forward training pass from step ``start_i`` to ``train_its``.

    After every step in ``make_save_iterations(train_its, serialize_k)`` the
    new state is queued, copied to host via ``dlpack_gpu2cpu`` in groups of
    four, and stored as ``saved_states[step]``.

    Args:
        start_i: step to resume from; ``saved_states[start_i]`` must exist.
        saved_states: mapping step -> state; written to in place.
        train_its: total number of training steps.
        train_batch_maker: step -> batch (wrapped with ``batch_for_state``).
        serialize_k: checkpoint spacing.
        per_sample_loss: loss function passed to ``functional_step``.
        minibs: minibatch size forwarded to ``functional_step``.
        leave_out_samples: optional collection of sample indices. With
            ``leave_out_weight`` the loss gets
            ``leave_out_mask = 1 - leave_out_weight * is_left_out``; without it
            the left-out examples are removed from the batch.
        leave_out_weight: see ``leave_out_samples``.

    Returns:
        The final training state, or ``None`` when ``start_i == train_its``
        (nothing left to train; the caller must reload the state itself).
    """
    if start_i == train_its:
        print('>> done training!')
        return

    state = to_gpu(saved_states[start_i])

    print('>> Training model!')
    iterator = tqdm(range(start_i, train_its))
    batch_maker = partial(batch_for_state, train_batch_maker=train_batch_maker)

    save_iterations = make_save_iterations(train_its, serialize_k)

    # gpu_states: (step, state) pairs waiting to be copied to host.
    # old_cpu_states / old_cpu_blocker: the copy launched by the previous
    # flush and the function that waits for it.
    gpu_states = []
    old_cpu_states = None
    old_cpu_blocker = None
    stream = make_io_stream()

    def flush_and_replace(gpu_states, old_cpu_states, old_cpu_blocker):
        """Commit the previous in-flight host copy to ``saved_states``, then
        start copying ``gpu_states`` to host.

        Returns ``([], new (step, cpu_state) pairs, blocking fn)``.
        """
        if old_cpu_states is not None:
            old_cpu_blocker()

            # save to disk (i.e. commit into saved_states)
            for i, cpu_state in old_cpu_states:
                saved_states[i] = cpu_state

        # now move gpu_states to cpu
        if len(gpu_states) == 0:
            return [], [], lambda: None

        # NOTE(review): the host buffers just stored in saved_states are passed
        # as `replace_buffers` for the next copy; presumably saved_states copies
        # on assignment, otherwise stored states would be overwritten — confirm.
        if old_cpu_states is not None:
            old_cpu_state_values = [v[1] for v in old_cpu_states]
            old_cpu_state_values = old_cpu_state_values[:len(gpu_states)]
        else:
            old_cpu_state_values = None

        gpu_state_values = [v[1] for v in gpu_states]
        gpu_state_indices = [v[0] for v in gpu_states]
        cpu_states, block_fn = dlpack_gpu2cpu(gpu_state_values, stream=stream,
                                              replace_buffers=old_cpu_state_values)
        old_cpu_states = list(zip(gpu_state_indices, cpu_states))
        return [], old_cpu_states, block_fn

    batch = batch_maker(start_i)

    for it in iterator:
        ### BEGIN STEP
        # Prefetch the next batch before running this step.
        if it != train_its - 1:
            next_batch = batch_maker(it + 1)
        else:
            next_batch = None

        if leave_out_samples is not None and leave_out_weight is not None:
            # Per-example weight: 1 for kept samples, 1 - leave_out_weight for
            # left-out ones (batch[0][0] presumably holds sample ids — confirm).
            mask = 1 - leave_out_weight * jnp.array([int(k) in leave_out_samples for k in batch[0][0]])
            this_psl = jax.tree_util.Partial(per_sample_loss, leave_out_mask=mask)
        elif leave_out_samples is not None:
            # Drop left-out examples from the batch; `aug_idxs` gives the kept
            # examples' original positions within the batch.
            print('Leave out samples but no weight, dropping samples from batch directly')
            (ix, g1), ((x, g2, g3), (y, g4)) = batch
            mask = jnp.array([int(k) not in leave_out_samples for k in ix])
            batch = ((ix[mask], g1[:mask.sum()]), ((x[mask], g2, g3), (y[mask], g4)))
            this_psl = jax.tree_util.Partial(per_sample_loss, aug_idxs=jnp.arange(x.shape[0])[mask])
        else:
            this_psl = per_sample_loss

        state = functional_step(state, batch, this_psl, minibs=minibs)
        ### END STEP
        curr_it = it + 1
        if curr_it in save_iterations:
            gpu_states.append((curr_it, state))

        # Copy to host in groups of 4 to bound GPU memory held by queued states.
        if len(gpu_states) >= 4:
            tup = flush_and_replace(gpu_states, old_cpu_states, old_cpu_blocker)
            gpu_states, old_cpu_states, old_cpu_blocker = tup

        batch = next_batch

    # Two flushes: the first starts copying the remaining queued states, the
    # second waits for that copy and commits it to saved_states.
    gpu_states, old_cpu_states, old_cpu_blocker = flush_and_replace(gpu_states,
                                                                    old_cpu_states,
                                                                    old_cpu_blocker)
    tup = flush_and_replace(gpu_states, old_cpu_states, old_cpu_blocker)
    print(f'>> Done: trained from cache @ {start_i} -> it {train_its}')
    return state

def okazaki_vjp(state, train_batch_maker, val_batch_maker, train_its, val_its,
                per_sample_loss, batch_cotangenter, model_cotangenter, save_dir,
                cache_dir, eval_every=200, optimal_k_factor=1,
                val_per_sample_loss=None,
                memory_backed=True, exit_after_forward=False, minibs=None,
                leave_out_inds=None, leave_out_weight=None, trak_eps=False,
                return_final_state=False, should_wash_state=True):
    """Train a model, then pull a final-state cotangent back through the whole
    training run using checkpointed ("okazaki") segments.

    Args:
        state: initial train state; includes the step counter.
        train_batch_maker: step -> training batch.
        val_batch_maker: i -> validation batch.
        train_its: number of training steps.
        val_its: number of validation batches.
        per_sample_loss: per-example loss, called as
            ``per_sample_loss(params, batch_stats, batch, train=...)``.
        batch_cotangenter: ``(per_sample_loss, get_grads, apply_grads, num_poison)
            -> (eps, get_grads, apply_grads)``; see ``make_forwards``.
        model_cotangenter: called on the final state; its result seeds the
            backward pass as the cotangent w.r.t. that state.
        save_dir: where ``val_loss.pkl`` and the done marker are written
            (str or Path).
        cache_dir: where intermediate states and cotangents are cached.
        eval_every: currently unused here.
        optimal_k_factor: checkpoint spacing is
            ``int(sqrt(train_its) * optimal_k_factor)``.
        val_per_sample_loss: loss for the second evaluation; defaults to
            ``per_sample_loss``.
        memory_backed: must be True.
        exit_after_forward: stop after training and evaluation.
        minibs: minibatch size forwarded to training/evaluation.
        leave_out_inds, leave_out_weight: forwarded to ``okazaki_forward``.
        trak_eps: currently unused here.
        return_final_state: with ``exit_after_forward``, also return the state.
        should_wash_state: pass ``state`` through ``wash_state`` first.

    Returns:
        With ``exit_after_forward``: ``(val_loss, val_val_loss)`` plus the
        final state if ``return_final_state``. Otherwise
        ``(saved_deps, val_loss, final_state, val_val_loss, all_depsilon_logging)``
        where ``saved_deps`` maps segment start -> host copy of that segment's
        epsilon cotangents.

    Raises:
        ValueError: if ``save_dir`` is already marked done or the final state
            is missing from the cache.
        NotANumberError: if a segment produces NaN epsilon cotangents.
    """
    if val_per_sample_loss is None:
        val_per_sample_loss = per_sample_loss

    if should_wash_state:
        state = wash_state(state)
    # Any non-empty DEBUG value (even "0") enables debug mode.
    debug_mode = bool(os.environ.get('DEBUG', False))
    assert memory_backed
    # TODO: enable local caching at some point

    # Accept str as well as pathlib.Path, like cache_dir below.
    save_dir = Path(save_dir)
    done_path = save_dir / 'done_shibboleth'
    if done_path.exists():
        raise ValueError('>> Already done!')

    cache_dir = Path(cache_dir)

    # saved_states: i -> state at start of iteration i
    # saved_dstates: i -> dstate_i (cotangent w.r.t. the state at step i)
    # saved_deps: segment start i -> deps for that segment (in memory only)
    # In debug mode, keep the caches (clear_on_exit=False) for inspection.
    should_clear = not debug_mode
    saved_states = MemCastedDDict.load_or_create(cache_dir / 'states', state,
                                                 None, clear_on_exit=should_clear)
    if exit_after_forward:
        saved_states.set_mode(True)

    saved_dstates = MemCastedDDict.load_or_create(cache_dir / 'dstates', state,
                                                  None, clear_on_exit=should_clear)
    saved_dstates.set_mode(True)

    saved_deps = {}

    if 0 not in saved_states:
        saved_states[0] = state

    # With optimal_k_factor == 1: about sqrt(train_its) checkpoints, each
    # segment recomputing about sqrt(train_its) steps.
    serialize_k = int(train_its**0.5 * optimal_k_factor)
    assert serialize_k > 0

    # get which state we need to start from
    curr_stage, curr_step = get_stage_and_step(saved_states.keys(),
                                               saved_dstates.keys())

    print('>> vjp stage:', curr_stage, ' @ ', curr_step)

    # Loss evaluators; `per_sample_loss` is static, so each loss function gets
    # its own compiled version.
    @partial(maybe_jit, static_argnames=('per_sample_loss', 'train'))
    def jitted_psl(*args, per_sample_loss=None, train=None, **kw):
        return per_sample_loss(*args, train=train, **kw)

    jitted_train_psl = partial(jitted_psl, per_sample_loss=partial(per_sample_loss, train=False))
    jitted_val_psl = partial(jitted_psl, per_sample_loss=partial(val_per_sample_loss, train=False))
    # Both evaluators run on the validation batches; they differ only in the loss.
    train_loss_evaler = partial(eval_model, val_batch_maker=val_batch_maker,
                                val_its=val_its, per_sample_loss=jitted_train_psl,
                                minibs=minibs)
    val_loss_evaler = partial(eval_model, val_batch_maker=val_batch_maker,
                              val_its=val_its, per_sample_loss=jitted_val_psl,
                              minibs=minibs)

    # okazaki_forward returns None when nothing is left to train, and is not
    # called at all when resuming in the backward stage; in both cases the
    # final state is reloaded from the cache below.
    final_state = None
    if curr_stage == 'forward':
        final_state = okazaki_forward(curr_step, saved_states, train_its, train_batch_maker,
                                      serialize_k, per_sample_loss, minibs,
                                      leave_out_samples=leave_out_inds,
                                      leave_out_weight=leave_out_weight)
        assert len(saved_dstates) == 0

    VAL_PATH = save_dir / 'val_loss.pkl'

    if train_its in saved_states:
        if final_state is None:
            final_state = to_gpu(saved_states[train_its])
        # `val_loss` is the *training* loss function evaluated on the val set.
        val_loss = train_loss_evaler(state=final_state, limited=debug_mode)
        val_val_loss = val_loss_evaler(state=final_state, limited=debug_mode)
        with open(VAL_PATH, 'wb') as fh:
            pickle.dump(val_loss, fh)
    else:
        raise ValueError('Final state not found in saved states')

    print(f'>> Trained model val loss on val set: {val_val_loss.mean():.6f}')
    print(f'>> Trained model train loss on val set: {val_loss.mean():.6f}')
    if exit_after_forward:
        if return_final_state:
            return np.array(val_loss), np.array(val_val_loss), final_state
        else:
            return np.array(val_loss), np.array(val_val_loss)

    # forward pass done: lets do the backward pass now
    # TODO: parallelize rest of code with submitit
    print(f'>> Doing okazaki...')

    # (start, end) checkpoint pairs, last segment first
    save_iterations = make_save_iterations(train_its, serialize_k)
    save_iterations = sorted(list(save_iterations))
    assert save_iterations[0] == 0
    all_segments = list(zip(save_iterations[:-1], save_iterations[1:]))[::-1]

    if len(saved_dstates) == 0:
        # Fresh backward pass: seed it with the cotangent of the final state.
        state = to_gpu(saved_states[train_its])
        state_cotangents = model_cotangenter(state)
        saved_dstates[train_its] = state_cotangents
        assert save_iterations[-1] == int(state.opt_state.count)
        curr_step = train_its
    else:
        # Resuming: continue from the stored state cotangent.
        state_cotangents = to_gpu(saved_dstates[curr_step])

    del state
    # only segments ending at or before curr_step remain to be processed
    segments = [seg for seg in all_segments if seg[1] <= curr_step]

    # let Jf = Jf_{okaz_N} ... Jf_{okaz_1}
    # Jf_{okaz_i} = Jf_{i * k - 1} ... Jf_{(i - 1) * k}
    batch_maker = partial(batch_for_state, train_batch_maker=train_batch_maker)
    print('>> Remaining segments', segments)

    saved_states.set_mode(memory_backed)
    assert memory_backed

    all_depsilon_logging = []
    # Key of the dstate checkpoint currently on disk. It is deleted once the
    # next one is written, so get_stage_and_step always finds exactly one
    # dstate when resuming.
    prev_start = curr_step
    for start, end in tqdm(segments, desc='Okazaki stages'):
        print('DOING SEGMENT', start, end)
        assert end > start
        ret = okazaki_stage(final_i=end, start_i=start,
                            batch_maker=batch_maker,
                            per_sample_loss=per_sample_loss,
                            batch_cotangenter=batch_cotangenter,
                            state_cotangents=state_cotangents,
                            saved_states=saved_states,
                            stage_num=f'{end}/{train_its}',
                            minibs=minibs)
        state_cotangents, eps_cotangents, depsilons_logging = ret
        all_depsilon_logging.extend(depsilons_logging)

        # check eps_cotangents pytree for any nans
        if pytree_isnan(eps_cotangents):
            raise NotANumberError('NAN in eps_cotangents')

        # Keep a host (numpy) copy of this segment's deps.
        on_cpu = jax.device_put(eps_cotangents, jax.devices('cpu')[0])
        on_cpu = jax.tree_util.tree_map(lambda v: np.array(v), on_cpu)
        saved_dstates[start] = state_cotangents
        if prev_start is not None:
            del saved_dstates[prev_start]

        prev_start = start
        saved_deps[start] = on_cpu

    done_path.touch()
    # NOTE(review): saved_deps only holds segments processed in this call; after
    # resuming a partially finished backward pass it is smaller than
    # all_segments and this assert fires.
    assert len(saved_deps) >= len(all_segments)
    return saved_deps, np.array(val_loss), final_state, np.array(val_val_loss), all_depsilon_logging

def okazaki_stage(final_i, start_i, batch_maker, per_sample_loss,
                  batch_cotangenter, state_cotangents, saved_states, stage_num,
                  minibs):
    """Pull ``state_cotangents`` back through training steps
    ``start_i .. final_i - 1`` (one okazaki segment).

    The segment is first re-trained from the checkpoint at ``start_i`` to
    rematerialize every intermediate state into ``saved_states``. The steps
    are then walked in reverse, calling ``backward_step`` on each.

    Args:
        final_i: segment end; ``saved_states[final_i]`` must exist.
        start_i: segment start; ``saved_states[start_i]`` must exist.
        batch_maker: step -> batch.
        per_sample_loss: loss used by the forward and backward steps.
        batch_cotangenter: forwarded to ``backward_step``.
        state_cotangents: cotangent w.r.t. the state at ``final_i``.
        saved_states: mapping step -> state. Intermediate states are added and
            then removed again; on return no key above ``start_i`` remains.
        stage_num: label used in the progress bars.
        minibs: minibatch size for the re-training pass only; the backward
            steps always run with ``minibs=None``.

    Returns:
        ``(state_cotangents, all_eps_cotangents, depsilons_logging)``:
        - the cotangent w.r.t. the state at ``start_i``;
        - a dict step -> epsilon cotangent;
        - a list of mean |eps| values for logging (best effort).
    """
    assert final_i > start_i
    assert start_i in saved_states

    # NOTE(review): wipe_rep only sets instance attributes; repr(batch_maker)
    # itself is unchanged.
    batch_maker = wipe_rep(batch_maker)

    print(f'>> Backward stage: {start_i} -> {final_i}')
    state = to_gpu(saved_states[start_i])

    # Start a host copy once more than this many GPU states are queued.
    MAX_QUEUE_SIZE = 8
    desc = f'|stage={stage_num} | Retraining segment..'

    gpu_states = {}
    old_states_cpu = None
    stream = make_io_stream()

    def flush_to_memory(gpu_states, old_states_cpu):
        """Commit the previous in-flight host copy (if any) into
        ``saved_states``, then start copying ``gpu_states`` to host.

        ``old_states_cpu`` is ``None`` or the ``(step -> cpu_state, block_fn)``
        pair returned by the previous call. Returns ``({}, pair_or_None)``.
        """
        if old_states_cpu is not None:
            pending, block_fn = old_states_cpu
            # presumably waits for the asynchronous copy to complete
            block_fn()
            for i, cpu_state in pending.items():
                saved_states.force_set(i, cpu_state)

        if len(gpu_states) == 0:
            return {}, None

        gpu_keys, gpu_values = zip(*gpu_states.items())
        cpu_values, block_fn = dlpack_gpu2cpu(gpu_values, stream=stream)
        return {}, (dict(zip(gpu_keys, cpu_values)), block_fn)

    # Recompute states start_i + 1 .. final_i - 1 (final_i is already cached).
    for curr_it in tqdm(range(start_i, final_i - 1), desc=desc):
        batch = batch_maker(curr_it)
        state = functional_step(state, batch, per_sample_loss, minibs=minibs)
        state_it = curr_it + 1
        gpu_states[state_it] = state

        if len(gpu_states) > MAX_QUEUE_SIZE:
            gpu_states, old_states_cpu = flush_to_memory(gpu_states, old_states_cpu)

    # Two flushes: the first starts copying the remaining queued states, the
    # second waits for that copy and commits it.
    gpu_states, old_states_cpu = flush_to_memory(gpu_states, old_states_cpu)
    gpu_states, old_states_cpu = flush_to_memory(gpu_states, old_states_cpu)

    del gpu_states
    del old_states_cpu

    last_prev_seen_state = int(state.opt_state.count) + 1
    del state

    backward_its = list(reversed(range(start_i, final_i)))
    assert len(backward_its) == final_i - start_i

    all_eps_cotangents = {}

    # state_p1 is the state after step back_it; (state, batch) are its inputs.
    state_p1 = to_gpu(saved_states[final_i])
    state, batch = saved_states[final_i - 1], batch_maker(final_i - 1)
    state = to_gpu(state)

    depsilons_logging = []
    for back_it in tqdm(backward_its, desc=f'|stage={stage_num} | Backward'):
        assert back_it < last_prev_seen_state
        # Prefetch the inputs of the previous step and drop its cached copy,
        # except start_i: it is the end state of the next (earlier) segment.
        if back_it != start_i:
            nbatch = batch_maker(back_it - 1)
            nstate = to_gpu(saved_states[back_it - 1])
            if back_it - 1 != start_i:
                del saved_states[back_it - 1]
        else:
            nbatch, nstate = None, None

        eps_cotangents, state_cotangents = backward_step(batch, state, state_p1,
                                                         state_cotangents,
                                                         batch_cotangenter,
                                                         per_sample_loss=per_sample_loss,
                                                         minibs=None)
        jax.block_until_ready((eps_cotangents, state_cotangents))
        all_eps_cotangents[back_it] = eps_cotangents
        # Best-effort logging: eps_cotangents need not be a single array.
        try:
            depsilons_logging.append(float(jnp.abs(eps_cotangents.reshape(-1)).mean()))
        except Exception:
            pass

        state, batch, state_p1 = nstate, nbatch, state

    # ensure that all saved states are less than or eq to start_i
    # (presumably `memkv` is MemCastedDDict's in-memory cache — confirm)
    saved_states.memkv = {}
    all_saved_states = list(saved_states.keys())
    for saved_i in all_saved_states:
        if saved_i > start_i:
            print('>> Deleting saved state', saved_i)
            del saved_states[saved_i]

    return state_cotangents, all_eps_cotangents, depsilons_logging

@jax.jit
def naive_backward_step(batch, state, _, dstate, batch_cotangenter,
                        per_sample_loss, minibs):
    """VJP of one full training step (gradient + update) in a single pass.

    Differentiates ``(eps, state) -> next_state`` and pulls ``dstate`` back
    through it. The third argument (the next state) and ``minibs`` are unused.

    Returns:
        ``(deps, dstate_prev)``: cotangents w.r.t. ``eps`` and ``state``.

    NOTE(review): ``batch_cotangenter`` and ``per_sample_loss`` are traced
    (non-static) jit arguments, so they must be pytrees such as
    ``jax.tree_util.Partial`` — confirm at the call sites.
    """
    num_poison = batch[1][0][1].shape[0]
    get_grads, apply_grads = factored_functional_step(use_jit=False)
    eps, get_grads, apply_grads = batch_cotangenter(per_sample_loss,
                                                    get_grads,
                                                    apply_grads,
                                                    num_poison=num_poison)

    def coupled_step(_eps, _state):
        grads, updates = get_grads(eps=_eps, state=_state, batch=batch)
        # Use the differentiated `_state` (not the closed-over `state`) so the
        # VJP also flows through the update applied to the state.
        next_state = apply_grads(eps=_eps, grads=grads, state=_state, updates=updates)
        return next_state

    # jax.vjp returns (primal_out, vjp_fn). The cotangent must mirror
    # coupled_step's output: a single state, not a 1-tuple.
    _, vjp_fn = jax.vjp(coupled_step, eps, state)
    return vjp_fn(dstate)

def backward_step(batch, state, next_state, dstate, batch_cotangenter,
                  per_sample_loss, minibs):
    """Pull ``dstate`` (cotangent w.r.t. ``next_state``) back through one step.

    The step is handled as two VJPs:
      (a) ``backward_apply``: through the update
          ``(eps, grads, state) -> next_state``, giving ``deps``, ``dgrads``
          and a partial ``dstate``;
      (b) ``backward_grad``: through ``(eps, state) -> (grads, batch_stats)``.
          Its cotangents are added to (a)'s.

    Args:
        minibs: if not None, (b) is run per minibatch via ``minibatch_func``.
            NOTE(review): ``minibatch_func`` raises ValueError whenever the
            wrapped function's second output is not None, which
            ``backward_grad``'s is, so this path currently always fails.

    Returns:
        ``(deps, dstate)``: cotangents w.r.t. epsilon and ``state``.
    """
    num_poison = batch[1][0][1].shape[0]
    # Cotangent w.r.t. next_state.batch_stats, which come from the gradient
    # computation's `updates` (see apply_grads); propagated through (b).
    dbatch_stats = dstate.batch_stats
    deps, dgrads, dstate = backward_apply(state, next_state, dstate,
                                          batch_cotangenter,
                                          per_sample_loss,
                                          num_poison=num_poison)
    # Discard (a)'s batch_stats cotangent; the real dependence is handled in (b).
    dstate = dstate.replace(batch_stats=jax.tree_util.tree_map(safe_zeros, state.batch_stats))
    minibatched_backward = partial(backward_grad, state=state,
                                   dgrads=dgrads,
                                   dupdates=dbatch_stats,
                                   batch_cotangenter=batch_cotangenter,
                                   per_sample_loss=per_sample_loss)

    if minibs is None:
        new_deps, new_dstate = minibatched_backward(batch, num_poison=num_poison)
        deps, dstate = add_trees((deps, dstate), (new_deps, new_dstate))
    else:
        # minibatch_func returns (accumulator, None).
        (deps, dstate), _ = minibatch_func(minibatched_backward, batch, minibs,
                                           acc=(deps, dstate))

    return deps, dstate

def make_forwards(batch_cotangenter, per_sample_loss, num_poison):
    """Build the epsilon-parameterized forward step.

    The non-jitted ``(get_grads, apply_grads)`` pair from
    ``factored_functional_step`` is wrapped by ``batch_cotangenter``. Returns
    the ``(eps, get_grads, apply_grads)`` triple it produces.
    """
    base_get, base_apply = factored_functional_step(use_jit=False)
    eps, wrapped_get, wrapped_apply = batch_cotangenter(per_sample_loss, base_get,
                                                        base_apply, num_poison)
    return eps, wrapped_get, wrapped_apply

def take_vjp(fn, cotangents, *args):
    """Pull ``cotangents`` back through ``fn`` evaluated at ``args``.

    Returns a tuple with one cotangent per positional argument in ``args``.
    """
    pullback = jax.vjp(fn, *args)[1]
    return pullback(cotangents)

@partial(maybe_jit, static_argnames=('num_poison',))
def backward_grad(batch, state, dgrads, dupdates, batch_cotangenter, per_sample_loss, num_poison=0):
    """VJP of the gradient computation ``(eps, state) -> (grads, batch_stats)``.

    ``dgrads`` comes from ``backward_apply``; ``dupdates`` is the cotangent for
    the updated batch statistics. Returns ``(deps, dstate)``.
    """
    eps, get_grads, _ = make_forwards(batch_cotangenter, per_sample_loss, num_poison)

    def grads_and_stats(eps_, state_):
        grads, updates = get_grads(eps=eps_, state=state_, batch=batch)
        return grads, updates['batch_stats']

    return take_vjp(grads_and_stats, (dgrads, dupdates), eps, state)

@partial(maybe_jit, static_argnames=('num_poison',))
def backward_apply(state, next_state, dstate, batch_cotangenter,
                   per_sample_loss, num_poison=0):
    """VJP of the update ``(eps, grads, state) -> next_state``.

    The gradient that produced ``next_state`` is recovered with
    ``state.infer_gradient_from(next_state)``, and the current batch
    statistics are passed as the updates. ``dstate``'s params, batch_stats and
    opt_state are grafted onto ``state`` so the cotangent has the primal
    state's structure.

    Returns:
        ``(deps, dgrads, dstate)``.
    """
    eps, _, apply_grads = make_forwards(batch_cotangenter, per_sample_loss, num_poison)
    grads = state.infer_gradient_from(next_state)

    def apply_step(eps_, grads_, state_):
        return apply_grads(eps=eps_,
                           grads=grads_,
                           state=state_,
                           updates={'batch_stats': state_.batch_stats})

    cotangent = state.replace(params=dstate.params,
                              batch_stats=dstate.batch_stats,
                              opt_state=dstate.opt_state)
    return take_vjp(apply_step, cotangent, eps, grads, state)

# Module-level cache holding the first state ever passed to `wash_state`.
garbage_state = None


def wash_state(state):
    """Rebuild ``state`` on top of the first state ever passed in.

    The first call stores ``state`` in the module-level ``garbage_state``.
    Every call returns that stored state with ``params``, ``opt_state`` and
    ``batch_stats`` taken from ``state``; all other fields come from the first
    state.
    """
    global garbage_state
    if garbage_state is None:
        garbage_state = state
    return garbage_state.replace(params=state.params,
                                 opt_state=state.opt_state,
                                 batch_stats=state.batch_stats)

def one_eval_step(params, batch_stats, batch, per_sample_loss, tot, acc_loss, minibs):
    """Append one evaluation batch's losses to ``acc_loss`` and return the result.

    Losses come from ``per_sample_loss(params, batch_stats, batch, train=False)``,
    routed through ``minibatch_func`` when ``minibs`` is set. They are
    multiplied by ``len(batch[0])`` before being concatenated. ``tot`` is unused.

    NOTE(review): scaling per-sample losses by the batch size looks like a
    leftover from a summed-loss version (callers print the mean) — confirm.
    NOTE(review): ``minibatch_func`` expects ``func`` to return a
    ``(value, updates)`` pair and itself returns ``(acc, None)``, which does
    not match how the result is used here — confirm the ``minibs`` path works.
    """
    def eval_losses(minibatch):
        return per_sample_loss(params, batch_stats, minibatch, train=False)

    if minibs is None:
        this_loss = eval_losses(batch)
    else:
        this_loss = minibatch_func(eval_losses, batch, minibs)
    batch_size = len(batch[0])
    return jnp.concatenate([acc_loss, this_loss * batch_size])

def eval_model(state, val_batch_maker, val_its, per_sample_loss, limited, minibs):
    """Evaluate ``state`` on ``val_its`` batches (at most 50 when ``limited``).

    Returns the concatenation of the ``one_eval_step`` outputs over all batches.
    """
    num_batches = min(50, val_its) if limited else val_its
    gpu = jax.devices('gpu')[0]
    losses = jax.device_put(np.array([]), gpu)
    tot = jax.device_put(np.array(0), gpu)
    params = state.params
    for it in tqdm(range(num_batches), desc='Evaluating model..'):
        losses = one_eval_step(params, state.batch_stats, val_batch_maker(it),
                               per_sample_loss, tot, losses, minibs)
    return losses

def batch_for_state(state, train_batch_maker):
    """Return the training batch for a step.

    ``state`` is either the step index itself (an ``int``) or a train state,
    whose step is read from ``state.opt_state.count``.
    """
    step = state if isinstance(state, int) else int(state.opt_state.count)
    return train_batch_maker(step)

def _grads_for_batch(batch, statek, per_sample_loss):
    print('Compiling grads for batch with # poison = ', batch[1][0][1].shape)
    def losser(params):
        losses, updates = per_sample_loss(params=params, 
                                          batch_stats=statek.batch_stats, 
                                          batch=batch, 
                                          train=True)
        return jnp.sum(losses), updates

    (_, updates), grads = jax.value_and_grad(losser, has_aux=True)(statek.params)
    return grads, updates

# Jitted (via maybe_jit) variant, used by factored_functional_step when use_jit=True.
grads_for_batch = maybe_jit(_grads_for_batch)

# Batch-norm aware: also installs the updated batch statistics.
@maybe_jit
def apply_grads(statek, grads, updates, lr_factor, wd_factor):
    """Apply ``grads`` to ``statek`` and install ``updates['batch_stats']``.

    Returns ``statek.apply_grads(grads, lr_factor=..., wd_factor=...)`` with
    its ``batch_stats`` replaced by ``updates['batch_stats']``.
    """
    stepped = statek.apply_grads(grads, lr_factor=lr_factor, wd_factor=wd_factor)
    new_batch_stats = updates['batch_stats']
    return stepped.replace(batch_stats=new_batch_stats)

def minibatch_func(func, batch, minibs, *, acc=None):
    """Apply ``func`` to consecutive minibatches of ``batch`` and sum the results.

    ``batch`` is unpacked as ``(ixs, (x, y))``. Each minibatch passed to
    ``func`` is ``(ixs[sl], (x[sl], y[sl]) on the GPU, positions[sl])``. The
    next minibatch is prepared before ``func`` runs on the current one.
    ``func`` must return ``(value, None)``; the values are summed into ``acc``
    with ``add_trees``.

    Returns:
        ``(acc, None)``.

    Raises:
        ValueError: if ``func`` returns a non-None second element.
    """
    gpu = jax.devices('gpu')[0]
    batch_size = len(batch[0])
    positions = jnp.arange(batch_size)
    chunk = batch_size if minibs is None else minibs

    def slice_batch(start):
        indices, (inputs, targets) = batch
        window = slice(start, start + chunk)
        examples = jax.device_put((inputs[window], targets[window]), gpu)
        return indices[window], examples, jax.device_put(positions[window])

    offset = 0
    current = slice_batch(offset)
    while offset < batch_size:
        assert len(current[0]) > 0
        upcoming_offset = offset + chunk
        upcoming = slice_batch(upcoming_offset) if upcoming_offset < batch_size else None
        value, updates = func(current)
        if updates is not None:
            raise ValueError('Minibatching is not supported with batchnorm')
        acc = value if acc is None else add_trees(acc, value)
        offset, current = upcoming_offset, upcoming

    return acc, None

def get(state, batch, per_sample_loss, this_grads_for_batch, use_jit, minibs):
    """Compute ``(grads, updates)`` for one batch.

    Args:
        state: train state providing ``params`` and ``batch_stats``.
        batch: the batch; ``len(batch[1][0])`` is taken as its size.
        per_sample_loss: forwarded to ``this_grads_for_batch``.
        this_grads_for_batch: ``(batch, per_sample_loss=, statek=) -> (grads, updates)``.
        use_jit, minibs: ``minibatch_func`` is used only when ``use_jit`` is
            true and ``minibs`` is not None.

    Returns:
        ``(grads, updates)``. For an empty batch, ``grads`` are zeros
        (``safe_zeros`` of the params) and ``updates`` carries the current
        ``batch_stats``, so ``apply_grads`` can still read
        ``updates['batch_stats']``.
    """
    num_in_batch = len(batch[1][0])
    if num_in_batch == 0:
        # Return a pair: every caller unpacks `grads, updates = get(...)`.
        zero_grads = jax.tree_util.tree_map(safe_zeros, state.params)
        return zero_grads, {'batch_stats': state.batch_stats}

    batch_to_grads = partial(this_grads_for_batch,
                             per_sample_loss=per_sample_loss, statek=state)
    if use_jit and minibs is not None:
        grads_acc, updates = minibatch_func(batch_to_grads, batch, minibs)
    else:
        grads_acc, updates = batch_to_grads(batch)

    return grads_acc, updates

def apply(state, grads_acc, updates, lr_factor, wd_factor):
    """Forward all arguments, positionally, to the module-level ``apply_grads``."""
    call_args = (state, grads_acc, updates, lr_factor, wd_factor)
    return apply_grads(*call_args)

def factored_functional_step(minibs=None, use_jit=True, lr_factor=get_one(),
                             wd_factor=get_one()):
    """Split one training step into ``(get_grads, apply_grads)`` callables.

    Both are ``jax.tree_util.Partial`` objects:
      * ``get_grads(state, batch, per_sample_loss) -> (grads, updates)``. It
        uses ``grads_for_batch`` when ``use_jit`` is true and
        ``_grads_for_batch`` otherwise; ``minibs`` is bound.
      * ``apply_grads(state, grads, updates) -> next state``, with
        ``lr_factor`` and ``wd_factor`` bound.
    """
    grads_fn = grads_for_batch if use_jit else _grads_for_batch

    p_get = jax.tree_util.Partial(get, this_grads_for_batch=grads_fn,
                                  use_jit=use_jit, minibs=minibs)
    p_apply = jax.tree_util.Partial(apply, lr_factor=lr_factor,
                                    wd_factor=wd_factor)
    return p_get, p_apply

def functional_step(state, batch, per_sample_loss, lr_factor=get_one(),
                    wd_factor=get_one(), minibs=None, use_jit=True):
    """Run one full training step on ``batch`` and return the next state."""
    compute_grads, commit = factored_functional_step(minibs, use_jit,
                                                     lr_factor, wd_factor)
    grads, updates = compute_grads(state, batch, per_sample_loss)
    return commit(state, grads, updates)

from .instruction_ds import HetSeqBatch

def mask_batch(ixs, x, y, leave_out_indices):
    """Drop the examples whose index appears in ``leave_out_indices``.

    Args:
        ixs: per-example indices, indexable with a boolean mask.
        x, y: inputs and targets. Either both are ``HetSeqBatch`` (filtered
            with ``subselect``) or both are indexable with a boolean mask.
        leave_out_indices: collection of indices to remove.

    Returns:
        ``(ixs[mask], (x_kept, y_kept))``.
    """
    # dtype=bool keeps this a valid boolean index even for an empty batch;
    # np.array([]) on its own would be float64 and rejected as an index.
    mask = np.array([tix not in leave_out_indices for tix in ixs], dtype=bool)
    if isinstance(x, HetSeqBatch):
        assert isinstance(y, HetSeqBatch)
        return ixs[mask], (x.subselect(mask), y.subselect(mask))
    return ixs[mask], (x[mask], y[mask])

