"""Evaluates two XPIDs together that share the same network architectures
"""
from ast import arg
from functools import partial
import os
import re
import argparse
import pickle
import glob
from typing import Type

import jax
import jax.numpy as jnp
import numpy as np
import glob
from flax import struct

from src.envs import make_env

from src.jaxzsc.jdpd.jdpdv3_ippo_ogc_w_bias_rnn import TrainConfig as TrainConfigOGC
from src.jaxzsc.jdpd.cec_ippo_ogc_w_bias_rnn import TrainConfig as TrainConfigRandomSP
from src.jaxzsc.brdiv.brdiv_ippo_ogc import TrainConfig as TrainConfigBRDiv
from src.jaxzsc.jdpd.sfl_e3t_ippo_ogc_w_bias_rnn import TrainConfig as TrainConfigSFLE3T
from src.jaxzsc.jdpd.dr_dr_ippo_ogc_w_bias_rnn import TrainConfig as TrainConfigDRDR

from src.agents.actors import ActorCriticRNN, ActorWithConditionalCritic, ScannedRNN
from src.agents.overcooked.agent_policy_wrappers import (
    OvercookedIndependentPolicyWrapper,
    OvercookedOnionPolicyWrapper,
    OvercookedPlatePolicyWrapper,
    OvercookedRandomPolicyWrapper,
    OvercookedStaticPolicyWrapper,
)
from src.envs.overcooked.augmented_layouts import augmented_layouts as overcooked_layouts

from src.envs.ogc.auto_replay_wrapper import AutoReplayWrapper
from src.envs.ogc.ogc import OGC, Level, make_level_generator


class RolloutStats(struct.PyTreeNode):
    """Per-episode accumulators carried through the rollout ``while_loop``s.

    Both fields default to scalar JAX arrays so an instance can be part of a
    ``jax.lax.while_loop`` carry.
    """
    reward: jax.Array = jnp.asarray(0.0)  # running sum of reward (rollouts here add agent_0's reward)
    length: jax.Array = jnp.asarray(0)  # number of env steps taken so far


def batchify(x: dict, agent_list, num_actors):
    """Stack ``x[agent]`` for each agent (in ``agent_list`` order) into a 2-D array.

    The stacked array is reshaped to ``(num_actors, -1)``.
    """
    per_agent = [x[agent] for agent in agent_list]
    stacked = jnp.stack(per_agent)
    return stacked.reshape(num_actors, -1)


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    """Reshape ``x`` to ``(num_actors, num_envs, -1)`` and key the rows by agent.

    Row ``i`` of the reshaped array is stored under ``agent_list[i]``.
    """
    per_actor = x.reshape((num_actors, num_envs, -1))
    return dict((agent, per_actor[idx]) for idx, agent in enumerate(agent_list))


