"""Evaluates every seed of one base XPID against a BRDiv partner population and
hardcoded Overcooked partners, printing one CSV row per evaluated step.
"""
from functools import partial
import os
import re
import argparse
import pickle
import glob

import jax
import jax.numpy as jnp
import numpy as np
import glob
from flax import struct

from src.envs import make_env

from src.jaxzsc.dpd.dpd_ippo_overcooked_rnn import TrainConfig as TrainConfigDPD
from src.jaxzsc.e3t.e3t_ippo_overcooked_rnn import TrainConfig as TrainConfigE3T
from src.jaxzsc.sp.sp_ippo_overcooked_rnn import TrainConfig as TrainConfigSP
from src.jaxzsc.brdiv.brdiv_ippo_overcooked import TrainConfig as TrainConfigBRDiv
from src.jaxzsc.best_response.best_response_ippo_overcooked_rnn import TrainConfig as TrainConfigBR

from src.agents.actors import ActorCriticRNN, ActorWithConditionalCritic, ScannedRNN
from src.agents.overcooked.agent_policy_wrappers import (
    OvercookedIndependentPolicyWrapper,
    OvercookedOnionPolicyWrapper,
    OvercookedPlatePolicyWrapper,
    OvercookedRandomPolicyWrapper,
    OvercookedStaticPolicyWrapper,
)
from src.envs.overcooked.augmented_layouts import augmented_layouts as overcooked_layouts


class RolloutStats(struct.PyTreeNode):
    """Running totals accumulated over a single rollout."""
    # Sum of the per-step rewards added by the rollout loops; starts at 0.0.
    reward: jax.Array = jnp.asarray(0.0)
    # Number of environment steps taken so far; starts at 0.
    length: jax.Array = jnp.asarray(0)


def batchify(x: dict, agent_list, num_actors):
    """Stack the per-agent entries of ``x`` and flatten them per actor.

    Args:
        x: Mapping from agent name to an array.
        agent_list: Agent names, in the order they should be stacked.
        num_actors: Size of the leading axis of the result.

    Returns:
        Array of shape ``(num_actors, -1)``.
    """
    per_agent = [x[agent] for agent in agent_list]
    stacked = jnp.stack(per_agent)
    return stacked.reshape(num_actors, -1)


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    """Split a batched array back into a per-agent dictionary.

    Args:
        x: Array that can be reshaped to ``(num_actors, num_envs, -1)``.
        agent_list: Agent names; the i-th name receives slice ``i``.
        num_envs: Size of the second axis after reshaping.
        num_actors: Size of the leading axis after reshaping.

    Returns:
        Dict mapping each agent name to its ``(num_envs, -1)`` slice.
    """
    grouped = x.reshape(num_actors, num_envs, -1)
    per_agent = {}
    for idx, agent in enumerate(agent_list):
        per_agent[agent] = grouped[idx]
    return per_agent


