"""
Based on PureJaxRL Implementation of PPO.
"""
import os
from pathlib import Path
import pickle
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pyrallis
from typing import NamedTuple
from flax.training.train_state import TrainState
import distrax
from flax import struct
from jaxmarl.viz.overcooked_visualizer import OvercookedVisualizer

import jax.scipy.stats

from dataclasses import asdict, dataclass

from src.agents.actors import ActorWithConditionalCritic, ScannedRNN, ActorCriticRNN

import wandb

from src.envs.ogc.auto_replay_wrapper import AutoReplayWrapper
from src.envs.ogc.ogc import OGC, Level, make_level_generator


# Number of discrete actions (assumed to match the OGC env's action space —
# TODO confirm). Sizes the uniform distribution mixed into the E3T policy
# and the ``bias_mask`` of PartnerParametersWithBias.
ACTION_SPACE_SIZE = 6


@dataclass
class TrainConfig:
    """Training configuration for SFL-style curriculum training on OGC.

    The fields are the user-facing knobs. ``__post_init__`` derives the actor,
    batch and update counts that the training loop consumes from them.
    """

    # Wandb and other logging
    project: str = "JaxZSC"
    mode: str = "disabled"  # one of "online", "offline", "disabled"
    entity: str = ""
    checkpoint_path: str = "checkpoints"
    checkpoint_freq: int = 25  # checkpoint every N updates

    # OGC environment
    env_name: str = "ogc"
    ogc_width: int = 5
    ogc_height: int = 5
    ogc_n_walls: int = 3
    random_reset: bool = False
    # Environment steps over which the dense-reward shaping coefficient decays to 0.
    rew_shaping_horizon: float = 3e9

    # Actor-Critic
    activation: str = "tanh"
    fc_dim_size: int = 256
    gru_hidden_dim: int = 256

    embedding_layers: int = 5
    actor_layers: int = 5
    critic_layers: int = 5

    use_layernorm: bool = True

    other_agent_prediction: bool = True
    moa_coef: float = 1.0

    # Training
    seed: int = 0
    lr: float = 1e-3
    anneal_lr: bool = True
    num_envs: int = 1024

    num_steps_per_env: int = 400
    num_steps_per_update: int = 400

    total_timesteps: float = 1e10
    update_epochs: int = 6
    num_minibatches: int = 8
    gamma: float = 0.99
    gae_lambda: float = 0.95
    clip_eps: float = 0.2
    ent_coef: float = 0.01
    vf_coef: float = 1.0
    max_grad_norm: float = 0.5

    # SFL
    # Learnability rollouts last max_steps * sfl_rollout_factor steps.
    sfl_rollout_factor: int = 10
    sfl_buffer_size: int = 1024
    sfl_num_batches: int = 1
    sfl_buffer_refresh_freq: int = 4
    # Envs drawn from the SFL buffer per meta update; the rest are freshly
    # generated (0 = every env is freshly generated).
    sfl_num_envs_to_sample: int = 1024
    sfl_num_ego_envs_to_sample: int = 0  # how many ego levels to include

    sfl_num_levels: int = 8192
    sfl_num_partner: int = 1

    # Options: gaussian-weighted-standarddev, variance, variance-x-mean,
    # mean-return, success-rate-over-global-median, coefficent-of-variation,
    # entropy, cole-rank-based, cole-inverse-mean-return, cv-squared
    learnability_function: str = "variance"
    sample_with_dense_rewards: bool = False

    num_devices: int = 1

    eval_against_pop: bool = True

    log_num_images: int = 20  # number of images to log
    log_images_update: int = 10

    def __post_init__(self):
        # Two seats per environment; every seat is a separate actor in the batch.
        self.num_agents = 2
        self.num_actors = self.num_envs * self.num_agents

        # Buffer-sampled envs plus freshly generated envs fill num_envs.
        self.sfl_num_envs_to_generate = self.num_envs - self.sfl_num_envs_to_sample
        self.sfl_batch_size = self.sfl_num_partner * self.sfl_num_levels

        self.num_inner_updates = self.num_steps_per_env // self.num_steps_per_update
        self.num_envs_per_device = self.num_envs // self.num_devices

        transitions_per_update = self.num_actors * self.num_steps_per_update
        self.minibatch_size = transitions_per_update // self.num_minibatches

        # Env steps consumed by one meta update on a single device.
        self.total_timesteps_per_device = self.total_timesteps // self.num_devices
        env_steps_per_meta_update = self.num_envs_per_device * self.num_steps_per_env
        self.num_meta_updates = round(
            self.total_timesteps_per_device / env_steps_per_meta_update)
        self.num_reward_shaping_updates = round(
            self.rew_shaping_horizon / env_steps_per_meta_update)
        self.num_outer_steps = self.num_meta_updates // self.sfl_buffer_refresh_freq

        print('num inner updates', self.num_inner_updates)


class RolloutStats(struct.PyTreeNode):
    """Running totals for a single evaluation episode.

    Used as the while-loop accumulator by ``rollout``, ``rollout_single_l`` and
    ``rollout_single_r``, which add ``agent_0``'s reward and 1 step to these
    fields on every environment step.
    """
    reward: jax.Array = jnp.asarray(0.0)  # summed agent_0 reward so far
    length: jax.Array = jnp.asarray(0)  # number of environment steps taken


class PartnerParametersWithBias(struct.PyTreeNode):
    """Presumably the parameters of an epsilon-mixed, action-biased partner.

    NOTE(review): not referenced in the visible part of this file; the field
    meanings below are inferred from names only — confirm against the users.
    """
    epsilon: jnp.float32  # presumably the mixing weight toward bias_mask — confirm
    epsilon_agent: jnp.int32  # presumably which seat the mixing applies to — confirm
    bias_mask: jnp.ndarray  # shape (ACTION_SPACE_SIZE,)