def rollout_single_l(rng, env, level, network, params, other_network, other_params, hidden_size, popsize):
    """Roll out one episode with the ego RNN policy as agent_0 and a partner as agent_1.

    Args:
        rng: PRNG key for this episode.
        env: environment exposing ``reset_env_to_level``, ``step``, ``agents`` and
            ``default_params`` (an ``AutoReplayWrapper(OGC(...))`` in ``main``).
        level: level the environment is reset to.
        network, params: ego recurrent actor-critic and its parameters. Besides the
            current observation, the ego is fed a window of its own last 5
            (obs, action) pairs.
        other_network, other_params: partner policy and its parameters, queried
            with an all-zero conditioning vector of length ``popsize``.
        hidden_size: hidden size passed to ``ScannedRNN.initialize_carry``.
        popsize: length of the partner's zero conditioning vector.

    Returns:
        ``(episode_return, episode_length)``: summed ``agent_0`` reward and number
        of env steps until every agent reports done.
    """

    def _push(window, value):
        # Shift the window one slot towards the front (dropping the oldest entry
        # along axis 1) and write ``value`` into the newest slot.
        window = window.at[:, :-1].set(window[:, 1:])
        return window.at[:, -1].set(value)

    def _cond_fn(carry):
        done = carry[-1]
        # Keep stepping while any agent is not yet done.
        return jnp.logical_not(done).any()

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, past_5_sa_pairs, done = carry
        rng, rng_action, rng_o_action, rng_step = jax.random.split(rng, 4)

        # Ego (agent_0): add leading batch/time dims to obs, done flag and history.
        ego_history = past_5_sa_pairs["agent_0"]
        ac_in = (
            last_obs["agent_0"][np.newaxis, np.newaxis, :],
            done[np.newaxis, ...][:, 0:1],
            jax.tree.map(lambda x: x[np.newaxis], ego_history),
        )
        hstate, pi, _, _ = network.apply(params, hstate, ac_in)
        ego_action = pi.sample(seed=rng_action).squeeze()

        # Partner (agent_1): conditioned on an all-zero vector of length popsize.
        pi_partner, _ = other_network.apply(
            other_params,
            (
                last_obs["agent_1"][np.newaxis, :],
                jnp.zeros(popsize)[np.newaxis, :],
            ),
        )
        partner_action = pi_partner.sample(seed=rng_o_action).squeeze()

        env_act = {
            "agent_0": ego_action,
            "agent_1": partner_action,
        }

        # Append the ego's current (obs, action) to its history for the next step.
        past_5_sa_pairs = {
            "agent_0": {
                "obs": _push(ego_history["obs"], last_obs["agent_0"]),
                "action": _push(ego_history["action"], ego_action),
            }
        }

        obsv, env_state, reward, done, _ = env.step(rng_step, env_state, env_act)

        stats = stats.replace(
            reward=stats.reward + reward["agent_0"],
            length=stats.length + 1,
        )
        done = batchify(done, env.agents, 2).squeeze()
        return (rng, env_state, stats, obsv, hstate, past_5_sa_pairs, done)

    key, key_r = jax.random.split(rng)
    obs, state = env.reset_env_to_level(key_r, level, env.default_params)

    # Seed every slot of the ego's 5-step history with the initial observation
    # and action value 4 (NOTE(review): presumably a no-op/stay action — confirm).
    past_5_sa_pairs = {
        "agent_0": {
            "obs": obs["agent_0"][:, None].repeat(5, axis=1),
            "action": jnp.ones((1, 5)) * 4,
        },
    }

    init_hstate = ScannedRNN.initialize_carry(1, hidden_size)  # batch size 1
    # Carry ``key`` rather than ``rng``: ``rng`` was already split to derive
    # ``key_r``, and reusing it would correlate the reset key with later samples.
    init_carry = (key, state, RolloutStats(), obs,
                  init_hstate, past_5_sa_pairs, jnp.array([False, False]))

    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    final_stats = final_carry[2]
    return final_stats.reward.squeeze(), final_stats.length.squeeze()


