import copy
import functools
import os
import time
from typing import Any, Dict, Tuple
import subprocess

import chex
import flax
import hydra
import jax
import jax.numpy as jnp
import optax
from colorama import Fore, Style
from flax.core.frozen_dict import FrozenDict
from jumanji.env import Environment
from jumanji.env import specs
from omegaconf import DictConfig, OmegaConf
from rich.pretty import pprint

from stoix.utils.flatten_util import ravel_pytree
from stoix.utils.free_step import get_diff_gradient
import stoix.utils.running_statistics as running_statistics
from stoix.base_types import ExperimentOutput, LearnerFn, RecActorApply, RecCriticApply
from stoix.evaluator import evaluator_setup, get_rec_distribution_act_fn
from stoix.networks.base import RecurrentActor, RecurrentCritic, ScannedRNN
from stoix.systems.ppo.ppo_types import (
    ActorCriticOptStates,
    ActorCriticOuterOptStates,
    ActorCriticParams,
    HiddenStates,
    RNNLearnerStateObsNorm,
    RNNPPOTransition,
)
from stoix.utils import make_env as environments
from stoix.utils.checkpointing import Checkpointer
from stoix.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from stoix.utils.logger import LogEvent, StoixLogger, get_logger_path
from stoix.utils.loss import (
    clipped_value_loss,
    dpo_loss,
    ppo_penalty_loss,
    unclipped_value_loss,
    ppo_clip_loss,
)
from stoix.utils.multistep import batch_truncated_generalized_advantage_estimation
from stoix.utils.total_timestep_checker import check_total_timesteps
from stoix.utils.training import make_learning_rate
from stoix.wrappers.episode_metrics import get_final_step_metrics
import numpy as np

