from functools import partial

import distrax
import jax
import jax.numpy as jnp
import chex
import optax
from flax.training import train_state

import algorithms.common.types as tp
from algorithms.common import markov_kernel
from algorithms.common.models.pisgrad_net import PISGRADNet
from algorithms.common.models.statetime_net import StateTimeNetwork
from algorithms.scld.is_weights import sub_traj_is_weights, update_samples_log_weights, get_lnz_elbo_increment
from algorithms.scld.loss_fns import get_loss_fn, lnZ_update_vanilla, lnZ_update_jensen, sub_traj_fwd_kl, sub_traj_rev_kl
from algorithms.scld.prioritised_buffer_traj import build_prioritised_subtraj_buffer
from algorithms.scld.prioritised_buffer import build_prioritised_buffer
from algorithms.scld.scld_eval import eval_scld
from algorithms.scld.scld_utils import GeometricAnnealingSchedule, print_results, gradient_step, flattened_traversal, pseudo_huber_loss, make_lr_scheduler, get_subtraj_weightscheme
from algorithms.scld.is_weights import per_sample_sub_traj_is_weight, per_subtraj_log_is
from algorithms.scld.is_weights import ode_log_prob, simulate_prob_flow_ode
from algorithms.scld import resampling
import wandb
import time

# Local aliases for names defined in algorithms.common.types (imported as `tp`),
# so the rest of the module can refer to them without the `tp.` prefix.
Array = tp.Array
FlowApply = tp.FlowApply
FlowParams = tp.FlowParams
LogDensityByStep = tp.LogDensityByStep
LogDensityNoStep = tp.LogDensityNoStep
MarkovKernelApply = tp.MarkovKernelApply
AcceptanceTuple = tp.AcceptanceTuple
RandomKey = tp.RandomKey
Samples = tp.Samples
# Shorthand for the chex shape-assertion helpers.
assert_equal_shape = chex.assert_equal_shape
assert_trees_all_equal_shapes = chex.assert_trees_all_equal_shapes

def f_cosine(x):
    """Evaluate the offset squared-cosine curve cos(((x + s) / (1 + s)) * pi / 2) ** 2.

    The offset ``s`` is fixed at 0.008. ``x`` may be a Python scalar or an
    array; the result is whatever ``jnp.cos`` returns for that input.
    """
    offset = 0.008
    angle = (x + offset) / (1 + offset) * jnp.pi / 2
    cosine = jnp.cos(angle)
    return cosine ** 2

def get_beta_nonlearnt(step, alg_cfg):
    """Return the annealing coefficient beta at ``step`` for a fixed schedule.

    Args:
        step: Step index (Python number or array).
        alg_cfg: Config exposing ``num_steps`` and
            ``annealing_schedule.schedule_type``.

    Returns:
        For ``"uniform"``: ``step / num_steps``.
        For ``"cosine"``: ``f_cosine(1 - step / num_steps) / f_cosine(0)``.

    Raises:
        NotImplementedError: If the schedule type is neither ``"uniform"``
            nor ``"cosine"``. The message names the offending value.
    """
    schedule_type = alg_cfg.annealing_schedule.schedule_type

    if schedule_type == "uniform":
        return step / alg_cfg.num_steps

    if schedule_type == "cosine":
        # eq17 of https://arxiv.org/pdf/2102.09672
        return f_cosine(1 - step / alg_cfg.num_steps) / f_cosine(0)

    # Name the unsupported value so a config typo is diagnosable from the traceback.
    raise NotImplementedError(
        f"Unsupported non-learnt annealing schedule type {schedule_type!r}; "
        "expected 'uniform' or 'cosine'."
    )

def get_beta_schedule(params, alg_cfg):
    """Build a callable mapping a step index to its annealing coefficient beta.

    For every schedule type other than ``"learnt"`` the result is
    ``get_beta_nonlearnt`` with ``alg_cfg`` pre-bound.

    For ``"learnt"``, ``params['params']['betas']`` is passed through a
    softplus, turned into a normalised cumulative sum and prefixed with 0;
    the returned callable looks up the entry at the (integer-cast) step.
    """
    if alg_cfg.annealing_schedule.schedule_type != "learnt":
        return partial(get_beta_nonlearnt, alg_cfg=alg_cfg)

    # Softplus keeps every increment positive, so the cumulative table is increasing.
    increments = jax.nn.softplus(params['params']['betas'])
    cumulative = jnp.cumsum(increments) / jnp.sum(increments)
    beta_table = jnp.concatenate((jnp.array([0]), cumulative))

    def lookup_beta(step):
        return beta_table[jnp.array(step, int)]

    return lookup_beta

def get_annealing_fn_and_prior(params, alg_cfg, target_log_prob):
    """Assemble the annealed log density, prior, beta schedule and noise schedule.

    Args:
        params: Parameter tree; reads ``prior_mean``, ``prior_log_stds`` and
            (when ``alg_cfg.learn_max_diffusion`` is truthy)
            ``log_max_diffusion`` from ``params['params']``.
        alg_cfg: Algorithm config; reads ``num_steps``, ``target_clip``,
            ``annealing_schedule.schedule_type``, ``learn_max_diffusion`` and
            ``noise_schedule``.
        target_log_prob: Log density of the target, handed to
            ``GeometricAnnealingSchedule``.

    Returns:
        ``(get_schedule, initial_density, beta_fn, noise_schedule)`` where
        ``get_schedule(step, x)`` evaluates
        ``(1 - beta) * initial_log_density(x) + beta * final_log_density(x)``
        with ``beta = beta_fn(step)``.
    """
    learnt = params['params']

    # Diagonal Gaussian prior; the stds are stored in log space.
    initial_density = distrax.MultivariateNormalDiag(
        learnt['prior_mean'],
        jnp.exp(learnt['prior_log_stds']),
    )

    # Only used as a holder for the initial / final log-density callables below.
    fixed_annealing_schedule = GeometricAnnealingSchedule(
        initial_density.log_prob,
        target_log_prob,
        alg_cfg.num_steps + 1,
        alg_cfg.target_clip,
        alg_cfg.annealing_schedule.schedule_type,
    )

    beta_fn = get_beta_schedule(params, alg_cfg)

    def get_schedule(step, x):
        beta = beta_fn(step)
        initial_term = (1. - beta) * fixed_annealing_schedule._initial_log_density(x)
        final_term = beta * fixed_annealing_schedule._final_log_density(x)
        return initial_term + final_term

    if alg_cfg.learn_max_diffusion:
        # Here alg_cfg.noise_schedule is called as a factory taking sigma_max.
        noise_schedule = alg_cfg.noise_schedule(sigma_max=jnp.exp(learnt['log_max_diffusion']))
    else:
        noise_schedule = alg_cfg.noise_schedule

    return get_schedule, initial_density, beta_fn, noise_schedule