def rollout_single_r(rng, env, level, network, params, other_network, other_params, hidden_size, popsize):
    """Roll out one episode with the ego RNN policy as agent_1 and a partner as agent_0.

    Args:
        rng: PRNG key for this episode.
        env: environment exposing ``reset_env_to_level``, ``step``, ``agents`` and
            ``default_params`` (an ``AutoReplayWrapper(OGC(...))`` in ``main``).
        level: level the environment is reset to.
        network, params: ego recurrent actor-critic and its parameters. Besides the
            current observation, the ego is fed a window of its own last 5
            (obs, action) pairs.
        other_network, other_params: partner policy and its parameters, queried
            with an all-zero conditioning vector of length ``popsize``.
        hidden_size: hidden size passed to ``ScannedRNN.initialize_carry``.
        popsize: length of the partner's zero conditioning vector.

    Returns:
        ``(episode_return, episode_length)``: summed ``agent_0`` reward and number
        of env steps until every agent reports done.
    """

    def _push(window, value):
        # Shift the window one slot towards the front (dropping the oldest entry
        # along axis 1) and write ``value`` into the newest slot.
        window = window.at[:, :-1].set(window[:, 1:])
        return window.at[:, -1].set(value)

    def _cond_fn(carry):
        done = carry[-1]
        # Keep stepping while any agent is not yet done.
        return jnp.logical_not(done).any()

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, past_5_sa_pairs, done = carry
        rng, rng_action, rng_o_action, rng_step = jax.random.split(rng, 4)

        # Ego (agent_1): add leading batch/time dims to obs, done flag and history.
        ego_history = past_5_sa_pairs["agent_1"]
        ac_in = (
            last_obs["agent_1"][np.newaxis, np.newaxis, :],
            done[np.newaxis, ...][:, 0:1],
            jax.tree.map(lambda x: x[np.newaxis], ego_history),
        )
        hstate, pi, _, _ = network.apply(params, hstate, ac_in)
        ego_action = pi.sample(seed=rng_action).squeeze()

        # Partner (agent_0): conditioned on an all-zero vector of length popsize.
        pi_partner, _ = other_network.apply(
            other_params,
            (
                last_obs["agent_0"][np.newaxis, :],
                jnp.zeros(popsize)[np.newaxis, :],
            ),
        )
        partner_action = pi_partner.sample(seed=rng_o_action).squeeze()

        env_act = {
            "agent_0": partner_action,
            "agent_1": ego_action,
        }

        # Append the ego's current (obs, action) to its history for the next step.
        past_5_sa_pairs = {
            "agent_1": {
                "obs": _push(ego_history["obs"], last_obs["agent_1"]),
                "action": _push(ego_history["action"], ego_action),
            }
        }

        obsv, env_state, reward, done, _ = env.step(rng_step, env_state, env_act)

        stats = stats.replace(
            reward=stats.reward + reward["agent_0"],
            length=stats.length + 1,
        )
        done = batchify(done, env.agents, 2).squeeze()
        return (rng, env_state, stats, obsv, hstate, past_5_sa_pairs, done)

    key, key_r = jax.random.split(rng)
    obs, state = env.reset_env_to_level(key_r, level, env.default_params)

    # Seed every slot of the ego's 5-step history with the initial observation
    # and action value 4 (NOTE(review): presumably a no-op/stay action — confirm).
    past_5_sa_pairs = {
        "agent_1": {
            "obs": obs["agent_1"][:, None].repeat(5, axis=1),
            "action": jnp.ones((1, 5)) * 4,
        },
    }

    init_hstate = ScannedRNN.initialize_carry(1, hidden_size)  # batch size 1
    # Carry ``key`` rather than ``rng``: ``rng`` was already split to derive
    # ``key_r``, and reusing it would correlate the reset key with later samples.
    init_carry = (key, state, RolloutStats(), obs,
                  init_hstate, past_5_sa_pairs, jnp.array([False, False]))

    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    final_stats = final_carry[2]
    return final_stats.reward.squeeze(), final_stats.length.squeeze()


def rollout_both_ways(eval_rng, env, level, network, params, partner_pop_actor, partner_pop_params, gru_hidden_dim, popsize):
    """Mean ego return against one partner, playing both seats over a batch of keys.

    ``eval_rng`` is a batch of PRNG keys (mapped over axis 0). The same keys are
    used for the agent_0 seat (``rollout_single_l``) and the agent_1 seat
    (``rollout_single_r``). Episode lengths are discarded.
    """
    batched_in_axes = (0,) + (None,) * 8
    shared_args = (env, level, network, params, partner_pop_actor,
                   partner_pop_params, gru_hidden_dim, popsize)

    seat_returns = []
    for rollout_fn in (rollout_single_l, rollout_single_r):
        episode_returns, _lengths = jax.vmap(rollout_fn, in_axes=batched_in_axes)(
            eval_rng, *shared_args)
        seat_returns.append(episode_returns)
    return jnp.array(seat_returns).mean()


# ----------- Rollout against hardcoded agents ------------