def rollout_nsteps(
    rng: jax.Array,
    env,
    level_params: Level,
    train_state: TrainState,
    init_hstate: jax.Array,
    num_steps: int = 400,
    sfl_rollout_factor: int = 2,
    use_dense_rewards: bool = False,
    dense_rewards_coeff: jax.Array = jnp.asarray(0.0),
):
    """Roll out ``num_steps`` self-play steps on one level and record episode returns.

    Both seats are driven by ``train_state``. At reset one seat is picked at
    random (``epsilon_agent``) and, for the whole rollout, samples from an
    E3T-style mixture ``0.5 * policy + 0.5 * uniform``; the other seat samples
    from the policy directly.

    Args:
        rng: PRNG key for the reset, the seat choice and every step.
        env: environment exposing ``reset_env_to_level``, ``step``, ``agents``,
            ``num_agents`` and ``default_params``.
        level_params: level to reset the environment to.
        train_state: ``TrainState`` whose ``apply_fn`` returns
            ``(hstate, pi, value, other_pi)``.
        init_hstate: initial recurrent state for both seats.
        num_steps: number of environment steps. Used as the ``lax.scan``
            length, so it must be a static Python int.
        sfl_rollout_factor: number of slots in the per-episode return buffer
            (used as an array shape, so it must be static).
        use_dense_rewards: if True, add the shaped reward of both agents,
            scaled by ``dense_rewards_coeff``, to the accumulated return.
            Branched on in Python, so it must be a static bool.
        dense_rewards_coeff: scale applied to the shaped reward.

    Returns:
        ``(stats, env_state)``: the final ``RolloutEpisodeStats`` and the final
        environment state. Returns of completed episodes are written to
        successive slots of ``stats.episode_return``; slots for episodes that
        never finished stay 0, and completions beyond ``sfl_rollout_factor``
        are dropped (JAX skips out-of-bounds scatter updates).
    """

    class RolloutEpisodeStats(struct.PyTreeNode):
        # Return accumulated so far in the current (unfinished) episode.
        reward: jax.Array = jnp.asarray(0.0)
        # Return of each completed episode, one slot per episode.
        episode_return: jax.Array = jnp.zeros((sfl_rollout_factor,))
        # Total environment steps taken.
        length: jax.Array = jnp.asarray(0)
        # Number of completed episodes, i.e. the next slot of episode_return.
        episode_counter: jax.Array = jnp.asarray(0)
        # Whether the most recent step ended an episode.
        done: jax.Array = jnp.asarray(False)

    # E3T mixing weight and the uniform distribution it mixes into the policy.
    e3t_epsilon = 0.5
    uniform_probs = jnp.ones((ACTION_SPACE_SIZE,)) / ACTION_SPACE_SIZE

    def _push_history(history, obs, action):
        """Shift the (obs, action) window one step left and write the newest pair last."""
        obs_hist = history['obs'].at[:, :-1].set(history['obs'][:, 1:])
        obs_hist = obs_hist.at[:, -1].set(obs)
        act_hist = history['action'].at[:, :-1].set(history['action'][:, 1:])
        act_hist = act_hist.at[:, -1].set(action)
        return {'obs': obs_hist, 'action': act_hist}

    def _env_step(carry, unused):
        (rng, env_state, stats, last_obs, epsilon_agent,
         last_done, hstate, past_5_sa_pairs) = carry
        rng, rng_action, rng_step = jax.random.split(rng, 3)

        obs_batch = batchify(last_obs, env.agents, 2)
        batched_sa_pairs = batchify_nested_dics(
            past_5_sa_pairs, env.agents, (1, 2, 5))
        ac_in = (
            obs_batch[np.newaxis, :],
            last_done[np.newaxis, :],
            batched_sa_pairs,
        )
        hstate, pi, _, _ = train_state.apply_fn(
            train_state.params, hstate, ac_in)

        # Both candidate actions share one key; each seat keeps exactly one.
        pi_e3t = distrax.Categorical(
            probs=(1 - e3t_epsilon) * pi.probs + e3t_epsilon * uniform_probs)
        e3t_action = pi_e3t.sample(seed=rng_action)
        base_action = pi.sample(seed=rng_action)

        # The first seat (env.agents[0]) uses the mixture when epsilon_agent
        # is True, the second seat when it is False.
        uses_e3t = jnp.array([epsilon_agent, ~epsilon_agent])
        action = jnp.where(uses_e3t, e3t_action, base_action).squeeze()

        env_act = unbatchify(action, env.agents, 1, env.num_agents)
        env_act = {k: v.flatten().squeeze() for k, v in env_act.items()}

        # Append the pair just acted on; the network saw only earlier pairs.
        past_5_sa_pairs = {
            agent: _push_history(hist, last_obs[agent], env_act[agent])
            for agent, hist in past_5_sa_pairs.items()
        }

        obsv, env_state, reward, done, info = env.step(
            rng_step, env_state, env_act)

        done_flag = done["__all__"]
        final_episode_return = stats.reward + reward["agent_0"]
        if use_dense_rewards:
            final_episode_return = final_episode_return + dense_rewards_coeff * (
                info["shaped_reward"]["agent_0"] +
                info["shaped_reward"]["agent_1"]
            )

        # Store the finished episode's return in the next free slot.
        new_episode_return = jax.lax.cond(
            done_flag,
            lambda: stats.episode_return.at[stats.episode_counter].set(
                final_episode_return),
            lambda: stats.episode_return,
        )
        # Restart the per-episode accumulator at an episode boundary.
        new_reward = jax.lax.cond(
            done_flag,
            lambda: jnp.array(0.0),
            lambda: final_episode_return,
        )

        stats = stats.replace(
            reward=new_reward,
            length=stats.length + 1,
            done=done_flag,
            episode_counter=stats.episode_counter + done_flag.astype(jnp.int32),
            episode_return=new_episode_return,
        )
        done = batchify(done, env.agents, 2).squeeze()
        carry = (rng, env_state, stats, obsv, epsilon_agent,
                 done, hstate, past_5_sa_pairs)
        return carry, None

    # Derive fresh keys; the parent key is not reused after this split.
    rng, key_reset, key_seat = jax.random.split(rng, 3)
    obs, state = env.reset_env_to_level(
        key_reset, level_params, env.default_params)

    # Pick which seat samples from the E3T mixture for the whole rollout.
    epsilon_agent = jax.random.bernoulli(key_seat)

    # Seed each seat's 5-step window with the reset observation and the
    # placeholder action index 4 (its meaning depends on the env — TODO confirm).
    past_5_sa_pairs = {
        agent: {
            'obs': obs[agent][:, None].repeat(5, axis=1),
            'action': jnp.ones((1, 5)) * 4,
        }
        for agent in ('agent_0', 'agent_1')
    }

    init_carry = (rng, state, RolloutEpisodeStats(), obs, epsilon_agent,
                  jnp.array([False, False]), init_hstate, past_5_sa_pairs)

    (_, final_env_state, final_stats, *_), _ = jax.lax.scan(
        _env_step, init_carry, None, length=num_steps)

    return final_stats, final_env_state


def rollout(rng, env, level, network, params, hidden_size):
    """Play one self-play episode on ``level`` and return its reward and length.

    ``network`` (with ``params``) controls both seats and samples actions
    directly from its policy. The loop stops once every entry of the batched
    ``done`` flags is True.

    Args:
        rng: PRNG key for the reset and all action / step sampling.
        env: environment exposing ``reset_env_to_level``, ``step``, ``agents``,
            ``num_agents`` and ``default_params``.
        level: level to reset the environment to.
        network: model whose ``apply`` returns ``(hstate, pi, value, other_pi)``.
        params: parameters for ``network``.
        hidden_size: recurrent state width passed to ``ScannedRNN.initialize_carry``.

    Returns:
        ``(reward, length)``: the summed ``agent_0`` reward and the number of
        environment steps taken.
    """

    def _push_history(history, obs, action):
        """Shift the (obs, action) window one step left and write the newest pair last."""
        obs_hist = history['obs'].at[:, :-1].set(history['obs'][:, 1:])
        obs_hist = obs_hist.at[:, -1].set(obs)
        act_hist = history['action'].at[:, :-1].set(history['action'][:, 1:])
        act_hist = act_hist.at[:, -1].set(action)
        return {'obs': obs_hist, 'action': act_hist}

    def _cond_fn(carry):
        done = carry[-1]
        return jnp.logical_not(done).any()  # Continue while any seat is not done.

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, past_5_sa_pairs, done = carry
        rng, rng_action, rng_step = jax.random.split(rng, 3)

        obs_batch = batchify(last_obs, env.agents, 2)
        batched_sa_pairs = batchify_nested_dics(
            past_5_sa_pairs, env.agents, (1, 2, 5))
        ac_in = (
            obs_batch[np.newaxis, :],
            done[np.newaxis, :],
            batched_sa_pairs,
        )

        hstate, pi, _, _ = network.apply(params, hstate, ac_in)
        action = pi.sample(seed=rng_action).squeeze()

        env_act = unbatchify(action, env.agents, 1, env.num_agents)
        env_act = {k: v.flatten().squeeze() for k, v in env_act.items()}

        # Append the pair just acted on; the network saw only earlier pairs.
        past_5_sa_pairs = {
            agent: _push_history(hist, last_obs[agent], env_act[agent])
            for agent, hist in past_5_sa_pairs.items()
        }

        obsv, env_state, reward, done, _ = env.step(
            rng_step, env_state, env_act)

        stats = stats.replace(
            reward=stats.reward + reward["agent_0"],
            length=stats.length + 1,
        )
        done = batchify(done, env.agents, 2).squeeze()
        return (rng, env_state, stats, obsv, hstate, past_5_sa_pairs, done)

    # Derive fresh keys; the parent key is not reused after this split.
    rng, key_reset = jax.random.split(rng)
    obs, state = env.reset_env_to_level(key_reset, level, env.default_params)

    # Seed each seat's 5-step window with the reset observation and the
    # placeholder action index 4 (its meaning depends on the env — TODO confirm).
    past_5_sa_pairs = {
        agent: {
            'obs': obs[agent][:, None].repeat(5, axis=1),
            'action': jnp.ones((1, 5)) * 4,
        }
        for agent in ('agent_0', 'agent_1')
    }

    init_hstate = ScannedRNN.initialize_carry(2, hidden_size)  # one row per seat
    init_carry = (rng, state, RolloutStats(), obs,
                  init_hstate, past_5_sa_pairs, jnp.array([False, False]))

    _, _, final_stats, *_ = jax.lax.while_loop(
        _cond_fn, _body_fn, init_val=init_carry)
    return final_stats.reward.squeeze(), final_stats.length.squeeze()


