import copy
import functools
import math
import os
import time
from typing import Any, Dict, Tuple
import subprocess

import chex
import flax
import hydra
import jax
import jax.numpy as jnp
import numpy as np
import optax
from colorama import Fore, Style
from flax.core.frozen_dict import FrozenDict
from jumanji.env import Environment
from jumanji.env import specs
from omegaconf import DictConfig, OmegaConf
from rich.pretty import pprint

from stoix.networks.inputs import ObservationInput
from stoix.utils.flatten_util import ravel_pytree
import stoix.utils.running_statistics as running_statistics
from stoix.base_types import ActorApply, CriticApply, ExperimentOutput, LearnerFn
from stoix.evaluator_multi_seed import evaluator_setup, get_distribution_act_fn
from stoix.networks.base import FeedForwardActor as Actor
from stoix.networks.base import FeedForwardCritic as Critic
from stoix.systems.ppo.ppo_types import (
    ActorCriticOuterOptStates,
    ActorCriticParams,
    LearnerStateObsNorm,
    PPOTransition,
)
from stoix.utils import make_env as environments
from stoix.utils.checkpointing import Checkpointer
from stoix.utils.free_step import get_diff_gradient
from stoix.utils.jax_utils import (
    merge_leading_dims,
    unreplicate_n_dims,
)
from stoix.utils.logger import LogEvent, MultiSeedStoixLogger, StoixLogger, get_logger_path
from stoix.utils.loss import (
    clipped_value_loss,
    dpo_loss,
    ppo_clip_loss,
    ppo_penalty_loss,
    unclipped_value_loss,
)
from stoix.utils.multistep import batch_truncated_generalized_advantage_estimation
from stoix.utils.total_timestep_checker_multi_seed import check_total_timesteps
from stoix.utils.training import make_learning_rate
from stoix.wrappers.episode_metrics import get_final_step_metrics