def rollout_vs_hardcoded(
        rng, env, level, network, params, init_hstate, agent_switch, hardcoded_partner):
    """Roll out one episode of the ego RNN policy against a hardcoded partner.

    Args:
        rng: PRNG key for this episode.
        env: environment exposing ``reset_env_to_level``, ``step``, ``agents`` and
            ``default_params``; its state must expose ``.env_state``, which is
            forwarded to the hardcoded partner.
        level: level the environment is reset to.
        network, params: ego recurrent actor-critic and its parameters.
        init_hstate: initial recurrent carry of the ego.
        agent_switch: Python bool. ``True`` puts the ego in the agent_0 seat and
            the partner in agent_1; ``False`` swaps the seats. It is resolved at
            trace time.
        hardcoded_partner: object providing ``init_hstate(...)`` and
            ``get_action(...)`` with the keyword arguments used below.

    Returns:
        Summed ``agent_0`` reward over the episode.
    """
    ego_id, partner_id = ("agent_0", "agent_1") if agent_switch else ("agent_1", "agent_0")
    agent_ids = ("agent_0", "agent_1")

    def _push(window, value):
        # Shift the window one slot towards the front (dropping the oldest entry
        # along axis 1) and write ``value`` into the newest slot.
        window = window.at[:, :-1].set(window[:, 1:])
        return window.at[:, -1].set(value)

    def _cond_fn(carry):
        done = carry[-1]
        # Keep stepping while any agent is not yet done.
        return jnp.logical_not(done).any()

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, other_agent_state, past_5_sa_pairs, done = carry
        rng, rng_action, rng_step = jax.random.split(rng, 3)

        ac_in = (
            last_obs[ego_id].reshape(-1)[np.newaxis, np.newaxis, :],
            done[np.newaxis, np.newaxis, 0],
            jax.tree.map(lambda x: x[np.newaxis], past_5_sa_pairs[ego_id]),
        )
        hstate, pi, _, _ = network.apply(params, hstate, ac_in)
        action = pi.sample(seed=rng_action).squeeze()
        action_other, other_agent_state = hardcoded_partner.get_action(
            params=None, obs=last_obs[partner_id], done=done[0],
            avail_actions=None, hstate=other_agent_state, rng=None,
            env_state=env_state.env_state)
        # Build the action dict in fixed agent_0/agent_1 order regardless of seat.
        env_act = {agent: (action if agent == ego_id else action_other)
                   for agent in agent_ids}

        obsv, env_state, reward, done, _ = env.step(rng_step, env_state, env_act)

        # Append each agent's (obs, action) from this step to its 5-step history.
        past_5_sa_pairs = {
            agent: {
                "obs": _push(past_5_sa_pairs[agent]["obs"], last_obs[agent]),
                "action": _push(past_5_sa_pairs[agent]["action"], env_act[agent]),
            }
            for agent in agent_ids
        }

        stats = stats.replace(reward=stats.reward + reward["agent_0"],
                              length=stats.length + 1)
        done = batchify(done, env.agents, 2).squeeze()
        return (rng, env_state, stats, obsv, hstate, other_agent_state,
                past_5_sa_pairs, done)

    key, key_r = jax.random.split(rng)
    obs, state = env.reset_env_to_level(key_r, level, env.default_params)

    # Seed every history slot with the initial observation and action value 4
    # (NOTE(review): presumably a no-op/stay action — confirm against the env).
    past_5_sa_pairs = {
        agent: {
            "obs": obs[agent][:, None].repeat(5, axis=1),
            "action": jnp.ones((1, 5)) * 4,
        }
        for agent in agent_ids
    }

    other_agent_id = 1 if agent_switch else 0
    other_agent_state = hardcoded_partner.init_hstate(
        None, aux_info={"agent_id": other_agent_id})

    # Carry ``key`` rather than ``rng``: ``rng`` was already split to derive
    # ``key_r``, and reusing it would correlate the reset key with later samples.
    init_carry = (key, state, RolloutStats(), obs,
                  init_hstate, other_agent_state, past_5_sa_pairs, jnp.array([False, False]))
    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    return final_carry[2].reward.squeeze()


