"""
Test of loss
"""
import random
from jax import random
from jax import numpy as np
from NeuralProcesses.training.losses import GetNeuralODEProcessMFVILossAcceptMultiTrajData, GetNeuralODEProcessNLLLossAcceptBatchData
# from NeuralProcesses.models.model import NeuralODEProcessAcceptMultiTimeSeriesDataButTakeOneOnly
from NeuralProcesses.data import datasets
import ml_collections
import jax

# Module-level configuration shared by the GP-vs-meta-learning comparison test in this file.
# Within this file the config is passed to `datasets.get_dataset` and
# `datasets.get_data_preprocessor`; `seed`, `data.known_traj_range` and
# `evaluation.num_steps` are additionally read directly by the test.
dummy_config = ml_collections.ConfigDict()
dummy_config.seed = 0
dummy_config.data = ml_collections.ConfigDict()
# Metric name -> loss identifier (presumably resolved by name elsewhere in the project; not read in this file).
dummy_config.data.aux_eval_metric = {'MSE': 'SANODEPTestMSELossAcceptBatchData', 'NLL': 'SANODEPNLLLossAcceptBatchData'} 
dummy_config.data.dataset_name = 'LOTKA_VOLTERRA_ODE'
dummy_config.data.shuffle_buffer_size = 1000
dummy_config.data.num_context_range = (1, 15)
dummy_config.data.use_initial = True # note: this is the default choice used in the NODEP repository, see its TimeNeuralProcessTrainer class
dummy_config.data.num_extra_target_range = (0, 45)
dummy_config.data.known_traj_range = (1, 20)
# NOTE: the key below is misspelled ("foracsting"); it is a runtime key and is left untouched.
dummy_config.data.foracsting_problem_prob = 0.5 # 1.0: always do the forecasting problem, 0.0: always do the regression problem 
# Arguments of the data generator.
dummy_config.data.args = ml_collections.ConfigDict()
dummy_config.data.args.data_gen_rng = random.PRNGKey(0)
dummy_config.data.args.dynamics_smp_num = 20
dummy_config.data.args.initial_condition_smp_num = 100
dummy_config.data.args.num_timesteps = 100

# Sampling ranges; presumably the time span, the initial-state bounds and the
# Lotka-Volterra coefficients (alpha, beta, delta, gamma) - confirm against the dataset class.
dummy_config.data.args.t_range = (0, 1.5)
dummy_config.data.args.x_0_range =  ((0.1, 0.1), (3.0, 3.0))
dummy_config.data.args.alpha_range = (1/3, 1.0) 
dummy_config.data.args.beta_range = (1.0, 2.0) 
dummy_config.data.args.delta_range = (0.5, 1.5) 
dummy_config.data.args.gamma_range = (0.5, 1.5) 
dummy_config.data.args.generator = True

# Evaluation settings: only `num_steps` is read in this file (progress-bar total and loop bound).
dummy_config.evaluation = ml_collections.ConfigDict()
dummy_config.evaluation.batch_size = 10
dummy_config.evaluation.num_steps = 1000
dummy_config.evaluation.rng = random.PRNGKey(1000)