def get_learner_fn(
    env: Environment,
    apply_fns: Tuple[ActorApply, CriticApply],
    update_fns: Tuple[
        optax.TransformUpdateFn,
        optax.TransformUpdateFn,
        optax.TransformUpdateFn,
        Any,
        optax.TransformUpdateFn,
    ],
    config: DictConfig,
) -> LearnerFn[LearnerStateObsNorm]:
    """Build the learner function for PPO with an outer optimiser and free-step updates.

    Args:
        env: Environment whose `step` is vmapped over the parallel environments.
        apply_fns: `(actor_apply_fn, critic_apply_fn)`. Both are called as
            `fn(network_params, obs_norm_params, observation)`.
        update_fns: Five update callables, in this order:
            `(actor_update_fn, critic_update_fn, outer_update_fn,
            obs_norm_update_fn, free_step_update_fn)`.
        config: Experiment configuration. Reads `config.system.*` and
            `config.arch.*` values at trace time.

    Returns:
        A `learner_fn` that scans a vmapped `_update_step` for
        `config.arch.num_updates_per_eval` iterations.

    Raises:
        ValueError: At trace time, if `config.system.loss_actor_type` or
            `config.system.loss_critic_type` is not a supported value.
    """

    # Get apply and update functions for actor and critic networks.
    actor_apply_fn, critic_apply_fn = apply_fns
    (
        actor_update_fn,
        critic_update_fn,
        outer_update_fn,
        obs_norm_update_fn,
        free_step_update_fn,
    ) = update_fns

    def _update_step(
        learner_state: LearnerStateObsNorm, _: Any
    ) -> Tuple[LearnerStateObsNorm, Tuple]:
        """A single update of the network.

        This function steps the environment and records the trajectory batch for
        training. It then calculates advantages and targets based on the recorded
        trajectory, applies the pending free-step updates, runs the PPO epochs and
        finally performs the outer update starting from the behaviour parameters.

        Args:
            learner_state (NamedTuple): Unpacked positionally as
                - params (ActorCriticParams): The current model parameters.
                - opt_states (OptStates): The current optimizer states.
                - obs_norm_params: The observation normalisation statistics.
                - key (PRNGKey): The random number generator state.
                - env_state (State): The environment state.
                - last_timestep (TimeStep): The last timestep in the current trajectory.
                - free step updates: Updates applied to the params before the PPO epochs.
                - previous outer gradient: Used only for the cosine-similarity metric.
            _ (Any): The current metrics info.

        Returns:
            The new learner state and a `(episode_metrics, loss_info)` tuple.
        """

        def _env_step(
            learner_state: LearnerStateObsNorm, _: Any
        ) -> Tuple[LearnerStateObsNorm, PPOTransition]:
            """Step the environment once and record the resulting transition."""
            (
                params,
                opt_states,
                obs_norm_params,
                key,
                env_state,
                last_timestep,
                fs_up,
                out_grad,
            ) = learner_state

            # SELECT ACTION
            key, policy_key = jax.random.split(key)
            actor_policy = actor_apply_fn(
                params.actor_params, obs_norm_params, last_timestep.observation
            )
            value = critic_apply_fn(
                params.critic_params, obs_norm_params, last_timestep.observation
            )
            action = actor_policy.sample(seed=policy_key)
            log_prob = actor_policy.log_prob(action)

            # STEP ENVIRONMENT
            env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)

            # LOG EPISODE METRICS
            # A zero discount marks termination; a last step with non-zero
            # discount is treated as a truncation.
            done = (timestep.discount == 0.0).reshape(-1)
            truncated = (timestep.last() & (timestep.discount != 0.0)).reshape(-1)
            info = timestep.extras["episode_metrics"]

            transition = PPOTransition(
                done,
                truncated,
                action,
                value,
                timestep.reward,
                log_prob,
                last_timestep.observation,
                info,
            )
            learner_state = LearnerStateObsNorm(
                params, opt_states, obs_norm_params, key, env_state, timestep, fs_up, out_grad
            )
            return learner_state, transition

        # STEP ENVIRONMENT FOR ROLLOUT LENGTH
        learner_state, traj_batch = jax.lax.scan(
            _env_step, learner_state, None, config.system.rollout_length
        )
        (
            params,
            opt_states,
            obs_norm_params,
            key,
            env_state,
            last_timestep,
            free_step_updates,
            prev_out_grad,
        ) = learner_state

        # UPDATE OBSERVATION NORMALIZATION PARAMETERS
        obs_norm_params = obs_norm_update_fn(obs_norm_params, traj_batch.obs)

        # CALCULATE ADVANTAGE
        last_val = critic_apply_fn(params.critic_params, obs_norm_params, last_timestep.observation)
        r_t = traj_batch.reward * config.system.reward_scaling
        v_t = jnp.concatenate([traj_batch.value, last_val[None, ...]], axis=0)
        d_t = 1.0 - traj_batch.done.astype(jnp.float32)
        d_t = (d_t * config.system.gamma).astype(jnp.float32)
        advantages, targets = batch_truncated_generalized_advantage_estimation(
            r_t,
            d_t,
            config.system.gae_lambda,
            v_t,
            time_major=True,
            standardize_advantages=config.system.standardize_advantages,
            truncation_flags=traj_batch.truncated,
        )

        # COPY BEHAVIOUR PARAMS
        # Kept so the outer step can be taken relative to the parameters that
        # generated the rollout (and for the KL term of the ppo-penalty loss).
        behaviour_params = params

        # APPLY FREE STEP UPDATES
        params = optax.apply_updates(params, free_step_updates)

        def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
            """Update the network for a single epoch."""

            def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
                """Update the network for a single minibatch."""

                # UNPACK TRAIN STATE AND BATCH INFO
                params, opt_states, key = train_state
                traj_batch, advantages, targets = batch_info

                def _actor_loss_fn(
                    actor_params: FrozenDict,
                    traj_batch: PPOTransition,
                    gae: chex.Array,
                    rng_key: chex.PRNGKey,
                ) -> Tuple:
                    """Calculate the actor loss."""
                    # RERUN NETWORK
                    actor_policy = actor_apply_fn(actor_params, obs_norm_params, traj_batch.obs)
                    log_prob = actor_policy.log_prob(traj_batch.action)

                    # CALCULATE ACTOR LOSS
                    if config.system.loss_actor_type == "ppo-penalty":
                        behaviour_policy = actor_apply_fn(
                            behaviour_params.actor_params, obs_norm_params, traj_batch.obs
                        )
                        loss_actor, _ = ppo_penalty_loss(
                            log_prob,
                            traj_batch.log_prob,
                            gae,
                            config.system.kl_penalty_coef,
                            actor_policy,
                            behaviour_policy,
                        )
                    elif config.system.loss_actor_type == "ppo-clip":
                        loss_actor = ppo_clip_loss(
                            log_prob, traj_batch.log_prob, gae, config.system.clip_eps
                        )
                    elif config.system.loss_actor_type == "dpo":
                        loss_actor = dpo_loss(
                            log_prob,
                            traj_batch.log_prob,
                            gae,
                            config.system.alpha,
                            config.system.beta,
                        )
                    else:
                        # Fail with a clear message instead of an unbound
                        # `loss_actor` further down.
                        raise ValueError(
                            f"Unknown loss_actor_type: {config.system.loss_actor_type!r}. "
                            "Expected one of 'ppo-penalty', 'ppo-clip' or 'dpo'."
                        )

                    entropy = actor_policy.entropy(seed=rng_key).mean()

                    total_loss_actor = loss_actor - config.system.ent_coef * entropy
                    loss_info = {
                        "actor_loss": loss_actor,
                        "entropy": entropy,
                    }
                    return total_loss_actor, loss_info

                def _critic_loss_fn(
                    critic_params: FrozenDict,
                    traj_batch: PPOTransition,
                    targets: chex.Array,
                ) -> Tuple:
                    """Calculate the critic loss."""
                    # RERUN NETWORK
                    value = critic_apply_fn(critic_params, obs_norm_params, traj_batch.obs)

                    # CALCULATE VALUE LOSS
                    if config.system.loss_critic_type == "clip":
                        value_loss = clipped_value_loss(
                            value, traj_batch.value, targets, config.system.clip_eps
                        )
                    elif config.system.loss_critic_type == "unclip":
                        value_loss = unclipped_value_loss(
                            value,
                            targets,
                        )
                    else:
                        # Fail with a clear message instead of an unbound
                        # `value_loss` further down.
                        raise ValueError(
                            f"Unknown loss_critic_type: {config.system.loss_critic_type!r}. "
                            "Expected one of 'clip' or 'unclip'."
                        )

                    critic_total_loss = config.system.vf_coef * value_loss
                    loss_info = {
                        "value_loss": value_loss,
                    }
                    return critic_total_loss, loss_info

                # CALCULATE ACTOR LOSS
                key, actor_loss_key = jax.random.split(key)
                actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True)
                actor_grads, actor_loss_info = actor_grad_fn(
                    params.actor_params,
                    traj_batch,
                    advantages,
                    actor_loss_key,
                )

                # CALCULATE CRITIC LOSS
                critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True)
                critic_grads, critic_loss_info = critic_grad_fn(
                    params.critic_params, traj_batch, targets
                )

                # NOTE: different seeds run over the cores and the update batches,
                # so gradients and loss info are deliberately NOT pmean-ed over
                # the "batch" or "device" axes; each replica trains independently.

                # UPDATE ACTOR PARAMS AND OPTIMISER STATE
                actor_updates, actor_new_opt_state = actor_update_fn(
                    actor_grads, opt_states.actor_opt_state
                )
                actor_new_params = optax.apply_updates(params.actor_params, actor_updates)

                # UPDATE CRITIC PARAMS AND OPTIMISER STATE
                critic_updates, critic_new_opt_state = critic_update_fn(
                    critic_grads, opt_states.critic_opt_state
                )
                critic_new_params = optax.apply_updates(params.critic_params, critic_updates)

                # PACK NEW PARAMS AND OPTIMISER STATE
                new_params = ActorCriticParams(actor_new_params, critic_new_params)
                new_opt_state = opt_states._replace(
                    actor_opt_state=actor_new_opt_state, critic_opt_state=critic_new_opt_state
                )

                # PACK LOSS INFO
                loss_info = {
                    **actor_loss_info,
                    **critic_loss_info,
                }
                return (new_params, new_opt_state, key), loss_info

            params, opt_states, traj_batch, advantages, targets, key = update_state
            key, shuffle_key = jax.random.split(key)

            # SHUFFLE MINIBATCHES
            batch_size = config.system.rollout_length * config.arch.num_envs
            permutation = jax.random.permutation(shuffle_key, batch_size)
            batch = (traj_batch, advantages, targets)
            batch = jax.tree_util.tree_map(lambda x: merge_leading_dims(x, 2), batch)
            shuffled_batch = jax.tree_util.tree_map(
                lambda x: jnp.take(x, permutation, axis=0), batch
            )
            minibatches = jax.tree_util.tree_map(
                lambda x: jnp.reshape(x, [config.system.num_minibatches, -1] + list(x.shape[1:])),
                shuffled_batch,
            )

            # UPDATE MINIBATCHES
            (params, opt_states, key), loss_info = jax.lax.scan(
                _update_minibatch, (params, opt_states, key), minibatches
            )

            update_state = (params, opt_states, traj_batch, advantages, targets, key)
            return update_state, loss_info

        update_state = (params, opt_states, traj_batch, advantages, targets, key)

        # UPDATE EPOCHS - i.e Perform PPO Iteration
        update_state, loss_info = jax.lax.scan(
            _update_epoch, update_state, None, config.system.epochs
        )

        latest_params, opt_states, traj_batch, advantages, targets, key = update_state

        # OUTER STEP
        # NOTE(review): `get_diff_gradient` presumably returns the difference
        # between the post-PPO params and the behaviour params - confirm in
        # stoix.utils.free_step.
        outer_grads = get_diff_gradient(latest_params, behaviour_params)
        outer_updates, outer_new_opt_state = outer_update_fn(
            outer_grads, opt_states.outer_opt_state
        )
        # Apply the updates to the behaviour parameters (not to `latest_params`).
        new_params = optax.apply_updates(behaviour_params, outer_updates)

        # UPDATE FREE STEP OPTIMISER AND GET THE UPDATES WE WILL APPLY
        # BEFORE THE NEXT PPO ITERATION
        free_step_updates, free_step_new_opt_state = free_step_update_fn(
            outer_grads, opt_states.free_step_opt_state
        )

        opt_states = opt_states._replace(
            outer_opt_state=outer_new_opt_state,
            free_step_opt_state=free_step_new_opt_state,
        )

        # Get Outer Gradient Metrics
        prev_out_grad_vector, _ = ravel_pytree(prev_out_grad)
        new_out_grad_vector, _ = ravel_pytree(outer_grads)

        def cosine_similarity(x: chex.Array, y: chex.Array) -> chex.Array:
            """Cosine similarity of two flat vectors; 1e-8 guards against zero norms."""
            return jnp.dot(x, y) / (jnp.linalg.norm(x) * jnp.linalg.norm(y) + 1e-8)

        cosine_sim = cosine_similarity(new_out_grad_vector, prev_out_grad_vector)
        out_grad_norm = optax.global_norm(outer_grads)

        loss_info["cosine_similarity"] = cosine_sim
        loss_info["outer_grad_norm"] = out_grad_norm

        learner_state = LearnerStateObsNorm(
            new_params,
            opt_states,
            obs_norm_params,
            key,
            env_state,
            last_timestep,
            free_step_updates,
            outer_grads,
        )
        metric = traj_batch.info
        return learner_state, (metric, loss_info)

    def learner_fn(learner_state: LearnerStateObsNorm) -> ExperimentOutput[LearnerStateObsNorm]:
        """Learner function.

        This function represents the learner, it updates the network parameters
        by iteratively applying the `_update_step` function for a fixed number of
        updates. The `_update_step` function is vectorized (vmapped with axis name
        "batch") over the leading axis of the learner state.

        Args:
            learner_state (NamedTuple): Unpacked positionally as
                - params (ActorCriticParams): The initial model parameters.
                - opt_states (OptStates): The initial optimizer state.
                - obs_norm_params: The observation normalisation statistics.
                - key (chex.PRNGKey): The random number generator state.
                - env_state (LogEnvState): The environment state.
                - timesteps (TimeStep): The initial timestep in the initial trajectory.
                - free step updates: Updates applied before the next PPO iteration.
                - previous outer gradient: Outer gradient of the previous update.

        Returns:
            An `ExperimentOutput` holding the final learner state together with the
            episode metrics and training metrics collected over all updates.
        """

        batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch")

        learner_state, (episode_info, loss_info) = jax.lax.scan(
            batched_update_step, learner_state, None, config.arch.num_updates_per_eval
        )
        return ExperimentOutput(
            learner_state=learner_state,
            episode_metrics=episode_info,
            train_metrics=loss_info,
        )

    return learner_fn


