from jax.config import config; config.update("jax_enable_x64", True)
from jax import numpy as np
from jax import random
import optax
from tqdm import tqdm


import jax
import tensorflow as tf

from tensorflow_probability.substrates import jax as tfp
distributions = tfp.distributions


class _SineData:
    """
    Dataset of functions f(x) = a * sin(x - b) where a and b are randomly
    sampled. The function is evaluated from -pi to pi.

    Sampling always starts from ``random.PRNGKey(0)``, so two instances built
    with the same arguments contain identical data.

    Parameters
    ----------
    amplitude_range : tuple of float
        Defines the range from which the amplitude (i.e. a) of the sine function
        is sampled.
    shift_range : tuple of float
        Defines the range from which the shift (i.e. b) of the sine function is
        sampled.
    num_samples : int
        Number of samples of the function contained in dataset.
    num_points : int
        Number of points at which to evaluate f(x) for x in [-pi, pi].

    Attributes
    ----------
    data : list
        One ``[x, y]`` pair per sampled function; ``x`` has shape
        (num_points, x_dim) and ``y`` has shape (num_points, y_dim).
    """
    def __init__(self, amplitude_range=(-1., 1.), shift_range=(-.5, .5),
                 num_samples=1000, num_points=100):
        self.amplitude_range = amplitude_range
        self.shift_range = shift_range
        self.num_samples = num_samples
        self.num_points = num_points
        self.x_dim = 1  # x and y dim are fixed for this dataset.
        self.y_dim = 1

        a_min, a_max = amplitude_range
        b_min, b_max = shift_range

        # The evaluation grid (evenly spaced, shape (num_points, x_dim)) is the
        # same for every sampled function, so it is built once and shared.
        x = np.linspace(-np.pi, np.pi, num_points)[..., None]

        # Generate data. The key is split once for the amplitude and once for
        # the shift of every sample; keep this order so the generated dataset
        # stays reproducible.
        self.data = []
        rng = random.PRNGKey(0)
        for _ in range(num_samples):
            # Sample random amplitude
            rng, amplitude_rng = random.split(rng, 2)
            a = (a_max - a_min) * random.uniform(amplitude_rng) + a_min
            # Sample random shift
            rng, shift_rng = random.split(rng, 2)
            b = (b_max - b_min) * random.uniform(shift_rng) + b_min
            # Shape (num_points, y_dim)
            y = a * np.sin(x - b)
            self.data.append([x, y])

    def get_data(self):
        """
        Return the samples as a ``tf.data.Dataset`` and as a single array.

        Returns
        -------
        tuple
            ``(dataset, ndarray_data)`` where ``ndarray_data`` stacks every
            ``[x, y]`` pair (shape (num_samples, 2, num_points, 1)) and
            ``dataset`` slices that array along its first axis.
        """
        ndarray_data = np.asarray(self.data)
        return tf.data.Dataset.from_tensor_slices(ndarray_data), ndarray_data