def rollout_both_ways_vs_hardcoded(
        rng, env, level, network, params, init_hstate, hardcoded_partner):
    """Play one episode in each seat against ``hardcoded_partner`` with the same key.

    Returns a length-2 array: ``[reward with ego as agent_0, reward with ego as agent_1]``.
    """
    per_seat = [
        rollout_vs_hardcoded(
            rng, env, level, network, params, init_hstate, ego_first, hardcoded_partner
        )
        for ego_first in (True, False)
    ]
    return jnp.array(per_seat)


# ------------------- Main Script -------------------

def interquartile_mean_vec(x, axis=-1):
    """
    Compute the interquartile mean (IQM) of x along the given axis (vectorized).

    The values along ``axis`` are sorted, and the lowest ``n // 4`` and highest
    ``n // 4`` of them are discarded before averaging, where ``n`` is the length
    of ``axis``.

    Args:
        x: JAX array of shape (..., N)
        axis: axis along which to compute IQM (default: last axis)

    Returns:
        iqm: JAX array with ``axis`` reduced, holding the IQM along that axis.
    """
    x_sorted = jnp.sort(x, axis=axis)
    n = x.shape[axis]

    # Exact integer bounds: floor(0.25 n) == n // 4 and ceil(0.75 n) == n - n // 4.
    # This avoids float32 rounding (jnp casts Python floats to float32) and the
    # device round-trips of jnp.floor/jnp.ceil on static shape information.
    lower = n // 4
    upper = n - lower

    index = [slice(None)] * x_sorted.ndim
    index[axis] = slice(lower, upper)
    return jnp.mean(x_sorted[tuple(index)], axis=axis)


def get_all_steps(base_xpid):
    """Return the sorted, de-duplicated checkpoint steps saved for seed 0 of ``base_xpid``.

    Steps are parsed from files named ``params_<step>_*.pt`` in
    ``checkpoints/<base_xpid>_SEED_0``. Only the file name is parsed, so a directory
    name that itself contains ``params_<digits>_`` cannot be misread as a step, and
    several files for the same step yield that step once (avoiding duplicate
    evaluations). Returns an empty list when nothing matches.
    """
    save_dir = f"checkpoints/{base_xpid}_SEED_0"
    step_pattern = re.compile(r"params_(\d+)_")

    steps = set()
    for path in glob.glob(os.path.join(save_dir, "params_*_*.pt")):
        match = step_pattern.match(os.path.basename(path))
        if match:
            steps.add(int(match.group(1)))
    return sorted(steps)


def get_last_step(base_xpid):
    """Return a one-element list holding the path of seed 0's final ``params.pt``.

    The list stands in for the ``steps`` list when only the final checkpoint is
    evaluated.
    """
    seed0_dir = f"checkpoints/{base_xpid}_SEED_0"
    final_params_path = os.path.join(seed0_dir, "params.pt")
    return [final_params_path]


def load_config(xpid):
    """Load ``checkpoints/<xpid>/config.pckl`` and wrap it in the matching TrainConfig.

    The config class is chosen from substrings of ``xpid``, checked in this order:
    ``RandomSP``, ``DPD``, ``SFLE3T``. For ``DPD`` the OGC config is tried first,
    falling back to the DR-DR config if the constructor raises ``TypeError``
    (e.g. on unexpected fields). The raw dict is printed before dispatch.

    Raises:
        ValueError: if no known substring is present in ``xpid``.
    """
    config_path = f"checkpoints/{xpid}/config.pckl"
    with open(config_path, "rb") as handle:
        config_kwargs = pickle.load(handle)

    print(config_kwargs)

    if "RandomSP" in xpid:
        return TrainConfigRandomSP(**config_kwargs)
    if "DPD" in xpid:
        try:
            return TrainConfigOGC(**config_kwargs)
        except TypeError:
            return TrainConfigDRDR(**config_kwargs)
    if "SFLE3T" in xpid:
        return TrainConfigSFLE3T(**config_kwargs)
    raise ValueError(f"Unknown config type for XPID: {xpid}")