def learner_setup(
    env: Environment, keys: chex.Array, config: DictConfig
) -> Tuple[LearnerFn[LearnerStateObsNorm], Actor, LearnerStateObsNorm]:
    """Initialise learner_fn, network, optimiser, environment and states.

    Args:
        env: Training environment; used for its specs, `reset` and `step`.
        keys: Unpacked as `(key, actor_net_key, critic_net_key)`. `key` is
            reshaped to `(devices * seeds_per_device, 2)`; the network keys are
            pmapped and vmapped over their two leading axes.
        config: Experiment configuration. This function MUTATES it: it sets
            `system.action_dim`, `system.action_minimum` / `system.action_maximum`
            (non-discrete action specs only) and the `transition_steps` of the
            outer and free-step learning-rate schedules.

    Returns:
        A tuple `(learn, actor_network_apply_fn, init_learner_state)`:
        the pmapped learner function, the actor apply function that normalises
        observations before calling the network (note: this is a callable taking
        `(actor_params, obs_norm_params, observation)`, not the `Actor` module
        itself), and the initial learner state.

    Raises:
        ValueError: If there are more parallel seeds than devices while more than
            one device is available.
    """

    # Either one seed per device, or every seed on a single device.
    if len(config.parallel_seeds) <= len(jax.devices()):
        # have a seed per device
        n_devices = len(config.parallel_seeds)
    elif len(jax.devices()) == 1:
        n_devices = 1
    else:
        raise ValueError("either one seed per device or all seeds on one device")

    # Get number/dimension of actions.
    action_spec = env.action_spec()
    if isinstance(action_spec, specs.DiscreteArray):
        num_actions = int(action_spec.num_values)
        network_kwargs = {}
    else:
        num_actions = int(action_spec.shape[-1])
        config.system.action_minimum = float(action_spec.minimum)
        config.system.action_maximum = float(action_spec.maximum)
        network_kwargs = {
            "minimum": config.system.action_minimum,
            "maximum": config.system.action_maximum,
        }

    config.system.action_dim = num_actions

    # PRNG keys.
    key, actor_net_key, critic_net_key = keys

    # Define network and optimiser.
    # The input layer is optional in the network config: fall back to the plain
    # observation input when it is absent or cannot be instantiated. Only
    # `Exception` is caught so KeyboardInterrupt/SystemExit still propagate.
    try:
        actor_input_layer = hydra.utils.instantiate(config.network.actor_network.input_layer)
    except Exception:
        actor_input_layer = ObservationInput()
    actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
    actor_action_head = hydra.utils.instantiate(
        config.network.actor_network.action_head,
        action_dim=num_actions,
        **network_kwargs,
    )
    try:
        critic_input_layer = hydra.utils.instantiate(config.network.critic_network.input_layer)
    except Exception:
        critic_input_layer = ObservationInput()
    critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
    critic_head = hydra.utils.instantiate(config.network.critic_network.critic_head)

    actor_network = Actor(
        torso=actor_torso, action_head=actor_action_head, input_layer=actor_input_layer
    )
    critic_network = Critic(
        torso=critic_torso, critic_head=critic_head, input_layer=critic_input_layer
    )

    actor_lr = make_learning_rate(
        config.system.actor_lr, config, config.system.epochs, config.system.num_minibatches
    )
    critic_lr = make_learning_rate(
        config.system.critic_lr, config, config.system.epochs, config.system.num_minibatches
    )

    # These minus scaled gradients
    actor_optim = optax.chain(
        optax.clip_by_global_norm(config.system.max_grad_norm),
        optax.adam(actor_lr, eps=1e-5),
    )
    critic_optim = optax.chain(
        optax.clip_by_global_norm(config.system.max_grad_norm),
        optax.adam(critic_lr, eps=1e-5),
    )

    # A plain numeric learning rate has no schedule to configure; anything else
    # is treated as a schedule config whose horizon is the number of updates.
    if not isinstance(config.system.outer_optimizer.learning_rate, (int, float)):
        config.system.outer_optimizer.learning_rate.transition_steps = config.arch.num_updates

    config.system.free_step_learning_rate.transition_steps = config.arch.num_updates

    # This adds scaled gradients
    outer_optim = optax.chain(
        optax.scale(-1),
        hydra.utils.instantiate(config.system.outer_optimizer),
    )

    free_step_lr = hydra.utils.instantiate(config.system.free_step_learning_rate)
    free_step_optim = optax.chain(
        optax.scale(-1),
        optax.ema(config.system.free_step_momentum),
        optax.inject_hyperparams(optax.scale_by_learning_rate)(learning_rate=free_step_lr),
    )

    # Initialise observation
    init_x = env.observation_spec().generate_value()
    init_x = jax.tree_util.tree_map(lambda x: x[None, ...], init_x)

    # Initialise observation normalization parameters (from the un-batched observation).
    obs_norm_params = running_statistics.init_state(
        jax.tree_util.tree_map(lambda x: x.squeeze(0), init_x)
    )

    # Define observation normalisation functions.
    # Statistics are always updated; they are only applied when enabled in the config.
    obs_norm_update_fn = running_statistics.update

    if config.system.normalize_observations:
        obs_norm_apply_fn = running_statistics.normalize
    else:
        obs_norm_apply_fn = lambda x, _: x

    # Initialise actor params and optimiser state.
    # Create params for each seed on each device as well as each seed that is vmapped.
    # This will create params of the shape (devices, seeds_per_device, ...)
    actor_params = jax.pmap(
        jax.vmap(actor_network.init, axis_name="batch", in_axes=(0, None)),
        axis_name="device",
        in_axes=(0, None),
    )(actor_net_key, init_x)
    actor_opt_state = jax.pmap(
        jax.vmap(actor_optim.init, axis_name="batch"), axis_name="device"
    )(actor_params)

    # Initialise critic params and optimiser state.
    critic_params = jax.pmap(
        jax.vmap(critic_network.init, axis_name="batch", in_axes=(0, None)),
        axis_name="device",
        in_axes=(0, None),
    )(critic_net_key, init_x)
    critic_opt_state = jax.pmap(
        jax.vmap(critic_optim.init, axis_name="batch"), axis_name="device"
    )(critic_params)

    # Pack params.
    params = ActorCriticParams(actor_params, critic_params)

    outer_opt_state = jax.pmap(
        jax.vmap(outer_optim.init, axis_name="batch"), axis_name="device"
    )(params)
    free_step_opt_state = jax.pmap(
        jax.vmap(free_step_optim.init, axis_name="batch"), axis_name="device"
    )(params)

    # Here we define the network apply functions with normalization.
    # Gradients are stopped through the normalisation statistics.
    actor_network_apply_fn = lambda act_params, norm_params, obs: actor_network.apply(
        act_params, obs_norm_apply_fn(obs, jax.lax.stop_gradient(norm_params))
    )
    critic_network_apply_fn = lambda crit_params, norm_params, obs: critic_network.apply(
        crit_params, obs_norm_apply_fn(obs, jax.lax.stop_gradient(norm_params))
    )

    # Pack apply and update functions.
    apply_fns = (actor_network_apply_fn, critic_network_apply_fn)
    update_fns = (
        actor_optim.update,
        critic_optim.update,
        outer_optim.update,
        obs_norm_update_fn,
        free_step_optim.update,
    )

    # Get batched iterated update and replicate it to pmap it over cores.
    learn = get_learner_fn(env, apply_fns, update_fns, config)
    learn = jax.pmap(learn, axis_name="device")

    # Initialise environment states and timesteps: across devices and batches.
    # The key is of shape (devices, seeds_per_device, 2)
    # We reshape it to (devices*seeds_per_device, 2) for ease of use
    key = key.reshape(n_devices * config.arch.num_seeds_per_device, 2)
    # we then create a subkey for each seed on each device
    key, sub_key = jax.vmap(jax.random.split, out_axes=1)(key)
    # subkey is of shape (devices*seeds_per_device, 2)
    # we use this subkey to create a key for each parallel environment
    # for each seed on each device - this makes the shape of env_keys
    # (devices*seeds_per_device, num_envs, 2)
    env_keys = jax.vmap(jax.random.split, in_axes=(0, None))(sub_key, config.arch.num_envs)
    # we create the env_states and timesteps for each parallel environment on each
    # seed on each device
    env_states, timesteps = jax.vmap(jax.vmap(env.reset, in_axes=(0)))(env_keys)
    # reshape the env_states and timesteps to (devices, seeds_per_device, num_envs, ...)
    reshape_states = lambda x: x.reshape(
        (n_devices, config.arch.num_seeds_per_device, config.arch.num_envs) + x.shape[2:]
    )
    env_states = jax.tree_util.tree_map(reshape_states, env_states)
    timesteps = jax.tree_util.tree_map(reshape_states, timesteps)

    # Load model from checkpoint if specified.
    if config.logger.checkpointing.load_model:
        loaded_checkpoint = Checkpointer(
            model_name=config.system.system_name,
            **config.logger.checkpointing.load_args,  # Other checkpoint args
        )
        # Restore the learner state from the checkpoint
        restored_params, _ = loaded_checkpoint.restore_params()
        # Update the params
        # NOTE(review): the optimiser states above were initialised from the
        # freshly initialised params, not from the restored ones - confirm this
        # is intended.
        params = restored_params

    # Define params to be replicated across devices and batches.
    # `key` is already of shape (devices*seeds_per_device, 2) here;
    # we make a subkey for each seed on each device
    key, step_key = jax.vmap(jax.random.split, out_axes=1)(key)
    # step_key is of shape (devices*seeds_per_device, 2)
    # we then reshape it to (devices, seeds_per_device, 2)
    step_keys = step_key.reshape(n_devices, config.arch.num_seeds_per_device, 2)
    opt_states = ActorCriticOuterOptStates(
        actor_opt_state, critic_opt_state, outer_opt_state, free_step_opt_state
    )

    # we then broadcast the obs_norm_params across devices and seeds
    replicate_learner = obs_norm_params

    # Duplicate learner for update_batch_size.
    broadcast = lambda x: jnp.broadcast_to(x, (config.arch.num_seeds_per_device,) + x.shape)
    replicate_learner = jax.tree_util.tree_map(broadcast, replicate_learner)

    # Duplicate learner across devices.
    replicate_learner = flax.jax_utils.replicate(
        replicate_learner, devices=jax.devices()[:n_devices]
    )

    # Initialise learner state.
    obs_norm_params = replicate_learner
    # Both the pending free-step updates and the previous outer gradient start
    # as zeros with the same structure as the params.
    init_free_step_updates = jax.tree_util.tree_map(jnp.zeros_like, params)
    init_prev_outer_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    init_learner_state = LearnerStateObsNorm(
        params,
        opt_states,
        obs_norm_params,
        step_keys,
        env_states,
        timesteps,
        init_free_step_updates,
        init_prev_outer_grads,
    )

    return learn, actor_network_apply_fn, init_learner_state