def inner_step_simulate(key,
                        model_state,
                        params,
                        samples,
                        log_weights,
                        sim_tuple,
                        markov_kernel_apply,
                        sub_traj,
                        config,
                        smc_settings,
                        batchsize_override):
    """Simulate one sub-trajectory and apply the SMC update to its end points.

    Args:
        key: PRNG key; split into one key per sample (``samples.shape[0]``)
            and one key for ``update_samples_log_weights``.
        model_state, params: Forwarded to ``sub_traj_is_weights``.
        samples: Particles at the start of the sub-trajectory (batch on axis 0).
        log_weights: Current particle log weights.
        sim_tuple: Simulation tuple; its first entry is the per-step log
            density, called as ``log_density_per_step(step, samples)``.
        markov_kernel_apply: Forwarded to ``update_samples_log_weights``.
        sub_traj: ``(start_point, end_point, idx, length)``; ``end_point[0]``
            is used as the step of the SMC update.
        config: Algorithm config; reads ``loss``, ``model``, ``resampler``,
            ``resample_threshold``, ``use_markov``, ``sweep_mode`` and
            optionally ``use_nobuffer_mode``.
        smc_settings: ``(use_resampling, use_mcmc)`` for this call.
        batchsize_override: When truthy (together with ``config.sweep_mode``)
            the simulated sub-trajectories are not returned.

    Returns:
        ``(next_samples, next_log_weights, increments, target_log_probs,
        log_is_weights, model_subtrajectories_or_None, debug_tuple)``.
    """
    use_resampling, use_mcmc = smc_settings
    log_density_per_step = sim_tuple[0]
    sub_traj_end_point = sub_traj[1]

    key, key_gen = jax.random.split(key)
    keys = jax.random.split(key, samples.shape[0])
    log_is_weights, aux = sub_traj_is_weights(keys, samples, model_state, params,
                                              sim_tuple, sub_traj, stop_grad=config.loss != "rev_kl",
                                              detach_langevin_pisgrad=config.model.get("model_detach_langevin", True))
    model_samples, target_log_probs, model_subtrajectories = aux

    increments = get_lnz_elbo_increment(log_is_weights, log_weights)

    key, _ = jax.random.split(key_gen)

    next_samples, next_log_weights, _acceptance_tuple, debug_tuple = update_samples_log_weights(
        samples=model_samples, log_is_weights=log_is_weights, markov_kernel_apply=markov_kernel_apply,
        log_weights=log_weights, step=sub_traj_end_point[0], key=key,
        use_reweighting=use_resampling, use_resampling=use_resampling, resampler=config.resampler,
        use_markov=use_mcmc, resample_threshold=config.resample_threshold,
        log_density_per_step=log_density_per_step)

    next_samples = jax.lax.stop_gradient(next_samples)

    # The MCMC move is controlled by `use_mcmc` (from smc_settings), which may
    # differ from config.use_markov (e.g. inference-time settings). Re-evaluate
    # the log density whenever MCMC actually ran so it matches next_samples;
    # config.use_markov is kept in the condition to preserve prior behaviour.
    if use_mcmc or config.use_markov:
        target_log_probs = log_density_per_step(sub_traj_end_point[0], next_samples)[:, None]

    # Memory saver mode (hacky): skip returning the simulated sub-trajectories.
    drop_subtrajectories = config.sweep_mode and (
        config.loss == "rev_kl" or batchsize_override or config.get("use_nobuffer_mode", False))
    if drop_subtrajectories:
        model_subtrajectories = None

    return (next_samples, next_log_weights, increments, target_log_probs,
            log_is_weights, model_subtrajectories, debug_tuple)


def simulate(key_gen,
             model_state,
             params,
             target_log_density,
             markov_kernel_apply,
             traj,
             config,
             smc_settings,
             batchsize_override=0):
    """Roll out a full trajectory as a scan over consecutive sub-trajectories.

    Args:
        key_gen: PRNG key; split for the prior draw and the per-sub-trajectory keys.
        model_state, params: Forwarded to ``inner_step_simulate``.
        target_log_density: Target log density passed to
            ``get_annealing_fn_and_prior``.
        markov_kernel_apply: Factory called as
            ``markov_kernel_apply(log_density_per_step, beta_fn)``.
        traj: ``(n_sub_traj, start_points, end_points, indices, sub_traj_length)``.
        config: Algorithm config; reads ``batch_size``, ``num_steps`` and
            ``langevin_norm_clip`` here.
        smc_settings: Forwarded to ``inner_step_simulate``.
        batchsize_override: Batch size to use instead of ``config.batch_size``
            (0 means no override).

    Returns:
        A 6-tuple:
          * samples with the initial samples prepended along axis 0,
          * per-sub-trajectory target log probs (``[:, :, 0]`` slice) with a
            leading row of ones standing in for the initial time,
          * ``(lnz, elbo)`` summed over sub-trajectories,
          * ``(final_state, sub_traj_log_is_weights)``,
          * the stacked sub-trajectories returned by each step,
          * per-step diagnostics: log ESS before the step, log ESS of the
            pre-resample weights, and both weight vectors.
    """
    if batchsize_override == 0:
        batch_size = config.batch_size
    else:
        batch_size = batchsize_override

    prior_key, key_gen = jax.random.split(key_gen)
    annealing = get_annealing_fn_and_prior(params, config, target_log_density)
    log_density_per_step, initial_density, beta_fn, noise_schedule = annealing
    initial_samples = initial_density.sample(seed=prior_key, sample_shape=(batch_size,))
    # Uniform log weights log(1/n); these deviate from 1/n iff resampling.
    initial_log_weights = jnp.ones(batch_size) * (-jnp.log(batch_size))
    kernel_apply = markov_kernel_apply(log_density_per_step, beta_fn)

    n_sub_traj, sub_traj_start_points, sub_traj_end_points, sub_traj_indices, sub_traj_length = traj

    scan_key, key_gen = jax.random.split(key_gen)
    sub_traj_keys = jax.random.split(scan_key, n_sub_traj)
    sim_tuple = (log_density_per_step, noise_schedule, config.num_steps, (config.langevin_norm_clip,))

    # Note: log_weights (particle weights) and log_is_weights (a.k.a. rnds) are
    # different quantities; target_log_probs are the annealed log densities at
    # the sample points.
    def scan_step(carry, xs):
        samples, log_weights = carry
        step_key, start_point, end_point, idx = xs
        step_result = inner_step_simulate(
            step_key, model_state, params,
            samples, log_weights,
            sim_tuple,
            kernel_apply,
            (start_point, end_point, idx, sub_traj_length),
            config,
            smc_settings, batchsize_override)
        (next_samples, next_log_weights, increments, target_log_probs,
         log_is_weights, subtrajectories, (pre_resample_logweights,)) = step_result

        diagnostics = (resampling.log_effective_sample_size(log_weights),
                       resampling.log_effective_sample_size(pre_resample_logweights),
                       log_weights,
                       pre_resample_logweights)
        outputs = (next_samples, increments, target_log_probs, log_is_weights,
                   subtrajectories, diagnostics)
        return (next_samples, next_log_weights), outputs

    # final_state holds the final samples and their log weights.
    final_state, per_step_outputs = jax.lax.scan(
        scan_step,
        (initial_samples, initial_log_weights),
        (sub_traj_keys, sub_traj_start_points, sub_traj_end_points, sub_traj_indices))

    (samples, (lnz_incs, elbo_incs), sub_traj_target_log_probs,
     sub_traj_log_is_weights, sub_trajs, log_ess_per_subtraj) = per_step_outputs
    lnz = jnp.sum(lnz_incs)
    elbo = jnp.sum(elbo_incs)

    # Prepend the time-0 samples so the array holds the positions at the
    # initial time and at every sub-trajectory end point.
    all_samples = jnp.concatenate([initial_samples[None, ...], samples], axis=0)
    # The first row is filled with ones rather than evaluated log densities.
    all_log_probs = jnp.concatenate(
        [jnp.ones((1, batch_size)), sub_traj_target_log_probs[:, :, 0]], axis=0)

    return (all_samples, all_log_probs, (lnz, elbo),
            (final_state, sub_traj_log_is_weights), sub_trajs, log_ess_per_subtraj)


