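"""In this file, I compare the per-call running time of
i) decoder forward passes (attention decoder, memory decoder and dummy decoder)
ii) a step in a Jumanji-style environment (TSP with 100 cities)
for a sweep over memory sizes and numbers of nearest neighbours. One row per
configuration is appended to ``time_comparisons.csv``.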

In this file, we:
- define the environment
- define the encoder and the three decoders
- fill an external memory with random keys and values
- jit the encoder, the decoders and the step function of the env
- run 200 warmup iterations (this triggers compilation)
- take the mean over 500 calls for each function

While we are here, we also do the same for the encoder.

"""

import time
from typing import Any

import acme
import chex
import haiku as hk
import jax
import jax.numpy as jnp
import rlax
from typing_extensions import TypeAlias

from experiments.playground.playground_decoders import TSPDummyDecoder
from memento import environments, networks
from memento.memory.external_memory import ExternalMemory, update_memory
from memento.utils.data import generate_zeros_from_spec

Observation: TypeAlias = Any

# num_nearest_neighbors = 5
# memory_size = 10

memory_recall_target = 0.95

key_dim = 128
value_dim = 129

if __name__ == "__main__":

    # create a csv file to put the data into
    import csv

    with open('time_comparisons.csv', 'w', newline='') as csvfile:
        fieldnames = ['memory_size', 'num_nearest_neighbors', 'key_dim', 'value_dim', 'encode_time', 'decode_time', 'memory_decode_time', 'dummy_decode_time', 'step_time']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        writer.writeheader()


    # for memory_size in [200, 1000, 5000, 10000, 50000, 100000, 200000, 300000, 400000, 500000, 600000]:
    for memory_size in [20000, 50000, 100000, 200000, 300000, 400000, 500000, 600000]:
         for num_nearest_neighbors in [1, 5, 10, 20, 40, 60, 80, 100, 120, 140, 160]:
        #for num_nearest_neighbors in [100]:
            for key_dim in [128]:
                #for value_dim in [10, 20, 50, 100, 129, 150, 200, 300, 500, 1000]:
                for value_dim in [129]:

                    print("--------------------------------------------------")
                    print("--------------------------------------------------")
                    print("--------------------------------------------------")

                    print(f"Memory size: {memory_size}")
                    print(f"Num nearest neighbors: {num_nearest_neighbors}")

                    timing_dictionary = {}
                    timing_dictionary["memory_size"] = memory_size
                    timing_dictionary["num_nearest_neighbors"] = num_nearest_neighbors
                    timing_dictionary["key_dim"] = key_dim
                    timing_dictionary["value_dim"] = value_dim

                    # define an environment
                    environment = environments.MementoTSP(num_cities=100)

                    # retrieve spec of the environment
                    environment_spec = acme.make_environment_spec(environment)

                    # dummy obs used to create the models params
                    _dummy_obs = environment.make_observation(
                        *jax.tree_util.tree_map(
                            generate_zeros_from_spec,
                            environment_spec.observations.generate_value(),
                        )
                    )

                    # init a random key
                    random_key = jax.random.PRNGKey(0)

                    def encoder_fn(problem: chex.Array):
                        # use the class/params given in config to instantiate an encoder
                        encoder = networks.tsp.TSPEncoder(
                            num_layers=6,
                            num_heads=8,
                            key_size=16,
                            expand_factor=4,
                            model_size=128,
                            name="shared_encoder",
                        )
                        return encoder(problem)

                    def decoder_fn(observation: Observation, embeddings: chex.Array):
                        # use the class/params given in config to instantiate decoder
                        decoder = networks.tsp.TSPDecoder(
                            num_heads=8, key_size=16, model_size=128, name="decoder"
                        )
                        return decoder(observation, embeddings)

                    # define the memory decoder
                    def memory_decoder_fn(
                        observation: Observation, embeddings: chex.Array, external_memory: ExternalMemory
                    ):
                        # use the class/params given in config to instantiate decoder
                        decoder = networks.tsp.TSPMemoryDecoder(
                            num_heads=8,
                            key_size=16,
                            model_size=128,
                            name="decoder",
                            num_nearest_neighbors=num_nearest_neighbors,
                            memory_recall_target=memory_recall_target,
                        )
                        return decoder(observation, embeddings, external_memory)

                    # define the dummy decoder
                    def dummy_decoder_fn(
                        observation: Observation, embeddings: chex.Array, dummy_matrix: chex.Array
                    ):
                        # use the class/params given in config to instantiate decoder
                        decoder = TSPDummyDecoder(
                            num_heads=8, key_size=16, model_size=128, name="decoder",
                        )
                        return decoder(observation, embeddings, dummy_matrix)

                    # classic encoder/decoder
                    encoder_fn = hk.without_apply_rng(hk.transform(encoder_fn))
                    decoder_fn = hk.without_apply_rng(hk.transform(decoder_fn))
                    memory_decoder_fn = hk.without_apply_rng(hk.transform(memory_decoder_fn))
                    dummy_decoder_fn = hk.without_apply_rng(hk.transform(dummy_decoder_fn))

                    # create an encoder
                    random_key, subkey = jax.random.split(random_key)
                    encoder_params = encoder_fn.init(subkey, _dummy_obs.problem)

                    encoder_structure = jax.tree_util.tree_structure(encoder_params)

                    # print("Encoder structure: ", encoder_structure)

                    # jax.tree_util.tree_map(lambda x: print(x.shape), encoder_params)

                    # print("Encoder params: ", encoder_params)

                    # get a dummy embedding to init the decoder
                    _dummy_embeddings = encoder_fn.apply(encoder_params, _dummy_obs.problem)

                    # create a decoder
                    random_key, subkey = jax.random.split(random_key)
                    decoder_params = decoder_fn.init(subkey, _dummy_obs, _dummy_embeddings)

                    # create an external memory

                    # create dummy data to init the memory
                    dummy_key = jnp.zeros((key_dim,), dtype=jnp.float32)
                    dummy_value = jnp.zeros((value_dim,), dtype=jnp.float32)
                    memory = ExternalMemory.create(memory_size // environment.get_episode_horizon(), dummy_key, dummy_value)
                    
                    memory = jax.tree_map(
                    lambda x: jnp.repeat(x[None, ...], repeats=environment.get_episode_horizon(), axis=0),
                    memory,
                    )
                    jax.tree_util.tree_map(lambda x: print(x.shape), memory)

                    # put fake data in the memory

                    # create some random data to be added
                    random_key, subkey = jax.random.split(random_key)
                    keys = jax.random.uniform(subkey, shape=(memory_size // environment.get_episode_horizon(), key_dim))
                    keys = jax.tree_map(
                    lambda x: jnp.repeat(x[None, ...], repeats=environment.get_episode_horizon(), axis=0),
                    keys,
                    )

                    random_key, subkey = jax.random.split(random_key)
                    values = jax.random.uniform(subkey, shape=(memory_size // environment.get_episode_horizon(), value_dim))
                    values = jax.tree_map(
                    lambda x: jnp.repeat(x[None, ...], repeats=environment.get_episode_horizon(), axis=0),
                    values,
                    )

                    # insert the data
                    memory = jax.vmap(update_memory)(memory, keys, values)

                    #print("keys : ", keys.shape)

                    # create a memory decoder
                    random_key, subkey = jax.random.split(random_key)
                    memory_decoder_params = memory_decoder_fn.init(
                        subkey, _dummy_obs, _dummy_embeddings, memory
                    )

                    # create a dummy matrix for the decoder (same dim as what will be retrieved)
                    random_key, subkey = jax.random.split(random_key)
                    dummy_matrix = jax.random.uniform(subkey, shape=(num_nearest_neighbors, value_dim))

                    # create a dummy decoder
                    random_key, subkey = jax.random.split(random_key)
                    dummy_decoder_params = dummy_decoder_fn.init(
                        subkey, _dummy_obs, _dummy_embeddings, dummy_matrix
                    )

                    # try to do a few steps in the env - as a sanity check

                    # create the jitted encoding function
                    jitted_encoding_fn = jax.jit(encoder_fn.apply)
                    jitted_decoding_fn = jax.jit(decoder_fn.apply)
                    jitted_memory_decoding_fn = jax.jit(memory_decoder_fn.apply)
                    jitted_dummy_decoding_fn = jax.jit(dummy_decoder_fn.apply)
                    jitted_env_step_fn = jax.jit(environment.step)

                    print("-------------- Start of comparison -----------------")

                    M = 200

                    # warm up all functions on M steps
                    problems = []
                    for i in range(M):
                        random_key, subkey = jax.random.split(random_key)

                        # create a problem
                        problem = environment.generate_problem(subkey, environment.get_problem_size())

                        # encode it
                        embeddings = jitted_encoding_fn(encoder_params, problem)

                        # get  a starting position
                        random_key, subkey = jax.random.split(random_key)
                        start_position = jax.random.randint(
                            subkey,
                            (1,),
                            minval=environment.get_min_start(),
                            maxval=environment.get_max_start() + 1,
                        )[0]

                        # reset the environment
                        state1, timestep1 = environment.reset_from_state(problem, start_position)
                        state2, timestep2 = environment.reset_from_state(problem, start_position)
                        state3, timestep3 = environment.reset_from_state(problem, start_position)

                        # decode to get logits
                        logits1, _ = jitted_decoding_fn(decoder_params, timestep1.observation, embeddings)
                        logits2, _ = jitted_memory_decoding_fn(memory_decoder_params, timestep2.observation, embeddings, memory)
                        logits3, _ = jitted_dummy_decoding_fn(dummy_decoder_params, timestep3.observation, embeddings, dummy_matrix)

                        # get an action from the logits
                        logits1 -= 1e30 * timestep1.observation.action_mask
                        logits2 -= 1e30 * timestep2.observation.action_mask
                        logits3 -= 1e30 * timestep3.observation.action_mask

                        random_key, subkey1, subkey2, subkey3 = jax.random.split(random_key, 4)
                        action1 = rlax.greedy().sample(subkey1, logits1)
                        action2 = rlax.greedy().sample(subkey2, logits2)
                        action3 = rlax.greedy().sample(subkey3, logits3)

                        # take a step in the environment
                        state1, timestep1 = jitted_env_step_fn(state1, action1)
                        state2, timestep2 = jitted_env_step_fn(state2, action2)
                        state3, timestep3 = jitted_env_step_fn(state3, action3)

                    # print("Create a problem")
                    # random_key, subkey = jax.random.split(random_key)
                    # problem = environment.generate_problem(subkey, environment.get_problem_size())

                    N = 500

                    # create N problems
                    problems = []
                    for i in range(N):
                        random_key, subkey = jax.random.split(random_key)
                        problems.append(
                            environment.generate_problem(subkey, environment.get_problem_size())
                        )

                    # encode all problems and time it
                    print("Encode all problems")

                    t0 = time.time()

                    embeddings_list = []

                    for problem in problems:
                        random_key, subkey = jax.random.split(random_key)

                        # add a gaussian noise to the encoder params
                        encoder_params = jax.tree_util.tree_map(
                            lambda x: x + 0.01 * jax.random.normal(subkey, x.shape),
                            encoder_params,
                        )

                        embeddings = jitted_encoding_fn(encoder_params, problem).block_until_ready()

                        embeddings_list.append(embeddings)  # should be negligeable in time

                    t1 = time.time()

                    print(f"Time taken to encode {N} problems: ", t1 - t0)
                    print(f"Time taken to encode 1 problem: ", (t1 - t0) / N)

                    timing_dictionary["encode_time"] = (t1 - t0) / N


                    # get state and timestep

                    states = []
                    timesteps = []

                    for problem in problems:
                        # get  a starting position
                        random_key, subkey = jax.random.split(random_key)
                        start_position = jax.random.randint(
                            subkey,
                            (1,),
                            minval=environment.get_min_start(),
                            maxval=environment.get_max_start() + 1,
                        )[0]

                        # reset the environment
                        state, timestep = environment.reset_from_state(problem, start_position)

                        states.append(state)
                        timesteps.append(timestep)

                    # decode all problems and time it
                    print("Decode all problems - Attention Decoder")

                    t0 = time.time()

                    logits_list = []

                    for embeddings, timestep in zip(embeddings_list, timesteps):
                        # add a gaussian noise to the decoder params
                        random_key, subkey = jax.random.split(random_key)
                        decoder_params = jax.tree_util.tree_map(
                            lambda x: x + 0.01 * jax.random.normal(subkey, x.shape),
                            decoder_params,
                        )

                        logits, _ = jitted_decoding_fn(
                            decoder_params, timestep.observation, embeddings
                        )

                        logits = logits.block_until_ready()

                        logits_list.append(logits)

                    t1 = time.time()

                    print(f"Time taken to decode {N} problems: ", t1 - t0)
                    print(f"Time taken to decode 1 problem: ", (t1 - t0) / N)

                    timing_dictionary["decode_time"] = (t1 - t0) / N

                    # decode all problems and time it
                    print("Decode all problems - Memory Decoder")

                    t0 = time.time()

                    logits_list = []

                    for embeddings, timestep in zip(embeddings_list, timesteps):
                        # add a gaussian noise to the decoder params
                        random_key, subkey = jax.random.split(random_key)
                        memory_decoder_params = jax.tree_util.tree_map(
                            lambda x: x + 0.01 * jax.random.normal(subkey, x.shape),
                            memory_decoder_params,
                        )

                        # add a gaussian noise to the memory keys and values
                        random_key, subkey = jax.random.split(random_key)
                        new_keys = memory.keys + 0.01 * jax.random.normal(subkey, memory.keys.shape)
                        new_values = memory.values + 0.01 * jax.random.normal(subkey, memory.values.shape)

                        memory = memory.replace(keys=new_keys, values=new_values)

                        logits, _ = jitted_memory_decoding_fn(
                            memory_decoder_params, timestep.observation, embeddings, memory
                        )

                        logits = logits.block_until_ready()

                        logits_list.append(logits)

                    t1 = time.time()

                    print(f"Time taken to decode {N} problems with memory: ", t1 - t0)
                    print(f"Time taken to decode 1 problem with memory: ", (t1 - t0) / N)

                    timing_dictionary["memory_decode_time"] = (t1 - t0) / N

                    # decode all problems and time it
                    print("Decode all problems - Dummy Decoder")

                    t0 = time.time()

                    logits_list = []

                    for embeddings, timestep in zip(embeddings_list, timesteps):
                        # add a gaussian noise to the decoder params
                        random_key, subkey = jax.random.split(random_key)
                        dummy_decoder_params = jax.tree_util.tree_map(
                            lambda x: x + 0.01 * jax.random.normal(subkey, x.shape),
                            dummy_decoder_params,
                        )

                        # add a gaussian noise to the dummy matrix
                        random_key, subkey = jax.random.split(random_key)
                        dummy_matrix = dummy_matrix + 0.01 * jax.random.normal(subkey, dummy_matrix.shape)

                        logits, _ = jitted_dummy_decoding_fn(
                            dummy_decoder_params, timestep.observation, embeddings, dummy_matrix
                        )

                        logits = logits.block_until_ready()

                        logits_list.append(logits)

                    t1 = time.time()

                    print(f"Time taken to decode {N} problems with dummy: ", t1 - t0)
                    print(f"Time taken to decode 1 problem with dummy: ", (t1 - t0) / N)

                    timing_dictionary["dummy_decode_time"] = (t1 - t0) / N

                    # step in the environment and time it
                    print("Step in the environment")

                    actions = []
                    for logits in logits_list:
                        random_key, subkey = jax.random.split(random_key)
                        action = rlax.greedy().sample(subkey, logits)

                        actions.append(action)

                    t0 = time.time()

                    for action, state in zip(actions, states):
                        new_state, timestep = jitted_env_step_fn(state, action)

                        jax.tree_util.tree_map(lambda x: x.block_until_ready(), new_state)

                    t1 = time.time()

                    print(f"Time taken to step in the environment {N} times: ", t1 - t0)
                    print(f"Time taken to step in the environment 1 time: ", (t1 - t0) / N)

                    timing_dictionary["step_time"] = (t1 - t0) / N

                    # write the timing dictionary to a csv file
                    with open('time_comparisons.csv', 'a', newline='') as csvfile:
                        fieldnames = ['memory_size', 'num_nearest_neighbors', 'key_dim', 'value_dim', 'encode_time', 'decode_time', 'memory_decode_time', 'dummy_decode_time', 'step_time']
                        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

                        writer.writerow(timing_dictionary)

                

                    print("-------------- End of comparison -----------------")