def rollout_single_l(rng, env, level, network, params, other_network, other_params, hidden_size, popsize):
    """Play one episode with the ego in seat ``agent_0`` and a partner in ``agent_1``.

    The ego ``network`` is recurrent and also receives a 5-step
    (obs, action) history of its own seat; the partner ``other_network`` is
    called with its observation and a zero vector of length ``popsize``
    (presumably a population-member conditioning input — confirm).
    The loop stops once every entry of the batched ``done`` flags is True.

    Returns:
        ``(reward, length)``: the summed ``agent_0`` reward and the number of
        environment steps taken.
    """

    def _push_history(history, obs, action):
        """Shift the (obs, action) window one step left and write the newest pair last."""
        obs_hist = history['obs'].at[:, :-1].set(history['obs'][:, 1:])
        obs_hist = obs_hist.at[:, -1].set(obs)
        act_hist = history['action'].at[:, :-1].set(history['action'][:, 1:])
        act_hist = act_hist.at[:, -1].set(action)
        return {'obs': obs_hist, 'action': act_hist}

    def _cond_fn(carry):
        done = carry[-1]
        return jnp.logical_not(done).any()  # Continue while any seat is not done.

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, past_5_sa_pairs, done = carry
        rng, rng_action, rng_o_action, rng_step = jax.random.split(rng, 4)

        ego_history = jax.tree.map(
            lambda x: x[np.newaxis], past_5_sa_pairs['agent_0'])
        ac_in = (
            last_obs['agent_0'][np.newaxis, np.newaxis, :],
            done[np.newaxis, ...][:, 0:1],
            ego_history,
        )
        hstate, pi, _, _ = network.apply(params, hstate, ac_in)
        ego_action = pi.sample(seed=rng_action).squeeze()

        partner_pi, _ = other_network.apply(
            other_params,
            (
                last_obs['agent_1'][np.newaxis, :],
                jnp.zeros(popsize)[np.newaxis, :],
            ),
        )
        partner_action = partner_pi.sample(seed=rng_o_action).squeeze()

        env_act = {
            "agent_0": ego_action,
            "agent_1": partner_action,
        }

        # Only the ego seat keeps a history; append the pair just acted on.
        past_5_sa_pairs = {
            'agent_0': _push_history(
                past_5_sa_pairs['agent_0'], last_obs['agent_0'], ego_action),
        }

        obsv, env_state, reward, done, _ = env.step(
            rng_step, env_state, env_act)

        stats = stats.replace(
            reward=stats.reward + reward["agent_0"],
            length=stats.length + 1,
        )
        done = batchify(done, env.agents, 2).squeeze()
        return (rng, env_state, stats, obsv, hstate, past_5_sa_pairs, done)

    # Derive fresh keys; the parent key is not reused after this split.
    rng, key_reset = jax.random.split(rng)
    obs, state = env.reset_env_to_level(key_reset, level, env.default_params)

    # Seed the ego's 5-step window with the reset observation and the
    # placeholder action index 4 (its meaning depends on the env — TODO confirm).
    past_5_sa_pairs = {
        'agent_0': {
            'obs': obs['agent_0'][:, None].repeat(5, axis=1),
            'action': jnp.ones((1, 5)) * 4,
        },
    }

    # Only the ego network is called with a recurrent state.
    init_hstate = ScannedRNN.initialize_carry(1, hidden_size)
    init_carry = (rng, state, RolloutStats(), obs,
                  init_hstate, past_5_sa_pairs, jnp.array([False, False]))

    _, _, final_stats, *_ = jax.lax.while_loop(
        _cond_fn, _body_fn, init_val=init_carry)
    return final_stats.reward.squeeze(), final_stats.length.squeeze()


def rollout_single_r(rng, env, level, network, params, other_network, other_params, hidden_size, popsize):
    """Play one episode with the ego in seat ``agent_1`` and a partner in ``agent_0``.

    Mirror image of ``rollout_single_l``. The ego ``network`` is recurrent and
    also receives a 5-step (obs, action) history of its own seat; the partner
    ``other_network`` is called with its observation and a zero vector of
    length ``popsize`` (presumably a population-member conditioning input —
    confirm). The loop stops once every entry of the batched ``done`` flags
    is True.

    Returns:
        ``(reward, length)``: the summed ``agent_0`` reward and the number of
        environment steps taken.
    """

    def _push_history(history, obs, action):
        """Shift the (obs, action) window one step left and write the newest pair last."""
        obs_hist = history['obs'].at[:, :-1].set(history['obs'][:, 1:])
        obs_hist = obs_hist.at[:, -1].set(obs)
        act_hist = history['action'].at[:, :-1].set(history['action'][:, 1:])
        act_hist = act_hist.at[:, -1].set(action)
        return {'obs': obs_hist, 'action': act_hist}

    def _cond_fn(carry):
        done = carry[-1]
        return jnp.logical_not(done).any()  # Continue while any seat is not done.

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, past_5_sa_pairs, done = carry
        rng, rng_action, rng_o_action, rng_step = jax.random.split(rng, 4)

        ego_history = jax.tree.map(
            lambda x: x[np.newaxis], past_5_sa_pairs['agent_1'])
        # NOTE(review): this feeds the first batched done flag (env.agents[0])
        # to the agent_1 network; if per-agent dones can differ, agent_1's
        # entry is probably the intended one — confirm.
        ac_in = (
            last_obs['agent_1'][np.newaxis, np.newaxis, :],
            done[np.newaxis, ...][:, 0:1],
            ego_history,
        )
        hstate, pi, _, _ = network.apply(params, hstate, ac_in)
        ego_action = pi.sample(seed=rng_action).squeeze()

        partner_pi, _ = other_network.apply(
            other_params,
            (
                last_obs['agent_0'][np.newaxis, :],
                jnp.zeros(popsize)[np.newaxis, :],
            ),
        )
        partner_action = partner_pi.sample(seed=rng_o_action).squeeze()

        env_act = {
            "agent_0": partner_action,
            "agent_1": ego_action,
        }

        # Only the ego seat keeps a history; append the pair just acted on.
        past_5_sa_pairs = {
            'agent_1': _push_history(
                past_5_sa_pairs['agent_1'], last_obs['agent_1'], ego_action),
        }

        obsv, env_state, reward, done, _ = env.step(
            rng_step, env_state, env_act)

        stats = stats.replace(
            reward=stats.reward + reward["agent_0"],
            length=stats.length + 1,
        )
        done = batchify(done, env.agents, 2).squeeze()
        return (rng, env_state, stats, obsv, hstate, past_5_sa_pairs, done)

    # Derive fresh keys; the parent key is not reused after this split.
    rng, key_reset = jax.random.split(rng)
    obs, state = env.reset_env_to_level(key_reset, level, env.default_params)

    # Seed the ego's 5-step window with the reset observation and the
    # placeholder action index 4 (its meaning depends on the env — TODO confirm).
    past_5_sa_pairs = {
        'agent_1': {
            'obs': obs['agent_1'][:, None].repeat(5, axis=1),
            'action': jnp.ones((1, 5)) * 4,
        },
    }

    # Only the ego network is called with a recurrent state.
    init_hstate = ScannedRNN.initialize_carry(1, hidden_size)
    init_carry = (rng, state, RolloutStats(), obs,
                  init_hstate, past_5_sa_pairs, jnp.array([False, False]))

    _, _, final_stats, *_ = jax.lax.while_loop(
        _cond_fn, _body_fn, init_val=init_carry)
    return final_stats.reward.squeeze(), final_stats.length.squeeze()


def rollout_both_ways(eval_rng, env, level, network, params, partner_pop_actor, partner_pop_params, gru_hidden_dim, popsize):
    """Average the ego's episode return over both seat assignments.

    For every key along the leading axis of ``eval_rng`` one episode is played
    with the ego in seat ``agent_0`` (``rollout_single_l``) and one with it in
    ``agent_1`` (``rollout_single_r``); both seatings reuse the same keys.
    Episode lengths are discarded and the mean of all returns is returned.
    """
    # Only the PRNG keys are batched; every other argument is shared.
    in_axes = (0,) + (None,) * 8
    shared_args = (env, level, network, params, partner_pop_actor,
                   partner_pop_params, gru_hidden_dim, popsize)

    per_seat_returns = []
    for seat_rollout in (rollout_single_l, rollout_single_r):
        returns, _lengths = jax.vmap(seat_rollout, in_axes=in_axes)(
            eval_rng, *shared_args)
        per_seat_returns.append(returns)

    return jnp.array(per_seat_returns).mean()