def context_target_mask_gen(rng, batch_size, num_points, num_context, num_extra_target, use_y0=True):
    """
    Draw random boolean context and target masks over ``num_points`` indices.

    For every batch row, ``num_context + num_extra_target`` distinct indices
    are selected. The first ``num_context`` of them form the context set and
    all of them form the target set, so the context is a subset of the target.

    Parameters
    ----------
    rng : PRNG key
        Split into one key per batch row.
    batch_size : int
        Number of independent rows of masks to draw.
    num_points : int
        Number of indices each mask ranges over.
    num_context : int
        Number of context indices per row.
    num_extra_target : int
        Number of additional indices that are target-only.
    use_y0 : bool
        If True, index 0 is always the first selected location and the
        remaining ones are drawn from indices 1..num_points-1.

    Returns
    -------
    tuple
        ``(context_mask, target_mask)``, two boolean arrays of shape
        (batch_size, num_points).
    """
    all_indices = np.arange(num_points)
    if use_y0:
        # Index 0 is forced in, so one fewer location is drawn at random and
        # index 0 itself is removed from the candidates.
        candidates = all_indices[1:]
        forced = np.zeros(shape=(batch_size, 1))
        num_drawn = num_context + num_extra_target - 1
    else:
        candidates = all_indices
        forced = np.zeros(shape=(batch_size, 0))
        num_drawn = num_context + num_extra_target

    def draw_row(row_key, pool):
        # Sampling without replacement keeps the locations of a row distinct.
        return random.choice(row_key, pool, (num_drawn,), replace=False)

    row_keys = random.split(rng, batch_size)
    drawn = jax.vmap(draw_row, in_axes=(0, None))(row_keys, candidates)
    locations = np.concatenate([forced, drawn], axis=-1)

    # Row-wise membership test: which indices appear among the row's locations.
    row_membership = jax.vmap(np.isin, in_axes=(None, 0))
    context_mask = row_membership(all_indices, locations[..., :num_context])
    target_mask = row_membership(all_indices, locations)
    return context_mask, target_mask