def load_params_for_seed(base_xpid, seed, step):
    """Load ``actor_params`` from the checkpoint of ``step`` for one seed.

    Looks for ``checkpoints/<base_xpid>_SEED_<seed>/params_<step>_*.pt``. When several
    files match, the lexicographically first path is used; ``glob`` returns files
    in arbitrary order, so the matches are sorted to keep the choice deterministic.

    Raises:
        FileNotFoundError: if no checkpoint matches ``step``.
    """
    xpid_seed = f"{base_xpid}_SEED_{seed}"
    save_dir = f"checkpoints/{xpid_seed}"
    pattern = os.path.join(save_dir, f"params_{step}_*.pt")
    files = sorted(glob.glob(pattern))
    if not files:
        raise FileNotFoundError(
            f"No checkpoint found for step {step} in seed {seed}")
    param_file = files[0]
    with open(param_file, "rb") as f:
        params = pickle.load(f)["actor_params"]
    return params


def load_final_params_for_seed(base_xpid, seed):
    """Load ``actor_params`` from a seed's final checkpoint.

    Reads ``checkpoints/<base_xpid>_SEED_<seed>/params.pt``. If that file does not
    exist, falls back to ``params_seed<seed>.pt`` in the same directory. A missing
    fallback propagates ``FileNotFoundError``.
    """
    seed_dir = f"checkpoints/{base_xpid}_SEED_{seed}"
    primary_path = os.path.join(seed_dir, "params.pt")
    fallback_path = os.path.join(seed_dir, f"params_seed{seed}.pt")

    try:
        handle = open(primary_path, "rb")
    except FileNotFoundError:
        handle = open(fallback_path, "rb")
    with handle:
        return pickle.load(handle)["actor_params"]


def parse_xpid(xpid):
    """Extract ``(method, layout)`` from an XPID string.

    Two naming schemes are tried in order:

    * ``...FF_RNN_<method>_<...>_Overcooked_<layout>_SEED...``, e.g.
      ``wg77pjuz___FF_RNN_E3T_IPPO_Overcooked_cramped_room_SEED_0`` gives
      ``("E3T", "cramped_room")``;
    * ``...FF_<method>_<...>_Overcooked_<layout>_FF_...`` (presumably the naming
      used for best-response runs).

    Raises:
        ValueError: if neither scheme matches.
    """
    candidate_patterns = (
        r'FF_RNN_(\w+)_(.*)_Overcooked_(.*)_SEED',
        r'FF_(\w+)_(.*)_Overcooked_(.*)_FF_',
    )
    for pattern in candidate_patterns:
        found = re.search(pattern, xpid)
        if found is not None:
            return found.group(1), found.group(3)
    raise ValueError("Could not parse XPID string")


# Order of the hardcoded-partner columns in the CSV line printed by ``main``.
_HARDCODED_PARTNER_NAMES = ("indp", "onion", "plate", "rndm", "static")


def _make_hardcoded_partners(layout_name, layout_obj):
    """Return an ordered ``{name: partner}`` dict of hardcoded partners for a layout.

    On ``forced_coord`` only the independent (with counter probabilities 0.6) and
    random partners are used. The onion/plate/static partners were not built for
    that layout.
    """
    if layout_name == "forced_coord":
        return {
            "indp": OvercookedIndependentPolicyWrapper(  # Left agent
                layout=layout_obj, p_onion_on_counter=0.6, p_plate_on_counter=0.6),
            "rndm": OvercookedRandomPolicyWrapper(layout=layout_obj),
        }
    return {
        "indp": OvercookedIndependentPolicyWrapper(layout=layout_obj),
        "onion": OvercookedOnionPolicyWrapper(layout=layout_obj),
        "plate": OvercookedPlatePolicyWrapper(layout=layout_obj),
        "rndm": OvercookedRandomPolicyWrapper(layout=layout_obj),
        "static": OvercookedStaticPolicyWrapper(layout=layout_obj),
    }


def _stack_pytrees(trees):
    """Stack identically-structured pytrees leaf-wise along a new leading axis 0."""
    return jax.tree.map(lambda *arrays: jnp.stack(arrays, axis=0), *trees)