class Transition(NamedTuple):
    """Per-step trajectory record, one field per collected quantity.

    NOTE(review): not constructed in the visible part of this file; field
    semantics below are inferred from names — confirm against the collection
    code. Field order matters for positional construction.
    """
    global_done: jnp.ndarray
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: jnp.ndarray
    entropy: jnp.ndarray
    hstate: jnp.ndarray
    other_action: jnp.ndarray  # presumably the partner's action (prediction target) — confirm
    past_sa_pairs: jnp.ndarray  # presumably the 5-step (obs, action) history fed to the network — confirm


def batchify(x: dict, agent_list, num_actors):
    """Stack per-agent arrays in ``agent_list`` order and flatten to ``(num_actors, -1)``.

    Args:
        x: mapping from agent name to array; every agent in ``agent_list``
            must be present.
        agent_list: agent names, defining the stacking order.
        num_actors: size of the leading axis of the result.

    Returns:
        Array of shape ``(num_actors, -1)``.
    """
    per_agent = [x[agent] for agent in agent_list]
    stacked = jnp.stack(per_agent)
    return jnp.reshape(stacked, (num_actors, -1))


def batchify_nested_dics(x: dict, agent_list, shape):
    """Stack per-agent pytrees leaf-wise and reshape every leaf to ``(*shape, -1)``.

    Args:
        x: mapping from agent name to a pytree; all agents' pytrees must share
            the same structure.
        agent_list: agent names, defining the stacking order.
        shape: leading dimensions of every output leaf; the trailing axis
            absorbs the rest.

    Returns:
        A pytree with the shared structure whose leaves are stacked and reshaped.
    """
    per_agent = [x[agent] for agent in agent_list]
    stacked = jax.tree.map(lambda *leaves: jnp.stack(leaves), *per_agent)
    target_shape = tuple(shape) + (-1,)
    return jax.tree.map(lambda leaf: leaf.reshape(target_shape), stacked)


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    """Split a batched array back into a per-agent dict (inverse of ``batchify``).

    ``x`` is reshaped to ``(num_actors, num_envs, -1)``; entry ``i`` along the
    leading axis is assigned to ``agent_list[i]``.

    Returns:
        Dict from agent name to an array of shape ``(num_envs, -1)``.
    """
    grouped = jnp.reshape(x, (num_actors, num_envs, -1))
    per_agent = {}
    for idx, agent in enumerate(agent_list):
        per_agent[agent] = grouped[idx]
    return per_agent