def get_learner_fn(
    env: Environment,
    apply_fns: Tuple[RecActorApply, RecCriticApply],
    update_fns: Tuple[Any, ...],
    config: DictConfig,
) -> LearnerFn[RNNLearnerStateObsNorm]:
    """Build the learner function for recurrent PPO with an outer optimiser.

    Args:
        env: Environment whose `step` is vmapped over the per-learner batch of
            environment states.
        apply_fns: `(actor_apply_fn, critic_apply_fn)`. Each is called as
            `fn(params, obs_norm_params, hidden_state, (obs, done))` and
            returns `(new_hidden_state, output)`.
        update_fns: Five update callables, unpacked in this order: actor
            optimiser update, critic optimiser update, outer optimiser update,
            observation-normalisation update, free-step optimiser update.
        config: Experiment configuration. This function reads
            `config.system.*` (rollout/epoch/minibatch sizes, loss types and
            coefficients) and `config.arch.num_envs` /
            `config.arch.num_updates_per_eval`.

    Returns:
        `learner_fn`, which scans a vmapped `_update_step` for
        `config.arch.num_updates_per_eval` iterations and returns an
        `ExperimentOutput`.
    """

    # Get apply and update functions for actor and critic networks.
    actor_apply_fn, critic_apply_fn = apply_fns
    (
        actor_update_fn,
        critic_update_fn,
        outer_update_fn,
        obs_norm_update_fn,
        free_step_update_fn,
    ) = update_fns

    def _update_step(
        learner_state: RNNLearnerStateObsNorm, _: Any
    ) -> Tuple[RNNLearnerStateObsNorm, Tuple]:
        """A single update of the network.

        This function steps the environment and records the trajectory batch for
        training. It then calculates advantages and targets based on the recorded
        trajectory, applies the pending free-step updates, runs the PPO epochs
        and finally performs the outer update starting from the behaviour
        parameters.

        Args:
            learner_state (NamedTuple):
                - params (ActorCriticParams): The current model parameters.
                - opt_states (OptStates): The current optimizer states.
                - obs_norm_params: The observation normalisation statistics.
                - key (PRNGKey): The random number generator state.
                - env_state (State): The environment state.
                - timestep (TimeStep): The last timestep in the current trajectory.
                - dones (bool): Whether the last timestep was a terminal state.
                - truncated (bool): Whether the last timestep was a truncation.
                - hstates (HiddenStates): The current hidden states of the RNN.
                - free_step_updates: Updates applied to the params before the
                    PPO epochs of this step.
                - outer_gradient: The outer gradient of the previous step (only
                    used for the cosine-similarity metric).
            _ (Any): The current metrics info.

        Returns:
            The new learner state and a `(episode_metrics, loss_info)` tuple.
        """

        def _env_step(
            learner_state: RNNLearnerStateObsNorm, _: Any
        ) -> Tuple[RNNLearnerStateObsNorm, RNNPPOTransition]:
            """Step the environment once and record the transition."""
            (
                params,
                opt_states,
                obs_norm_params,
                key,
                env_state,
                last_timestep,
                last_done,
                last_truncated,
                hstates,
                fs_up,
                out_grad,
            ) = learner_state

            key, policy_key = jax.random.split(key)

            # Add a batch dimension to the observation.
            batched_observation = jax.tree_util.tree_map(
                lambda x: x[jnp.newaxis, :], last_timestep.observation
            )
            ac_in = (
                batched_observation,
                last_done[jnp.newaxis, :],
            )

            # Run the network.
            policy_hidden_state, actor_policy = actor_apply_fn(
                params.actor_params,
                obs_norm_params,
                hstates.policy_hidden_state,
                ac_in,
            )
            critic_hidden_state, value = critic_apply_fn(
                params.critic_params,
                obs_norm_params,
                hstates.critic_hidden_state,
                ac_in,
            )

            # Sample action from the policy and squeeze out the batch dimension.
            action = actor_policy.sample(seed=policy_key)
            log_prob = actor_policy.log_prob(action)
            value, action, log_prob = (
                value.squeeze(0),
                action.squeeze(0),
                log_prob.squeeze(0),
            )

            # Step the environment.
            env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)

            # log episode return and length
            # A zero discount marks termination; a last step with non-zero
            # discount is treated as a truncation.
            done = (timestep.discount == 0.0).reshape(-1)
            truncated = (timestep.last() & (timestep.discount != 0.0)).reshape(-1)
            info = timestep.extras["episode_metrics"]

            hstates = HiddenStates(policy_hidden_state, critic_hidden_state)
            # The transition stores the done/truncated flags and observation that
            # were fed INTO the networks on this step (i.e. the previous ones).
            transition = RNNPPOTransition(
                last_done,
                last_truncated,
                action,
                value,
                timestep.reward,
                log_prob,
                last_timestep.observation,
                hstates,
                info,
            )
            learner_state = RNNLearnerStateObsNorm(
                params,
                opt_states,
                obs_norm_params,
                key,
                env_state,
                timestep,
                done,
                truncated,
                hstates,
                fs_up,
                out_grad,
            )
            return learner_state, transition

        # INITIALISE RNN STATE
        initial_hstates = learner_state.hstates

        # STEP ENVIRONMENT FOR ROLLOUT LENGTH
        learner_state, traj_batch = jax.lax.scan(
            _env_step, learner_state, None, config.system.rollout_length
        )

        # CALCULATE ADVANTAGE
        (
            params,
            opt_states,
            obs_norm_params,
            key,
            env_state,
            last_timestep,
            last_done,
            last_truncated,
            hstates,
            free_step_updates,
            prev_out_grad,
        ) = learner_state

        # UPDATE OBSERVATION NORMALIZATION PARAMETERS
        obs_norm_params = obs_norm_update_fn(obs_norm_params, traj_batch.obs)

        # Add a batch dimension to the observation.
        batched_last_observation = jax.tree_util.tree_map(
            lambda x: x[jnp.newaxis, :], last_timestep.observation
        )
        ac_in = (
            batched_last_observation,
            last_done[jnp.newaxis, :],
        )

        # Run the network.
        _, last_val = critic_apply_fn(
            params.critic_params, obs_norm_params, hstates.critic_hidden_state, ac_in
        )
        # Squeeze out the batch dimension and mask out the value of terminal states.
        last_val = last_val.squeeze(0)
        last_val = jnp.where(last_done, jnp.zeros_like(last_val), last_val)

        r_t = traj_batch.reward * config.system.reward_scaling
        v_t = jnp.concatenate([traj_batch.value, last_val[None, ...]], axis=0)
        d_t = 1.0 - traj_batch.done.astype(jnp.float32)
        d_t = (d_t * config.system.gamma).astype(jnp.float32)
        advantages, targets = batch_truncated_generalized_advantage_estimation(
            r_t,
            d_t,
            config.system.gae_lambda,
            v_t,
            time_major=True,
            standardize_advantages=config.system.standardize_advantages,
            truncation_flags=traj_batch.truncated,
        )

        # COPY BEHAVIOUR PARAMS
        # These are the parameters that generated the rollout; the outer step
        # below is applied relative to them.
        behaviour_params = params

        # APPLY FREE STEP UPDATES
        params = optax.apply_updates(params, free_step_updates)

        def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
            """Update the network for a single epoch."""

            def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
                """Update the network for a single minibatch."""

                params, opt_states, rng_key = train_state
                (
                    traj_batch,
                    advantages,
                    targets,
                ) = batch_info

                def _actor_loss_fn(
                    actor_params: FrozenDict,
                    traj_batch: RNNPPOTransition,
                    gae: chex.Array,
                    rng_key: chex.PRNGKey,
                ) -> Tuple:
                    """Calculate the actor loss.

                    Raises:
                        ValueError: If `config.system.loss_actor_type` is not
                            one of "ppo-penalty", "ppo-clip" or "dpo".
                    """
                    # RERUN NETWORK

                    obs_and_done = (traj_batch.obs, traj_batch.done)
                    # Only the hidden state at the start of the chunk is needed;
                    # the network is re-run over the chunk from there.
                    policy_hidden_state = jax.tree_util.tree_map(
                        lambda x: x[0], traj_batch.hstates.policy_hidden_state
                    )
                    _, actor_policy = actor_apply_fn(
                        actor_params, obs_norm_params, policy_hidden_state, obs_and_done
                    )
                    log_prob = actor_policy.log_prob(traj_batch.action)

                    # CALCULATE ACTOR LOSS
                    if config.system.loss_actor_type == "ppo-penalty":
                        # The recurrent apply function takes a hidden state and
                        # an (obs, done) pair and returns (hidden_state, policy),
                        # exactly as for the current policy above.
                        _, behaviour_policy = actor_apply_fn(
                            behaviour_params.actor_params,
                            obs_norm_params,
                            policy_hidden_state,
                            obs_and_done,
                        )
                        loss_actor, _ = ppo_penalty_loss(
                            log_prob,
                            traj_batch.log_prob,
                            gae,
                            config.system.kl_penalty_coef,
                            actor_policy,
                            behaviour_policy,
                        )
                    elif config.system.loss_actor_type == "ppo-clip":
                        loss_actor = ppo_clip_loss(
                            log_prob, traj_batch.log_prob, gae, config.system.clip_eps
                        )
                    elif config.system.loss_actor_type == "dpo":
                        loss_actor = dpo_loss(
                            log_prob,
                            traj_batch.log_prob,
                            gae,
                            config.system.alpha,
                            config.system.beta,
                        )
                    else:
                        raise ValueError(
                            f"Unknown actor loss type: {config.system.loss_actor_type}"
                        )

                    entropy = actor_policy.entropy(seed=rng_key).mean()

                    total_loss = loss_actor - config.system.ent_coef * entropy
                    loss_info = {
                        "actor_loss": loss_actor,
                        "entropy": entropy,
                    }
                    return total_loss, loss_info

                def _critic_loss_fn(
                    critic_params: FrozenDict,
                    traj_batch: RNNPPOTransition,
                    targets: chex.Array,
                ) -> Tuple:
                    """Calculate the critic loss.

                    Raises:
                        ValueError: If `config.system.loss_critic_type` is not
                            "clip" or "unclip".
                    """
                    # RERUN NETWORK
                    obs_and_done = (traj_batch.obs, traj_batch.done)
                    critic_hidden_state = jax.tree_util.tree_map(
                        lambda x: x[0], traj_batch.hstates.critic_hidden_state
                    )
                    _, value = critic_apply_fn(
                        critic_params, obs_norm_params, critic_hidden_state, obs_and_done
                    )

                    # CALCULATE VALUE LOSS
                    if config.system.loss_critic_type == "clip":
                        value_loss = clipped_value_loss(
                            value, traj_batch.value, targets, config.system.clip_eps
                        )
                    elif config.system.loss_critic_type == "unclip":
                        value_loss = unclipped_value_loss(
                            value,
                            targets,
                        )
                    else:
                        raise ValueError(
                            f"Unknown critic loss type: {config.system.loss_critic_type}"
                        )

                    total_loss = config.system.vf_coef * value_loss
                    loss_info = {
                        "value_loss": value_loss,
                    }
                    return total_loss, loss_info

                rng_key, actor_key = jax.random.split(rng_key)

                # CALCULATE ACTOR LOSS
                actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True)
                actor_grads, actor_loss_info = actor_grad_fn(
                    params.actor_params, traj_batch, advantages, actor_key
                )

                # CALCULATE CRITIC LOSS
                critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True)
                critic_grads, critic_loss_info = critic_grad_fn(
                    params.critic_params, traj_batch, targets
                )

                # Compute the parallel mean (pmean) over the batch.
                # This calculation is inspired by the Anakin architecture demo notebook.
                # available at https://tinyurl.com/26tdzs5x
                # This pmean could be a regular mean as the batch axis is on the same device.
                actor_grads, actor_loss_info = jax.lax.pmean(
                    (actor_grads, actor_loss_info), axis_name="batch"
                )
                # pmean over devices.
                actor_grads, actor_loss_info = jax.lax.pmean(
                    (actor_grads, actor_loss_info), axis_name="device"
                )

                critic_grads, critic_loss_info = jax.lax.pmean(
                    (critic_grads, critic_loss_info), axis_name="batch"
                )
                # pmean over devices.
                critic_grads, critic_loss_info = jax.lax.pmean(
                    (critic_grads, critic_loss_info), axis_name="device"
                )

                # UPDATE ACTOR PARAMS AND OPTIMISER STATE
                actor_updates, actor_new_opt_state = actor_update_fn(
                    actor_grads, opt_states.actor_opt_state
                )
                actor_new_params = optax.apply_updates(params.actor_params, actor_updates)

                # UPDATE CRITIC PARAMS AND OPTIMISER STATE
                critic_updates, critic_new_opt_state = critic_update_fn(
                    critic_grads, opt_states.critic_opt_state
                )
                critic_new_params = optax.apply_updates(params.critic_params, critic_updates)

                # PACK NEW PARAMS AND OPTIMISER STATE
                new_params = ActorCriticParams(actor_new_params, critic_new_params)
                new_opt_state = opt_states._replace(
                    actor_opt_state=actor_new_opt_state, critic_opt_state=critic_new_opt_state
                )

                # PACK LOSS INFO
                loss_info = {
                    **actor_loss_info,
                    **critic_loss_info,
                }

                return (new_params, new_opt_state, rng_key), loss_info

            (
                params,
                opt_states,
                init_hstates,
                traj_batch,
                advantages,
                targets,
                key,
            ) = update_state
            key, shuffle_key, loss_key = jax.random.split(key, 3)

            # SHUFFLE MINIBATCHES
            batch = (traj_batch, advantages, targets)
            num_recurrent_chunks = (
                config.system.rollout_length // config.system.recurrent_chunk_size
            )
            # NOTE(review): this is a plain row-major reshape of the leading two
            # axes. When recurrent_chunk_size < rollout_length, confirm that each
            # resulting column is the intended contiguous time window.
            batch = jax.tree_util.tree_map(
                lambda x: x.reshape(
                    config.system.recurrent_chunk_size,
                    config.arch.num_envs * num_recurrent_chunks,
                    *x.shape[2:],
                ),
                batch,
            )
            # Shuffle along axis 1 only, so the leading (time) axis of every
            # column is left in order.
            permutation = jax.random.permutation(
                shuffle_key, config.arch.num_envs * num_recurrent_chunks
            )
            shuffled_batch = jax.tree_util.tree_map(
                lambda x: jnp.take(x, permutation, axis=1), batch
            )
            reshaped_batch = jax.tree_util.tree_map(
                lambda x: jnp.reshape(
                    x, (x.shape[0], config.system.num_minibatches, -1, *x.shape[2:])
                ),
                shuffled_batch,
            )
            # Move the minibatch axis to the front so that scan iterates over it.
            minibatches = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 0), reshaped_batch)

            # UPDATE MINIBATCHES
            (params, opt_states, loss_key), loss_info = jax.lax.scan(
                _update_minibatch, (params, opt_states, loss_key), minibatches
            )

            update_state = (
                params,
                opt_states,
                init_hstates,
                traj_batch,
                advantages,
                targets,
                key,
            )
            return update_state, loss_info

        init_hstates = jax.tree_util.tree_map(lambda x: x[None, :], initial_hstates)
        update_state = (
            params,
            opt_states,
            init_hstates,
            traj_batch,
            advantages,
            targets,
            key,
        )

        # UPDATE EPOCHS
        update_state, loss_info = jax.lax.scan(
            _update_epoch, update_state, None, config.system.epochs
        )

        latest_params, opt_states, _, traj_batch, advantages, targets, key = update_state

        # OUTER STEP
        # The "outer gradient" is derived from the post-PPO params and the
        # behaviour params (presumably their difference; see get_diff_gradient).
        outer_grads = get_diff_gradient(latest_params, behaviour_params)
        outer_updates, outer_new_opt_state = outer_update_fn(
            outer_grads, opt_states.outer_opt_state
        )
        # Apply the updates to the parameters
        new_params = optax.apply_updates(behaviour_params, outer_updates)

        # UPDATE FREE STEP OPTIMISER AND GET THE UPDATES WE WILL APPLY
        # BEFORE THE NEXT PPO ITERATION
        free_step_updates, free_step_new_opt_state = free_step_update_fn(
            outer_grads, opt_states.free_step_opt_state
        )

        opt_states = opt_states._replace(
            outer_opt_state=outer_new_opt_state,
            free_step_opt_state=free_step_new_opt_state,
        )

        # Get Outer Gradient Metrics
        prev_out_grad_vector, _ = ravel_pytree(prev_out_grad)
        new_out_grad_vector, _ = ravel_pytree(outer_grads)

        def cosine_similarity(x, y):
            """Cosine similarity of two vectors; the epsilon avoids 0/0."""
            return jnp.dot(x, y) / (jnp.linalg.norm(x) * jnp.linalg.norm(y) + 1e-8)

        cosine_sim = cosine_similarity(new_out_grad_vector, prev_out_grad_vector)
        out_grad_norm = optax.global_norm(outer_grads)

        loss_info["cosine_similarity"] = cosine_sim
        loss_info["outer_grad_norm"] = out_grad_norm

        learner_state = RNNLearnerStateObsNorm(
            new_params,
            opt_states,
            obs_norm_params,
            key,
            env_state,
            last_timestep,
            last_done,
            last_truncated,
            hstates,
            free_step_updates,
            outer_grads,
        )
        metric = traj_batch.info
        return learner_state, (metric, loss_info)

    def learner_fn(
        learner_state: RNNLearnerStateObsNorm,
    ) -> ExperimentOutput[RNNLearnerStateObsNorm]:
        """Learner function.

        This function represents the learner, it updates the network parameters
        by iteratively applying the `_update_step` function for a fixed number of
        updates. The `_update_step` function is vectorized over a batch of inputs.

        Args:
            learner_state (NamedTuple):
                - params (ActorCriticParams): The initial model parameters.
                - opt_states (OptStates): The initial optimizer states.
                - obs_norm_params: The observation normalisation statistics.
                - key (chex.PRNGKey): The random number generator state.
                - env_state (LogEnvState): The environment state.
                - timestep (TimeStep): The initial timestep in the initial trajectory.
                - dones (bool): Whether the initial timestep was a terminal state.
                - truncated (bool): Whether the initial timestep was a truncation.
                - hstates (HiddenStates): The initial hidden states of the RNN.
                - free_step_updates: Updates applied before the first PPO epochs.
                - outer_gradient: The previous outer gradient.

        Returns:
            An `ExperimentOutput` holding the final learner state, the episode
            metrics and the training metrics of every update step.
        """

        # The "batch" axis name is what the pmean calls inside the update refer to.
        batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch")

        learner_state, (episode_info, loss_info) = jax.lax.scan(
            batched_update_step, learner_state, None, config.arch.num_updates_per_eval
        )
        return ExperimentOutput(
            learner_state=learner_state,
            episode_metrics=episode_info,
            train_metrics=loss_info,
        )

    return learner_fn