def _eval_vs_hardcoded(rollout_fn, rngs, env, level, network, ego_params_stacked, init_hstate):
    """Roll every ego seed out against one hardcoded partner; return ``(mean, std)``.

    The inner vmap maps over ``rngs`` and the outer vmap over axis 0 of
    ``ego_params_stacked`` (one entry per seed). The mean and std are taken over
    all resulting rewards.
    """
    rewards = jax.vmap(
        jax.vmap(rollout_fn, in_axes=(0, None, None, None, None, None)),
        in_axes=(None, None, None, None, 0, None),
    )(rngs, env, level, network, ego_params_stacked, init_hstate)
    return float(rewards.mean()), float(rewards.std())


def _propagated_std(stds):
    """Std of the mean of independent quantities with the given stds: sqrt(sum s^2) / n."""
    stds = np.asarray(stds, dtype=np.float64)
    return float(np.sqrt((stds ** 2).sum()) / len(stds))


def main():
    """Evaluate all seeds of one XPID against BRDiv and hardcoded partners per layout.

    For each layout that has a BRDiv population under
    ``eval_populations/FF_BRDiv_OGC/<layout>/``, and for each evaluated checkpoint,
    one CSV line is printed with the columns
    ``Base_XPID,Method,Layout,Mean_all,Mean_all_std,BRDiv,Indp,Onion,Plate,Random,
    Static,Mean,BRDiv_std,Indp_std,Onion_std,Plate_std,Random_std,Static_std,Mean_std``.
    Hardcoded partners not evaluated on a layout are reported as 0.

    NOTE(review): with ``--eval_all_steps`` the CSV line does not include the
    step, so lines for different steps are distinguished only by order.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_xpid', type=str, required=True,
                        help="XPID of seed 0; must end with '_SEED_0'")
    parser.add_argument('--max_seed', type=int, required=True,
                        help='Max seed index (inclusive)')
    parser.add_argument('--eval_all_steps', action='store_true')
    args = parser.parse_args()

    num_seeds = args.max_seed + 1
    method = args.base_xpid  # The CSV "Method" column currently repeats the XPID.

    seed_suffix = "_SEED_0"
    if not args.base_xpid.endswith(seed_suffix):
        raise ValueError("Expected --base_xpid to end with '_SEED_0'")
    base_xpid = args.base_xpid[:-len(seed_suffix)]

    # Load config (assume identical across seeds)
    config = load_config(f"{base_xpid}_SEED_0")

    # No single layout since we evaluate on 3
    layout_names = ["forced_coord", "cramped_room_5_5", "coord_ring"]

    for layout_name in layout_names:
        print(f"\nEvaluating on layout: {layout_name}")

        # BRDiv config + population (sorted so population order is deterministic).
        population_dir = f"eval_populations/FF_BRDiv_OGC/{layout_name}"
        brdiv_config_path = f"{population_dir}/config.pckl"
        file_paths = sorted(glob.glob(f"{population_dir}/params_*.pt"))

        if not file_paths:
            print(
                f"Warning: No BRDiv population found for layout {layout_name}. Skipping.")
            continue

        populations = []
        for path in file_paths:
            with open(path, 'rb') as f:
                populations.append(pickle.load(f))

        with open(brdiv_config_path, 'rb') as f:
            print(brdiv_config_path)
            config_brdiv = pickle.load(f)
        config_brdiv = TrainConfigBRDiv(**config_brdiv)

        # One entry per params_*.pt file along axis 0.
        stacked_population = _stack_pytrees(
            [member["actor_params"] for member in populations])

        # Environment for this layout
        env = AutoReplayWrapper(OGC(width=config.ogc_width, height=config.ogc_height))
        level = Level.from_layout_name(layout_name)

        rng = jax.random.PRNGKey(0)

        network_ego = ActorCriticRNN(
            env.action_space("agent_0").n,
            gru_hidden_dim_size=config.gru_hidden_dim,
            fc_dim_size=config.fc_dim_size,
            embedding_layers=config.embedding_layers,
            actor_layers=config.actor_layers,
            critic_layers=config.critic_layers,
            other_agent_prediction=config.other_agent_prediction,
            use_layernorm=config.use_layernorm,
        )

        network_brdiv = ActorWithConditionalCritic(
            env.action_space("agent_1").n,
            activation=config_brdiv.activation,
        )

        # One jitted two-seat rollout per hardcoded partner (env and network static).
        partners = _make_hardcoded_partners(layout_name, overcooked_layouts[layout_name])
        hardcoded_rollout_fns = {
            name: jax.jit(
                partial(rollout_both_ways_vs_hardcoded, hardcoded_partner=partner),
                static_argnums=(1, 3))
            for name, partner in partners.items()
        }

        if args.eval_all_steps:
            steps = get_all_steps(base_xpid)
        else:
            steps = get_last_step(base_xpid)

        for step in steps:
            if args.eval_all_steps:
                params_list = [load_params_for_seed(base_xpid, seed, step)
                               for seed in range(num_seeds)]
            else:
                params_list = [load_final_params_for_seed(base_xpid, seed)
                               for seed in range(num_seeds)]
            ego_params_stacked = _stack_pytrees(params_list)

            # Fresh, non-overlapping keys for the BRDiv and hardcoded evaluations.
            rng, brdiv_key, hardcoded_key = jax.random.split(rng, 3)

            # BRDiv evaluation. Inner vmap: population members; outer vmap: ego
            # seeds, so axis 0 of ``rewards`` indexes seeds.
            brdiv_rngs = jax.random.split(brdiv_key, 100)
            rewards = jax.vmap(
                jax.vmap(
                    rollout_both_ways,
                    in_axes=(None, None, None, None, None, None, 0, None, None)),
                in_axes=(None, None, None, None, 0, None, None, None, None),
            )(
                brdiv_rngs, env, level, network_ego, ego_params_stacked, network_brdiv,
                stacked_population, config.gru_hidden_dim, config_brdiv.partner_pop_size,
            )
            # Per-population-member mean/std across ego seeds.
            brdiv_mean = np.asarray(rewards.mean(axis=0))
            brdiv_std = np.asarray(rewards.std(axis=0))

            # Hardcoded evaluations (all partners share the same keys).
            init_hstate = ScannedRNN.initialize_carry(1, config.gru_hidden_dim)
            hardcoded_rngs = jax.random.split(hardcoded_key, 100)
            hardcoded_stats = {
                name: _eval_vs_hardcoded(fn, hardcoded_rngs, env, level, network_ego,
                                         ego_params_stacked, init_hstate)
                for name, fn in hardcoded_rollout_fns.items()
            }
            hc_means = np.array([m for m, _ in hardcoded_stats.values()])
            hc_stds = np.array([s for _, s in hardcoded_stats.values()])

            # Mean_all: simple average over every evaluated partner.
            partner_means = np.concatenate([hc_means, brdiv_mean])
            partner_stds = np.concatenate([hc_stds, brdiv_std])
            mean_all = partner_means.mean()
            mean_all_std = _propagated_std(partner_stds)

            # Mean: average of the hardcoded-partner mean and the BRDiv mean, with
            # stds propagated assuming independence.
            hardcoded_mean = hc_means.mean()
            hardcoded_std = _propagated_std(hc_stds)
            brdiv_overall_std = _propagated_std(brdiv_std)
            mean = (hardcoded_mean + brdiv_mean.mean()) / 2
            mean_std = np.sqrt(hardcoded_std ** 2 + brdiv_overall_std ** 2) / 2

            # CSV line for Google Sheets (columns as listed in the docstring).
            hc_mean_cols = [hardcoded_stats.get(name, (0.0, 0.0))[0]
                            for name in _HARDCODED_PARTNER_NAMES]
            hc_std_cols = [hardcoded_stats.get(name, (0.0, 0.0))[1]
                           for name in _HARDCODED_PARTNER_NAMES]
            numeric_cols = [
                mean_all, mean_all_std, brdiv_mean.mean(), *hc_mean_cols,
                mean, brdiv_std.mean(), *hc_std_cols, mean_std,
            ]
            print(",".join([args.base_xpid, method, layout_name,
                            *(f"{value:.3f}" for value in numeric_cols)]))


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()