def sample_and_concat(key, buffer_samples, new_samples, buffer_on_cpu = False, subtraj_id = None):
    """Mix buffer samples with an equally sized random subset of fresh samples.

    Args:
        key: PRNG key used to choose which fresh samples are kept.
        buffer_samples: Samples drawn from the buffer; axis 1 is the batch axis.
        new_samples: Freshly simulated samples; axis 1 is the batch axis.
        buffer_on_cpu: If True, ``buffer_samples`` is first moved to the first
            GPU device.
        subtraj_id: If not None, the chosen subset is indexed with it along
            axis 0 (train on one sub-trajectory only).

    Returns:
        ``buffer_samples`` and the chosen subset of ``new_samples``
        concatenated along axis 1. Gradients through the chosen subset are
        stopped.

    Raises:
        ValueError: From ``jax.random.choice`` if ``new_samples`` has fewer
            batch entries than ``buffer_samples``.
    """
    # remark: this can be made more sample-efficient by taking all new_samples
    # and matching them w/ equal number of buffer_samples

    n_from_buffer = buffer_samples.shape[1]
    # Draw from the actual number of fresh samples instead of assuming it is
    # exactly twice the buffer batch: an out-of-range index would be clamped
    # silently by JAX, and a too-small range would ignore part of new_samples.
    n_new = new_samples.shape[1]

    # Pick as many distinct fresh samples as there are buffer samples.
    random_indices = jax.random.choice(key, n_new, shape=(n_from_buffer,), replace=False)

    # Detach gradients from the selected fresh samples.
    sampled_subset = jax.lax.stop_gradient(new_samples[:, random_indices, :])

    if subtraj_id is not None:
        # train on one subtraj only
        sampled_subset = sampled_subset[subtraj_id]

    if buffer_on_cpu:
        # Bring the buffer batch onto the GPU before combining.
        buffer_samples = jax.device_put(buffer_samples, device=jax.devices('gpu')[0])

    # Concatenate the buffer samples with the sampled subset of new_samples.
    combined_samples = jnp.concatenate([buffer_samples, sampled_subset], axis=1)

    return combined_samples

def make_subtrajectory_boundaries(num_steps, n_sub_traj):
    """Split ``num_steps`` into ``n_sub_traj`` equal-length sub-trajectories.

    The sub-trajectory length is ``num_steps // n_sub_traj`` (floor division),
    so sub-trajectory ``t`` spans ``[t * length, (t + 1) * length]``.

    Returns:
        ``(sub_traj_length, start_points, end_points, indices, traj)`` where
        ``start_points`` / ``end_points`` are column arrays with one row per
        sub-trajectory, ``indices`` is ``arange(n_sub_traj)`` and ``traj``
        bundles ``(n_sub_traj, start_points, end_points, indices,
        sub_traj_length)``.
    """
    sub_traj_length = num_steps // n_sub_traj

    # Consecutive boundaries 0, L, 2L, ..., n*L; starts drop the last one and
    # ends drop the first one.
    boundaries = [k * sub_traj_length for k in range(n_sub_traj + 1)]
    sub_traj_start_points = jnp.array([[b] for b in boundaries[:-1]])
    sub_traj_end_points = jnp.array([[b] for b in boundaries[1:]])
    sub_traj_indices = jnp.arange(n_sub_traj)

    traj = (
        n_sub_traj,
        sub_traj_start_points,
        sub_traj_end_points,
        sub_traj_indices,
        sub_traj_length,
    )
    return sub_traj_length, sub_traj_start_points, sub_traj_end_points, sub_traj_indices, traj