def _split_eval_keys(
    key_e: jax.Array, n_devices: int, seeds_per_device: int
) -> Tuple[jax.Array, jax.Array]:
    """Split each seed's evaluation key into a carry-over key and a fresh evaluation key.

    Args:
        key_e: Evaluation PRNG keys, one length-2 key per seed, in any layout that
            reshapes to ``(n_devices * seeds_per_device, 2)``.
        n_devices: Number of devices the seeds are spread over.
        seeds_per_device: Number of seeds hosted on each device.

    Returns:
        ``(carry_keys, eval_keys)``. ``carry_keys`` has shape
        ``(n_devices, seeds_per_device, 2)`` and replaces ``key_e`` for the next split;
        ``eval_keys`` is reshaped to ``(n_devices, seeds_per_device, -1)``.
    """
    flat_keys = key_e.reshape(n_devices * seeds_per_device, 2)
    # Each key is split in two; out_axes=1 moves the per-seed axis to position 1 so the
    # leading axis (size 2) unpacks into the carry keys and the evaluation keys.
    carry_keys, eval_keys = jax.vmap(jax.random.split, out_axes=1)(flat_keys)
    carry_keys = carry_keys.reshape(n_devices, seeds_per_device, 2)
    eval_keys = jnp.stack(eval_keys).reshape(n_devices, seeds_per_device, -1)
    return carry_keys, eval_keys