def rollout_single_l(rng, env, network, params, other_network, other_params, hidden_size, popsize):
    """Roll out one episode with the ego network acting as ``agent_0``.

    The partner network acts as ``agent_1``. The loop keeps stepping the
    environment until every entry of the done vector is True.

    Args:
        rng: PRNG key used for the environment reset and for the loop.
        env: Environment providing ``reset``, ``step``, ``agents`` and
            ``observation_space``.
        network: Ego model; ``network.apply(params, hstate, ac_in)`` returns
            four values, of which the new hidden state and the action
            distribution are used here.
        params: Parameters for ``network``.
        other_network: Partner model; ``apply`` returns two values, the first
            being an action distribution.
        other_params: Parameters for ``other_network``.
        hidden_size: Forwarded to ``ScannedRNN.initialize_carry``.
        popsize: Length of the all-zero vector fed to the partner next to its
            observation (presumably a population-id input; TODO confirm).

    Returns:
        Tuple ``(reward, length)``: the accumulated ``reward["agent_0"]`` and
        the number of environment steps, both squeezed.
    """
    def _cond_fn(carry):
        rng, env_state, stats, obsv, hstate, past_5_sa_pairs, done = carry
        return (done != True).any()  # Continue if not done.

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, past_5_sa_pairs, done = carry

        rng, rng_action, rng_o_action, rng_step = jax.random.split(rng, 4)

        # Ego input: observation with two leading singleton axes, the first
        # done flag, and the state-action history with one extra leading axis.
        in_past_sa_pairs = jax.tree.map(
            lambda x: x[np.newaxis], past_5_sa_pairs["agent_0"])
        ac_in = (
            last_obs["agent_0"][np.newaxis, np.newaxis, :],
            done[np.newaxis, ...][:, 0:1],
            in_past_sa_pairs,
        )
        hstate, pi, _, _ = network.apply(params, hstate, ac_in)
        action0 = pi.sample(seed=rng_action).squeeze()

        # Partner input: its observation plus a zero vector of length popsize.
        pi1, _ = other_network.apply(
            other_params,
            (
                last_obs["agent_1"][np.newaxis, :],
                jnp.zeros(popsize)[np.newaxis, :]
            )
        )
        action1 = pi1.sample(seed=rng_o_action).squeeze()

        env_act = {
            "agent_0": action0,
            "agent_1": action1
        }

        # Shift the history one slot to the left, then write the current
        # observation / action into the last slot.
        # NOTE(review): this mutates the dict taken from the carry in place.
        past_5_sa_pairs['agent_0']['obs'] = past_5_sa_pairs[
            'agent_0']['obs'].at[:, :-1].set(past_5_sa_pairs['agent_0']['obs'][:, 1:])
        past_5_sa_pairs['agent_0']['obs'] = past_5_sa_pairs[
            'agent_0']['obs'].at[:, - 1].set(last_obs['agent_0'])
        past_5_sa_pairs['agent_0']['action'] = past_5_sa_pairs[
            'agent_0']['action'].at[:, :-1].set(past_5_sa_pairs['agent_0']['action'][:, 1:])
        past_5_sa_pairs['agent_0']['action'] = past_5_sa_pairs[
            'agent_0']['action'].at[:, -1].set(env_act['agent_0'])

        obsv, env_state, reward, done, info = env.step(
            rng_step, env_state, env_act
        )

        stats = stats.replace(
            reward=stats.reward + reward["agent_0"],
            length=stats.length + 1
        )
        # Turn the per-agent done dict into an array ordered like env.agents.
        done = batchify(done, env.agents, 2)
        carry = (rng, env_state, stats, obsv, hstate,
                 past_5_sa_pairs, done.squeeze())
        return carry

    # NOTE(review): `key` is never used; the loop carry below is seeded with
    # the original `rng`, the same key that `key_r` was derived from.
    key, key_r = jax.random.split(rng)
    obs, state = env.reset(key_r)

    init_x = jnp.zeros(env.observation_space("agent_0").shape)
    init_x = init_x.flatten()

    # These zero placeholders are overwritten immediately below.
    past_5_sa_pairs = {
        'agent_0': {
            'obs': jnp.zeros((1, 5, init_x.shape[0])),
            'action': jnp.zeros((1, 5, 1))
        },
    }

    # Fill all 5 history slots with the initial observation and with action
    # value 4 (presumably the no-op action index; TODO confirm).
    past_5_sa_pairs['agent_0']['obs'] = obs['agent_0'][:,
                                                       None].repeat(5, axis=1)
    past_5_sa_pairs['agent_0']['action'] = jnp.ones(
        (1, 5)) * 4

    init_hstate = ScannedRNN.initialize_carry(
        1, hidden_size)  # Hardcoded
    init_carry = (rng, state, RolloutStats(), obs,
                  init_hstate, past_5_sa_pairs, jnp.array([False, False]))

    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    # Index 2 of the carry is the RolloutStats instance.
    return final_carry[2].reward.squeeze(), final_carry[2].length.squeeze()