def test_NeuralProcess_output_the_same_if_using_mask():
    """
    Integration test of whether NeuralProcess returns the same loss with or without mask.

    For every batch the objective is computed twice with the same parameters:
    once on the full set of points with context/target masks, and once on the
    points explicitly gathered with those masks (no mask arguments). The test
    asserts that the difference between the two objectives is below 1e-10.
    """
    # prepare the data
    dataset, _ = _SineData(amplitude_range=(-1., 1.), shift_range=(-.5, .5), num_samples=20).get_data() # 20 sampled functions here (K=2000 in algorithm 1)
    # 20 samples in batches of 5 -> 4 batches per pass over the data
    dataset = dataset.shuffle(buffer_size=10000, reshuffle_each_iteration=True).batch(batch_size=5)
    data_iterator = dataset.as_numpy_iterator()

    from flax.training import train_state
    # hyperparams, except for the num_z_sample, all the rest are the same as the replication repo

    y_dim = 1
    r_dim = 50
    z_dim = 50
    h_dim = 50
    lr = 3e-4 
    num_z_sample = 32
    z_sample_rng = random.PRNGKey(0)

    from NeuralProcesses.models.model import NeuralProcess
    NP = NeuralProcess(x_dim=1, y_dim=y_dim, r_dim=r_dim, z_dim=z_dim, h_dim=h_dim)
    # initialize model
    init_rng = random.PRNGKey(0)
    # `rng` is the key threaded through the batch loop below
    params_rng, rng = jax.random.split(init_rng, num=2)
    initial_params = NP.init(params_rng, x_context=np.zeros(shape=(1, 1)), y_context=np.zeros(shape=(1, 1)),
                        x_target=np.zeros(shape=(1, 1)), y_target=np.zeros(shape=(1, 1)), training=True, 
                        sample_size=5, sample_rng = random.PRNGKey(0))['params'] # training=False used for dropout
    
    def NP_loss_fn(params, rng, data_tuple):
        """
        Return (masked objective - gathered objective) of the Neural Process for one batch.

        `data_tuple` is (data_x, data_y, context_mask, target_mask). Both
        objectives are the negative batch mean of
        (Monte Carlo target log-likelihood - KL(q_all || q_context)).
        """
        variables = {'params': params}
        data_x, data_y, context_mask, target_mask = data_tuple
        # Masked form: the full point set is passed as both context and target,
        # and the masks tell the model which points belong to each set.
        # The same `rng` is closed over for every batch row.
        partial_model_apply = lambda x_ctx, y_ctx, x_tgt, y_tgt, ctx_msk, tgt_msk: NP.apply(variables, x_context=x_ctx, y_context=y_ctx, x_target=x_tgt, y_target=y_tgt, sample_rng=rng, sample_size=1, training=True, context_mask=ctx_msk, target_mask=tgt_msk)
        # vmap over the leading (batch) axis of every argument
        y_pred_mu, y_pred_sigma, mu_context, sigma_context, mu_tgt, sigma_tgt = jax.vmap(partial_model_apply, in_axes=(0, 0, 0, 0, 0, 0))(data_x, data_y, data_x, data_y, context_mask, target_mask)

        # add a sample axis to y so it broadcasts against the predictions
        expanded_aug_y = np.expand_dims(data_y, axis=1)
        # multiplying by the target mask zeroes the log-likelihood of non-target points
        MC_avg_target_log_likelihood = distributions.Normal(y_pred_mu, y_pred_sigma).log_prob(expanded_aug_y) * np.expand_dims(np.expand_dims(target_mask, axis=1), axis=-1) # [batch_size, num_samples, num_points, 1]
        # sum over points, average over samples, drop the trailing axis
        MC_avg_target_log_likelihood = np.squeeze(np.mean(np.sum(MC_avg_target_log_likelihood, axis=-2), axis=1), axis=-1) # [batch_size], assuming the shape noted on the previous line - TODO confirm

        q_context = distributions.MultivariateNormalDiag(mu_context, sigma_context)
        q_all = distributions.MultivariateNormalDiag(mu_tgt, sigma_tgt)
        kl = distributions.kl_divergence(q_all, q_context)
        masked_form_objective =  - np.mean(MC_avg_target_log_likelihood - kl)

        # As a comparison, we do not use mask here, and they should be similar
        # Gather the selected points row by row with boolean indexing; np.stack
        # needs every row to select the same number of points.
        context_x = np.stack([_data_x[_context_mask] for _data_x, _context_mask in zip(data_x, context_mask)])
        target_x = np.stack([_data_x[_target_mask] for _data_x, _target_mask in zip(data_x, target_mask)])
        context_y = np.stack([_data_y[_context_mask] for _data_y, _context_mask in zip(data_y, context_mask)])
        target_y = np.stack([_data_y[_target_mask] for _data_y, _target_mask in zip(data_y, target_mask)])



        # Same model and same `rng`, but called without mask arguments on the gathered points
        partial_model_apply = lambda x_ctx, y_ctx, x_tgt, y_tgt: NP.apply(variables, x_context=x_ctx, y_context=y_ctx, x_target=x_tgt, y_target=y_tgt, sample_rng=rng, sample_size=1, training=True)
        y_pred_mu_dynamic, y_pred_sigma_dynamic, mu_context_dynamic, sigma_context_dynamic, mu_tgt_dynamic, sigma_tgt_dynamic = jax.vmap(partial_model_apply, in_axes=(0, 0, 0, 0))(context_x, context_y, target_x, target_y)
        
        p_y_pred_dynamic = distributions.Normal(y_pred_mu_dynamic, y_pred_sigma_dynamic)

        q_context_dynamic = distributions.MultivariateNormalDiag(mu_context_dynamic, sigma_context_dynamic)
        q_all_dynamic = distributions.MultivariateNormalDiag(mu_tgt_dynamic, sigma_tgt_dynamic)
        
        # NOTE: despite its name this variable holds the target y values
        expanded_target_x = np.expand_dims(target_y, axis=1)
        MC_avg_target_log_likelihood_dynamic = p_y_pred_dynamic.log_prob(expanded_target_x) # [batch_size, num_samples, num_points, 1]
        MC_avg_target_log_likelihood_dynamic = np.squeeze(np.mean(np.sum(MC_avg_target_log_likelihood_dynamic, axis=-2), axis=1), axis=-1) # [batch_size], assuming the shape noted on the previous line - TODO confirm

        kl_dynamic = distributions.kl_divergence(q_all_dynamic, q_context_dynamic)
        naive_form_objective_dynamic =  - np.mean(MC_avg_target_log_likelihood_dynamic - kl_dynamic)

        # expected to be ~0 when masking and gathering are equivalent
        return masked_form_objective - naive_form_objective_dynamic


    # The optimizer and train state only serve as a container for the initial
    # parameters; no gradient step is taken in this test.
    tx = getattr(optax, 'adam')(learning_rate=lr)
    training_state = train_state.TrainState.create(apply_fn=NP.apply, params=initial_params, tx=tx)
    losses = []
    for current_epoch in range(1):
        pbar = tqdm(data_iterator, total=4, initial=0, desc=f"Training Epoch: {current_epoch}")
        for step, data_batch in enumerate(pbar):
            z_sample_rng, z_sample_rng_step = random.split(z_sample_rng, 2)
            # NOTE(review): z_sample is not used anywhere below
            z_sample = random.normal(key=z_sample_rng_step, shape=(num_z_sample, z_dim))  
            # each batch stacks [x, y] along axis 1: split it and drop that axis
            data_x, data_y = np.split(data_batch, 2, axis=1)
            data_x = np.squeeze(data_x, 1)
            data_y = np.squeeze(data_y, 1)
            rng, uniform_rng = random.split(rng, 2)
            # randint's upper bound is exclusive: num_context is in [1, 20)
            num_context = random.randint(uniform_rng, (1,), 1, 20)[0]
            rng, uniform_rng2 = random.split(rng, 2)
            # num_extra_target is in [1, 20)
            num_extra_target = random.randint(uniform_rng2, (1,), 1, 20)[0]
            # NOTE(review): the same `rng` key is consumed here and by NP_loss_fn below
            context_mask, target_mask = context_target_mask_gen(rng, data_x.shape[0], data_x.shape[1], num_context, num_extra_target, use_y0=False)

            data_tuples = (data_x, data_y, context_mask, target_mask)
            losses.append(NP_loss_fn(training_state.params, rng, data_tuples))
    # every per-batch difference between the two objectives must be negligible
    assert np.all(np.abs(np.asarray(losses)) < 1e-10) # this threshold depends on floating-point precision