def learner_setup(
    env: Environment, keys: chex.Array, config: DictConfig
) -> Tuple[LearnerFn[RNNLearnerStateObsNorm], RecurrentActor, ScannedRNN, RNNLearnerStateObsNorm]:
    """Initialise learner_fn, network, optimiser, environment and states.

    Args:
        env: The training environment; its action/observation specs size the
            networks and its `reset` produces the initial states.
        keys: Three PRNG keys, unpacked as `(key, actor_net_key, critic_net_key)`.
        config: Experiment configuration. NOTE: this function writes to it
            (`system.action_dim`, `system.action_minimum/maximum` for
            non-discrete action specs, and the `transition_steps` of the
            learning-rate schedule nodes).

    Returns:
        A tuple `(learn, actor_network_apply_fn, actor_rnn, init_learner_state)`:
        the pmapped learner function, the actor apply function (with
        observation normalisation), the actor's scanned RNN module and the
        initial learner state laid out as (devices, update_batch_size, ...).

    Raises:
        NotImplementedError: If an EMA outer optimiser is combined with a
            schedule ("ema" and "sched" both in `config.sweep_name`), or if
            `config.logger.checkpointing.load_model` is set.
    """
    # Get available TPU cores.
    n_devices = len(jax.devices())

    # Get number/dimension of actions.
    action_spec = env.action_spec()
    if isinstance(action_spec, specs.DiscreteArray):
        num_actions = int(action_spec.num_values)
        network_kwargs = {}
    else:
        num_actions = int(env.action_spec().shape[-1])
        config.system.action_minimum = float(env.action_spec().minimum)
        config.system.action_maximum = float(env.action_spec().maximum)
        network_kwargs = {
            "minimum": config.system.action_minimum,
            "maximum": config.system.action_maximum,
        }

    config.system.action_dim = num_actions

    # PRNG keys.
    key, actor_net_key, critic_net_key = keys

    # Define network and optimisers.
    actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
    actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
    actor_action_head = hydra.utils.instantiate(
        config.network.actor_network.action_head, action_dim=num_actions, **network_kwargs
    )
    critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
    critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)
    critic_head = hydra.utils.instantiate(config.network.critic_network.critic_head)

    # The actor must use the actor's RNN settings so that it agrees with
    # `actor_rnn` below, which creates the policy hidden state.
    actor_network = RecurrentActor(
        pre_torso=actor_pre_torso,
        hidden_state_dim=config.network.actor_network.rnn_layer.hidden_state_dim,
        cell_type=config.network.actor_network.rnn_layer.cell_type,
        post_torso=actor_post_torso,
        action_head=actor_action_head,
    )
    critic_network = RecurrentCritic(
        pre_torso=critic_pre_torso,
        hidden_state_dim=config.network.critic_network.rnn_layer.hidden_state_dim,
        cell_type=config.network.critic_network.rnn_layer.cell_type,
        post_torso=critic_post_torso,
        critic_head=critic_head,
    )
    actor_rnn = ScannedRNN(
        hidden_state_dim=config.network.actor_network.rnn_layer.hidden_state_dim,
        cell_type=config.network.actor_network.rnn_layer.cell_type,
    )
    critic_rnn = ScannedRNN(
        hidden_state_dim=config.network.critic_network.rnn_layer.hidden_state_dim,
        cell_type=config.network.critic_network.rnn_layer.cell_type,
    )

    actor_lr = make_learning_rate(
        config.system.actor_lr, config, config.system.epochs, config.system.num_minibatches
    )
    critic_lr = make_learning_rate(
        config.system.critic_lr, config, config.system.epochs, config.system.num_minibatches
    )

    # Inner (PPO) optimisers: these subtract the scaled gradients.
    actor_optim = optax.chain(
        optax.clip_by_global_norm(config.system.max_grad_norm),
        optax.adam(actor_lr, eps=1e-5),
    )
    critic_optim = optax.chain(
        optax.clip_by_global_norm(config.system.max_grad_norm),
        optax.adam(critic_lr, eps=1e-5),
    )

    # A plain numeric learning rate (float or int) has no schedule fields to
    # fill in; anything else is treated as a schedule config node.
    if not isinstance(config.system.outer_optimizer.learning_rate, (int, float)):
        config.system.outer_optimizer.learning_rate.transition_steps = config.arch.num_updates

    config.system.free_step_learning_rate.transition_steps = config.arch.num_updates

    # The sweep name selects how the outer optimiser is built.
    if "hb" in config.sweep_name or 'nest' in config.sweep_name or 'ema' in config.sweep_name:
        momentum_schedule = optax.linear_schedule(
            config.system.outer_optimizer.init_momentum,
            config.system.outer_optimizer.end_momentum,
            config.arch.num_updates,
        )

        if 'ema' in config.sweep_name:

            if 'sched' in config.sweep_name:
                raise NotImplementedError("EMA not implemented with sched")

            outer_optimizer = optax.chain(
                optax.inject_hyperparams(optax.ema)(momentum_schedule),
                optax.scale_by_learning_rate(config.system.outer_optimizer.learning_rate),
            )

        else:

            if 'sched' not in config.sweep_name:
                # lr will be constant if hb_warm or a schedule if hb_sched_warm
                outer_optimizer = optax.inject_hyperparams(optax.sgd)(
                    learning_rate=config.system.outer_optimizer.learning_rate,
                    momentum=momentum_schedule,
                    nesterov=config.system.outer_optimizer.nesterov,
                )
            else:

                learning_rate_schedule = optax.cosine_onecycle_schedule(
                    transition_steps=config.arch.num_updates,
                    peak_value=config.system.outer_optimizer.learning_rate.peak_value,
                    pct_start=config.system.outer_optimizer.learning_rate.pct_start,
                    div_factor=config.system.outer_optimizer.learning_rate.div_factor,
                    final_div_factor=config.system.outer_optimizer.learning_rate.final_div_factor,
                )

                outer_optimizer = optax.inject_hyperparams(optax.sgd)(
                    learning_rate=learning_rate_schedule,
                    momentum=momentum_schedule,
                    nesterov=config.system.outer_optimizer.nesterov,
                )

    else:
        outer_optimizer = hydra.utils.instantiate(config.system.outer_optimizer)

    # This adds scaled gradients: the incoming outer gradient is negated before
    # it reaches the wrapped optimiser.
    outer_optim = optax.chain(
        optax.scale(-1),
        outer_optimizer,
    )

    free_step_optim = optax.chain(
        optax.scale(-1),
        optax.ema(config.system.free_step_momentum),
        optax.inject_hyperparams(optax.scale_by_learning_rate)(
            learning_rate=hydra.utils.instantiate(config.system.free_step_learning_rate)
        ),
    )

    # Initialise observation
    init_obs = env.observation_spec().generate_value()

    # Initialise observation normalization parameters
    obs_norm_params = running_statistics.init_state(init_obs)

    # Build a dummy network input with leading axes (1, num_envs).
    init_obs = jax.tree_util.tree_map(
        lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0),
        init_obs,
    )
    init_obs = jax.tree_util.tree_map(lambda x: x[None, ...], init_obs)
    init_done = jnp.zeros((1, config.arch.num_envs), dtype=bool)
    init_x = (init_obs, init_done)

    # Define observation normalisation functions.
    obs_norm_update_fn = functools.partial(
        running_statistics.update, pmap_axis_name=["batch", "device"]
    )
    if config.system.normalize_observations:
        obs_norm_apply_fn = running_statistics.normalize
    else:
        # Identity: the statistics are still tracked but not applied.
        obs_norm_apply_fn = lambda x, _: x

    # Initialise hidden states.
    init_policy_hstate = actor_rnn.initialize_carry(config.arch.num_envs)
    init_critic_hstate = critic_rnn.initialize_carry(config.arch.num_envs)

    # initialise params and optimiser state.
    actor_params = actor_network.init(actor_net_key, init_policy_hstate, init_x)
    actor_opt_state = actor_optim.init(actor_params)
    critic_params = critic_network.init(critic_net_key, init_critic_hstate, init_x)
    critic_opt_state = critic_optim.init(critic_params)

    # Pack params and initial states.
    params = ActorCriticParams(actor_params, critic_params)
    hstates = HiddenStates(init_policy_hstate, init_critic_hstate)

    # Initialise outer optimiser state.
    outer_optim_state = outer_optim.init(params)
    free_step_optim_state = free_step_optim.init(params)

    # Here we define the network apply functions with normalization.
    # No gradient flows into the normalisation statistics.
    actor_network_apply_fn = lambda act_params, norm_params, hstates, ac_in: actor_network.apply(
        act_params,
        hstates,
        (obs_norm_apply_fn(ac_in[0], jax.lax.stop_gradient(norm_params)), ac_in[1]),
    )
    critic_network_apply_fn = lambda crit_params, norm_params, hstates, ac_in: critic_network.apply(
        crit_params,
        hstates,
        (obs_norm_apply_fn(ac_in[0], jax.lax.stop_gradient(norm_params)), ac_in[1]),
    )

    # Pack apply and update functions.
    apply_fns = (actor_network_apply_fn, critic_network_apply_fn)
    update_fns = (
        actor_optim.update,
        critic_optim.update,
        outer_optim.update,
        obs_norm_update_fn,
        free_step_optim.update,
    )

    # Get batched iterated update and replicate it to pmap it over cores.
    learn = get_learner_fn(env, apply_fns, update_fns, config)
    learn = jax.pmap(learn, axis_name="device")

    # Initialise environment states and timesteps: across devices and batches.
    key, *env_keys = jax.random.split(
        key, n_devices * config.system.update_batch_size * config.arch.num_envs + 1
    )
    env_states, timesteps = jax.vmap(env.reset, in_axes=(0))(
        jnp.stack(env_keys),
    )
    reshape_states = lambda x: x.reshape(
        (n_devices, config.system.update_batch_size, config.arch.num_envs) + x.shape[1:]
    )
    # (devices, update batch size, num_envs, ...)
    env_states = jax.tree_util.tree_map(reshape_states, env_states)
    timesteps = jax.tree_util.tree_map(reshape_states, timesteps)

    # Load model from checkpoint if specified.
    if config.logger.checkpointing.load_model:
        raise NotImplementedError("Loading model from checkpoint is not implemented.")
        #loaded_checkpoint = Checkpointer(
        #    model_name=config.system.system_name,
        #    **config.logger.checkpointing.load_args,  # Other checkpoint args
        #)
        ## Restore the learner state from the checkpoint
        #restored_params, _ = loaded_checkpoint.restore_params()
        ## Update the params
        #params = restored_params

    # Define params to be replicated across devices and batches.
    dones = jnp.zeros(
        (config.arch.num_envs,),
        dtype=bool,
    )
    truncated = jnp.zeros(
        (config.arch.num_envs,),
        dtype=bool,
    )
    key, step_key = jax.random.split(key)
    step_keys = jax.random.split(step_key, n_devices * config.system.update_batch_size)
    reshape_keys = lambda x: x.reshape((n_devices, config.system.update_batch_size) + x.shape[1:])
    step_keys = reshape_keys(jnp.stack(step_keys))
    opt_states = ActorCriticOuterOptStates(
        actor_opt_state, critic_opt_state, outer_optim_state, free_step_optim_state
    )
    replicate_learner = (params, opt_states, obs_norm_params, hstates, dones, truncated)

    # Duplicate learner for update_batch_size.
    broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size,) + x.shape)
    replicate_learner = jax.tree_util.tree_map(broadcast, replicate_learner)

    # Duplicate learner across devices.
    replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices())

    # Initialise learner state.
    params, opt_states, obs_norm_params, hstates, dones, truncated = replicate_learner
    # The free-step updates and the previous outer gradient start at zero, with
    # the same (already replicated) structure as the params.
    init_learner_state = RNNLearnerStateObsNorm(
        params=params,
        opt_states=opt_states,
        obs_norm_params=obs_norm_params,
        key=step_keys,
        env_state=env_states,
        timestep=timesteps,
        dones=dones,
        truncated=truncated,
        hstates=hstates,
        free_step_updates=jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), params),
        outer_gradient=jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), params),
    )
    return learn, actor_network_apply_fn, actor_rnn, init_learner_state