def test_getneuralodeprocessmfvilossacceptmultitrajdata_masked_log_likelihood_is_the_same_as_getneuralodeprocessnlllossacceptbatchdata():
    """
    Check that the MFVI loss and the NLL loss report the same log-likelihood term.

    A model is initialised on dummy inputs, one pre-processed batch of 'SineData'
    is drawn, and both loss functions are evaluated with the same rng, the same
    initial parameters and the same batch. The first value returned by each of
    them must be numerically close.
    """
    # The module-level import of the model class is commented out in this file,
    # so the class is imported here to keep the test from failing with a NameError.
    # NOTE(review): confirm the class still lives at this import path.
    from NeuralProcesses.models.model import (
        NeuralODEProcessAcceptMultiTimeSeriesDataButTakeOneOnly,
    )

    rng = random.PRNGKey(0)

    # ----------------------------- model config -----------------------------
    config = ml_collections.ConfigDict()
    config.model = ml_collections.ConfigDict()
    config.model.name = 'NeuralODEProcessAcceptMultiTimeSeriesDataAsAblation'
    config.model.cfg_args = ml_collections.ConfigDict()
    config.model.cfg_args.x_dim = 1
    config.model.cfg_args.r_dim = 50
    config.model.cfg_args.encoder_h_dim = 50
    config.model.cfg_args.ode_layer_h_dim = 50
    config.model.cfg_args.decoder_h_dim = 50
    config.model.cfg_args.latent_d_dim = 90
    config.model.cfg_args.latent_l_dim = 10
    config.model.autonomous_ode = False
    config.model.sample_size = 1
    config.model.t0 = - np.pi  # refer to the initial t in their decoder
    config.model.t1 = np.pi

    # Dummy inputs that are forwarded as keyword arguments to `model.init`.
    config.model.init_args = ml_collections.ConfigDict()
    config.model.init_args.t_context = np.zeros(shape=(1, 1))
    config.model.init_args.x_context = np.zeros(shape=(1, 1, 1))
    config.model.init_args.t_target = np.zeros(shape=(1, 1))
    config.model.init_args.x_target = np.zeros(shape=(1, 1, 1))
    config.model.init_args.context_mask = np.ones(shape=(1, 1))
    config.model.init_args.target_mask = np.ones(shape=(1, 1))
    config.model.init_args.t0 = config.model.t0
    config.model.init_args.t1 = config.model.t1
    config.model.init_args.training = True
    config.model.init_args.sample_size = 1
    config.model.init_args.sample_rng = random.PRNGKey(0)

    # ------------------------------ data config ------------------------------
    config.data = ml_collections.ConfigDict()
    config.data.aux_eval_metric = {'MSE': 'NeuralODEProcessTestMSELossAcceptBatchData', 'NLL': 'NeuralODEProcessNLLLossAcceptBatchData'}
    config.data.aux_cfg = {'CE': {'levels': 10}}
    config.data.dataset_name = 'SineData'
    config.data.shuffle_buffer_size = 1000
    config.data.num_context_range = (1, 25)
    config.data.num_extra_target_range = (0, 75)
    config.data.args = ml_collections.ConfigDict()
    config.data.args.data_gen_rng = random.PRNGKey(0)
    config.data.args.amplitude_range = (1.0, 1.0)
    config.data.args.shift_range = (-np.pi, np.pi)
    config.data.args.omega_range = (np.pi / 4, np.pi)
    config.data.args.dynamics_smp_num = 1000
    config.data.args.initial_condition_smp_num = 30
    config.data.args.aux = {'MSE': (20, 10), 'NLL': (20, 10)}
    config.data.use_initial = True  # note: this is the default choice used in the NODEP repository, see its TimeNeuralProcessTrainer class

    # ------------------------- model initialisation --------------------------
    rng, init_rng = jax.random.split(rng, num=2)
    model = NeuralODEProcessAcceptMultiTimeSeriesDataButTakeOneOnly(**config.model.cfg_args)
    params_rng, rng = jax.random.split(init_rng, num=2)
    variables = model.init(params_rng, **config.model.init_args)
    # NOTE(review): this unpacking relies on `pop` returning a
    # (remaining_collections, popped_value) pair - confirm for the installed library version.
    initial_model_state, initial_params = variables.pop('params')

    # ------------------------------ data pipeline ----------------------------
    data, dataset_inst = datasets.get_dataset(config)
    data = data.shuffle(buffer_size=config.data.shuffle_buffer_size, seed=int(1),
                        reshuffle_each_iteration=True).batch(batch_size=10)
    data_iterator = data.as_numpy_iterator()
    pre_processor = datasets.get_data_preprocessor(dataset_inst, config)

    # Only the first batch is compared. Fail loudly when there is none, so an
    # empty dataset cannot make the test pass without checking anything.
    data_batch = next(data_iterator, None)
    assert data_batch is not None, "The dataset yielded no batch, nothing was compared"
    processed_data, rng = pre_processor(data_batch, rng)

    # Get the functions
    mfvi_loss_func = GetNeuralODEProcessMFVILossAcceptMultiTrajData(model=model, training=True, config=config, return_each_component=True)
    nll_loss_func = GetNeuralODEProcessNLLLossAcceptBatchData(model=model, training=False, config=config)

    # Call the functions with identical rng, parameters and data
    mc_log_likelihood1, _, _, _ = mfvi_loss_func(rng, initial_params, processed_data)
    mc_log_likelihood2, _ = nll_loss_func(rng, initial_params, processed_data)

    # Check if the values are close enough
    assert np.allclose(mc_log_likelihood1, mc_log_likelihood2), "The MC_avg_target_log_likelihood values are not the same"