def test_NeuralODEProcesses_output_the_same_if_using_mask():
    """
    Integration test of whether NeuralODEProcess returns the same loss with or without mask.

    For every batch the objective is computed twice with the same parameters:
    once on the full, time-sorted set of points with context/target masks, and
    once on the points explicitly gathered with those masks (no mask
    arguments). The test asserts that the difference between the two
    objectives is below 1e-10.
    """
    # prepare the data
    dataset, _ = _SineData(amplitude_range=(-1., 1.), shift_range=(-.5, .5), num_samples=20).get_data() # 20 sampled functions here (K=2000 in algorithm 1)
    # 20 samples in batches of 5 -> 4 batches per pass over the data
    dataset = dataset.shuffle(buffer_size=10000, reshuffle_each_iteration=True).batch(batch_size=5)
    data_iterator = dataset.as_numpy_iterator()

    from flax.training import train_state
    # hyperparams, except for the num_z_sample, all the rest are the same as the replication repo
    x_dim =1
    r_dim = 50 # NOT SURE IF MENTIONED IN PAPER
    h_dim = 50  # NOT SURE IF MENTIONED IN PAPER
    lr = 1e-3   # NOT SURE IF MENTIONED IN PAPER
    latent_d_dim = 10
    latent_l_dim = 40  # NOT SURE IF MENTIONED IN PAPER
    max_num_context = 10
    max_num_target = 5
    # time interval handed to the model; matches the [-pi, pi] grid of _SineData
    t0 = -np.pi
    t1 = np.pi

    from NeuralProcesses.models.model import NeuralODEProcess
    NODEP = NeuralODEProcess(x_dim=x_dim, r_dim=r_dim, h_dim=h_dim, latent_l_dim=latent_l_dim, latent_d_dim=latent_d_dim)
    # initialize model
    init_rng = random.PRNGKey(1817)
    # `rng` is the key threaded through the batch loop below
    params_rng, rng = jax.random.split(init_rng, num=2)
    initial_params = NODEP.init(params_rng, t_context=np.zeros(shape=(1)), x_context=np.zeros(shape=(1, 1)),
                        t_target=np.zeros(shape=(1,)), x_target=np.zeros(shape=(1, 1)), training=True,
                        context_mask = np.ones(shape=(1,)), target_mask = np.ones(shape=(1,)), sample_rng=random.PRNGKey(0), 
                        t0=t0, t1=t1, sample_size=5, init=True)['params'] # training=False used for dropout
    



    def NODEP_loss_fn(params, rng, data_tuple):
        """
        Return (masked objective - gathered objective) of the Neural ODE Process for one batch.

        `data_tuple` is (data_t, data_x, context_mask, target_mask). Both
        objectives are the negative batch mean of
        (Monte Carlo target log-likelihood - KL on z0 - KL on the global latent).
        """
        variables = {'params': params}
        # context_t, context_x, target_t, target_x = data_tuple
        data_t, data_x, context_mask, target_mask = data_tuple
        # drop the trailing singleton axis of the time values
        data_t = np.squeeze(data_t, axis=-1)

        # Sort every row by time and apply the same permutation to the values
        # and to both masks so they stay aligned.
        sort_indices = np.argsort(data_t, axis=1)
        data_t = np.take_along_axis(data_t, sort_indices, axis=1)
        expanded_sort_indices = np.expand_dims(sort_indices, axis=-1)
        data_x = np.take_along_axis(data_x, expanded_sort_indices, axis=1)
        context_mask = np.take_along_axis(context_mask, sort_indices, axis=1)
        target_mask = np.take_along_axis(target_mask, sort_indices, axis=1)

        # Masked form: the full point set is passed as both context and target,
        # and the masks tell the model which points belong to each set.
        # The same `rng` is closed over for every batch row.
        batch_model_apply = lambda tctx, x_ctx, t_tgt, x_tgt, mask_ctx, mask_tgt: NODEP.apply(
            variables, t_context=tctx, x_context=x_ctx, t_target=t_tgt, 
            context_mask = mask_ctx, target_mask=mask_tgt, sample_rng = rng, sample_size = 1, 
            x_target=x_tgt, training=True, solver='Dopri5', t0 = t0, t1 = t1)
 
        # vmap over the leading (batch) axis of every argument
        y_pred_mu, y_pred_sigma, mu_z0_ctx, sigma_z0_ctx, mu_z0_tgt, sigma_z0_tgt, mu_global_ctx, sigma_global_ctx, mu_global_tgt, sigma_global_tgt = \
            jax.vmap(batch_model_apply, in_axes=(0, 0, 0, 0, 0, 0))(data_t, data_x, data_t, data_x, context_mask, target_mask)

        p_y_pred = distributions.Normal(y_pred_mu, y_pred_sigma)
        q_z0_context = distributions.MultivariateNormalDiag(mu_z0_ctx, sigma_z0_ctx)
        q_z0_all = distributions.MultivariateNormalDiag(mu_z0_tgt, sigma_z0_tgt)
        q_global_context = distributions.MultivariateNormalDiag(mu_global_ctx, sigma_global_ctx)
        q_global_all = distributions.MultivariateNormalDiag(mu_global_tgt, sigma_global_tgt)
        # add a sample axis to the observed values so they broadcast against the predictions
        expanded_target_x = np.expand_dims(data_x, axis=1)
        MC_avg_target_log_likelihood = p_y_pred.log_prob(expanded_target_x) # [batch_size, num_samples, num_points, 1]
        # multiplying by the target mask zeroes the log-likelihood of non-target points
        masked_MC_avg_target_log_likelihood = MC_avg_target_log_likelihood * np.expand_dims(np.expand_dims(target_mask, axis=1), axis=-1)
        # sum over points, average over samples, drop the trailing axis
        MC_avg_target_log_likelihood = np.squeeze(np.mean(np.sum(masked_MC_avg_target_log_likelihood, axis=-2), axis=1), axis=-1) # [batch_size], assuming the shape noted above - TODO confirm
        kl_z0 = distributions.kl_divergence(q_z0_all, q_z0_context)
        kl_control = distributions.kl_divergence(q_global_all, q_global_context)
        masked_form_objective = - np.mean(MC_avg_target_log_likelihood - kl_z0 - kl_control)
        
        # As a comparison, we do not use mask here, and they should be similar
        # Gather the selected points row by row with boolean indexing; np.stack
        # needs every row to select the same number of points.
        context_t = np.stack([_data_t[_context_mask] for _data_t, _context_mask in zip(data_t, context_mask)])
        target_t = np.stack([_data_t[_target_mask] for _data_t, _target_mask in zip(data_t, target_mask)])
        context_x = np.stack([_data_x[_context_mask] for _data_x, _context_mask in zip(data_x, context_mask)])
        target_x = np.stack([_data_x[_target_mask] for _data_x, _target_mask in zip(data_x, target_mask)])


        # Re-sort the gathered context points by time (data_t was already
        # sorted above, so this is expected to leave the order unchanged).
        sort_indices = np.argsort(context_t, axis=1)
        context_t = np.take_along_axis(context_t, sort_indices, axis=1)
        expanded_sort_indices = np.expand_dims(sort_indices, axis=-1)
        context_x = np.take_along_axis(context_x, expanded_sort_indices, axis=1)

        # Same for the gathered target points.
        sort_indices = np.argsort(target_t, axis=1)
        target_t = np.take_along_axis(target_t, sort_indices, axis=1)
        expanded_sort_indices = np.expand_dims(sort_indices, axis=-1)
        target_x = np.take_along_axis(target_x, expanded_sort_indices, axis=1)


        # Same model and same `rng`, but called without mask arguments on the gathered points
        batch_model_apply_dynamic = lambda tctx, x_ctx, t_tgt, x_tgt: NODEP.apply(
            variables, t_context=tctx, x_context=x_ctx, t_target=t_tgt, sample_rng = rng, sample_size = 1, 
            x_target=x_tgt, training=True, solver='Dopri5', t0 = t0, t1 = t1)
        y_pred_mu_dynamic, y_pred_sigma_dynamic, mu_z0_ctx_dynamic, sigma_z0_ctx_dynamic, mu_z0_tgt_dynamic, \
            sigma_z0_tgt_dynamic, mu_global_ctx_dynamic, sigma_global_ctx_dynamic, mu_global_tgt_dynamic, \
                sigma_global_tgt_dynamic = jax.vmap(batch_model_apply_dynamic, in_axes=(0, 0, 0, 0))(\
                    context_t, context_x, target_t, target_x)
        p_y_pred_dynamic = distributions.Normal(y_pred_mu_dynamic, y_pred_sigma_dynamic)

        q_z0_context_dynamic = distributions.MultivariateNormalDiag(mu_z0_ctx_dynamic, sigma_z0_ctx_dynamic)
        q_z0_all_dynamic = distributions.MultivariateNormalDiag(mu_z0_tgt_dynamic, sigma_z0_tgt_dynamic)
        q_global_context_dynamic = distributions.MultivariateNormalDiag(mu_global_ctx_dynamic, sigma_global_ctx_dynamic)
        q_global_all_dynamic = distributions.MultivariateNormalDiag(mu_global_tgt_dynamic, sigma_global_tgt_dynamic)
        
        expanded_target_x = np.expand_dims(target_x, axis=1)
        MC_avg_target_log_likelihood_dynamic = p_y_pred_dynamic.log_prob(expanded_target_x) # [batch_size, num_samples, num_points, 1]
        MC_avg_target_log_likelihood_dynamic = np.squeeze(np.mean(np.sum(MC_avg_target_log_likelihood_dynamic, axis=-2), axis=1), axis=-1) # [batch_size], assuming the shape noted on the previous line - TODO confirm

        kl_z0_dynamic = distributions.kl_divergence(q_z0_all_dynamic, q_z0_context_dynamic)
        kl_control_dynamic = distributions.kl_divergence(q_global_all_dynamic, q_global_context_dynamic)
        naive_form_objective_dynamic =  - np.mean(MC_avg_target_log_likelihood_dynamic - kl_z0_dynamic - kl_control_dynamic)

        # compare differences
        # There are subtle differences in NODEP.tx_to_z_dist, but after careful checking this is believed to be
        # due to numerical stability, as switching to float64 reduces the difference to the 1e-16 level.
        # The *_diff values below do not feed into the returned value; they are
        # only computed so they can be inspected while debugging.
        mu_z0_ctx_diff = mu_z0_ctx - mu_z0_ctx_dynamic
        sigma_z0_ctx_diff = sigma_z0_ctx - sigma_z0_ctx_dynamic
        mu_global_ctx_diff = mu_global_ctx - mu_global_ctx_dynamic
        sigma_global_ctx_diff = sigma_global_ctx - sigma_global_ctx_dynamic
        mu_z0_all_diff = mu_z0_tgt - mu_z0_tgt_dynamic
        sigma_z0_all_diff = sigma_z0_tgt - sigma_z0_tgt_dynamic
        mu_global_all_dff = mu_global_tgt - mu_global_tgt_dynamic
        sigma_global_all_diff = sigma_global_tgt - sigma_global_tgt_dynamic

        kl_z0_diff = kl_z0 - kl_z0_dynamic
        kl_control_diff = kl_control - kl_control_dynamic
        # expected to be ~0 when masking and gathering are equivalent
        diff = masked_form_objective - naive_form_objective_dynamic
        return diff

    # The optimizer and train state only serve as a container for the initial
    # parameters; no gradient step is taken in this test.
    tx = optax.rmsprop(learning_rate=lr)
    training_state = train_state.TrainState.create(apply_fn=NODEP.apply, params=initial_params, tx=tx)
    pbar = tqdm(data_iterator, total=4, initial=0)
    losses = []
    for _, data_batch in enumerate(pbar):
        # Each batch stacks [x, y] along axis 1. Here the sine input x plays
        # the role of time t and the sine output y the role of the observed
        # state, which is why NODEP_loss_fn unpacks them as (data_t, data_x).
        data_x, data_y = np.split(data_batch, 2, axis=1)
        data_x = np.squeeze(data_x, 1)
        data_y = np.squeeze(data_y, 1)
        rng, uniform_rng = random.split(rng, 2)
        # randint's upper bound is exclusive: num_context is in [1, max_num_context)
        num_context = random.randint(uniform_rng, (1,), 1, max_num_context)[0]
        rng, uniform_rng2 = random.split(rng, 2)
        # num_extra_target is in [0, max_num_target), so it may be 0
        num_extra_target = random.randint(uniform_rng2, (1,), 0, max_num_target)[0]
        # NOTE(review): the same `rng` key is consumed here and by NODEP_loss_fn below
        context_mask, target_mask = context_target_mask_gen(rng, data_x.shape[0], data_x.shape[1], num_context, num_extra_target, use_y0=False)
        data_tuples = (data_x, data_y, context_mask, target_mask)
        losses.append(NODEP_loss_fn(training_state.params, rng, data_tuples))
    # every per-batch difference between the two objectives must be negligible
    assert np.all(np.abs(np.asarray(losses)) < 1e-10) # this threshold depends on floating-point precision





if __name__ == '__main__':
    # Allow running both integration tests directly, without a test runner.
    for _run_test in (
        test_NeuralProcess_output_the_same_if_using_mask,
        test_NeuralODEProcesses_output_the_same_if_using_mask,
    ):
        _run_test()