def rollout_single_r(rng, env, network, params, other_network, other_params, hidden_size, popsize):
    """Roll out one episode with the ego network acting as ``agent_1``.

    The partner network acts as ``agent_0``. The loop keeps stepping the
    environment until every entry of the done vector is True.

    Args:
        rng: PRNG key used for the environment reset and for the loop.
        env: Environment providing ``reset``, ``step``, ``agents`` and
            ``observation_space``.
        network: Ego model; ``network.apply(params, hstate, ac_in)`` returns
            four values, of which the new hidden state and the action
            distribution are used here.
        params: Parameters for ``network``.
        other_network: Partner model; ``apply`` returns two values, the first
            being an action distribution.
        other_params: Parameters for ``other_network``.
        hidden_size: Forwarded to ``ScannedRNN.initialize_carry``.
        popsize: Length of the all-zero vector fed to the partner next to its
            observation (presumably a population-id input; TODO confirm).

    Returns:
        Tuple ``(reward, length)``: the accumulated ``reward["agent_0"]`` and
        the number of environment steps, both squeezed.
    """
    def _cond_fn(carry):
        rng, env_state, stats, obsv, hstate, past_5_sa_pairs, done = carry
        return (done != True).any()  # Continue if not done.

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, past_5_sa_pairs, done = carry

        rng, rng_action, rng_o_action, rng_step = jax.random.split(rng, 4)

        # Ego input: observation with two leading singleton axes, the first
        # done flag, and the state-action history with one extra leading axis.
        in_past_sa_pairs = jax.tree.map(
            lambda x: x[np.newaxis], past_5_sa_pairs["agent_1"])
        ac_in = (
            last_obs["agent_1"][np.newaxis, np.newaxis, :],
            done[np.newaxis, ...][:, 0:1],
            in_past_sa_pairs,
        )

        # `value` and `other_pi` are returned by the network but not used.
        hstate, pi, value, other_pi = network.apply(params, hstate, ac_in)
        action0 = pi.sample(seed=rng_action).squeeze()

        # Partner input: its observation plus a zero vector of length popsize.
        pi1, _ = other_network.apply(
            other_params,
            (
                last_obs["agent_0"][np.newaxis, :],
                jnp.zeros(popsize)[np.newaxis, :]
            )
        )
        action1 = pi1.sample(seed=rng_o_action).squeeze()

        # Here `action0` is the ego action (seat agent_1) and `action1` is the
        # partner action (seat agent_0).
        env_act = {
            "agent_0": action1,
            "agent_1": action0,
        }

        # Shift the history one slot to the left, then write the current
        # observation / action into the last slot.
        # NOTE(review): this mutates the dict taken from the carry in place.
        past_5_sa_pairs['agent_1']['obs'] = past_5_sa_pairs[
            'agent_1']['obs'].at[:, :-1].set(past_5_sa_pairs['agent_1']['obs'][:, 1:])
        past_5_sa_pairs['agent_1']['obs'] = past_5_sa_pairs[
            'agent_1']['obs'].at[:, - 1].set(last_obs['agent_1'])
        past_5_sa_pairs['agent_1']['action'] = past_5_sa_pairs[
            'agent_1']['action'].at[:, :-1].set(past_5_sa_pairs['agent_1']['action'][:, 1:])
        past_5_sa_pairs['agent_1']['action'] = past_5_sa_pairs[
            'agent_1']['action'].at[:, -1].set(env_act['agent_1'])

        obsv, env_state, reward, done, info = env.step(
            rng_step, env_state, env_act
        )

        # The reward of agent_0 is accumulated even though the ego plays
        # agent_1 (presumably the reward is shared; TODO confirm).
        stats = stats.replace(
            reward=stats.reward + reward["agent_0"],
            length=stats.length + 1
        )
        # Turn the per-agent done dict into an array ordered like env.agents.
        done = batchify(done, env.agents, 2)
        carry = (rng, env_state, stats, obsv, hstate,
                 past_5_sa_pairs, done.squeeze())
        return carry

    # NOTE(review): `key` is never used; the loop carry below is seeded with
    # the original `rng`, the same key that `key_r` was derived from.
    key, key_r = jax.random.split(rng)
    obs, state = env.reset(key_r)

    init_x = jnp.zeros(env.observation_space("agent_0").shape)
    init_x = init_x.flatten()

    # These zero placeholders are overwritten immediately below.
    past_5_sa_pairs = {
        'agent_1': {
            'obs': jnp.zeros((1, 5, init_x.shape[0])),
            'action': jnp.zeros((1, 5, 1))
        },
    }

    # Fill all 5 history slots with the initial observation and with action
    # value 4 (presumably the no-op action index; TODO confirm).
    past_5_sa_pairs['agent_1']['obs'] = obs['agent_1'][:,
                                                       None].repeat(5, axis=1)
    past_5_sa_pairs['agent_1']['action'] = jnp.ones(
        (1, 5)) * 4

    init_hstate = ScannedRNN.initialize_carry(
        1, hidden_size)  # Hardcoded
    init_carry = (rng, state, RolloutStats(), obs,
                  init_hstate, past_5_sa_pairs, jnp.array([False, False]))

    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    # Index 2 of the carry is the RolloutStats instance.
    return final_carry[2].reward.squeeze(), final_carry[2].length.squeeze()


def rollout_both_ways(eval_rng, env, network, params, partner_pop_actor, partner_pop_params, gru_hidden_dim, popsize):
    """Average return of the ego network over both seats.

    Both single-seat rollouts are vmapped over the leading axis of
    ``eval_rng``; every other argument is shared across that axis.

    Returns:
        Scalar mean of the returns from ``rollout_single_l`` and
        ``rollout_single_r`` over all keys.
    """
    # Only the PRNG keys (argument 0) are batched.
    key_only_axes = (0,) + (None,) * 7
    rollout_args = (eval_rng, env, network, params, partner_pop_actor,
                    partner_pop_params, gru_hidden_dim, popsize)

    returns_per_seat = []
    for single_seat_rollout in (rollout_single_l, rollout_single_r):
        seat_returns, _lengths = jax.vmap(
            single_seat_rollout, in_axes=key_only_axes)(*rollout_args)
        returns_per_seat.append(seat_returns)

    return jnp.array(returns_per_seat).mean()

# ----------- Rollout against hardcoded agents ------------


