from jax import random
from jax import numpy as np
from NeuralProcesses.data.utils import context_target_mask_gen


def test_context_target_mask_gen_when_use_initial_is_False():
    """Check mask shapes, per-trajectory counts and context/target nesting.

    Calls ``context_target_mask_gen`` with ``use_initial=False`` on a batch of
    two samples with three trajectory slots each, where only the first
    ``known_traj[b]`` slots of sample ``b`` are expected to carry any mask.
    """
    key = random.PRNGKey(0)
    batch_size, traj_size, num_points = 2, 3, 100
    known_traj = np.array([1, 2])
    num_context = np.array([[3, 4, 5], [2, 1, 3]])
    num_extra_target = np.array([[2, 3, 2], [3, 0, 2]])

    context_mask, target_mask = context_target_mask_gen(
        key,
        batch_size,
        known_traj,
        traj_size,
        num_points,
        num_context,
        num_extra_target,
        False,
    )

    # Both masks must be laid out as (batch, trajectory, point).
    expected_shape = (batch_size, traj_size, num_points)
    assert context_mask.shape == expected_shape
    assert target_mask.shape == expected_shape

    # Number of selected points per trajectory, shape (batch, trajectory).
    context_counts = context_mask.sum(-1)
    for counts, n_known, wanted in zip(context_counts, known_traj, num_context):
        # Known trajectories carry exactly the requested number of context points.
        assert np.all(counts[:n_known] == wanted[:n_known])
        # Screened-out trajectories must not be masked at all; the check is
        # skipped when every trajectory of the sample is known.
        unused = counts[n_known:]
        if unused.size > 0:
            assert np.all(unused == 0)

    # Target points are the context points plus the extra target points.
    target_counts = target_mask.sum(-1)
    target_wanted = num_context + num_extra_target
    for counts, n_known, wanted in zip(target_counts, known_traj, target_wanted):
        assert np.all(counts[:n_known] == wanted[:n_known])
        # Screened-out trajectories must not be masked at all.
        unused = counts[n_known:]
        if unused.size > 0:
            assert np.all(unused == 0)

    # Every context point must also be a target point.
    shared = np.logical_and(context_mask, target_mask)
    assert np.all(shared == context_mask)


def test_context_target_mask_gen_when_use_initial_is_True():
    """Check shapes, counts, initial-point selection and context/target nesting.

    Calls ``context_target_mask_gen`` with ``use_initial=True`` on a batch of
    two samples with three trajectory slots each, and additionally asserts
    that the first point of every known trajectory is selected in both masks.
    """
    key = random.PRNGKey(0)
    batch_size, traj_size, num_points = 2, 3, 100
    known_traj = np.array([1, 2])
    num_context = np.array([[3, 4, 5], [2, 1, 3]])
    num_extra_target = np.array([[2, 3, 2], [3, 0, 2]])

    context_mask, target_mask = context_target_mask_gen(
        key,
        batch_size,
        known_traj,
        traj_size,
        num_points,
        num_context,
        num_extra_target,
        True,
    )

    # Both masks must be laid out as (batch, trajectory, point).
    expected_shape = (batch_size, traj_size, num_points)
    assert context_mask.shape == expected_shape
    assert target_mask.shape == expected_shape

    # Number of selected points per trajectory, shape (batch, trajectory).
    context_counts = context_mask.sum(-1)
    for sample_mask, counts, n_known, wanted in zip(
        context_mask, context_counts, known_traj, num_context
    ):
        # Known trajectories carry exactly the requested number of context points.
        assert np.all(counts[:n_known] == wanted[:n_known])
        # The initial point of each known trajectory must be selected.
        first_points = sample_mask[:n_known, 0]
        assert np.all(first_points == True)
        # Screened-out trajectories must not be masked at all; the check is
        # skipped when every trajectory of the sample is known.
        unused = counts[n_known:]
        if unused.size > 0:
            assert np.all(unused == 0)

    # Target points are the context points plus the extra target points.
    target_counts = target_mask.sum(-1)
    target_wanted = num_context + num_extra_target
    for sample_mask, counts, n_known, wanted in zip(
        target_mask, target_counts, known_traj, target_wanted
    ):
        assert np.all(counts[:n_known] == wanted[:n_known])
        # The initial point of each known trajectory must be selected.
        first_points = sample_mask[:n_known, 0]
        assert np.all(first_points == True)
        # Screened-out trajectories must not be masked at all.
        unused = counts[n_known:]
        if unused.size > 0:
            assert np.all(unused == 0)

    # Every context point must also be a target point.
    shared = np.logical_and(context_mask, target_mask)
    assert np.all(shared == context_mask)


def test_context_target_mask_gen_target_is_all_true_when_num_context_and_num_target_sum_up_to_timesteps_use_initial_is_False():
    """Target mask is expected to be fully True when the counts fill every timestep.

    One sample with a single known trajectory of 10 timesteps, requesting
    5 context and 5 extra target points with ``use_initial=False``.
    """
    key = random.PRNGKey(0)
    batch_size, traj_size, num_timesteps = 1, 1, 10
    known_traj = np.array([1])
    num_context = np.array([[5]])
    num_extra_target = np.array([[5]])

    # Only the target mask is inspected here; the context mask is discarded.
    _, target_mask = context_target_mask_gen(
        key,
        batch_size,
        known_traj,
        traj_size,
        num_timesteps,
        num_context,
        num_extra_target,
        False,
    )

    every_point_selected = np.all(target_mask == True)
    assert every_point_selected, "All target mask should be True when num_context + num_extra_target = num_timesteps and use_initial is False"


def test_context_target_mask_gen_target_is_all_true_when_num_context_and_num_target_sum_up_to_timesteps_use_initial_is_True():
    """Target mask is expected to be fully True when the counts fill every timestep.

    One sample with a single known trajectory of 10 timesteps, requesting
    5 context and 5 extra target points with ``use_initial=True``.
    """
    key = random.PRNGKey(0)
    batch_size, traj_size, num_timesteps = 1, 1, 10
    known_traj = np.array([1])
    num_context = np.array([[5]])
    num_extra_target = np.array([[5]])

    # Only the target mask is inspected here; the context mask is discarded.
    _, target_mask = context_target_mask_gen(
        key,
        batch_size,
        known_traj,
        traj_size,
        num_timesteps,
        num_context,
        num_extra_target,
        True,
    )

    every_point_selected = np.all(target_mask == True)
    assert every_point_selected, "All target mask should be True when num_context + num_extra_target = num_timesteps and use_initial is True"