"""Test the external memory"""

import jax
import jax.numpy as jnp

from memento.memory.external_memory import ExternalMemory, n_update_memory

def _print_memory(memory):
    """Print the ``keys`` and ``values`` attributes of ``memory``.

    Args:
        memory: any object exposing ``keys`` and ``values`` attributes
            (here, the ``ExternalMemory`` under test).
    """
    print("Memory : ")
    print("Keys : ", memory.keys)
    print("Values : ", memory.values)


def _constant_batch(fill, batch_size, horizon, key_dim, value_dim):
    """Build one batch of constant-valued data to insert into the memory.

    Args:
        fill: scalar every key and value entry is set to.
        batch_size: size of the leading axis of each array.
        horizon: size of the second axis of each array.
        key_dim: size of the last axis of the keys.
        value_dim: size of the last axis of the values.

    Returns:
        A ``(keys, values, is_done)`` tuple where ``keys`` is float32 of shape
        ``(batch_size, horizon, key_dim)``, ``values`` is float32 of shape
        ``(batch_size, horizon, value_dim)`` and ``is_done`` is an all-zero
        int32 array of shape ``(batch_size, horizon)``.
    """
    keys = jnp.ones((batch_size, horizon, key_dim), dtype=jnp.float32) * fill
    values = jnp.ones((batch_size, horizon, value_dim), dtype=jnp.float32) * fill
    # NOTE(review): presumably per-step termination flags; all zeros here -- confirm
    # the exact semantics against n_update_memory.
    is_done = jnp.zeros((batch_size, horizon), dtype=jnp.int32)
    return keys, values, is_done


def main():
    """Create an external memory, insert three constant batches, print it each time."""
    # create the external memory
    memory_size = 50
    key_dim = 2
    value_dim = 3

    horizon = 5
    batch_size = 4

    # create dummy data to init the memory
    dummy_key = jnp.zeros((key_dim,), dtype=jnp.float32)
    dummy_value = jnp.zeros((value_dim,), dtype=jnp.float32)
    memory = ExternalMemory.create(memory_size, dummy_key, dummy_value)
    _print_memory(memory)

    # Each round uses a different constant (1, 2, 3) so the rounds can be told
    # apart in the printed memory contents.
    for fill in (1, 2, 3):
        # create some data to be added (fresh arrays every round)
        keys, values, is_done = _constant_batch(
            fill, batch_size, horizon, key_dim, value_dim
        )

        # insert the data
        memory = n_update_memory(memory, keys, values, is_done)
        _print_memory(memory)


if __name__ == "__main__":
    main()