def _validate_outer_ppo_disabled(config: DictConfig) -> None:
    """Raise if outer-loop hyperparameters are non-neutral while `run_outer_ppo` is off.

    When `config.system.run_outer_ppo` is False the free-step learning rate must be 0,
    the outer optimizer's initial momentum (if momentum is configured) must be 0, and
    the outer learning rate (scalar or schedule) must be exactly 1.

    Raises:
        ValueError: if any of the above settings deviates from its neutral value.
    """
    if config.system.run_outer_ppo:
        return

    if config.system.free_step_learning_rate.peak_value != 0.0:
        raise ValueError("Free step learning rate peak value must be 0.0 if run_outer_ppo is False")

    outer_optimizer = config.system.outer_optimizer
    if outer_optimizer.momentum is not None and outer_optimizer.init_momentum != 0.0:
        raise ValueError("Outer optimizer learning init_momentum must be 0.0 if run_outer_ppo is False")

    outer_lr = outer_optimizer.learning_rate
    if hasattr(outer_lr, 'peak_value'):
        # The learning rate is given as a schedule config rather than a scalar.
        if outer_lr.peak_value != 1.0:
            raise ValueError("Outer optimizer learning rate peak value must be 1.0 if run_outer_ppo is False")
        if outer_lr.div_factor != 1.0:
            raise ValueError("Outer optimizer learning rate div_factor must be 1.0 if run_outer_ppo is False")
        if outer_lr.final_div_factor != 1.0:
            raise ValueError("Outer optimizer learning rate final_div_factor must be 1.0 if run_outer_ppo is False")
    elif outer_lr != 1.0:
        raise ValueError("Outer optimizer learning rate must be 1.0 if run_outer_ppo is False")