def rollout_vs_hardcoded(
        rng, env, network, params, init_hstate, agent_switch, hardcoded_partner):
    """Roll out one episode of the ego network with a hardcoded partner.

    Args:
        rng: PRNG key used for the environment reset and for the loop.
        env: Environment providing ``reset``, ``step``, ``agents`` and
            ``observation_space``.
        network: Ego model; ``network.apply(params, hstate, ac_in)`` returns
            four values, of which the new hidden state and the action
            distribution are used here.
        params: Parameters for ``network``.
        init_hstate: Initial recurrent state for ``network``.
        agent_switch: If truthy the ego plays ``agent_0`` and the partner
            ``agent_1``; otherwise the seats are swapped. It is tested with a
            plain Python ``if`` inside the loop body.
        hardcoded_partner: Object providing ``init_hstate`` and ``get_action``.

    Returns:
        The accumulated ``reward["agent_0"]`` of the episode, squeezed.
    """
    def _cond_fn(carry):
        rng, env_state, stats, obsv, hstate, other_agent_state, past_5_sa_pairs, done = carry
        # Keep looping while any done flag is still False.
        return (done != True).any()

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, other_agent_state, past_5_sa_pairs, done = carry
        rng, rng_action, rng_step = jax.random.split(rng, 3)

        if agent_switch:
            # Ego is agent_0: flattened observation with two leading singleton
            # axes, the first done flag, and agent_0's history.
            in_past_sa_pairs = jax.tree.map(
                lambda x: x[np.newaxis], past_5_sa_pairs["agent_0"])
            ac_in = (
                last_obs["agent_0"].reshape(-1)[np.newaxis, np.newaxis, :],
                done[np.newaxis, np.newaxis, 0],
                in_past_sa_pairs
            )
            hstate, pi, value, _ = network.apply(params, hstate, ac_in)
            action = pi.sample(seed=rng_action).squeeze()
            # The partner gets agent_1's observation and the inner env state.
            action_other, other_agent_state = hardcoded_partner.get_action(
                params=None, obs=last_obs["agent_1"], done=done[0],
                avail_actions=None, hstate=other_agent_state, rng=None, env_state=env_state.env_state)
            env_act = {"agent_0": action, "agent_1": action_other}
        else:
            # Ego is agent_1: same construction with the seats swapped.
            in_past_sa_pairs = jax.tree.map(
                lambda x: x[np.newaxis], past_5_sa_pairs["agent_1"])
            ac_in = (
                last_obs["agent_1"].reshape(-1)[np.newaxis, np.newaxis, :],
                done[np.newaxis, np.newaxis, 0],
                in_past_sa_pairs
            )
            hstate, pi, value, _ = network.apply(params, hstate, ac_in)
            action = pi.sample(seed=rng_action).squeeze()
            action_other, other_agent_state = hardcoded_partner.get_action(
                params=None, obs=last_obs["agent_0"], done=done[0],
                avail_actions=None, hstate=other_agent_state, rng=None, env_state=env_state.env_state)
            env_act = {"agent_0": action_other, "agent_1": action}

        obsv, env_state, reward, done, info = env.step(
            rng_step, env_state, env_act)

        # Histories of both agents are updated: shift one slot to the left and
        # write the pre-step observation / chosen action into the last slot.
        # NOTE(review): this mutates the dict taken from the carry in place.
        past_5_sa_pairs['agent_0']['obs'] = past_5_sa_pairs[
            'agent_0']['obs'].at[:, :-1].set(past_5_sa_pairs['agent_0']['obs'][:, 1:])
        past_5_sa_pairs['agent_0']['obs'] = past_5_sa_pairs[
            'agent_0']['obs'].at[:, - 1].set(last_obs['agent_0'])
        past_5_sa_pairs['agent_0']['action'] = past_5_sa_pairs[
            'agent_0']['action'].at[:, :-1].set(past_5_sa_pairs['agent_0']['action'][:, 1:])
        past_5_sa_pairs['agent_0']['action'] = past_5_sa_pairs[
            'agent_0']['action'].at[:, -1].set(env_act['agent_0'])

        past_5_sa_pairs['agent_1']['obs'] = past_5_sa_pairs[
            'agent_1']['obs'].at[:, :-1].set(past_5_sa_pairs['agent_1']['obs'][:, 1:])
        past_5_sa_pairs['agent_1']['obs'] = past_5_sa_pairs[
            'agent_1']['obs'].at[:, -1].set(last_obs['agent_1'])
        past_5_sa_pairs['agent_1']['action'] = past_5_sa_pairs[
            'agent_1']['action'].at[:, :-1].set(past_5_sa_pairs['agent_1']['action'][:, 1:])
        past_5_sa_pairs['agent_1']['action'] = past_5_sa_pairs[
            'agent_1']['action'].at[:, -1].set(env_act['agent_1'])

        # agent_0's reward is accumulated regardless of which seat the ego
        # plays (presumably the reward is shared; TODO confirm).
        stats = stats.replace(reward=stats.reward +
                              reward["agent_0"], length=stats.length + 1)
        # Turn the per-agent done dict into an array ordered like env.agents.
        done = batchify(done, env.agents, 2)
        carry = (rng, env_state, stats, obsv, hstate, other_agent_state,
                 past_5_sa_pairs, done.squeeze())
        return carry

    # NOTE(review): `key` is never used; the loop carry below is seeded with
    # the original `rng`, the same key that `key_r` was derived from.
    key, key_r = jax.random.split(rng)
    obs, state = env.reset(key_r)

    init_x = jnp.zeros(env.observation_space("agent_0").shape)
    init_x = init_x.flatten()

    # These zero placeholders are overwritten immediately below.
    past_5_sa_pairs = {
        'agent_0': {
            'obs': jnp.zeros((1, 5, init_x.shape[0])),
            'action': jnp.zeros((1, 5, 1))
        },
        'agent_1': {
            'obs': jnp.zeros((1, 5, init_x.shape[0])),
            'action': jnp.zeros((1, 5, 1))
        }
    }

    # Fill all 5 history slots with the initial observation and with action
    # value 4 (presumably the no-op action index; TODO confirm).
    past_5_sa_pairs['agent_0']['obs'] = obs['agent_0'][:,
                                                       None].repeat(5, axis=1)
    past_5_sa_pairs['agent_0']['action'] = jnp.ones(
        (1, 5)) * 4
    past_5_sa_pairs['agent_1']['obs'] = obs[
        'agent_1'][:, None].repeat(5, axis=1)
    past_5_sa_pairs['agent_1']['action'] = jnp.ones(
        (1, 5)) * 4

    # The partner takes whichever seat the ego does not occupy.
    other_agent_id = 1 if agent_switch else 0
    other_agent_state = hardcoded_partner.init_hstate(
        None, aux_info={"agent_id": other_agent_id})

    init_carry = (rng, state, RolloutStats(), obs,
                  init_hstate, other_agent_state, past_5_sa_pairs, jnp.array([False, False]))
    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    # Index 2 of the carry is the RolloutStats instance.
    return final_carry[2].reward.squeeze()