def test_gp_mse_is_the_same_as_meta_learn_ode_calculation():
    """
    Check that the GP metrics are computed the same way as for NODEP, SANODEP and NP.

    The aim is to avoid any mistake in model comparison. Since the loss
    calculation flows of the GP and of the meta-learned models are
    fundamentally different, it is hard to design a modular test. As a
    workaround, both calculations are pasted here and compared on the
    predictions of the very same GP:

    * GP style: for every trajectory the GP predicts only the target points
      (selected by boolean indexing); the MSE is the mean over those points and
      states, the NLL is the sum over those points and states. Both are then
      averaged over the trajectories with a finite value.
    * Meta-learning style: the GP predicts every time step of the trajectory,
      the target mask is applied multiplicatively, the MSE is normalised by the
      number of target points, and the result is averaged over trajectories.

    The two MSE values and the two NLL values must agree (rtol=1e-6) on every
    evaluated batch.
    """
    sys_id = 0  # index along the first axis of the batch; only this entry is evaluated

    import tensorflow as tf
    from tqdm import tqdm
    from tensorflow_probability import distributions as tfd
    from trieste.data import Dataset as trieste_Dataset

    from NeuralProcesses.models.gpflow.builder import \
        build_stacked_independent_objectives_model

    def _as_float64_dataset(query_points, observations):
        """Wrap inputs and outputs in a trieste Dataset after casting both to float64."""
        return trieste_Dataset(
            tf.cast(query_points, dtype=tf.float64),
            tf.cast(observations, dtype=tf.float64),
        )

    def _build_gp_with_fixed_hyperparameters(train_dataset, num_states):
        """Build the stacked GP and overwrite the hyperparameters of its first two sub-models."""
        models = build_stacked_independent_objectives_model(
            train_dataset, _num_states=num_states
        )
        # The gp can fail to optimize; the hyperparameters are overwritten with
        # fixed values here. Only the lengthscale differs between the two sub-models.
        for state_idx, lengthscale in enumerate((0.3, 1.0)):
            gp = models._models[state_idx]._model
            gp.kernel.lengthscales = tf.constant([lengthscale], dtype=tf.float64)
            gp.kernel.variance = tf.constant([1.0], dtype=tf.float64)
            gp.likelihood.variance = tf.constant([1e-3], dtype=tf.float64)
        return models

    # Build data iterators
    rng = jax.random.PRNGKey(dummy_config.seed)
    # The second key is not used; the split is kept so that `rng` follows the
    # same key sequence as before.
    rng, _ = jax.random.split(rng)
    data, dataset_inst = datasets.get_dataset(dummy_config)
    pre_processor = datasets.get_data_preprocessor(dataset_inst, dummy_config)
    data_iterator = data.as_numpy_iterator()

    pbar = tqdm(
        data_iterator,
        total=dummy_config.evaluation.num_steps,
        initial=0,
        desc="Evaluation Step",
    )
    for step, data_batch in enumerate(pbar):
        if step == dummy_config.evaluation.num_steps:
            break
        processed_data, rng = pre_processor(
            data_batch,
            rng,
            known_traj_range=dummy_config.data.known_traj_range,
            all_as_target=True,
        )

        # Only the times, the states and the two masks are needed here.
        (
            data_t,
            data_x,
            _,
            _,
            _,
            obs_mask_all_sys,
            target_mask_all_sys,
            _,
            _,
            _,
        ) = processed_data
        obs_mask = obs_mask_all_sys[sys_id]  # presumably [num_traj, num_traj, timesteps] - TODO confirm
        data_t = data_t[sys_id]
        data_x = data_x[sys_id]
        num_states = data_x.shape[-1]

        # The target mask of trajectory i is meant to be the i-th diagonal entry
        # of the mask (a plain loop over the mask would always stay in the
        # first trajectory).
        # NOTE(review): `np.diagonal` appends the extracted diagonal as the LAST
        # axis. If the mask is laid out as [num_traj, num_traj, timesteps] the
        # result is [timesteps, num_traj], and iterating over its first axis
        # yields one row per time step rather than per trajectory. This goes
        # unnoticed when num_traj == timesteps. Both sides of this test use the
        # same array, so the comparison itself is unaffected - confirm the layout.
        per_traj_target_mask = np.diagonal(
            target_mask_all_sys[sys_id], axis1=-3, axis2=-2
        )

        # GP input: time concatenated with the initial condition of the
        # trajectory (`:1` selects the first time step), repeated along time.
        # It does not depend on the trajectory index, so it is built once.
        init_cond = np.repeat(
            data_x[:, :1, :],
            axis=1,
            repeats=data_x.shape[1],
        )
        augmented_input = np.concatenate([data_t, init_cond], axis=-1)

        # ==========================    GP PREDICTION STYLE    ==========================
        metrics = {"mse": [], "nll": []}

        # FIXME: This for loop is somewhat ill-designed, we should not loop through the trajectories
        for traj_idx, (obs_mask_single_traj, target_mask_single_traj) in tqdm(
            enumerate(zip(obs_mask, per_traj_target_mask))
        ):
            # training data: every observed point selected by the mask of this trajectory
            train_dataset = _as_float64_dataset(
                augmented_input[obs_mask_single_traj],
                data_x[obs_mask_single_traj],
            )
            models = _build_gp_with_fixed_hyperparameters(train_dataset, num_states)

            # test data: only the target points of the focused trajectory
            test_dataset = _as_float64_dataset(
                augmented_input[traj_idx][target_mask_single_traj],
                data_x[traj_idx][target_mask_single_traj],
            )

            gp_mean, gp_var = models.predict(test_dataset.query_points)

            # [timesteps, state_dim] -> scalar by averaging across state and time
            mse = ((gp_mean - test_dataset.observations) ** 2).numpy().mean()

            # To echo how the NLL is calculated in meta learning, the NLL is the
            # summation across the state and time dimensions.
            nll = np.sum(
                -tfd.Normal(gp_mean, tf.sqrt(gp_var))
                .log_prob(test_dataset.observations)
                .numpy(),
                axis=(-1, -2),
            )  # [time_steps, state_dim] -> scalar
            if np.isfinite(mse):
                metrics["mse"].append(mse)
            if np.isfinite(nll):
                metrics["nll"].append(nll)
            print(
                f"mse: {mse}, nll: {nll}, mse_mean: {np.mean(np.asarray(metrics['mse']))}, nll_mean: {np.mean(np.asarray(metrics['nll']))}"
            )

        gp_mse_mean = np.mean(np.asarray(metrics["mse"]))
        gp_nll_mean = np.mean(np.asarray(metrics["nll"]))

        # ==========================    META LEARN MODEL PREDICTION STYLE    ==========================
        ml_means = []
        ml_vars = []
        test_obs = []
        for traj_idx, (obs_mask_single_traj, _) in tqdm(
            enumerate(zip(obs_mask, per_traj_target_mask))
        ):
            print(f"Trajectory {traj_idx}")
            train_dataset = _as_float64_dataset(
                augmented_input[obs_mask_single_traj],
                data_x[obs_mask_single_traj],
            )
            models = _build_gp_with_fixed_hyperparameters(train_dataset, num_states)

            # Predict EVERY time step of the trajectory; the target mask is
            # applied afterwards, as the meta-learned models do.
            test_dataset = _as_float64_dataset(
                augmented_input[traj_idx],
                data_x[traj_idx],
            )
            traj_mean, traj_var = models.predict(test_dataset.query_points)
            ml_means.append(traj_mean)
            ml_vars.append(traj_var)
            test_obs.append(test_dataset.observations)

        # Add a leading axis of size 1 and a singleton axis at position 2
        # (the Monte Carlo sample axis of the meta-learned models).
        ml_means = np.expand_dims(np.array(ml_means)[None, ...], axis=2)  # [1, num_traj, 1, timesteps, state_dim]
        ml_vars = np.expand_dims(np.array(ml_vars)[None, ...], axis=2)  # [1, num_traj, 1, timesteps, state_dim]
        test_obs = np.array(test_obs)[None, ...]  # [1, num_traj, timesteps, state_dim]
        expanded_obs = np.expand_dims(test_obs, axis=2)  # [1, num_traj, 1, timesteps, state_dim]

        # Target mask of the evaluated system only, keeping the leading axis.
        target_mask = np.diagonal(target_mask_all_sys, axis1=-3, axis2=-2)
        sys_target_mask = target_mask[sys_id:sys_id + 1]
        # Broadcastable against the predictions: add the sample and state axes.
        broadcast_mask = np.expand_dims(np.expand_dims(sys_target_mask, axis=2), axis=-1)

        # Masked squared error, averaged over states, summed over time and
        # divided by the number of target points (at least 1 to avoid 0/0).
        element_wise_mse = ((ml_means - expanded_obs) ** 2) * broadcast_mask
        ml_mse = np.sum(np.mean(element_wise_mse, axis=-1), axis=-1) / np.maximum(
            np.expand_dims(np.sum(sys_target_mask, axis=-1), axis=-1), 1.0
        )
        ml_mse_traj_mc_wise_mean = np.mean(ml_mse, axis=(-1, -2))
        # ml_mse_traj_mc_wise_mean should be the same as gp_mse_mean
        print(f'Meta_learn_MSE_Mean: {np.mean(np.squeeze(ml_mse))}, GP_MSE_MEAN: {gp_mse_mean}')

        # Masked Gaussian log-likelihood, summed over state and time, then
        # averaged over the sample and trajectory axes.
        target_log_likelihood = tfd.Normal(ml_means, np.sqrt(ml_vars)).log_prob(expanded_obs)
        target_log_likelihood = target_log_likelihood * broadcast_mask.astype(data_x.dtype)
        target_log_likelihood = np.mean(
            np.sum(target_log_likelihood.numpy(), axis=(-1, -2)), axis=(-1, -2)
        )
        meta_learn_nll = -target_log_likelihood

        assert np.allclose(gp_mse_mean, ml_mse_traj_mc_wise_mean, rtol=1E-6), "The GP MSE is not the same as the meta learned MSE"
        assert np.allclose(gp_nll_mean, meta_learn_nll, rtol=1e-6), "The GP NLL is not the same as the meta learned NLL"


def test_gp_nll_is_the_same_as_meta_learn_ode_calculation():
    """
    Placeholder: the GP NLL must be calculated like the NODEP, SANODEP and NP NLL.

    The function has no body yet and therefore checks nothing; it exists to
    avoid any mistake in model comparison once it is filled in.
    """


if __name__ == "__main__":
    # Script entry point: only the GP-vs-meta-learning comparison is executed.
    # The MFVI-vs-NLL log-likelihood test is deliberately not called here.
    test_gp_mse_is_the_same_as_meta_learn_ode_calculation()