"""In this file, I am going to test the memory decoder.

1. Create the networks (encoder, decoder and memory decoder).
2. Create the external memory.
3. Create the observation.
4. Create the embeddings.
5. Call the networks.
"""

from typing import Any

import acme
import chex
import haiku as hk
import jax
import jax.numpy as jnp
import rlax
from typing_extensions import TypeAlias

from memento import environments, networks
from memento.memory.external_memory import ExternalMemory, update_memory
from memento.utils.data import generate_zeros_from_spec

# Loose alias used to annotate the observation argument of the decoder
# functions below; no concrete observation type is imported in this script.
Observation: TypeAlias = Any


if __name__ == "__main__":
    # Build the TSP environment used throughout this script.
    environment = environments.MementoTSP(num_cities=100)

    # Ask acme for the environment spec.
    environment_spec = acme.make_environment_spec(environment)

    # Dummy observation, only used to initialise the network parameters:
    # map `generate_zeros_from_spec` over a value generated from the
    # observation spec, then rebuild an observation from the result.
    _dummy_obs_fields = jax.tree_util.tree_map(
        generate_zeros_from_spec,
        environment_spec.observations.generate_value(),
    )
    _dummy_obs = environment.make_observation(*_dummy_obs_fields)

    # Root PRNG key; every random draw below splits a subkey off this one.
    random_key = jax.random.PRNGKey(0)

    # get embedding
    # random_key, subkey = jax.random.split(random_key)
    # embeddings = jax.random.uniform(subkey, minval=-1.0, maxval=1.0)

    def encoder_fn(problem: chex.Array):
        """Run the shared TSP encoder on ``problem``.

        The hyper-parameters are hard-coded for this test script. The
        function is later wrapped with ``hk.transform``.
        """
        encoder_kwargs = dict(
            num_layers=6,
            num_heads=8,
            key_size=16,
            expand_factor=4,
            model_size=128,
            name="shared_encoder",
        )
        return networks.tsp.TSPEncoder(**encoder_kwargs)(problem)

    def decoder_fn(observation: Observation, embeddings: chex.Array):
        """Run the plain (memory-free) TSP decoder.

        The hyper-parameters are hard-coded for this test script. The
        function is later wrapped with ``hk.transform``.
        """
        tsp_decoder = networks.tsp.TSPDecoder(
            name="decoder",
            model_size=128,
            key_size=16,
            num_heads=8,
        )
        return tsp_decoder(observation, embeddings)

    def memory_decoder_fn(
        observation: Observation,
        embeddings: chex.Array,
        external_memory: ExternalMemory,
    ):
        """Run the TSP memory decoder, which also receives the external memory.

        The module is given the name "decoder" and hard-coded
        hyper-parameters. The function is later wrapped with
        ``hk.transform``.
        """
        tsp_memory_decoder = networks.tsp.TSPMemoryDecoder(
            name="decoder",
            model_size=128,
            key_size=16,
            num_heads=8,
        )
        return tsp_memory_decoder(observation, embeddings, external_memory)

    # Wrap the three network functions (encoder, decoder, memory decoder)
    # with hk.transform and drop the RNG argument from their `apply`.
    encoder_fn, decoder_fn, memory_decoder_fn = [
        hk.without_apply_rng(hk.transform(network_fn))
        for network_fn in (encoder_fn, decoder_fn, memory_decoder_fn)
    ]

    # Initialise the encoder parameters from the dummy problem.
    random_key, subkey = jax.random.split(random_key)
    encoder_params = encoder_fn.init(subkey, _dummy_obs.problem)

    # Show the parameter tree layout, then the shape of every leaf.
    encoder_structure = jax.tree_util.tree_structure(encoder_params)
    print("Encoder structure: ", encoder_structure)

    for leaf in jax.tree_util.tree_leaves(encoder_params):
        print(leaf.shape)

    # The decoder needs embeddings at init time, so encode the dummy problem.
    _dummy_embeddings = encoder_fn.apply(encoder_params, _dummy_obs.problem)

    # Initialise the plain decoder parameters.
    random_key, subkey = jax.random.split(random_key)
    decoder_params = decoder_fn.init(subkey, _dummy_obs, _dummy_embeddings)

    # Sanity check: take a few steps in the environment with the plain
    # (memory-free) decoder.

    # Sample a problem instance.
    print("Create a problem")
    random_key, subkey = jax.random.split(random_key)
    problem = environment.generate_problem(subkey, environment.get_problem_size())

    # Sample a starting position; `maxval` is exclusive in
    # jax.random.randint, hence the +1 on the environment's maximum.
    print("Create a starting position")
    random_key, subkey = jax.random.split(random_key)
    start_position = jax.random.randint(
        subkey,
        (1,),
        minval=environment.get_min_start(),
        maxval=environment.get_max_start() + 1,
    )[0]

    # Encode the problem once; the embeddings are reused at every step.
    print("Get embeddings from the encoder")
    embeddings = encoder_fn.apply(encoder_params, problem)

    # Start an episode from the sampled problem and starting position.
    state, timestep = environment.reset_from_state(problem, start_position)

    print("Do steps in the environment")

    # Actions picked by the plain decoder, kept for the final comparison.
    chosen_actions = []
    for i in range(20):
        logits, _ = decoder_fn.apply(decoder_params, timestep.observation, embeddings)

        # Push the logits of the entries flagged in `action_mask` far down so
        # the greedy choice ignores them (presumably the mask flags
        # unavailable actions - TODO confirm). The penalty magnitude matches
        # the one used in the memory-decoder rollout so both rollouts are
        # masked identically.
        logits -= 1e30 * timestep.observation.action_mask

        random_key, subkey = jax.random.split(random_key)
        action = rlax.greedy().sample(subkey, logits)

        state, timestep = environment.step(state, action)

        chosen_actions.append(action)
        print(f"Took action {action} at step {i}.")

    # Now, the goal is to create a memory and to use the memory decoder to
    # act in the environment and to fill the memory.

    # Size of the external memory and dimensions of its entries.
    memory_size = 10
    key_dim, value_dim = 128, 129

    # Number of random (key, value) pairs written to the memory below.
    batch_size = 8

    # Zero-valued float32 key/value used to create the (empty) memory.
    dummy_key = jnp.zeros(shape=(key_dim,), dtype=jnp.float32)
    dummy_value = jnp.zeros(shape=(value_dim,), dtype=jnp.float32)
    memory = ExternalMemory.create(memory_size, dummy_key, dummy_value)

    # Draw uniform random keys, then uniform random values, each from its
    # own subkey.
    keys_shape = (batch_size, key_dim)
    random_key, subkey = jax.random.split(random_key)
    keys = jax.random.uniform(subkey, shape=keys_shape)

    values_shape = (batch_size, value_dim)
    random_key, subkey = jax.random.split(random_key)
    values = jax.random.uniform(subkey, shape=values_shape)

    # Write the random batch into the memory.
    memory = update_memory(memory, keys, values)

    print("Memory : ", memory)

    # Initialise the memory decoder parameters with the dummy inputs and the
    # filled memory.
    random_key, subkey = jax.random.split(random_key)
    memory_decoder_params = memory_decoder_fn.init(
        subkey,
        _dummy_obs,
        _dummy_embeddings,
        memory,
    )

    # Show the layout of the freshly initialised memory decoder parameters.
    decoder_structure = jax.tree_util.tree_structure(memory_decoder_params)

    print("Decoder structure: ", decoder_structure)

    # Merge the two parameter sets. The plain decoder is passed last so that
    # its values override the memory decoder's wherever both define the same
    # entry.
    memory_decoder_params = hk.data_structures.merge(
        memory_decoder_params, decoder_params
    )

    def decrease_scale_fn(module_name, name, value):
        """Shrink parameters of modules whose name contains "memory".

        Called by ``hk.data_structures.map`` with the module name, the
        parameter name and the parameter value. Parameters of matching
        modules are multiplied by 0.1; all others are returned unchanged.
        """
        del name  # required by the map signature but not needed here
        if "memory" in module_name:
            return 0.1 * value
        return value

    # Decrease the typical scale of the memory decoder params, but only for
    # the memory part.
    memory_decoder_params = hk.data_structures.map(
        decrease_scale_fn, memory_decoder_params
    )

    print("Memory decoder params: ", memory_decoder_params)

    # Let's try to act in the environment with our new memory decoder.

    # Restart the episode from the same problem and starting position.
    state, timestep = environment.reset_from_state(problem, start_position)

    print("Do steps in the environment")

    # Actions picked by the memory decoder, kept for the final comparison.
    new_chosen_actions = []
    for step in range(20):
        observation = timestep.observation
        logits, _ = memory_decoder_fn.apply(
            memory_decoder_params, observation, embeddings, memory
        )
        # Push the logits of the entries flagged in `action_mask` far down
        # before the greedy choice.
        logits = logits - 1e30 * observation.action_mask

        random_key, subkey = jax.random.split(random_key)
        action = rlax.greedy().sample(subkey, logits)

        state, timestep = environment.step(state, action)

        print(f"Took action {action} at step {step}.")
        new_chosen_actions.append(action)

    # Print both action sequences side by side for a manual comparison.
    print("Action chosen by the decoder: ", chosen_actions)
    print("Action chosen by the memory decoder: ", new_chosen_actions)