def scld_trainer(cfg, target):
    # Initialization
    
    key_gen = jax.random.PRNGKey(cfg.seed)
    dim = target.dim
    alg_cfg = cfg.algorithm

    target_samples = target.sample(seed=jax.random.PRNGKey(cfg.seed), sample_shape=(cfg.eval_samples,))
    do_single_subtraj_training = False if not hasattr(alg_cfg, "memory_saver") else alg_cfg.memory_saver.subtraj_scheme != "all"


    n_sub_traj, num_transitions = alg_cfg.n_sub_traj, alg_cfg.num_steps + 1  # todo check if this is correct
    sub_traj_length, sub_traj_start_points, sub_traj_end_points, sub_traj_indices, traj = make_subtrajectory_boundaries(alg_cfg.num_steps, alg_cfg.n_sub_traj)
    
    traj_inference = traj 
    
    BENCHMARK_THRESHOLD = 1 # 100 if alg_cfg.get("time_everything", False) else 1 

    if hasattr(alg_cfg, "n_sub_traj_inference"):
        _, _, _, _, traj_inference = make_subtrajectory_boundaries(alg_cfg.num_steps, alg_cfg.n_sub_traj_inference)
    
    # make the buffer
    buffer_on_cpu = False if not hasattr(alg_cfg, "memory_saver") else alg_cfg.memory_saver.buffer_on_cpu
    cpu_device = jax.devices('cpu')[0]
    buffer = build_prioritised_subtraj_buffer(dim, alg_cfg.n_sub_traj,
                                        jnp.array(alg_cfg.buffer.max_length_in_batches * alg_cfg.batch_size, dtype=int),
                                        jnp.array(alg_cfg.buffer.min_length_in_batches * alg_cfg.batch_size, dtype=int),
                                        sub_traj_length+1, # length of one subtraj (including startpoint)  
                                        sample_with_replacement=alg_cfg.buffer.sample_with_replacement,
                                        prioritized=alg_cfg.buffer.prioritized, temperature=alg_cfg.buffer.temperature if hasattr(alg_cfg.buffer, "temperature") else 1,
                                        on_cpu=buffer_on_cpu)
    
    # Define the model
    # TODO do something less hacky
    model = None
    if hasattr(alg_cfg.model, "inner_clip"):
        model = PISGRADNet(**alg_cfg.model)
    else:
        model = StateTimeNetwork(**alg_cfg.model)
    
    key, key_gen = jax.random.split(key_gen)
    params = model.init(key, jnp.ones([alg_cfg.batch_size, dim]),
                        jnp.ones([alg_cfg.batch_size, 1]),
                        jnp.ones([alg_cfg.batch_size, dim]))

    additional_params = {}
    if alg_cfg.loss in ['rev_tb', 'fwd_tb']:
        # include global baseline for second moment (i.e tb) loss
        # to adapt to subtrajectory setting, each subtrajectory tracks a delta_lnZ
        # where delta_lnZ tracks ln Z_{t+1}- lnZ_t = ln E_forward[dP_{back, t+1}/dP_{forward, t+1}]
        initial_logZ = alg_cfg.init_logZ if not alg_cfg.leak_true_lnZ else alg_cfg.true_lnZ
        additional_params['logZ_deltas'] = jnp.ones(n_sub_traj) * initial_logZ
    elif alg_cfg.loss in ['rev_kl','fwd_kl']:
        assert(alg_cfg.buffer.max_length_in_batches == 1)
        #assert(alg_cfg.buffer.use_subtraj_buffer == False) Always use subtraj buffer now, deprecated
        assert(alg_cfg.buffer.sampling_scheme == "vanilla")
    
    params_to_freeze = []
    if alg_cfg.annealing_schedule.schedule_type == "learnt":
        additional_params['betas'] = jnp.ones((alg_cfg.num_steps,))
    else:
        params_to_freeze.append('betas')

    additional_params['prior_log_stds'] = jnp.ones((dim,)) * jnp.log(alg_cfg.init_std)
    additional_params['prior_mean'] = jnp.zeros((dim,))


    learnt_prior_params = []
    prior_lr = 0 if not hasattr(alg_cfg,"prior") else alg_cfg.prior.lr
    if hasattr(alg_cfg,"prior"):
        if alg_cfg.prior.learn_variance:
            learnt_prior_params.append('prior_log_stds')
        else:
            params_to_freeze.append('prior_log_stds')

        if alg_cfg.prior.learn_mean:
            learnt_prior_params.append('prior_mean')
        else:
            params_to_freeze.append('prior_mean')
    else:
        params_to_freeze = ['prior_mean', 'prior_log_stds']
    

    additional_params['log_max_diffusion'] = jnp.log(alg_cfg.max_diffusion)

    if alg_cfg.learn_max_diffusion:
        # check that alg_cfg.noise_schedule is a factory
        assert(callable(alg_cfg.noise_schedule(sigma_max = 0))) 
        learnt_prior_params.append('log_max_diffusion')
    else:
        # check that alg_cfg.noise_schedule is a noise_schedule
        assert(not callable(alg_cfg.noise_schedule(0))) 
        params_to_freeze.append('log_max_diffusion')

    params['params'] = {**params['params'], **additional_params}

    lr_scheduler = make_lr_scheduler(alg_cfg)

    # if betas are learnt, apply same optimizer
    optimizer = optax.chain(
        optax.clip(alg_cfg.grad_clip) if alg_cfg.grad_clip > 0 else optax.identity(), # clip gradients to 1 (it can be a lot bigger)
        optax.masked(optax.adam(learning_rate=lr_scheduler),
                     mask=flattened_traversal(lambda path, _: path[-1] not in ['logZ_deltas','betas', 'prior_log_stds', 'prior_mean'])),
        
        optax.masked(optax.adam(learning_rate=alg_cfg.annealing_schedule.schedule_lr), 
                     mask=flattened_traversal(lambda path, _: path[-1] in ['betas'])),
        
        optax.masked(optax.adam(learning_rate=prior_lr),
                     mask=flattened_traversal(lambda path, _: path[-1] in learnt_prior_params)),

        optax.masked(optax.set_to_zero(), mask=flattened_traversal(lambda path, _: path[-1] in params_to_freeze)), 
        # we end up doing logZ learning manually
        #optax.masked(optax.sgd(learning_rate=alg_cfg.logZ_step_size),
        #             mask=flattened_traversal(lambda path, _: path[-1] == 'logZ')) if alg_cfg.loss in ['rev_tb',
        #                                                                                               'fwd_tb'] else optax.identity(),
    )

    # add gradient accumulation
    if hasattr(alg_cfg,"memory_saver"):
        # take gradients steps at fixed intervals
        # note to have the same training result n_sims must be scaled accordingly
        optimizer = optax.MultiSteps(optimizer, every_k_schedule=alg_cfg.memory_saver.accumulate_gradients)

    elif hasattr(alg_cfg,"gradient_accumulation_steps"):
        # take gradients steps at fixed intervals
        # note to have the same training result n_sims must be scaled accordingly
        optimizer = optax.MultiSteps(optimizer, every_k_schedule=alg_cfg.gradient_accumulation_steps)

    model_state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

    """
    initial_density = distrax.MultivariateNormalDiag(jnp.zeros(dim),
                                                     jnp.ones(dim) * alg_cfg.init_std)
    
    log_density_per_step = GeometricAnnealingSchedule(initial_density.log_prob, target.log_prob,
-                                                      num_transitions, alg_cfg.target_clip,
-                                                      alg_cfg.annealing_schedule.schedule_type)
    """
    if hasattr(alg_cfg, "use_beta_adapted_mcmc") and alg_cfg.use_beta_adapted_mcmc:
        markov_kernel_by_step = lambda annealing_fn, beta_fn: markov_kernel.MarkovTransitionKernel(alg_cfg.mcmc, annealing_fn, num_transitions,
                                                                                           fixed_linear_beta_schedule=False, beta_schedule=beta_fn)
    else:
        markov_kernel_by_step = lambda annealing_fn, _: markov_kernel.MarkovTransitionKernel(alg_cfg.mcmc, annealing_fn, num_transitions,
                                                                                           fixed_linear_beta_schedule=True)
    # curried, aliased and jitted versions of the simulate function
    simulate_short_train_nojit = partial(simulate, 
                                           target_log_density = target.log_prob,
                                           markov_kernel_apply=markov_kernel_by_step,
                                           traj=traj, config=alg_cfg, batchsize_override = 0, smc_settings = (alg_cfg.use_resampling, alg_cfg.use_markov))
    simulate_short_train = jax.jit(simulate_short_train_nojit, static_argnames=('config', 'smc_settings'))

    simulate_short_no_smc = jax.jit(partial(simulate,
                                            target_log_density = target.log_prob,
                                            markov_kernel_apply=markov_kernel_by_step, traj=traj,
                                            config=alg_cfg, batchsize_override=cfg.eval_samples, smc_settings = (False, False)),
                                      static_argnames=('config'))
    
    simulate_fn_list = []

    if alg_cfg.get("sim_all_lengths",False):
        for i in range(8):
            _, _, _, _, traj_inference_i = make_subtrajectory_boundaries(alg_cfg.num_steps, 2**i)
            simulate_fn_list.append((2**i, jax.jit(partial(simulate,
                                                target_log_density = target.log_prob,
                                                markov_kernel_apply=markov_kernel_by_step, traj=traj_inference_i,
                                                config=alg_cfg, batchsize_override=cfg.eval_samples, smc_settings = (alg_cfg.use_resampling_inference, alg_cfg.use_markov_inference)))))
    elif alg_cfg.get("ablate_markov",False):
        simulate_fn_list = [("no_markov", jax.jit(partial(simulate,
                                                target_log_density = target.log_prob,
                                                markov_kernel_apply=markov_kernel_by_step, traj=traj_inference,
                                                config=alg_cfg, batchsize_override=cfg.eval_samples, smc_settings = (alg_cfg.use_resampling_inference, False)),
                                        static_argnames=('config'))),
                            ("markov", jax.jit(partial(simulate,
                                                target_log_density = target.log_prob,
                                                markov_kernel_apply=markov_kernel_by_step, traj=traj_inference,
                                                config=alg_cfg, batchsize_override=cfg.eval_samples, smc_settings = (alg_cfg.use_resampling_inference, True)),
                                        static_argnames=('config')))]
    else:
        simulate_short_smc = jax.jit(partial(simulate,
                                                target_log_density = target.log_prob,
                                                markov_kernel_apply=markov_kernel_by_step, traj=traj_inference,
                                                config=alg_cfg, batchsize_override=cfg.eval_samples, smc_settings = (alg_cfg.use_resampling_inference, alg_cfg.use_markov_inference)),
                                        static_argnames=('config'))
        simulate_fn_list = [(alg_cfg.get("n_sub_traj_inference", alg_cfg.n_sub_traj), simulate_short_smc)]

    lnz_upd_fn = partial(lnZ_update_jensen, lr=alg_cfg.logZ_step_size, batch_size = alg_cfg.batch_size)\
              if alg_cfg.use_jensen_trick else\
                  partial(lnZ_update_vanilla,lr=alg_cfg.logZ_step_size)
    
    sub_traj_loss = get_loss_fn(alg_cfg.loss)
    
    if alg_cfg.loss in ['rev_tb']:
        if alg_cfg.use_pseudohuber:
            f_fn = partial(pseudo_huber_loss,\
                                delta=alg_cfg.pseudo_huber_delta)
            sub_traj_loss = partial(sub_traj_loss, f_fn = f_fn)
    
    
    # Make the prob_flow_ode eval function
    def get_probflow_ode_data(key, model_state, params, target_has_samples = True):

        key1, key2 = jax.random.split(key)

        annealing_fn, initial_density, _, noise_schedule = get_annealing_fn_and_prior(params, alg_cfg, target.log_prob)
        sim_tuple = (annealing_fn, noise_schedule, num_transitions, (alg_cfg.langevin_norm_clip,))    
        
        x_ode, logp_odes = simulate_prob_flow_ode(key1, alg_cfg.batch_size,
                                   initial_density, model_state,
                                   params, sim_tuple)
        if target_has_samples:
            target_sample_logps = ode_log_prob(key2, target_samples, model_state,
                                    params, sim_tuple)
        else:
            target_sample_logps = -jnp.inf
        
        return x_ode, logp_odes, target_sample_logps
        
    probflow_ode_fn = jax.jit(partial(get_probflow_ode_data, target_has_samples = \
                                      (target_samples is not None)))

    def data_loss_simple_kl(keys, samples, model_state, params):
        annealing_fn, initial_density, _, noise_schedule = get_annealing_fn_and_prior(params, alg_cfg, target.log_prob)
        sim_tuple = (annealing_fn, noise_schedule, num_transitions, (alg_cfg.langevin_norm_clip,))

        sub_traj_start_points_simple = jnp.array([0])
        sub_traj_end_points_simple = jnp.array([alg_cfg.num_steps])
        sub_traj_indices_simple = 0
        
        return sub_traj_fwd_kl(keys, None,samples , model_state, params, sim_tuple,
                             sub_traj_start_points_simple, sub_traj_end_points_simple, sub_traj_indices_simple, sub_traj_length,
                             per_sample_rnd_fn=per_sample_sub_traj_is_weight)
    

    if hasattr(alg_cfg, "data"):
        # do pretraining on data
        assert alg_cfg.data.n_samples % alg_cfg.data.batch_size == 0, "Batch size must divide the number of rows exactly."
        train_samples = target.sample(seed=jax.random.PRNGKey(cfg.seed + 100), sample_shape=(alg_cfg.data.n_samples,))
        # Calculate the number of batches
        num_batches = alg_cfg.data.n_samples // alg_cfg.data.batch_size
        batches = jnp.reshape(train_samples, (num_batches, alg_cfg.data.batch_size, train_samples.shape[-1]))
        key_gen_data, key_gen = jax.random.split(key_gen)
        
        loss_fn_data = jax.value_and_grad(jax.jit(data_loss_simple_kl), 3, has_aux=True)

        n_batches_trained = 0
        for i in range(alg_cfg.data.epochs):      
            for b in range(num_batches):
                key, key_gen_data = jax.random.split(key_gen_data)
                keys = jax.random.split(key, (alg_cfg.data.batch_size,))
                n_batches_trained += 1
                (per_sample_loss, aux), grads_all = loss_fn_data(keys, batches[b], model_state, model_state.params)
                    
                model_state = gradient_step(model_state, grads_all)

                if n_batches_trained % 10 == 0 or n_batches_trained == 1:
                    print(f"#batches = {n_batches_trained}: loss = {per_sample_loss.mean()}")

    key, key_gen = jax.random.split(key_gen)

    get_schedule_and_prior_fn = partial(get_annealing_fn_and_prior, alg_cfg = alg_cfg, target_log_prob = target.log_prob)
    eval_fn = eval_scld(simulate_short_no_smc, simulate_fn_list, probflow_ode_fn, get_schedule_and_prior_fn ,
                        target, target_samples, cfg)

    logger = {}
    eval_freq = alg_cfg.n_sim * alg_cfg.n_updates_per_sim // cfg.n_evals

    if eval_freq == 0:
        eval_freq = alg_cfg.n_sim - 1

    times = {
        'simulation_time': [],
        'buffer_storage_time': [], # time to store simulated batch
        'buffer_access_time': [], # time to assemble the training batch and put it on GPU
        'loss_computation_time': [],
        'weight_update_time': [],
        'move_trajs_gpu_to_cpu_time': []
    }

    moving_averages = {
        'loss': []
    }

    start_time = time.time()

    if alg_cfg.loss == "rev_kl":

        def kl_loss(key, model_state, params):

            sim_samples, sub_traj_target_log_probs, (lnz_est, elbo_est), ((final_samples, final_weights), sub_traj_logrnds), sub_trajs, _ = simulate_short_train_nojit(key, model_state, params) # final_weights is not log(1/n) iff resampling
            return -elbo_est, None 
        
        def kl_loss_one_subtraj(key, model_state, params):
            key1, key2 = jax.random.split(key)
            annealing_fn, initial_density, _, noise_schedule = get_annealing_fn_and_prior(params, alg_cfg, target.log_prob)
            initial_samples = initial_density.sample(seed=key1, sample_shape=(alg_cfg.batch_size,))
            sim_tuple = (annealing_fn, noise_schedule, num_transitions, (alg_cfg.langevin_norm_clip,))    
            return sub_traj_rev_kl(jax.random.split(key2, alg_cfg.batch_size), initial_samples, None, model_state, params, sim_tuple,
                                sub_traj_start_points[0], sub_traj_end_points[0], sub_traj_indices[0], sub_traj_length)

        kl_loss_jitted = jax.jit(jax.value_and_grad(kl_loss, 2, has_aux=True)) if alg_cfg.n_sub_traj > 1 else\
                         jax.jit(jax.value_and_grad(kl_loss_one_subtraj, 2, has_aux=True))

        time_elapsed_benchmarking = 0
        for i in range(alg_cfg.n_sim):
            key, key_gen = jax.random.split(key_gen)

            time_stamp1 = time.time()
            (neg_elbo, aux), grads =  kl_loss_jitted(key, model_state, model_state.params)
            # print(f'lnz {lnz_est}, elbo {elbo_est}')
            model_state = model_state.apply_gradients(grads = grads) 

            if i>= BENCHMARK_THRESHOLD:
                time_elapsed_benchmarking += time.time() - time_stamp1
            if cfg.use_wandb:
                    # loss_hist plots loss for each subtrajectory section
                    # loss plots overall loss
                    moving_averages['loss'].append(neg_elbo)
                    wandb.log({
                             #'loss_hist': per_sample_loss, TODO add per subtraj loss breakdown
                            'stats/n_inner_its': i,
                            'stats/n_sims': i, 
                            'loss': neg_elbo, 
                            'times/time_elapsed_secs': time.time() - start_time,
                            'time_benchmarking':  time_elapsed_benchmarking})
                    
            if i % eval_freq == 0 or i+1 == alg_cfg.n_sim:
                # target.visualise(buffer_samples[10], show=True)
                key, key_gen = jax.random.split(key_gen)
                logger.update(eval_fn(model_state, model_state.params, key, i))
                logger["stats/step"] = i
                print_results(i, logger, cfg)
                print(f'Loss: {neg_elbo}')

                if cfg.use_wandb:
                    wandb.log(logger)
    elif alg_cfg.buffer.max_length_in_batches == 1 and alg_cfg.get("use_nobuffer_mode",False):
        def lv_loss(key, model_state, params):
            sim_samples, sub_traj_target_log_probs, (lnz_est, elbo_est), ((final_samples, final_weights), sub_traj_logrnds), sub_trajs, _ = simulate_short_train_nojit(key, model_state, params) # final_weights is not log(1/n) iff resampling
            return jnp.clip(sub_traj_logrnds.var(ddof=0, axis=1), -1e7, 1e7).mean(), None
        
        # Jitted value-and-gradient of lv_loss, differentiated w.r.t. positional argument 2 (`params`);
        # has_aux=True because lv_loss returns a (loss, aux) pair.
        lv_loss_jitted = jax.jit(jax.value_and_grad(lv_loss, 2, has_aux=True))

        # Wall-clock seconds accumulated over the iterations with i >= BENCHMARK_THRESHOLD.
        time_elapsed_benchmarking = 0
        for i in range(alg_cfg.n_sim):
            # Fresh PRNG key for this iteration's loss evaluation.
            key, key_gen = jax.random.split(key_gen)

            time_stamp1 = time.time()
            (loss, aux), grads =  lv_loss_jitted(key, model_state, model_state.params)
            model_state = model_state.apply_gradients(grads = grads) 

            # NOTE(review): JAX dispatches asynchronously, so this delta may under-count device time
            # unless something above forces a sync — confirm.
            if i>= BENCHMARK_THRESHOLD:
                time_elapsed_benchmarking += time.time() - time_stamp1
            if cfg.use_wandb:
                    # In this mode there is one gradient step per simulation, so the same
                    # counter `i` is logged for both n_inner_its and n_sims.
                    moving_averages['loss'].append(loss)
                    wandb.log({
                             # TODO: add a per-subtrajectory loss breakdown ('loss_hist')
                            'stats/n_inner_its': i,
                            'stats/n_sims': i, 
                            'loss': loss, 
                            'times/time_elapsed_secs': time.time() - start_time,
                            'time_benchmarking':  time_elapsed_benchmarking})
                    
            # Evaluate every eval_freq iterations and on the final iteration.
            if i % eval_freq == 0 or i+1 == alg_cfg.n_sim:
                # target.visualise(buffer_samples[10], show=True)
                key, key_gen = jax.random.split(key_gen)
                logger.update(eval_fn(model_state, model_state.params, key, i))
                logger["stats/step"] = i
                print_results(i, logger, cfg)
                print(f'Loss: {loss}')

                if cfg.use_wandb:
                    wandb.log(logger)
    else:
        # Buffer-based training: build the loss function and seed the sub-trajectory buffer.
        # Remark: the closures below may reference outside variables as long as those are constants.
        assert(alg_cfg.buffer.use_subtraj_buffer)
        key, key_gen = jax.random.split(key_gen)

        # Disabled debugging shortcut (evaluate once and exit):
        # if cfg.use_wandb:
        #     wandb.log(eval_fn(model_state, model_state.params, key))
        # return

        # One initial simulation; only its sub-trajectories and log-RNDs are used, to initialise the buffer.
        # NOTE(review): this passes `params` rather than `model_state.params` — presumably identical
        # at this point; confirm.
        init_samples, sub_traj_target_log_probs, _, (_, log_rnds), inital_subtrajs, _ = simulate_short_train(key, model_state, params)
        buffer_state = buffer.init(inital_subtrajs, log_rnds)

        # Per-sample RND function handed to sub_traj_loss by sub_traj_loss_short below.
        per_point_rnd_fn = per_subtraj_log_is 
        def sub_traj_loss_short(keys, samples, next_samples, model_state, params, sub_traj_start_points,
                                sub_traj_end_points, sub_traj_indices, subtraj_choice = None):
            """Evaluate `sub_traj_loss` with the simulation settings derived from `params`.

            The annealing function and noise schedule are rebuilt from `params` on
            every call via `get_annealing_fn_and_prior`, then bundled with the
            enclosing scope's `num_transitions` and Langevin norm clip into the
            simulation tuple that `sub_traj_loss` expects.

            `subtraj_choice` is accepted for signature compatibility but is not
            used inside this function.

            Returns:
                Whatever `sub_traj_loss` returns for these inputs.
            """
            annealing, _, _, noise = get_annealing_fn_and_prior(params, alg_cfg, target.log_prob)
            simulation_setup = (annealing, noise, num_transitions, (alg_cfg.langevin_norm_clip,))
            detach_langevin = alg_cfg.get("model_detach_langevin",True)
            return sub_traj_loss(
                keys, samples, next_samples, model_state, params, simulation_setup,
                sub_traj_start_points, sub_traj_end_points, sub_traj_indices, sub_traj_length,
                per_sample_rnd_fn=per_point_rnd_fn,
                detach_langevin_pisgrad=detach_langevin,
            )
        
        # Loss and gradient w.r.t. positional argument 4 (`params`), vmapped over the leading axis of
        # keys, samples, next_samples and the three sub-trajectory index arguments. model_state, params
        # and subtraj_choice are broadcast (in_axes=None). Because in_axes is a 9-tuple, every call
        # must supply exactly nine positional arguments.
        loss_fn_tmp = jax.vmap(jax.value_and_grad(jax.jit(sub_traj_loss_short), 4, has_aux=True),
                        in_axes=(0, 0, 0, None, None, 0, 0, 0, None))
    
        def train_on_single_subtrajectory(keys, samples, next_samples, model_state, params, sub_traj_start_points,
                                sub_traj_end_points, sub_traj_indices, subtraj_choice):
            """Evaluate loss and gradients for the selected sub-trajectory only.

            Used when do_single_subtraj_training is enabled. `keys` and the three
            sub-trajectory index arguments are indexed by `subtraj_choice`, while
            `samples` / `next_samples` are passed through unchanged
            (original note: they are (1, B, L, D) tensors in this mode).

            Returns:
                The output of `loss_fn_tmp` for the selected sub-trajectory.
            """
            # The trailing None fills the broadcast `subtraj_choice` slot of loss_fn_tmp. Without it the
            # call has 8 positional arguments against a 9-entry in_axes tuple, which jax.vmap rejects.
            return loss_fn_tmp(
                keys[subtraj_choice], samples, next_samples,
                model_state, params, sub_traj_start_points[subtraj_choice],
                sub_traj_end_points[subtraj_choice], sub_traj_indices[subtraj_choice],
                None)

        loss_fn = train_on_single_subtrajectory if do_single_subtraj_training else loss_fn_tmp
        
        # Buffer-based training loop. Each of the alg_cfg.n_sim outer iterations:
        #   1. simulates fresh sub-trajectories with the current parameters,
        #   2. adds them (with their log-RNDs) to the buffer,
        #   3. for the 'rev_tb' loss (unless the true lnZ is leaked) refreshes logZ_deltas,
        #   4. performs alg_cfg.n_updates_per_sim gradient steps on batches drawn from the buffer,
        #   5. evaluates and logs every eval_freq iterations and on the last iteration.
        # time_elapsed_benchmarking accumulates wall-clock seconds for iterations i >= BENCHMARK_THRESHOLD.
        time_elapsed_benchmarking = 0
        for i in range(alg_cfg.n_sim):
            key, key_gen = jax.random.split(key_gen)

            time_stamp1 = time.time()
            sim_samples, sub_traj_target_log_probs, (lnz_est, elbo_est), ((final_samples, final_weights), sub_traj_logrnds), sub_trajs, _ = simulate_short_train(key, model_state,
                                                                                        model_state.params) # final_weights is not log(1/n) iff resampling
            time_stamp2 = time.time()

            # Move the new data to the CPU first when the buffer lives there.
            xs = sub_trajs if not buffer_on_cpu else jax.device_put(sub_trajs, cpu_device)
            logws = sub_traj_logrnds if not buffer_on_cpu else jax.device_put(sub_traj_logrnds, cpu_device)

            time_stamp2a = time.time()
            buffer_state = buffer.add(xs, logws, buffer_state=buffer_state)
            time_stamp3 = time.time()
            if alg_cfg.loss in ['rev_tb']:
                if not alg_cfg.leak_true_lnZ:
                    # take advantage of this otherwise wasted data
                    # to update lnZ
                    # NOTE(review): the assignment below mutates model_state.params in place, which
                    # presumably requires the params pytree to be a plain mutable dict — confirm.
                    new_lnZs = jnp.array([lnz_upd_fn(model_state.params['params']['logZ_deltas'][subtraj_id],
                                        sub_traj_logrnds[subtraj_id]) for subtraj_id in range(n_sub_traj)])
                    model_state.params['params']['logZ_deltas'] = new_lnZs

            times['buffer_storage_time'].append(time_stamp3 - time_stamp2a)
            times['simulation_time'].append(time_stamp2 - time_stamp1)
            times['move_trajs_gpu_to_cpu_time'].append(time_stamp2a - time_stamp2)

            if i >= BENCHMARK_THRESHOLD:
                time_elapsed_benchmarking += time_stamp3 - time_stamp1

            for j in range(alg_cfg.n_updates_per_sim):

                key, key_gen = jax.random.split(key_gen)

                buffer_indices = None
                train_batch = None

                # Stays None when every sub-trajectory is trained on at once.
                subtraj_to_train_on = None

                time_stamp4 = time.time()
                if do_single_subtraj_training:
                    key, key_gen = jax.random.split(key_gen)
                    if alg_cfg.memory_saver.subtraj_scheme == "random":
                        subtraj_to_train_on = jax.random.randint(key, (1,),0,alg_cfg.n_sub_traj)
                    elif alg_cfg.memory_saver.subtraj_scheme == "cyclic":
                        # Cycle through the sub-trajectories using the global inner-iteration counter.
                        subtraj_to_train_on = jnp.array([(i*alg_cfg.n_updates_per_sim + j) % alg_cfg.n_sub_traj])
                    else:
                        raise ValueError(
                            f"Unknown memory_saver.subtraj_scheme: {alg_cfg.memory_saver.subtraj_scheme!r}")

                # TODO: some sort of preloading on GPU? i.e rather than ship one batch at a time,
                # we preload onto GPU
                if alg_cfg.buffer.sampling_scheme == "vanilla":
                    train_batch, buffer_indices = buffer.sample(key=key, buffer_state=buffer_state, batch_size=alg_cfg.batch_size, subtraj_id = subtraj_to_train_on)
                    if buffer_on_cpu:
                        train_batch = jax.device_put(train_batch, jax.devices("gpu")[0])
                elif alg_cfg.buffer.sampling_scheme == "New":
                    # for a new application always generate a new key
                    # require batch_size even as want 2* (N//2) = N
                    key2, key_gen = jax.random.split(key_gen)

                    # Half of the batch comes from the buffer, the rest from the fresh sub-trajectories.
                    old_batch, buffer_indices = buffer.sample(key=key, buffer_state=buffer_state, batch_size=alg_cfg.batch_size//2, subtraj_id = subtraj_to_train_on)
                    train_batch = sample_and_concat(key2, old_batch, sub_trajs, buffer_on_cpu, subtraj_id = subtraj_to_train_on)
                else:
                    # Fail early instead of passing train_batch=None into the loss function.
                    raise ValueError(
                        f"Unknown buffer.sampling_scheme: {alg_cfg.buffer.sampling_scheme!r}")
                time_stamp5 = time.time()
                key, key_gen = jax.random.split(key_gen)
                keys = jax.random.split(key, (n_sub_traj, alg_cfg.batch_size,))

                (per_sample_loss, (recomputed_logws, _)), grads_all = loss_fn(keys, train_batch,  train_batch, model_state,
                                                                model_state.params,
                                                                sub_traj_start_points, sub_traj_end_points, sub_traj_indices,  subtraj_to_train_on)
                time_stamp6 = time.time()

                if alg_cfg.buffer.update_weights:
                    # update weights

                    # recomputed_logws is (S, B, 1)-tensor, S = #subtraj, B = batch_size, if not conserving memory
                    # and (1, B, 1) tensor IF conserving memory
                    # The indices corresponding to particles from the buffer are the ones numbered 0 ... buffer_indices.shape[1]
                    # We extract those new_rnds into logw_update as a (1 or S, B) tensor and use it to update
                    logw_update = recomputed_logws[:, :buffer_indices.shape[1],0]
                    if buffer_on_cpu:
                        logw_update = jax.device_put(logw_update, cpu_device)

                    if subtraj_to_train_on is None:
                        # trained on all subtrajs
                        buffer_state = buffer.upd_weights(logw_update, buffer_indices, buffer_state)
                    else:
                        buffer_state = buffer.upd_weights(logw_update, buffer_indices, buffer_state, subtraj_to_train_on[0].item())
                time_stamp7 = time.time()

                times['buffer_access_time'].append(time_stamp5-time_stamp4)
                times['loss_computation_time'].append(time_stamp6-time_stamp5)
                times['weight_update_time'].append(time_stamp7-time_stamp6)

                weights = get_subtraj_weightscheme(per_sample_loss, alg_cfg)
                model_state = gradient_step(model_state, grads_all, weights)

                if i >= BENCHMARK_THRESHOLD:
                    time_elapsed_benchmarking += time.time() - time_stamp4

                if cfg.use_wandb:
                    # loss_hist plots loss for each subtrajectory section
                    # loss plots overall loss
                    moving_averages['loss'].append(jnp.mean(per_sample_loss))
                    wandb.log({'loss_hist': per_sample_loss,
                            'stats/n_inner_its': alg_cfg.n_updates_per_sim * i + j,
                            'stats/n_sims': i,
                            'loss': jnp.mean(per_sample_loss),
                            'times/time_elapsed_secs': time.time() - start_time,
                            'time_benchmarking':  time_elapsed_benchmarking})

                    if not alg_cfg.sweep_mode:
                        times_upload_dict = {}
                        # Moving average of each timing over its last (up to) 100 entries.
                        # The loop variable is deliberately not called `key`, to avoid clobbering the PRNG key.
                        for metric_name, times_list in times.items():
                            recent_times = times_list[-100:]
                            times_upload_dict[f'times/{metric_name}_moving_average_secs'] = sum(recent_times)/len(recent_times)

                        # Moving average of each tracked metric over its last (up to) 20 entries.
                        # The divisor is the length of the window actually summed, and the builtin
                        # `list` is no longer shadowed.
                        for metric_name, metric_values in moving_averages.items():
                            recent_values = metric_values[-20:]
                            if not recent_values:
                                # Nothing recorded for this metric yet.
                                continue
                            times_upload_dict[f'MovingAverages/{metric_name}_moving_average_secs'] = sum(recent_values)/len(recent_values)

                        wandb.log(times_upload_dict)

            # Evaluate every eval_freq iterations and on the final iteration.
            if i % eval_freq == 0 or i+1 == alg_cfg.n_sim:
                # target.visualise(buffer_samples[10], show=True)
                key, key_gen = jax.random.split(key_gen)
                logger.update(eval_fn(model_state, model_state.params, key, i))
                logger["stats/step"] = i

                print_results(i, logger, cfg)
                print(f'Loss: {jnp.mean(per_sample_loss)}')

                if cfg.use_wandb:
                    wandb.log(logger)
    
    # Optional benchmarking: time 20 calls each of simulate_short_no_smc and simulate_short_smc
    # and log the mean / standard deviation of the wall-clock durations to wandb.
    if hasattr(alg_cfg, "time_everything") and alg_cfg.time_everything:
        model_eval_times = []

        # Burn-in: one untimed call of each function (presumably so compilation is excluded
        # from the timings below — confirm).
        simulate_short_no_smc(jax.random.PRNGKey(0), model_state, params)
        simulate_short_smc(jax.random.PRNGKey(0), model_state, params)

        # NOTE(review): JAX dispatches asynchronously; unless these functions block on their results
        # internally, the deltas below may measure dispatch time only — confirm.
        for i in range(20):
            timestamp = time.time()
            simulate_short_no_smc(jax.random.PRNGKey(i), model_state, params)
            model_eval_times.append(time.time() - timestamp)

        smc_eval_times = []
        for i in range(20):
            timestamp = time.time()
            simulate_short_smc(jax.random.PRNGKey(i), model_state, params)
            smc_eval_times.append(time.time() - timestamp)
        
        # NOTE(review): unlike the other wandb.log calls in this function, this one is not guarded
        # by cfg.use_wandb — confirm that time_everything is only enabled together with wandb.
        wandb.log({'avg_evaluation_time_model': jnp.array(model_eval_times).mean(),
                   'avg_evaluation_time_smc': jnp.array(smc_eval_times).mean(),
                   'std_evaluation_time_model': jnp.array(model_eval_times).std(),
                   'std_evaluation_time_smc': jnp.array(smc_eval_times).std()})