def rollout_both_ways_vs_hardcoded(
        rng, env, network, params, init_hstate, hardcoded_partner):
    """Evaluate the ego network against a hardcoded partner in both seats.

    The same ``rng`` is used for both rollouts.

    Returns:
        Array ``[reward_as_agent_0, reward_as_agent_1]``.
    """
    rewards_per_seat = [
        rollout_vs_hardcoded(
            rng, env, network, params, init_hstate, ego_is_agent_0,
            hardcoded_partner)
        for ego_is_agent_0 in (True, False)
    ]
    return jnp.array(rewards_per_seat)


# ------------------- Main Script -------------------

def interquartile_mean_vec(x, axis=-1):
    """
    Compute IQM along specified axis of x (vectorized).

    The values are sorted along ``axis`` and the mean is taken over the
    entries with sorted index in ``[floor(0.25 * N), ceil(0.75 * N))``.

    Args:
        x: JAX array of shape (..., N)
        axis: axis along which to compute IQM (default: last axis)

    Returns:
        iqm: JAX array of shape (...), with IQM along the given axis.
    """
    x_sorted = jnp.sort(x, axis=axis)

    n = x.shape[axis]

    # Quartile boundaries in exact integer arithmetic: floor(n / 4) and
    # ceil(3 * n / 4). This avoids pushing Python scalars through jnp, where
    # they are rounded to the default float dtype before floor/ceil is taken.
    q1_idx = n // 4
    q3_idx = (3 * n + 3) // 4

    # Keep every axis whole except the one the IQM is taken over.
    slices = [slice(None)] * x.ndim
    slices[axis] = slice(q1_idx, q3_idx)
    interquartile_slice = x_sorted[tuple(slices)]

    return jnp.mean(interquartile_slice, axis=axis)


def get_all_steps(base_xpid):
    """Return the sorted, unique checkpoint steps saved for seed 0.

    Looks for files named ``params_<step>_*.pt`` inside
    ``checkpoints/<base_xpid>_SEED_0``.

    Args:
        base_xpid: XPID without the ``_SEED_<n>`` suffix.

    Returns:
        Ascending list of distinct integer steps; empty if nothing matches.
    """
    save_dir = f"checkpoints/{base_xpid}_SEED_0"
    # Escape the directory so glob metacharacters in the XPID stay literal.
    pattern = os.path.join(glob.escape(save_dir), "params_*_*.pt")

    steps = set()
    for path in glob.glob(pattern):
        # Match on the file name only, so a directory name that happens to
        # contain "params_<n>_" cannot be mistaken for the step.
        match = re.match(r"params_(\d+)_", os.path.basename(path))
        if match:
            # A set keeps each step once even if several files share it.
            steps.add(int(match.group(1)))
    return sorted(steps)