def _log_eval_metrics(
    logger: Any,
    evaluator_output: Any,
    elapsed_time: float,
    t: int,
    eval_step: int,
    event: Any,
    num_seeds: int,
) -> None:
    """Add a throughput entry to the evaluation metrics and log them once per seed.

    ``evaluator_output.episode_metrics`` is mutated in place: a ``"steps_per_second"``
    entry is added. The leading axis of every metric must index the seed.

    Args:
        logger: Logger exposing ``log(metrics, t, eval_step, event, seed=...)``.
        evaluator_output: Evaluator output whose ``episode_metrics`` dict holds at least
            ``"episode_length"``.
        elapsed_time: Wall-clock seconds the evaluation took.
        t: Timestep to log against.
        eval_step: Index of the evaluation being logged.
        event: Log event type passed through to the logger.
        num_seeds: Number of seeds, i.e. the size of the leading metric axis.
    """
    episode_metrics = evaluator_output.episode_metrics
    steps_per_eval = jnp.sum(episode_metrics["episode_length"], axis=-1)
    episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time
    for i in range(num_seeds):
        per_seed_metrics = jax.tree.map(lambda x: x[i], episode_metrics)
        logger.log(per_seed_metrics, t, eval_step, event, seed=i)


def run_experiment(_config: DictConfig) -> float:
    """Train and evaluate PPO for every seed in ``config.parallel_seeds``.

    The config is deep-copied, so the caller's object is not modified. Training alternates
    ``config.arch.num_evaluation`` times between a call to the learner and an evaluation;
    the per-seed parameters with the best mean evaluation return are tracked and, when
    ``config.arch.absolute_metric`` is set, evaluated once more at the end.

    Args:
        _config: Full experiment configuration.

    Returns:
        The mean of ``config.env.eval_metric`` over the last evaluator output (the
        absolute-metric evaluation when enabled, otherwise the final evaluation), or
        ``nan`` if the actor loss contained a NaN during training.

    Raises:
        ValueError: If ``run_outer_ppo`` is False but outer-optimizer / free-step
            hyperparameters are not at their neutral values, or if the seeds can neither
            be placed one per device nor all on a single device.
    """
    config = copy.deepcopy(_config)

    # With outer PPO disabled, every outer-loop hyperparameter must be at its neutral value.
    if not config.system.run_outer_ppo:
        if config.system.free_step_learning_rate.peak_value != 0.0:
            raise ValueError("Free step learning rate peak value must be 0.0 if run_outer_ppo is False")
        if config.system.outer_optimizer.momentum is not None:
            if config.system.outer_optimizer.init_momentum != 0.0:
                raise ValueError("Outer optimizer learning init_momentum must be 0.0 if run_outer_ppo is False")
        if hasattr(config.system.outer_optimizer.learning_rate, 'peak_value'):
            if config.system.outer_optimizer.learning_rate.peak_value != 1.0:
                raise ValueError("Outer optimizer learning rate peak value must be 1.0 if run_outer_ppo is False")
            if config.system.outer_optimizer.learning_rate.div_factor != 1.0:
                raise ValueError("Outer optimizer learning rate div_factor must be 1.0 if run_outer_ppo is False")
            if config.system.outer_optimizer.learning_rate.final_div_factor != 1.0:
                raise ValueError("Outer optimizer learning rate final_div_factor must be 1.0 if run_outer_ppo is False")
        elif config.system.outer_optimizer.learning_rate != 1.0:
            raise ValueError("Outer optimizer learning rate must be 1.0 if run_outer_ppo is False")

    if len(config.parallel_seeds) <= len(jax.devices()):
        # have a seed per device
        n_devices = len(config.parallel_seeds)
    elif len(jax.devices()) == 1:
        n_devices = 1
    else:
        raise ValueError("either one seed per device or all seeds on one device")

    config.num_devices = n_devices
    config = check_total_timesteps(config)
    assert (
        config.arch.num_updates > config.arch.num_evaluation
    ), "Number of updates per evaluation must be less than total number of updates."

    # Create the environments for train and eval.
    env, eval_env = environments.make(config=config)

    # PRNG keys.
    num_seeds = len(config.parallel_seeds)
    config.arch.num_seeds = num_seeds
    seeds_per_device = num_seeds // n_devices
    config.arch.num_seeds_per_device = seeds_per_device
    total_seeds = n_devices * seeds_per_device

    # Every seed yields four keys: general, evaluation, actor-network and critic-network.
    per_seed_keys = [
        jax.random.split(jax.random.PRNGKey(seed), num=4) for seed in config.parallel_seeds
    ]
    # Group the keys by role across seeds, then stack and reshape each group to
    # (n_devices, seeds_per_device, -1).
    key, key_e, actor_net_key, critic_net_key = (
        jnp.stack(role_keys).reshape(n_devices, seeds_per_device, -1)
        for role_keys in zip(*per_seed_keys)
    )

    # Setup learner.
    learn, actor_network_apply_fn, learner_state = learner_setup(
        env, (key, actor_net_key, critic_net_key), config
    )

    # Setup evaluator. The params/keys it also returns are not needed: both are rebuilt
    # from the learner state and key_e before every evaluation below.
    evaluator, absolute_metric_evaluator, _ = evaluator_setup(
        eval_env=eval_env,
        key_e=key_e,
        eval_act_fn=get_distribution_act_fn(config, actor_network_apply_fn),
        params=learner_state.params.actor_params,
        config=config,
    )

    # Calculate number of updates per evaluation.
    config.arch.num_updates_per_eval = config.arch.num_updates // config.arch.num_evaluation
    steps_per_rollout = (
        config.arch.num_updates_per_eval
        * config.system.rollout_length
        * config.arch.num_envs
    )

    # Logger setup
    logger = MultiSeedStoixLogger(config)
    cfg: Dict = OmegaConf.to_container(config, resolve=True)
    cfg["arch"]["devices"] = len(config.parallel_seeds)

    pprint(cfg)

    # Set up checkpointer.
    # NOTE: saving checkpoints is not supported for multi-seed runs, so the checkpointer
    # is constructed here but never written to.
    save_checkpoint = config.logger.checkpointing.save_model
    if save_checkpoint:
        checkpointer = Checkpointer(  # noqa: F841
            metadata=config,  # Save all config as metadata in the checkpoint
            model_name=config.system.system_name,
            **config.logger.checkpointing.save_args,  # Checkpoint args
        )

    # Best return seen so far, and the params that achieved it, for each seed.
    max_episode_return = [jnp.float32(-1e7) for _ in range(total_seeds)]
    init_params = jax.tree.map(lambda x: merge_leading_dims(x, 2), learner_state.params.actor_params)
    init_obs_norm_params = jax.tree.map(lambda x: merge_leading_dims(x, 2), learner_state.obs_norm_params)
    best_params = [jax.tree.map(lambda x: x[i], init_params) for i in range(total_seeds)]
    best_obs_norm_params = [jax.tree.map(lambda x: x[i], init_obs_norm_params) for i in range(total_seeds)]

    # Perform Initial Evaluation Before Training
    start_time = time.time()
    trained_params = learner_state.params.actor_params  # Select only actor params
    obs_norm_params = learner_state.obs_norm_params
    key_e, eval_keys = _split_eval_keys(key_e, n_devices, seeds_per_device)

    # Evaluate.
    evaluator_output = evaluator(trained_params, obs_norm_params, eval_keys)
    jax.block_until_ready(evaluator_output)

    # Log the results of the evaluation.
    elapsed_time = time.time() - start_time
    evaluator_output = jax.tree.map(lambda x: merge_leading_dims(x, 2), evaluator_output)
    _log_eval_metrics(logger, evaluator_output, elapsed_time, 0, 0, LogEvent.EVAL, total_seeds)

    # Run experiment for a total number of evaluations.
    for eval_step in range(1, config.arch.num_evaluation + 1):
        # Train.
        start_time = time.time()

        learner_output = learn(learner_state)
        jax.block_until_ready(learner_output)

        # Log the results of the training.
        elapsed_time = time.time() - start_time
        t = int(steps_per_rollout * eval_step)
        # learner_output.episode_metrics is of shape
        # (num_devices, updates_per_eval, num_seeds_per_device, rollout_length, num_envs)
        episode_metrics = learner_output.episode_metrics
        episode_metrics = jax.tree.map(lambda x: jnp.swapaxes(x, 1, 2), episode_metrics)
        # episode_metrics is now of shape
        # (num_devices, num_seeds_per_device, updates_per_eval, rollout_length, num_envs)
        episode_metrics = jax.tree.map(
            lambda x: jnp.reshape(x, (total_seeds,) + x.shape[2:]), episode_metrics
        )
        # episode_metrics is now of shape
        # (devices*seeds_per_device, updates_per_eval, rollout_length, num_envs)
        seed_metrics = []
        seed_ep_completed = []
        for i in range(total_seeds):
            ep_metrics = jax.tree.map(lambda x: x[i], episode_metrics)
            ep_metrics, ep_completed = get_final_step_metrics(ep_metrics)
            ep_metrics["steps_per_second"] = steps_per_rollout / elapsed_time
            seed_metrics.append(ep_metrics)
            seed_ep_completed.append(ep_completed)

        # Separately log timesteps, actoring metrics and training metrics.
        for i, ep_metrics in enumerate(seed_metrics):
            logger.log({"timestep": t}, t, eval_step, LogEvent.MISC, seed=i)
            if seed_ep_completed[i]:  # only log episode metrics if an episode was completed in the rollout.
                logger.log(ep_metrics, t, eval_step, LogEvent.ACT, seed=i)

        train_metrics = learner_output.train_metrics
        train_metrics = jax.tree.map(lambda x: jnp.swapaxes(x, 1, 2), train_metrics)
        train_metrics = jax.tree.map(
            lambda x: jnp.reshape(x, (total_seeds,) + x.shape[2:]), train_metrics
        )
        for i in range(total_seeds):
            metrics = jax.tree.map(lambda x: x[i], train_metrics)
            logger.log(metrics, t, eval_step, LogEvent.TRAIN, seed=i)
            unique_token = logger.unique_tokens[i]

            cosine_sim = metrics["cosine_similarity"]
            grad_norm = metrics["outer_grad_norm"]
            chex.assert_shape(cosine_sim, (config.arch.num_updates_per_eval,))
            chex.assert_shape(grad_norm, (config.arch.num_updates_per_eval,))

            # Append this seed's per-update cosine similarity and outer gradient norm to a
            # text file under the seed's JSON log path.
            json_exp_path = get_logger_path(config, "json")
            path = os.path.join(
                config.logger.base_exp_path, f"{json_exp_path}/{unique_token}/cosine_sim.txt"
            )
            # The target directory may not exist yet; create it rather than crash mid-run.
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, "a", encoding="utf-8") as f:
                f.write(f"{np.array(cosine_sim).tolist()} {np.array(grad_norm).tolist()}\n")

        # check for any nans in the actor loss, stop the experiment if there are any
        if bool(jnp.any(jnp.isnan(learner_output.train_metrics['actor_loss']))):
            # Close the logger before bailing out so it is not leaked on the early exit.
            logger.stop()
            return float('nan')

        # Prepare for evaluation.
        start_time = time.time()
        trained_params = learner_output.learner_state.params.actor_params  # Select only actor params
        obs_norm_params = learner_output.learner_state.obs_norm_params
        key_e, eval_keys = _split_eval_keys(key_e, n_devices, seeds_per_device)

        # Evaluate.
        evaluator_output = evaluator(trained_params, obs_norm_params, eval_keys)
        jax.block_until_ready(evaluator_output)

        # Log the results of the evaluation.
        elapsed_time = time.time() - start_time
        evaluator_output = jax.tree.map(lambda x: merge_leading_dims(x, 2), evaluator_output)
        _log_eval_metrics(
            logger, evaluator_output, elapsed_time, t, eval_step, LogEvent.EVAL, total_seeds
        )

        episode_returns = jnp.mean(evaluator_output.episode_metrics["episode_return"], axis=-1)

        # Keep, per seed, the params with the best mean evaluation return so far.
        trained_params = jax.tree.map(lambda x: merge_leading_dims(x, 2), trained_params)
        obs_norm_params = jax.tree.map(lambda x: merge_leading_dims(x, 2), obs_norm_params)
        for i in range(total_seeds):
            if config.arch.absolute_metric and max_episode_return[i] <= episode_returns[i]:
                trained_params_seed = jax.tree.map(lambda x: x[i], trained_params)
                obs_norm_params_seed = jax.tree.map(lambda x: x[i], obs_norm_params)
                best_params[i] = copy.deepcopy(trained_params_seed)
                best_obs_norm_params[i] = copy.deepcopy(obs_norm_params_seed)
                max_episode_return[i] = episode_returns[i]

        # Update runner state to continue training.
        learner_state = learner_output.learner_state

    # Measure absolute metric.
    if config.arch.absolute_metric:
        start_time = time.time()
        key_e, eval_keys = _split_eval_keys(key_e, n_devices, seeds_per_device)

        def tree_stack(trees):
            """Stack a list of identically-structured pytrees along a new leading axis."""
            return jax.tree.map(lambda *v: jnp.stack(v), *trees)

        # Re-assemble the per-seed best params into (n_devices, seeds_per_device, ...) trees.
        best_params = tree_stack(best_params)
        best_obs_norm_params = tree_stack(best_obs_norm_params)
        best_params = jax.tree.map(
            lambda x: jnp.reshape(x, (n_devices, seeds_per_device) + x.shape[1:]), best_params
        )
        best_obs_norm_params = jax.tree.map(
            lambda x: jnp.reshape(x, (n_devices, seeds_per_device) + x.shape[1:]),
            best_obs_norm_params,
        )
        evaluator_output = absolute_metric_evaluator(best_params, best_obs_norm_params, eval_keys)
        jax.block_until_ready(evaluator_output)

        evaluator_output = jax.tree.map(
            lambda x: jnp.reshape(x, (total_seeds,) + x.shape[2:]), evaluator_output
        )

        elapsed_time = time.time() - start_time
        t = int(steps_per_rollout * eval_step)
        _log_eval_metrics(
            logger, evaluator_output, elapsed_time, t, eval_step, LogEvent.ABSOLUTE, total_seeds
        )

    # Stop the logger.
    logger.stop()
    # Record the performance for the final evaluation run. If the absolute metric is not
    # calculated, this will be the final evaluation run.
    eval_performance = float(jnp.mean(evaluator_output.episode_metrics[config.env.eval_metric]))

    if config.logger.kwargs.neptune_mode == 'offline':
        # since we are using several nodes, and in some cases several parallel runs
        # per node we found async (even with larger flush_period) to lead to many
        # connection issues. this command will sync the offline neptune logs to neptune
        # without maintaining a connection throughout the run.
        subprocess.run(['neptune', 'sync', '-p', config.logger.kwargs.neptune_project])

    return eval_performance


@hydra.main(
    config_path="../../configs",
    config_name="default_ff_ppo.yaml",
    version_base="1.2",
)
def hydra_entry_point(cfg: DictConfig) -> float:
    """Hydra-decorated launcher: run the experiment for ``cfg`` and return its score.

    Args:
        cfg: Configuration object supplied by Hydra.

    Returns:
        The value returned by ``run_experiment``.
    """
    # Switch struct mode off so new keys can be attached to the config.
    OmegaConf.set_struct(cfg, False)

    final_score = run_experiment(cfg)

    # Announce completion in bright cyan, then reset the terminal styling.
    completion_banner = f"{Fore.CYAN}{Style.BRIGHT}PPO experiment completed{Style.RESET_ALL}"
    print(completion_banner)
    return final_score


# Script entry point: only launch the Hydra-wrapped experiment when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    hydra_entry_point()