def make_update_fn(config, env, network, sample_random_level, eval_levels, partner_pop_actor, partner_pop_params, pop_size):
    rew_shaping_anneal = optax.linear_schedule(
        init_value=1.,
        end_value=0.,
        transition_steps=config.rew_shaping_horizon
    )

    def log_levels(levels, step):
        rng = jax.random.PRNGKey(0)

        log_dict = {}
        for i in range(levels.wall_map.shape[0]):

            level = jax.tree.map(lambda x: x[i], levels)
            # reset env to level and then reset to get state -> print state.
            # img = env.render(env_params, t)
            _, state = env.reset_env_to_level(
                rng, level, env.default_params)
            grid = np.asarray(state.env_state.maze_map)
            img = OvercookedVisualizer._render_grid(
                grid,
                tile_size=32,
                highlight_mask=None,
                agent_dir_idx=state.env_state.agent_dir_idx,
                agent_inv=state.env_state.agent_inv
            )
            log_dict.update({f"images/{i}_level": wandb.Image(np.array(img))})
        wandb.log(log_dict, step=step)

    def train_loop(
        meta_state: tuple, outer_idx,
    ):

        def _sample_learnability_buffer(rng, train_state):
            update_step = meta_state[-1]

            def _batch_step(unused, batch_rng):
                parner_rng, level_rng, rollout_rng = jax.random.split(
                    batch_rng, 3)
                # sample rulesets
                parner_rng = jax.random.split(
                    parner_rng, num=config.sfl_batch_size)
                level_rng = jax.random.split(
                    level_rng, num=config.sfl_batch_size)
                level_params = jax.vmap(sample_random_level)(
                    level_rng)

                rollout_rng = jax.random.split(
                    rollout_rng, num=config.sfl_batch_size)
                rollout_stats, last_state = jax.vmap(rollout_nsteps, in_axes=(0, None, 0, None, None, None, None, None, None))(
                    rollout_rng,
                    env,
                    level_params,
                    train_state,
                    jnp.zeros((config.num_agents, config.gru_hidden_dim)),
                    env.default_params.max_steps * config.sfl_rollout_factor,
                    config.sfl_rollout_factor,
                    config.sample_with_dense_rewards,
                    jnp.maximum(0.0, 1.0 - (update_step /
                                config.num_reward_shaping_updates))
                )
                return None, (level_params, rollout_stats)

            batch_rng = jax.random.split(rng, num=config.sfl_num_batches)
            _, (level_params, rollout_stats) = jax.lax.scan(
                _batch_step, None, batch_rng)

            if config.learnability_function == "gaussian-weighted-standarddev":
                mean_return = jnp.mean(
                    rollout_stats.episode_return, axis=-1)  # [2]
                std_return = jnp.std(
                    rollout_stats.episode_return, axis=-1)   # [2]
                # Global stats
                global_mean = jnp.mean(mean_return)
                global_std = jnp.std(mean_return) + 1e-8  # prevent div-by-zero
                # Gaussian weight centered on global mean
                gaussian_weight = jax.scipy.stats.norm.pdf(
                    mean_return, loc=global_mean, scale=global_std)
                # Generalised learnability
                # shape: [2]
                learnability = (std_return * gaussian_weight).squeeze()
            elif config.learnability_function == "variance":
                learnability = jnp.var(
                    rollout_stats.episode_return, axis=-1).squeeze()
            elif config.learnability_function == "variance-x-mean":
                mean_return = jnp.mean(rollout_stats.episode_return, axis=-1)
                std_return = jnp.std(rollout_stats.episode_return, axis=-1)
                learnability = (std_return * mean_return).squeeze()
            elif config.learnability_function == "mean-return":
                mean_return = jnp.mean(rollout_stats.episode_return, axis=-1)
                learnability = mean_return.squeeze()
            elif config.learnability_function == "success-rate-over-global-median":
                returns = rollout_stats.episode_return
                global_median = jnp.median(returns)
                success = returns > global_median
                success_rate = jnp.mean(
                    success, axis=-1).squeeze()
                learnability = success_rate * (1 - success_rate)
            elif config.learnability_function == "coefficent-of-variation":
                mean_return = jnp.mean(rollout_stats.episode_return, axis=-1)
                std_return = jnp.std(rollout_stats.episode_return, axis=-1)
                learnability = (std_return / (mean_return + 1e-8)).squeeze()
            elif config.learnability_function == "entropy":
                returns = rollout_stats.episode_return  # shape [2, N]
                hist_bins = 10
                hist_range = (jnp.min(returns), jnp.max(returns))
                hist, bin_edges = jnp.histogram(
                    returns, bins=hist_bins, range=hist_range, axis=-1, density=True)
                entropy = -jnp.sum(hist * jnp.log(hist + 1e-8), axis=-1)
                learnability = entropy.squeeze()
            elif config.learnability_function == "cole-rank-based":
                mean_return = jnp.mean(rollout_stats.episode_return, axis=-1)
                # higher ranks = harder partners
                ranks = jnp.argsort(jnp.argsort(-mean_return))
                learnability = (ranks + 1).astype(jnp.float32)
                learnability = learnability.squeeze()
            elif config.learnability_function == "cole-inverse-mean-return":
                mean_return = jnp.mean(rollout_stats.episode_return, axis=-1)
                learnability = 1.0 / (mean_return + 1e-8)
                learnability = learnability.squeeze()
            elif config.learnability_function == "cv-squared":
                mean_return = jnp.mean(rollout_stats.episode_return, axis=-1)
                var_return = jnp.var(rollout_stats.episode_return, axis=-1)
                learnability = (
                    var_return / (mean_return + 1e-8) ** 2).squeeze()

            print('rollout stats', rollout_stats)
            flat_level_params = jax.tree.map(
                lambda x: x.reshape((-1,) + x.shape[2:]), level_params)

            top_learnability = jnp.argsort(
                learnability)[-config.sfl_buffer_size:]
            top_level_params = jax.tree.map(
                lambda x: x.at[top_learnability].get(), flat_level_params)

            info = {
                "buffer_learnability_scores": learnability.at[top_learnability].get(),
                # "top_gaussian_weight": gaussian_weight.at[top_learnability].get(),
                # "all_gaussian_weight": gaussian_weight,
            }

            return (top_level_params), info

        def _meta_step(meta_state, update_idx):
            rng, train_state, sfl_buffer, past_5_sa_pairs, learnability_info, update_steps = meta_state

            # sample rulesets for this meta update
            rng, _rng1, _rng2, _rng3, _rng4 = jax.random.split(rng, num=5)

            parnter_gen_rng = jax.random.split(
                _rng1, num=config.sfl_num_envs_to_generate)
            level_gen_rng = jax.random.split(
                _rng4, num=config.sfl_num_envs_to_generate)
            level_params_gen = jax.vmap(sample_random_level)(level_gen_rng)
            params_gen = level_params_gen

            # sample from sfl buffer
            partner_params_idxs = jax.random.randint(
                _rng2, (config.sfl_num_envs_to_sample,), 0, config.sfl_buffer_size)
            sampled_params = jax.tree.map(
                lambda x: x.at[partner_params_idxs].get(), sfl_buffer)

            params = jax.tree.map(lambda x, y: jnp.concatenate(
                [x, y], axis=0), params_gen, sampled_params)
            level_params = params

            reset_rng = jax.random.split(_rng3, num=config.num_envs_per_device)
            obsv, env_state = jax.vmap(
                env.reset_env_to_level, in_axes=(0, 0, None))(reset_rng, level_params, env.default_params)

            init_hstate = ScannedRNN.initialize_carry(
                config.num_actors, config.gru_hidden_dim)

            def _update_step(update_runner_state, _):
                # COLLECT TRAJECTORIES
                runner_state, update_steps, learnability_info = update_runner_state

                def _env_step(runner_state, unused):
                    train_state, env_state, last_obs, last_done, hstate, rng, update_step, epsilon_agent, past_5_sa_pairs = runner_state
                    rng, _rng = jax.random.split(rng)
                    obs_batch = batchify(
                        last_obs, env.agents, config.num_actors)

                    def get_e3t_action(args):
                        pi_ego, k, e3t_epsilon, bias_mask = args
                        pi_random = distrax.Categorical(probs=bias_mask)
                        pi_e3t_probs = (1 - e3t_epsilon) * \
                            pi_ego.probs + e3t_epsilon * pi_random.probs
                        pi_e3t = distrax.Categorical(probs=pi_e3t_probs)
                        sampled_a = pi_e3t.sample(seed=k)
                        log_prob_a = pi_e3t.log_prob(sampled_a)
                        entropy_a = pi_e3t.entropy()
                        return sampled_a, log_prob_a, entropy_a

                    def get_base_action(args):
                        pi_ego, k, e3t_epsilo, bias_mask = args
                        sampled_a = pi_ego.sample(seed=k)
                        log_prob_a = pi_ego.log_prob(sampled_a)
                        entropy_a = pi_ego.entropy()
                        return sampled_a, log_prob_a, entropy_a

                    batched_sa_pairs = batchify_nested_dics(
                        past_5_sa_pairs, env.agents, (1, config.num_actors, 5))
                    ac_in = (
                        obs_batch[np.newaxis, :],
                        last_done[np.newaxis, :],
                        batched_sa_pairs,
                    )
                    hstate, pi, value, pred_pi = network.apply(
                        train_state.params, hstate, ac_in)

                    uniform_random = jnp.ones(
                        (ACTION_SPACE_SIZE,)) / ACTION_SPACE_SIZE

                    e3t_action, e3t_log_prob, e3t_entropy = get_e3t_action(
                        (pi, _rng, 0.5, uniform_random[np.newaxis, np.newaxis, ...]))
                    base_action, base_log_prob, base_entropy = get_base_action(
                        (pi, _rng, 0.5, uniform_random[np.newaxis, np.newaxis, ...]))

                    epsilon_agent_both = jnp.concatenate(
                        [epsilon_agent, ~epsilon_agent], axis=0)
                    action = jnp.where(epsilon_agent_both,
                                       e3t_action, base_action)
                    action = action.squeeze()

                    log_prob = jnp.where(
                        epsilon_agent_both, e3t_log_prob, base_log_prob).squeeze()
                    entropy = jnp.where(epsilon_agent_both,
                                        e3t_entropy, base_entropy).squeeze()

                    env_act = unbatchify(
                        action, env.agents, config.num_envs, env.num_agents
                    )
                    env_act = {k: v.squeeze() for k, v in env_act.items()}

                    # Update state-action pairs
                    past_5_sa_pairs['agent_0']['obs'] = past_5_sa_pairs[
                        'agent_0']['obs'].at[:, :-1, :].set(past_5_sa_pairs['agent_0']['obs'][:, 1:, :])
                    past_5_sa_pairs['agent_0']['obs'] = past_5_sa_pairs[
                        'agent_0']['obs'].at[:, - 1, :].set(last_obs['agent_0'])
                    past_5_sa_pairs['agent_0']['action'] = past_5_sa_pairs[
                        'agent_0']['action'].at[:, :-1].set(past_5_sa_pairs['agent_0']['action'][:, 1:])
                    past_5_sa_pairs['agent_0']['action'] = past_5_sa_pairs[
                        'agent_0']['action'].at[:, -1].set(env_act['agent_0'])

                    past_5_sa_pairs['agent_1']['obs'] = past_5_sa_pairs[
                        'agent_1']['obs'].at[:, :-1, :].set(past_5_sa_pairs['agent_1']['obs'][:, 1:, :])
                    past_5_sa_pairs['agent_1']['obs'] = past_5_sa_pairs[
                        'agent_1']['obs'].at[:, -1, :].set(last_obs['agent_1'])
                    past_5_sa_pairs['agent_1']['action'] = past_5_sa_pairs[
                        'agent_1']['action'].at[:, :-1].set(past_5_sa_pairs['agent_1']['action'][:, 1:])
                    past_5_sa_pairs['agent_1']['action'] = past_5_sa_pairs[
                        'agent_1']['action'].at[:, -1].set(env_act['agent_1'])

                    # STEP ENV
                    rng, _rng = jax.random.split(rng)
                    rng_step = jax.random.split(_rng, config.num_envs)
                    obsv, env_state, reward, done, info = jax.vmap(
                        env.step, in_axes=(0, 0, 0)
                    )(rng_step, env_state, env_act)
                    shaped_reward = info['shaped_reward']
                    reward_shaping_frac = jnp.maximum(
                        0.0, 1.0 - (update_step / config.num_reward_shaping_updates))
                    shaped_reward = unbatchify(
                        shaped_reward.transpose(1, 0), env.agents, config.num_envs, env.num_agents)
                    shaped_reward = {k: v.squeeze()
                                     for k, v in shaped_reward.items()}
                    reward = jax.tree.map(lambda x, y: x + y *
                                          reward_shaping_frac, reward, shaped_reward)

                    del info['shaped_reward']

                    info = jax.tree.map(lambda x: x.reshape(
                        (config.num_actors)), info)
                    done_batch = batchify(
                        done, env.agents, config.num_actors).squeeze()
                    other_action = jnp.concatenate([
                        env_act["agent_1"], env_act["agent_0"]
                    ], axis=-1)
                    transition = Transition(
                        jnp.tile(done["__all__"], env.num_agents),
                        last_done,
                        action.squeeze(),
                        value.squeeze(),
                        batchify(reward, env.agents,
                                 config.num_actors).squeeze(),
                        log_prob.squeeze(),
                        obs_batch,
                        info,
                        entropy,
                        hstate,
                        other_action.squeeze(),
                        batchify_nested_dics(
                            past_5_sa_pairs, env.agents, (config.num_actors, 5))
                    )
                    runner_state = (train_state, env_state, obsv,
                                    done_batch, hstate, rng, update_step, epsilon_agent, past_5_sa_pairs)
                    return runner_state, transition

                initial_hstate = runner_state[-3]
                (train_state, env_state, obsv,
                 done_batch, hstate, past_5_sa_pairs, rng) = runner_state
                # sample which agent we'll increase beta to
                # DO we really sample this here?
                epsilon_agent = jax.random.bernoulli(
                    rng, shape=(config.num_envs,))
                rng, _rng = jax.random.split(rng)

                runner_state = (train_state, env_state, obsv, done_batch,
                                hstate, rng, update_steps, epsilon_agent, past_5_sa_pairs)
                runner_state, traj_batch = jax.lax.scan(
                    _env_step, runner_state, None, config.num_steps_per_update
                )

                # CALCULATE ADVANTAGE
                train_state, env_state, last_obs, last_done, hstate, rng, update_steps, beta_agent, past_5_sa_pairs = runner_state
                last_obs_batch = batchify(
                    last_obs, env.agents, config.num_actors)
                batched_sa_pairs = batchify_nested_dics(
                    past_5_sa_pairs, env.agents, (1, config.num_actors, 5))

                ac_in = (
                    last_obs_batch[np.newaxis, :],
                    last_done[np.newaxis, :],
                    batched_sa_pairs,
                )

                _, _, last_val, _ = network.apply(
                    train_state.params, hstate, ac_in)
                last_val = last_val.squeeze()

                def _calculate_gae(traj_batch, last_val):
                    def _get_advantages(gae_and_next_value, transition):
                        gae, next_value = gae_and_next_value
                        done, value, reward = (
                            transition.done,
                            transition.value,
                            transition.reward,
                        )
                        delta = reward + config.gamma * \
                            next_value * (1 - done) - value
                        gae = (
                            delta
                            + config.gamma *
                            config.gae_lambda * (1 - done) * gae
                        )
                        return (gae, value), gae

                    _, advantages = jax.lax.scan(
                        _get_advantages,
                        (jnp.zeros_like(last_val), last_val),
                        traj_batch,
                        reverse=True,
                        unroll=16,
                    )
                    return advantages, advantages + traj_batch.value

                advantages, targets = _calculate_gae(traj_batch, last_val)

                # UPDATE NETWORK
                def _update_epoch(update_state, unused):
                    def _update_minbatch(train_state, batch_info):
                        init_hstate, traj_batch, advantages, targets = batch_info

                        def _loss_fn(params, traj_batch, gae, targets):
                            # RERUN NETWORK
                            _, pi, value, other_pi = network.apply(
                                params,
                                jax.tree.map(
                                    lambda h: h.squeeze(), init_hstate),
                                (traj_batch.obs, traj_batch.done,
                                 traj_batch.past_sa_pairs),
                            )
                            log_prob = pi.log_prob(traj_batch.action)

                            # CALCULATE VALUE LOSS
                            value_pred_clipped = traj_batch.value + (
                                value - traj_batch.value
                            ).clip(-config.clip_eps, config.clip_eps)
                            value_losses = jnp.square(value - targets)
                            value_losses_clipped = jnp.square(
                                value_pred_clipped - targets)
                            value_loss = (
                                0.5 * jnp.maximum(value_losses,
                                                  value_losses_clipped).mean()
                            )

                            # CALCULATE ACTOR LOSS
                            logratio = log_prob - traj_batch.log_prob
                            ratio = jnp.exp(logratio)
                            gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                            loss_actor1 = ratio * gae
                            loss_actor2 = (
                                jnp.clip(
                                    ratio,
                                    1.0 - config.clip_eps,
                                    1.0 + config.clip_eps,
                                )
                                * gae
                            )
                            loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                            loss_actor = loss_actor.mean()
                            entropy = pi.entropy().mean()

                            def calc_moa_loss(args):
                                other_pi, other_action = args
                                other_log_prob = other_pi.log_prob(
                                    other_action)
                                # NLL loss for other agent
                                other_loss = -other_log_prob
                                other_loss = other_loss.mean()
                                return other_loss

                            def dummy_moa_loss(x): return jnp.array(0.0)
                            moa_loss = jax.lax.cond(
                                config.other_agent_prediction, calc_moa_loss, dummy_moa_loss, (other_pi, traj_batch.other_action))

                            approx_kl = ((ratio - 1) - logratio).mean()
                            clip_frac = jnp.mean(
                                jnp.abs(ratio - 1) > config.clip_eps)

                            total_loss = (
                                loss_actor
                                + config.moa_coef * moa_loss
                                + config.vf_coef * value_loss
                                - config.ent_coef * entropy
                            )
                            return total_loss, (value_loss, loss_actor, entropy, ratio, approx_kl, clip_frac)

                        grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                        total_loss, grads = grad_fn(
                            train_state.params, traj_batch, advantages, targets
                        )
                        train_state = train_state.apply_gradients(grads=grads)
                        return train_state, total_loss

                    (
                        train_state,
                        init_hstate,
                        traj_batch,
                        advantages,
                        targets,
                        rng,
                    ) = update_state
                    rng, _rng = jax.random.split(rng)
                    batch_size = config.minibatch_size * config.num_minibatches
                    assert (
                        batch_size == config.num_steps_per_update * config.num_actors
                    ), "batch size must be equal to number of steps * number of actors per population member"
                    permutation = jax.random.permutation(_rng, batch_size)

                    init_hstate = jax.tree.map(lambda h: jnp.reshape(
                        h, (1, config.num_actors, -1)), init_hstate)
                    batch = (
                        init_hstate,
                        traj_batch,
                        advantages.squeeze(),
                        targets.squeeze(),
                    )
                    permutation = jax.random.permutation(
                        _rng, config.num_actors)

                    shuffled_batch = jax.tree_util.tree_map(
                        lambda x: jnp.take(x, permutation, axis=1), batch
                    )

                    minibatches = jax.tree_util.tree_map(
                        lambda x: jnp.swapaxes(
                            jnp.reshape(
                                x,
                                [x.shape[0], config.num_minibatches, -1]
                                + list(x.shape[2:]),
                            ),
                            1,
                            0,
                        ),
                        shuffled_batch,
                    )
                    train_state, total_loss = jax.lax.scan(
                        _update_minbatch, train_state, minibatches
                    )
                    update_state = (
                        train_state,
                        jax.tree.map(lambda h: h.squeeze(), init_hstate),
                        traj_batch,
                        advantages,
                        targets,
                        rng,
                    )
                    return update_state, total_loss

                update_state = (
                    train_state,
                    initial_hstate,
                    traj_batch,
                    advantages,
                    targets,
                    rng,
                )
                update_state, loss_info = jax.lax.scan(
                    _update_epoch, update_state, None, config.update_epochs
                )
                train_state = update_state[0]

                # Run evals
                rng, _rng = jax.random.split(rng, 2)
                eval_rng = jax.random.split(_rng, 100)

                eval_returns, _ = jax.vmap(jax.vmap(
                    rollout, in_axes=(0, None, None, None, None, None)),
                    in_axes=(None, None, 0, None, None, None))(
                    eval_rng, env, eval_levels, network, train_state.params, config.gru_hidden_dim)

                if config.eval_against_pop:
                    rng, _rng = jax.random.split(rng, 2)
                    eval_rng = jax.random.split(_rng, 100)
                    # First vmap over partners and then run 100 games

                    cr_level = jax.tree.map(
                        lambda x: x[0], eval_levels
                    )
                    cr_pop_params = partner_pop_params[0]
                    eval_pop_returns_cr = jax.vmap(
                        rollout_both_ways,
                        in_axes=(None, None, None, None,
                                 None, None, 0, None, None)
                    )(
                        eval_rng, env, cr_level, network, train_state.params, partner_pop_actor,
                        cr_pop_params, config.gru_hidden_dim, pop_size[0])

                    fc_level = jax.tree.map(
                        lambda x: x[1], eval_levels
                    )
                    fc_pop_params = partner_pop_params[1]
                    eval_pop_returns_fc = jax.vmap(
                        rollout_both_ways,
                        in_axes=(None, None, None, None,
                                 None, None, 0, None, None)
                    )(
                        eval_rng, env, fc_level, network, train_state.params, partner_pop_actor,
                        fc_pop_params, config.gru_hidden_dim, pop_size[1])

                    croom_level = jax.tree.map(
                        lambda x: x[2], eval_levels
                    )
                    croom_pop_params = partner_pop_params[2]
                    eval_pop_returns_croom = jax.vmap(
                        rollout_both_ways,
                        in_axes=(None, None, None, None,
                                 None, None, 0, None, None)
                    )(
                        eval_rng, env, croom_level, network, train_state.params, partner_pop_actor,
                        croom_pop_params, config.gru_hidden_dim, pop_size[2])
                else:
                    eval_pop_returns_croom = jnp.asarray(0.0)
                    eval_pop_returns_fc = jnp.asarray(0.0)
                    eval_pop_returns_cr = jnp.asarray(0.0)

                level_params_to_log = jax.tree.map(
                    lambda x: x.at[:config.log_num_images].get(), level_params)
                # jax.experimental.io_callback(log_levels, None, timesteps_to_log, rulesets_to_log, env_params, update_idx)

                jax.lax.cond(
                    update_idx % config.log_images_update == 0,
                    lambda *_: jax.debug.callback(
                        log_levels, level_params_to_log, update_idx),
                    lambda *_: None,
                )

                metric = traj_batch.info
                metric = jax.tree.map(
                    lambda x: x.reshape(
                        (config.num_steps_per_update,
                         config.num_envs, env.num_agents)
                    ),
                    traj_batch.info,
                )

                metric = jax.tree.map(lambda x: x[-1, ...].mean(), metric)
                ratio_0 = loss_info[1][3].at[0, 0].get().mean()
                loss_info = jax.tree.map(lambda x: x.mean(), loss_info)
                metric["loss"] = {
                    "total_loss": loss_info[0],
                    "value_loss": loss_info[1][0],
                    "actor_loss": loss_info[1][1],
                    "entropy": loss_info[1][2],
                    "ratio": loss_info[1][3],
                    "ratio_0": ratio_0,
                    "approx_kl": loss_info[1][4],
                    "clip_frac": loss_info[1][5],
                }
                metric["eval"] = {
                    "eval_sp_return_cr": eval_returns[0].mean(),
                    "eval_sp_return_fc": eval_returns[1].mean(),
                    "eval_sp_return_croom": eval_returns[2].mean(),
                    "eval_sp": eval_returns.mean(),
                    "eval_pop_returns_cr": eval_pop_returns_cr.mean(),
                    "eval_pop_returns_fc": eval_pop_returns_fc.mean(),
                    "eval_pop_returns_croom": eval_pop_returns_croom.mean(),
                    "eval_pop_returns": jnp.array([eval_pop_returns_cr.mean(), eval_pop_returns_fc.mean(), eval_pop_returns_croom.mean()]).mean()
                }
                # hist, _ = jnp.histogram(
                #     sfl_buffer.epsilon, bins=5, range=(0, 1))
                metric["partner_params"] = {
                    "buffer_learnability_scores": learnability_info["buffer_learnability_scores"].mean(),
                }
                metric["reward_shaping_factor"] = jnp.maximum(
                    0.0, 1.0 - (update_steps / config.num_reward_shaping_updates))
                rng = update_state[-1]

                def callback(metric):
                    wandb.log(metric)

                update_steps = update_steps + 1
                valid_mask = metric["returned_episode"]
                returns = metric["returned_episode_returns"]

                # Convert mask to float (1.0 where valid, 0.0 where not)
                mask = valid_mask.astype(jnp.float32)

                # Masked sum and count
                masked_sum = (returns * mask).sum()
                count = mask.sum()

                # Avoid divide-by-zero
                mean_return = jnp.where(count > 0, masked_sum / count, 0.0)
                metric["mean_return"] = mean_return
                metric = jax.tree.map(lambda x: x.mean(), metric)

                metric["update_step"] = update_steps
                metric["env_step"] = update_steps * \
                    config.num_steps_per_update * config.num_envs
                jax.debug.callback(callback, metric)

                runner_state = (train_state, env_state, last_obs,
                                last_done, hstate, past_5_sa_pairs, rng)
                return (runner_state, update_steps, learnability_info), metric

            runner_state = (
                train_state,
                env_state,
                obsv,
                jnp.zeros((config.num_actors), dtype=bool),
                init_hstate,
                past_5_sa_pairs,
                rng)
            update_state = (runner_state, update_idx, learnability_info)
            (runner_state, update_steps, learnability_info), loss_info = jax.lax.scan(
                _update_step, update_state, None, config.num_inner_updates)

            meta_state = (
                runner_state[-1], runner_state[0], sfl_buffer, past_5_sa_pairs, learnability_info, update_steps)
            return meta_state, loss_info

        rng, train_state, past_5_sa_pairs, update_steps = meta_state
        rng, _rng = jax.random.split(rng)
        sfl_buffer, learnability_info = _sample_learnability_buffer(
            _rng, train_state)

        def __buffer_callback(x):
            info, step = x
            wandb.log(info, step=step)

        inner_idx = jnp.arange(config.sfl_buffer_refresh_freq) + \
            (outer_idx)*config.sfl_buffer_refresh_freq
        rng, _rng = jax.random.split(rng)
        meta_state, loss_info = jax.lax.scan(
            _meta_step, (_rng, train_state, sfl_buffer, past_5_sa_pairs, learnability_info, update_steps), inner_idx, config.sfl_buffer_refresh_freq)
        return meta_state, (loss_info, learnability_info)

    return train_loop