def get_last_step(base_xpid):
    """Return a one-element list holding the path of seed 0's final checkpoint.

    The path is ``checkpoints/<base_xpid>_SEED_0/params.pt``; it is not
    checked for existence.
    """
    seed0_dir = f"checkpoints/{base_xpid}_SEED_0"
    return [os.path.join(seed0_dir, "params.pt")]


def load_config(xpid):
    """Load the pickled training config of ``xpid`` as a TrainConfig object.

    The config class is chosen by the first of the tags ``E3T``, ``DPD``,
    ``SP``, ``BR`` (checked in that order) that occurs in ``xpid``.

    Args:
        xpid: Full experiment id; the config is read from
            ``checkpoints/<xpid>/config.pckl``.

    Returns:
        The matching TrainConfig instance built from the pickled dict.

    Raises:
        ValueError: If none of the known tags occurs in ``xpid``.
    """
    config_path = f"checkpoints/{xpid}/config.pckl"
    with open(config_path, "rb") as handle:
        raw_config = pickle.load(handle)

    # Order matters: the first tag found in the XPID wins.
    tagged_config_classes = (
        ("E3T", TrainConfigE3T),
        ("DPD", TrainConfigDPD),
        ("SP", TrainConfigSP),
        ("BR", TrainConfigBR),
    )
    for tag, config_cls in tagged_config_classes:
        if tag in xpid:
            return config_cls(**raw_config)

    raise ValueError(f"Unknown config type for XPID: {xpid}")


def load_params_for_seed(base_xpid, seed, step):
    """Load the actor parameters of one seed at a given checkpoint step.

    Args:
        base_xpid: XPID without the ``_SEED_<n>`` suffix.
        seed: Seed index of the run.
        step: Checkpoint step; files named ``params_<step>_*.pt`` are
            considered.

    Returns:
        The ``"actor_params"`` entry of the pickled checkpoint.

    Raises:
        FileNotFoundError: If no checkpoint file matches ``step``.
    """
    save_dir = f"checkpoints/{base_xpid}_SEED_{seed}"
    # Escape the directory so glob metacharacters in the XPID stay literal.
    pattern = os.path.join(glob.escape(save_dir), f"params_{step}_*.pt")
    files = glob.glob(pattern)
    if not files:
        raise FileNotFoundError(
            f"No checkpoint found for step {step} in seed {seed}")
    # glob returns matches in arbitrary order, so pick the lexicographically
    # smallest path to make the choice reproducible across machines.
    param_file = min(files)
    # NOTE: pickle can execute arbitrary code; only load trusted checkpoints.
    with open(param_file, "rb") as f:
        return pickle.load(f)["actor_params"]


def load_final_params_for_seed(base_xpid, seed):
    """Load the final actor parameters of one seed.

    ``params.pt`` is tried first; if it does not exist,
    ``params_seed<seed>.pt`` in the same directory is used instead.

    Args:
        base_xpid: XPID without the ``_SEED_<n>`` suffix.
        seed: Seed index of the run.

    Returns:
        The ``"actor_params"`` entry of the pickled checkpoint.

    Raises:
        FileNotFoundError: If neither file exists.
    """
    checkpoint_dir = f"checkpoints/{base_xpid}_SEED_{seed}"
    primary_path = os.path.join(checkpoint_dir, "params.pt")
    fallback_path = os.path.join(checkpoint_dir, f"params_seed{seed}.pt")

    try:
        handle = open(primary_path, "rb")
    except FileNotFoundError:
        handle = open(fallback_path, "rb")

    with handle:
        checkpoint = pickle.load(handle)
    return checkpoint["actor_params"]


def parse_xpid(xpid):
    """Extract ``(method, layout)`` from an XPID string.

    Example: ``wg77pjuz___FF_RNN_E3T_IPPO_Overcooked_cramped_room_SEED_0``
    gives ``("E3T", "cramped_room")``. Two patterns are tried in order: the
    regular ``FF_RNN_..._SEED`` form and then the form ending in ``_FF_``
    (which might be a BR string).

    Raises:
        ValueError: If neither pattern matches.
    """
    xpid_patterns = (
        r'FF_RNN_(\w+)_(.*)_Overcooked_(.*)_SEED',
        r'FF_(\w+)_(.*)_Overcooked_(.*)_FF_',
    )
    for pattern in xpid_patterns:
        found = re.search(pattern, xpid)
        if found:
            # Group 1 is the method, group 3 the layout; group 2 is unused.
            return found.group(1), found.group(3)
    raise ValueError("Could not parse XPID string")


