from jax.config import config; config.update("jax_enable_x64", True)
from NeuralProcesses.models.utils.nn import CrossAttention
from jax import numpy as np


# def test_scaled_dot_product_attention_works_with_mask():
#     from jax import random
#     rng = random.PRNGKey(0)
#     keys = random.normal(rng, shape=(10, 50))
#     rng, q_rng = random.split(rng, 2)
#     queries = random.normal(q_rng, shape=(15, 50))
#     rng, v_rng = random.split(rng, 2)
#     values = random.normal(v_rng, shape=(10, 50))
#     Q_mask = np.any(random.choice(rng, 15, shape=(7,), replace=False)[..., None] == np.arange(15), 0)
# 
#     K_mask = np.any(random.choice(rng, 10, shape=(5,), replace=False)[..., None] == np.arange(10), 0)
#     A = _scaled_dot_product_attention(queries, keys, values, Q_mask, K_mask)
#     B = _scaled_dot_product_attention(queries[Q_mask], keys[K_mask], values[K_mask], 
#                                       np.ones((7,)).astype(np.bool_), np.ones((5,)).astype(np.bool_))
# 
#     assert np.sum(np.abs(A[Q_mask] - B)) < 1E-10

def test_cross_attention_works_with_mask():
    """
    Check that the context and target masks are properly implemented in the cross attention block.

    The block is applied twice with the same variables:

    * on the full queries/keys/values together with boolean masks, and
    * on inputs that were filtered by those masks beforehand, together with
      all-True masks.

    The rows of the first output selected by the query mask must match the
    second output up to a small numerical tolerance.
    """
    from jax import random

    num_queries, num_keys, feat_dim = 15, 10, 50
    num_kept_queries, num_kept_keys = 7, 5

    # JAX PRNG keys must not be reused: derive one independent sub-key per
    # random draw (and one for parameter init) from a single root key.
    (k_rng, q_rng, v_rng,
     q_mask_rng, k_mask_rng, init_rng) = random.split(random.PRNGKey(0), 6)

    blk = CrossAttention(in_ch=128, num_heads=8)
    keys = random.normal(k_rng, shape=(num_keys, feat_dim))
    queries = random.normal(q_rng, shape=(num_queries, feat_dim))
    values = random.normal(v_rng, shape=(num_keys, feat_dim))

    # Draw distinct indices without replacement and turn them into boolean
    # masks with exactly num_kept_* True entries.
    kept_q_idx = random.choice(q_mask_rng, num_queries, shape=(num_kept_queries,), replace=False)
    Q_mask = np.any(kept_q_idx[..., None] == np.arange(num_queries), 0)

    kept_k_idx = random.choice(k_mask_rng, num_keys, shape=(num_kept_keys,), replace=False)
    K_mask = np.any(kept_k_idx[..., None] == np.arange(num_keys), 0)

    variables = blk.init(init_rng, queries, Q_mask, keys, K_mask, values)

    # A: full inputs + masks.  B: pre-filtered inputs + all-True masks.
    A = blk.apply(variables, queries, Q_mask, keys, K_mask, values)
    B = blk.apply(
        variables,
        queries[Q_mask], np.ones((num_kept_queries,), dtype=np.bool_),
        keys[K_mask], np.ones((num_kept_keys,), dtype=np.bool_),
        values[K_mask],
    )

    total_abs_diff = np.sum(np.abs(A[Q_mask] - B))
    assert total_abs_diff < 1E-10, (
        f"masked and pre-filtered cross attention outputs differ: {total_abs_diff}"
    )

if __name__ == "__main__":
    # Script entry point: run the test directly, without a test runner.
    #
    # Scratch snippet kept for reference (softmax with a `where` mask):
    #     from flax.linen import Dense, Module, silu, softmax
    #     a = np.ones(shape=(3,))
    #     debug_a = softmax(a)
    #     debug_b = softmax(a, where=np.array([False, True, True]), initial=0.0)
    #
    # Disabled, its definition is commented out in this file:
    #     test_scaled_dot_product_attention_works_with_mask()
    test_cross_attention_works_with_mask()