def get_run_string(config: TrainConfig):
    """Return the wandb group name identifying this experiment family.

    The name is the fixed algorithm prefix followed by the learnability
    function and the OGC grid height, width and wall count, all joined
    with underscores. The seed is not included, so runs that differ only
    by seed share a group.
    """
    fn_name = config.learnability_function
    height, width, walls = config.ogc_height, config.ogc_width, config.ogc_n_walls
    return f"FF_RNN_SFLE3T_WBIAS_IPPO_{fn_name}_OGC_{height}_{width}_{walls}"


def _load_partner_population(layout_name, env):
    """Load the pre-trained FF_BRDiv partner population for one eval layout.

    Every file under ``eval_populations/FF_BRDiv/<layout_name>`` whose name
    contains ``"param"`` is unpickled, and its ``"actor_params"`` entry is kept.
    The directory's ``config.pckl`` supplies the population size. Files are
    visited in sorted order so the stacked population (and therefore which
    partner sits at which index) is the same on every filesystem;
    ``Path.iterdir`` by itself yields entries in arbitrary order.

    NOTE(review): ``pickle.load`` can execute arbitrary code. These files are
    assumed to be trusted, locally produced artifacts.

    Args:
        layout_name: Name of the eval layout sub-directory.
        env: Environment whose ``action_space(env.agents[0]).n`` sizes the
            partner actor's action head.

    Returns:
        ``(partner_pop_params, partner_pop_actor, pop_size)``:
        - the actor param pytrees stacked along a new leading axis,
        - the ``ActorWithConditionalCritic`` module used to apply them,
        - ``config.pckl["partner_pop_size"]``.

    Raises:
        FileNotFoundError: If the directory contains no ``"param"`` files.
            Stacking an empty list would otherwise fail with an opaque
            ``TypeError``.
    """
    pop_dir = Path("eval_populations/FF_BRDiv") / layout_name

    partner_pop_params = []
    for file in sorted(pop_dir.iterdir()):
        if "param" in file.name:
            with open(file, "rb") as f:
                partner_pop_params.append(pickle.load(f)["actor_params"])
    if not partner_pop_params:
        raise FileNotFoundError(
            f"No partner parameter files found in {pop_dir}")

    with open(pop_dir / "config.pckl", "rb") as f:
        other_config = pickle.load(f)

    # Stack the per-partner pytrees leaf-wise into one pytree with a leading
    # population axis.
    stacked_params = jax.tree.map(
        lambda *x: jnp.stack(x), *partner_pop_params)
    partner_pop_actor = ActorWithConditionalCritic(
        env.action_space(env.agents[0]).n)
    return stacked_params, partner_pop_actor, other_config["partner_pop_size"]