def main():
    """Evaluate all seeds of one XPID and print one CSV row per step.

    For every evaluated step the stacked ego parameters of all seeds are
    rolled out (100 PRNG keys, both seats) against each member of the BRDiv
    population found under ``eval_populations/FF_BRDiv/<layout>`` and against
    the hardcoded Overcooked partners. The CSV row has the columns

        Base_XPID,Method,Layout,Mean_all,Mean_all_std,BRDiv,Indp,Onion,Plate,
        Random,Static,Mean,BRDiv_std,Indp_std,Onion_std,Plate_std,Random_std,
        Static_std,Mean_std

    Raises:
        ValueError: If ``--base_xpid`` does not end with ``_SEED_0`` or the
            XPID cannot be parsed.
        FileNotFoundError: If no BRDiv population checkpoint is found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_xpid', type=str, required=True,
                        help="XPID of the seed-0 run; must end with '_SEED_0'")
    parser.add_argument('--max_seed', type=int, required=True,
                        help='Max seed index (inclusive)')
    parser.add_argument('--eval_all_steps', action='store_true')
    args = parser.parse_args()

    num_seeds = args.max_seed + 1

    seed0_suffix = '_SEED_0'
    if not args.base_xpid.endswith(seed0_suffix):
        raise ValueError("Expected --base_xpid to end with '_SEED_0'")
    base_xpid = args.base_xpid[:-len(seed0_suffix)]

    # Load config (assume identical across seeds)
    config = load_config(f"{base_xpid}_SEED_0")

    # Load BRDiv population
    brdiv_dir = f"eval_populations/FF_BRDiv/{config.layout_name}"
    brdiv_config_path = f"{brdiv_dir}/config.pckl"
    # Sorted so the population order does not depend on the file system.
    file_paths = sorted(glob.glob(f"{brdiv_dir}/params_*.pt"))

    method, layout = parse_xpid(args.base_xpid)

    if not file_paths:
        raise FileNotFoundError(
            f"No BRDiv population checkpoints found in {brdiv_dir}")

    # NOTE: pickle can execute arbitrary code; only load trusted checkpoints.
    populations = []
    for path in file_paths:
        with open(path, 'rb') as f:
            populations.append(pickle.load(f))
    with open(brdiv_config_path, 'rb') as f:
        config_brdiv = TrainConfigBRDiv(**pickle.load(f))

    stacked_population = jax.tree.map(
        lambda *arrays: jnp.stack(arrays, axis=0), *populations)
    stacked_population = stacked_population["actor_params"]

    # Prepare environment
    env = make_env("overcooked-v1",
                   {"layout": config.layout_name, "random_reset": False})
    rng = jax.random.PRNGKey(0)

    # Create networks. Only the two flags below differ between methods; the
    # config attributes are read lazily per branch because not every branch
    # reads every attribute.
    if "SP" in base_xpid:
        other_agent_prediction = False
        use_layernorm = False
    elif "BR" in base_xpid:
        other_agent_prediction = False
        use_layernorm = config.use_layernorm
    else:
        other_agent_prediction = config.other_agent_prediction
        use_layernorm = config.use_layernorm

    network_ego = ActorCriticRNN(
        env.action_space("agent_0").n,
        gru_hidden_dim_size=config.gru_hidden_dim,
        fc_dim_size=config.fc_dim_size,
        embedding_layers=config.embedding_layers,
        actor_layers=config.actor_layers,
        critic_layers=config.critic_layers,
        other_agent_prediction=other_agent_prediction,
        use_layernorm=use_layernorm,
    )

    network_brdiv = ActorWithConditionalCritic(
        env.action_space("agent_1").n,
        activation=config_brdiv.activation,
    )

    # Hardcoded partners, keyed by their CSV column, in CSV column order.
    all_partner_names = ("indp", "onion", "plate", "rndm", "static")
    layout_obj = overcooked_layouts[config.layout_name]
    if config.layout_name == "forced_coord":
        # Most hardcoded partners were not built for forced coord.
        # NOTE(review): the previous version also rolled out a Random partner
        # here, but then reset its result to 0 and left it out of every
        # aggregate. That reported behaviour is kept, so the rollout is
        # skipped; confirm whether Random should be reported for this layout.
        hardcoded_partners = {
            "indp": OvercookedIndependentPolicyWrapper(  # Left agent
                layout=layout_obj, p_onion_on_counter=0.6,
                p_plate_on_counter=0.6),
        }
    else:
        hardcoded_partners = {
            "indp": OvercookedIndependentPolicyWrapper(layout=layout_obj),
            "onion": OvercookedOnionPolicyWrapper(layout=layout_obj),
            "plate": OvercookedPlatePolicyWrapper(layout=layout_obj),
            "rndm": OvercookedRandomPolicyWrapper(layout=layout_obj),
            "static": OvercookedStaticPolicyWrapper(layout=layout_obj),
        }
    reported_partners = [
        name for name in all_partner_names if name in hardcoded_partners]
    num_hardcoded = len(reported_partners)

    hardcoded_rollouts = {
        name: jax.jit(
            partial(rollout_both_ways_vs_hardcoded,
                    hardcoded_partner=hardcoded_partners[name]),
            static_argnums=(1, 2))
        for name in reported_partners
    }

    if args.eval_all_steps:
        steps = get_all_steps(base_xpid)
    else:
        # A single placeholder entry; the final checkpoint is loaded below.
        steps = get_last_step(base_xpid)

    for step in steps:
        if args.eval_all_steps:
            params_list = [load_params_for_seed(base_xpid, seed, step)
                           for seed in range(num_seeds)]
        else:
            params_list = [load_final_params_for_seed(base_xpid, seed)
                           for seed in range(num_seeds)]

        ego_params_stacked = jax.tree.map(
            lambda *arrays: jnp.stack(arrays, axis=0), *params_list)

        # BRDiv evaluation: outer vmap over seeds (ego params), inner vmap
        # over population members, giving shape (num_seeds, population).
        rng, _rng = jax.random.split(rng, 2)
        eval_rng = jax.random.split(_rng, 100)
        rewards = jax.vmap(jax.vmap(
            rollout_both_ways,
            in_axes=(None, None, None, None, None, 0, None, None)
        ), in_axes=(None, None, None, 0, None, None, None, None)
        )(
            eval_rng, env, network_ego, ego_params_stacked, network_brdiv,
            stacked_population, config.gru_hidden_dim, config_brdiv.partner_pop_size
        )

        # Per-partner mean / std across seeds.
        brdiv_mean = np.asarray(rewards.mean(axis=0), dtype=np.float64)
        brdiv_std = np.asarray(rewards.std(axis=0), dtype=np.float64)

        # Hardcoded evaluations
        init_hstate = ScannedRNN.initialize_carry(1, config.gru_hidden_dim)

        # NOTE(review): `rng` is split here without being replaced, so the
        # same key is split again at the top of the next iteration. Kept as
        # is to leave the evaluation seeds unchanged.
        _rng = jax.random.split(rng, 100)

        # Partners that are not evaluated are reported as 0.
        partner_mean = {name: 0.0 for name in all_partner_names}
        partner_std = {name: 0.0 for name in all_partner_names}
        for name in reported_partners:
            # Inner vmap over keys, outer over seeds: (num_seeds, 100, 2).
            hard_coded_rewards = jax.vmap(jax.vmap(
                hardcoded_rollouts[name],
                in_axes=(0, None, None, None, None)),
                in_axes=(None, None, None, 0, None)
            )(
                _rng, env, network_ego, ego_params_stacked, init_hstate
            )
            partner_mean[name] = float(hard_coded_rewards.mean())
            # Std across seeds of the per-seed mean reward.
            partner_std[name] = float(
                hard_coded_rewards.mean(axis=(1, 2)).std())

        # Hardcoded partners first, then one entry per BRDiv member.
        partner_means = np.array(
            [*(partner_mean[name] for name in reported_partners),
             *brdiv_mean.tolist()])
        partner_stds = np.array(
            [*(partner_std[name] for name in reported_partners),
             *brdiv_std.tolist()])

        # Mean_all: simple average
        mean_all = partner_means.mean()

        # Std for Mean_all: propagate assuming independence
        mean_all_std = np.sqrt((partner_stds**2).sum() / len(partner_stds)**2)

        # Mean: average between (hardcoded average) and brdiv_mean. Slice by
        # the number of hardcoded partners; slicing with [:-1] would only
        # drop the last BRDiv member and mix the rest into this average.
        hardcoded_mean = partner_means[:num_hardcoded].mean()
        hardcoded_std = np.sqrt(
            (partner_stds[:num_hardcoded]**2).sum() / num_hardcoded**2)

        brdiv_avg = float(brdiv_mean.mean())
        brdiv_std_avg = float(brdiv_std.mean())
        mean = (hardcoded_mean + brdiv_avg) / 2

        # Std for Mean: propagate from hardcoded_mean and brdiv_mean assuming independence
        mean_std = np.sqrt((hardcoded_std**2 + brdiv_std_avg**2) / 4)

        # CSV line for Google Sheets, in the column order from the docstring:
        # Mean_all_std directly follows Mean_all and Mean_std comes last.
        numeric_fields = [
            mean_all, mean_all_std, brdiv_avg,
            *(partner_mean[name] for name in all_partner_names),
            mean, brdiv_std_avg,
            *(partner_std[name] for name in all_partner_names),
            mean_std,
        ]
        print(",".join([args.base_xpid, method, layout,
                        *(f"{value:.3f}" for value in numeric_fields)]))


if __name__ == '__main__':
    main()