def _split_eval_keys(key_e: chex.PRNGKey, n_devices: int) -> Tuple[chex.PRNGKey, chex.PRNGKey]:
    """Split `key_e` into a carried-over key and a `(n_devices, -1)` array of eval keys."""
    key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
    eval_keys = jnp.stack(eval_keys).reshape(n_devices, -1)
    return key_e, eval_keys


def run_experiment(_config: DictConfig) -> float:
    """Runs experiment.

    Sets up the learner, evaluator, logger and (optionally) checkpointer, performs an
    evaluation before any training (logged at timestep 0 / eval step 0), then alternates
    `config.arch.num_evaluation` times between a training block and an evaluation.
    Per-update cosine similarity and outer gradient norm are appended to a
    `cosine_sim.txt` file under the JSON logger directory.

    Args:
        _config: experiment configuration. It is deep-copied, so the caller's object is
            not mutated.

    Returns:
        The mean of `config.env.eval_metric` over the last evaluation that was run (the
        absolute-metric evaluation when `config.arch.absolute_metric` is set), or NaN if
        a NaN actor loss was detected, in which case training stops early.

    Raises:
        ValueError: if `run_outer_ppo` is False but outer-loop hyperparameters are set
            to non-neutral values.
    """
    config = copy.deepcopy(_config)

    _validate_outer_ppo_disabled(config)

    # Calculate total timesteps.
    n_devices = len(jax.devices())
    config.num_devices = n_devices
    config = check_total_timesteps(config)
    assert (
        config.arch.num_updates > config.arch.num_evaluation
    ), "Number of updates per evaluation must be less than total number of updates."

    # Set recurrent chunk size.
    if config.system.recurrent_chunk_size is None:
        config.system.recurrent_chunk_size = config.system.rollout_length
    else:
        assert (
            config.system.rollout_length % config.system.recurrent_chunk_size == 0
        ), "Rollout length must be divisible by recurrent chunk size."

    # Create the environments for train and eval.
    env, eval_env = environments.make(config=config)

    # PRNG keys.
    key, key_e, actor_net_key, critic_net_key = jax.random.split(
        jax.random.PRNGKey(config.arch.seed), num=4
    )

    # Setup learner.
    learn, actor_network_apply_fn, actor_rnn, learner_state = learner_setup(
        env, (key, actor_net_key, critic_net_key), config
    )

    # Setup evaluator. The params/keys it returns are not used: both are re-derived
    # from the learner state and `key_e` before every evaluation below.
    evaluator, absolute_metric_evaluator, _ = evaluator_setup(
        eval_env=eval_env,
        key_e=key_e,
        eval_act_fn=get_rec_distribution_act_fn(config, actor_network_apply_fn),
        params=learner_state.params.actor_params,
        config=config,
        use_recurrent_net=True,
        scanned_rnn=actor_rnn,
    )

    # Calculate number of updates per evaluation.
    config.arch.num_updates_per_eval = config.arch.num_updates // config.arch.num_evaluation
    steps_per_rollout = (
        n_devices
        * config.arch.num_updates_per_eval
        * config.system.rollout_length
        * config.system.update_batch_size
        * config.arch.num_envs
    )

    # Logger setup
    logger = StoixLogger(config)
    cfg: Dict = OmegaConf.to_container(config, resolve=True)
    cfg["arch"]["devices"] = jax.devices()
    pprint(cfg)

    # Set up checkpointer
    save_checkpoint = config.logger.checkpointing.save_model
    if save_checkpoint:
        checkpointer = Checkpointer(
            metadata=config,  # Save all config as metadata in the checkpoint
            model_name=config.system.system_name,
            **config.logger.checkpointing.save_args,  # Checkpoint args
        )

    # File that accumulates the cosine similarity / outer gradient norm traces. The
    # path does not change during the run, so resolve it once and make sure its
    # directory exists (nothing else is guaranteed to have created it).
    json_exp_path = get_logger_path(config, "json")
    cosine_sim_path = os.path.join(
        config.logger.base_exp_path, f"{json_exp_path}/{logger.unique_token}/cosine_sim.txt"
    )
    os.makedirs(os.path.dirname(cosine_sim_path), exist_ok=True)

    # Run experiment for a total number of evaluations.
    max_episode_return = jnp.float32(-1e7)
    best_params = unreplicate_batch_dim(learner_state.params.actor_params)
    best_obs_norm_params = unreplicate_batch_dim(learner_state.obs_norm_params)

    # Perform initial evaluation (before any training).
    start_time = time.time()
    trained_params = unreplicate_batch_dim(learner_state.params.actor_params)
    obs_norm_params = unreplicate_batch_dim(learner_state.obs_norm_params)
    key_e, eval_keys = _split_eval_keys(key_e, n_devices)

    # Evaluate.
    evaluator_output = evaluator(trained_params, obs_norm_params, eval_keys)
    jax.block_until_ready(evaluator_output)

    # Log the results of the evaluation.
    elapsed_time = time.time() - start_time
    episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"])

    steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"]))
    evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time
    logger.log(evaluator_output.episode_metrics, 0, 0, LogEvent.EVAL)

    for eval_step in range(1, config.arch.num_evaluation + 1):
        # Train.
        start_time = time.time()

        learner_output = learn(learner_state)
        jax.block_until_ready(learner_output)

        # Log the results of the training.
        elapsed_time = time.time() - start_time
        # `eval_step` is 1-based (step 0 is the pre-training evaluation logged at
        # t=0), so exactly `eval_step` training blocks have been completed here.
        t = int(steps_per_rollout * eval_step)
        episode_metrics, ep_completed = get_final_step_metrics(learner_output.episode_metrics)
        episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time

        # Separately log timesteps, actoring metrics and training metrics.
        logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
        if ep_completed:  # only log episode metrics if an episode was completed in the rollout.
            logger.log(episode_metrics, t, eval_step, LogEvent.ACT)
        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)

        # Log cosine sim
        cosine_sim = learner_output.train_metrics["cosine_similarity"].squeeze(axis=-1)
        grad_norm = learner_output.train_metrics["outer_grad_norm"].squeeze(axis=-1)
        chex.assert_shape(cosine_sim, (config.num_devices, config.arch.num_updates_per_eval,))
        chex.assert_shape(grad_norm, (config.num_devices, config.arch.num_updates_per_eval,))
        # Take first device since they are all the same
        cosine_sim = cosine_sim[0]
        grad_norm = grad_norm[0]
        with open(cosine_sim_path, "a") as f:
            f.write(f"{np.array(cosine_sim).tolist()} {np.array(grad_norm).tolist()}\n")

        # check for any nans in the actor loss, stop the experiment if there are any
        if jnp.isnan(learner_output.train_metrics['actor_loss']).any():
            # Close the logger before bailing out so it is not left open.
            logger.stop()
            return float('nan')

        # Prepare for evaluation.
        start_time = time.time()
        trained_params = unreplicate_batch_dim(
            learner_output.learner_state.params.actor_params
        )  # Select only actor params
        obs_norm_params = unreplicate_batch_dim(learner_output.learner_state.obs_norm_params)
        key_e, eval_keys = _split_eval_keys(key_e, n_devices)

        # Evaluate.
        evaluator_output = evaluator(trained_params, obs_norm_params, eval_keys)
        jax.block_until_ready(evaluator_output)

        # Log the results of the evaluation.
        elapsed_time = time.time() - start_time
        episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"])

        steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"]))
        evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time
        logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL)

        if save_checkpoint:
            # Save checkpoint of learner state
            checkpointer.save(
                timestep=t,
                unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state),
                episode_return=episode_return,
            )

        if config.arch.absolute_metric and max_episode_return <= episode_return:
            best_params = copy.deepcopy(trained_params)
            best_obs_norm_params = copy.deepcopy(obs_norm_params)
            max_episode_return = episode_return

        # Update runner state to continue training.
        learner_state = learner_output.learner_state

    # Measure absolute metric.
    if config.arch.absolute_metric:
        start_time = time.time()

        key_e, eval_keys = _split_eval_keys(key_e, n_devices)

        evaluator_output = absolute_metric_evaluator(best_params, best_obs_norm_params, eval_keys)
        jax.block_until_ready(evaluator_output)

        elapsed_time = time.time() - start_time
        # After the loop `eval_step` equals the number of completed training blocks.
        t = int(steps_per_rollout * eval_step)
        steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"]))
        evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time
        logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE)

    # Stop the logger.
    logger.stop()
    # Record the performance for the final evaluation run. If the absolute metric is not
    # calculated, this will be the final evaluation run.
    eval_performance = float(jnp.mean(evaluator_output.episode_metrics[config.env.eval_metric]))

    if config.logger.kwargs.neptune_mode == 'offline':
        # since we are using several nodes, and in some cases several parallel runs
        # per node we found async (even with larger flush_period) to lead to many
        # connection issues. this command will sync the offline neptune logs to neptune
        # without maintaining a connection throughout the run.
        subprocess.run(['neptune', 'sync', '-p', config.logger.kwargs.neptune_project])

    return eval_performance


@hydra.main(config_path="../../configs", config_name="default_rec_ppo.yaml", version_base="1.2")
def hydra_entry_point(cfg: DictConfig) -> float:
    """Hydra-decorated launcher: run one experiment and return its evaluation score.

    Args:
        cfg: the configuration object supplied by hydra.

    Returns:
        Whatever `run_experiment` returns for this configuration.
    """
    # Turn struct mode off so that new keys can be attached to the config at runtime.
    OmegaConf.set_struct(cfg, False)

    final_score = run_experiment(cfg)

    completion_banner = f"{Fore.CYAN}{Style.BRIGHT}PPO experiment completed{Style.RESET_ALL}"
    print(completion_banner)
    return final_score


# When executed as a script (rather than imported), launch the hydra-decorated
# entry point defined above.
if __name__ == "__main__":
    hydra_entry_point()