def _save_checkpoint(save_dir, params, step):
    """Pickle ``{"actor_params": params}`` to ``params_<step>.pt`` and ``params.pt``.

    ``params.pt`` is overwritten on every call, so it always holds the most
    recent checkpoint. Both files are written inside ``with`` blocks so the
    handles are closed, and the data flushed, before training continues.

    Returns:
        The path of the step-specific checkpoint file.
    """
    os.makedirs(save_dir, exist_ok=True)
    payload = {"actor_params": params}
    step_path = os.path.join(save_dir, f"params_{step}.pt")
    for out_path in (step_path, os.path.join(save_dir, "params.pt")):
        with open(out_path, "wb") as f:
            pickle.dump(payload, f)
    return step_path


@pyrallis.wrap()
def train(config: TrainConfig):
    """Set up logging, env, network and optimizer, then run outer training steps.

    ``config`` is supplied by the ``@pyrallis.wrap()`` decorator. The run is
    logged to wandb. If ``config.checkpoint_path`` is set:
    - the config is pickled to ``<checkpoint_path>/<run.name>/config.pckl``
      up front, so an unwritable path fails before training starts;
    - actor params are checkpointed every ``checkpoint_freq`` outer steps
      (excluding step 0) and after the final outer step.

    Returns:
        ``{"runner_state": meta_state, "metrics": loss_info}`` from the last
        outer step. ``loss_info`` is ``None`` if ``num_outer_steps == 0``.
    """
    ##### WANDB and other setup #####
    tags = [
        "FF",
        "RNN",
        "SFLE3T",
        "IPPO",
        "OGC",
        "W/Bias",
        config.learnability_function,
    ]

    group_string = get_run_string(config)
    run_string = f"{group_string}_SEED_{config.seed}"

    run = wandb.init(
        project=config.project,
        group=group_string,
        mode=config.mode,
        config=asdict(config),
        save_code=True,
        tags=tags,
    )

    if run.sweep_id is not None:
        run.name = run.sweep_id + "___" + run_string
    else:
        run.name = run.name + "___" + run_string

    print("XPID ID name:")
    print(run.name)
    print("-------------")

    #### Setup and check saving before training ####
    save_dir = None
    if config.checkpoint_path is not None:
        save_dir = os.path.join(config.checkpoint_path, run.name)
        # Make sure we can write the checkpoint later _before_ we wait 1 day for training!
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "config.pckl"), "wb") as f:
            pickle.dump(asdict(config), f)

    env = OGC(width=config.ogc_width, height=config.ogc_height)
    env = AutoReplayWrapper(env)

    def linear_schedule(count):
        # Piecewise-constant decay: the LR steps down once per block of
        # num_minibatches * update_epochs * num_inner_updates optimizer steps
        # and reaches 0 after num_meta_updates such blocks.
        total_inner_updates = config.num_minibatches * \
            config.update_epochs * config.num_inner_updates
        frac = 1.0 - (count // total_inner_updates) / config.num_meta_updates
        return config.lr * frac

    rng = jax.random.PRNGKey(config.seed)

    # INIT NETWORK
    network = ActorCriticRNN(
        env.action_space("agent_0").n,
        gru_hidden_dim_size=config.gru_hidden_dim,
        fc_dim_size=config.fc_dim_size,
        embedding_layers=config.embedding_layers,
        actor_layers=config.actor_layers,
        critic_layers=config.critic_layers,
        other_agent_prediction=config.other_agent_prediction,
        use_layernorm=config.use_layernorm,
    )
    init_x = jnp.zeros(env.observation_space("agent_0").shape)
    init_x = init_x.flatten()

    # Dummy 5-step state-action history with a leading time axis of 1. It is
    # only used to build the network.init input; the per-env history is
    # refilled from a real reset further below.
    past_5_sa_pairs = {
        'agent_0': {
            'obs': jnp.zeros((1, config.num_envs, 5, init_x.shape[0])),
            'action': jnp.zeros((1, config.num_envs, 5, 1))
        },
        'agent_1': {
            'obs': jnp.zeros((1, config.num_envs, 5, init_x.shape[0])),
            'action': jnp.zeros((1, config.num_envs, 5, 1))
        }
    }

    batched_sa_pairs = batchify_nested_dics(
        past_5_sa_pairs, env.agents, (1, config.num_actors, 5))

    init_x = (
        jnp.zeros(
            (1, config.num_actors, init_x.shape[0])
        ),
        jnp.zeros((1, config.num_actors)),
        batched_sa_pairs,
    )
    init_hstate = ScannedRNN.initialize_carry(
        config.num_actors, config.gru_hidden_dim)

    rng, _rng = jax.random.split(rng)
    network_params = network.init(_rng, init_hstate, init_x)

    if config.anneal_lr:
        tx = optax.chain(
            optax.clip_by_global_norm(config.max_grad_norm),
            optax.adam(learning_rate=linear_schedule, eps=1e-5),
        )
    else:
        tx = optax.chain(optax.clip_by_global_norm(
            config.max_grad_norm), optax.adam(config.lr, eps=1e-5))

    train_state = TrainState.create(
        apply_fn=network.apply,
        params=network_params,
        tx=tx,
    )

    eval_level_names = ["coord_ring", "forced_coord", "cramped_room_5_5"]
    # These levels are always needed, independent of eval_against_pop:
    # - the update function's self-play eval rolls out on eval_levels and
    #   indexes its three entries;
    # - the level generator below must hold them out of training.
    eval_levels = Level.stack(
        [Level.from_layout_name(name) for name in eval_level_names])

    if config.eval_against_pop:
        # Evaluate with a population of partners on each eval layout.
        eval_partner_params, eval_popsizes = [], []
        for layout_name in eval_level_names:
            pop_params, partner_pop_actor, pop_size = _load_partner_population(
                layout_name, env)
            eval_partner_params.append(pop_params)
            eval_popsizes.append(pop_size)
    else:
        eval_partner_params = None
        partner_pop_actor = None
        eval_popsizes = 0

    sample_random_level = make_level_generator(
        width=env.width, height=env.height, n_walls=config.ogc_n_walls, heldout_set=eval_levels)

    # INIT UPDATE FUNCTION
    _update_step = make_update_fn(
        config, env, network, sample_random_level,
        eval_levels, partner_pop_actor, eval_partner_params, eval_popsizes)
    jitted_update_step = jax.jit(_update_step)

    # INIT ENV
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, config.num_envs)
    levels = jax.vmap(sample_random_level)(reset_rng)
    obsv, _ = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
        reset_rng, levels, env.default_params)
    # Fill past s-a: repeat each agent's first observation 5 times. Action id 4
    # is used as the filler action (presumably a no-op/stay; confirm against
    # the OGC action set).
    past_5_sa_pairs['agent_0']['obs'] = obsv[
        'agent_0'][:, None, :].repeat(5, axis=1)
    past_5_sa_pairs['agent_0']['action'] = jnp.ones(
        (config.num_envs, 5)) * 4
    past_5_sa_pairs['agent_1']['obs'] = obsv[
        'agent_1'][:, None, :].repeat(5, axis=1)
    past_5_sa_pairs['agent_1']['action'] = jnp.ones(
        (config.num_envs, 5)) * 4

    update_steps = 0
    # None is a placeholder for the SFL buffer
    meta_state = (rng, train_state, None, past_5_sa_pairs, update_steps)

    print(config.num_outer_steps)

    loss_info = None  # stays None if no outer step runs
    for i in range(config.num_outer_steps):
        print(i)
        # Only rng (index 0), train_state (index 1) and update_steps (last)
        # are carried over from the previous step's meta_state.
        # NOTE(review): the initial past_5_sa_pairs is passed every step rather
        # than meta_state[3]; confirm this is intended.
        meta_state, (loss_info, learnability_info) = jitted_update_step(
            (meta_state[0], meta_state[1], past_5_sa_pairs, meta_state[-1]), i)
        train_state = meta_state[1]

        params = train_state.params

        # Remarkably, saving is among the most expensive operations
        is_checkpoint_step = (
            (i % config.checkpoint_freq == 0 and i != 0)
            or i == config.num_outer_steps - 1
        )
        if save_dir is not None and is_checkpoint_step:
            ckpt_path = _save_checkpoint(save_dir, params, i)
            print(f"Saved params to {ckpt_path}")

    return {"runner_state": meta_state, "metrics": loss_info}


if __name__ == '__main__':
    # train() is called without arguments: its `config` argument is supplied
    # by the @pyrallis.wrap() decorator on train (presumably parsed from the
    # command line into TrainConfig).
    train()
