"""Evaluate an LBF agent (XPID) against a BRDiv population and hardcoded LBF partners."""
import os
import re
import glob
import argparse
import pickle
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from flax import struct

from src.envs import make_env
from src.envs.log_wrapper import LogWrapper
from src.agents.actors import ActorCriticRNN, ActorWithConditionalCritic, ScannedRNN

# LBF hardcoded partners
from src.agents.lbf.agent_policy_wrappers import (
    LBFRandomPolicyWrapper,
    LBFSequentialFruitPolicyWrapper,
)

# TrainConfigs for LBF runs
from src.jaxzsc.e3t.e3t_ippo_lbf_rnn import TrainConfig as TrainConfigE3T
from src.jaxzsc.dpd.dpd_ippo_lbf_w_bias_rnn import TrainConfig as TrainConfigDPD


# ---------------- utils ----------------

class RolloutStats(struct.PyTreeNode):
    """Running per-episode statistics carried through the rollout while-loops."""
    # Cumulative reward; the rollout bodies add the step reward to it each step.
    reward: jax.Array = jnp.asarray(0.0)
    # Number of environment steps taken so far; incremented by one per step.
    length: jax.Array = jnp.asarray(0)


def batchify(x: dict, agent_list, num_actors):
    """Stack the entries of ``x`` in ``agent_list`` order and flatten to ``(num_actors, -1)``."""
    per_agent = tuple(x[agent] for agent in agent_list)
    stacked = jnp.stack(per_agent)
    return jnp.reshape(stacked, (num_actors, -1))


def batchify_nested_dics(x: dict, agent_list, shape):
    """Stack matching leaves of the per-agent pytrees and reshape each to ``(*shape, -1)``.

    The pytrees ``x[a]`` for ``a`` in ``agent_list`` must share one structure;
    the result has that same structure.
    """
    def _stack_then_reshape(*leaves):
        # Leaves arrive in ``agent_list`` order, one per agent.
        return jnp.stack(leaves).reshape((*shape, -1))

    return jax.tree.map(_stack_then_reshape, *(x[agent] for agent in agent_list))


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    """Reshape ``x`` to ``(num_actors, num_envs, -1)`` and key each leading slice by agent name."""
    per_actor = jnp.reshape(x, (num_actors, num_envs, -1))
    by_agent = {}
    for idx, agent in enumerate(agent_list):
        by_agent[agent] = per_actor[idx]
    return by_agent


def mask_and_norm(probs, mask, eps=1e-8):
    """Zero out masked entries of ``probs`` and renormalise along the last axis.

    Rows whose masked mass is zero fall back to a uniform distribution over the
    entries allowed by ``mask`` (all zeros when the mask itself is empty).
    """
    masked = probs * mask
    total = jnp.sum(masked, axis=-1, keepdims=True)
    # Guard the divisor so an all-zero mask does not divide by zero.
    valid_count = jnp.maximum(mask.sum(-1, keepdims=True), 1.0)
    fallback = mask / valid_count
    renormalised = masked / jnp.maximum(total, eps)
    return jnp.where(total > 0, renormalised, fallback)


def cvar(returns: np.ndarray, alpha=0.1):
    """Compute CVaR at level ``alpha``: the mean of the worst ``alpha`` fraction of returns.

    Args:
        returns: list or array of returns. Multi-dimensional input is flattened
            so that the tail is taken over all values.
        alpha: tail fraction. The tail always contains at least one value and at
            most all of them, so ``alpha=0`` yields the minimum return.

    Returns:
        The tail mean as a Python float, or ``nan`` when ``returns`` is empty.
    """
    ordered = np.sort(np.asarray(returns, dtype=float).ravel())
    if ordered.size == 0:
        # Nothing to average; report NaN explicitly instead of relying on
        # np.mean of an empty slice (which also emits a RuntimeWarning).
        return float("nan")
    cutoff = int(np.ceil(alpha * ordered.size))
    # Clamp so a zero or tiny alpha never produces an empty tail.
    cutoff = min(max(cutoff, 1), ordered.size)
    return float(np.mean(ordered[:cutoff]))


# ---------------- rollouts (LBF) ----------------

def rollout_single_l(rng, env, network, params, other_network, other_params, hidden_size, popsize):
    """Play one episode with the ego policy as ``agent_0`` and a population partner as ``agent_1``.

    Args:
        rng: PRNG key; split once into the environment-reset key and the key
            threaded through the step loop.
        env: environment exposing ``reset``, ``step``, ``get_avail_actions`` and ``agents``.
        network: recurrent ego network; ``network.apply(params, hstate, inputs)``
            is expected to return ``(hstate, pi, _, _)``.
        params: ego network parameters.
        other_network: partner network; its ``apply`` is expected to return ``(pi, _)``.
        other_params: partner network parameters.
        hidden_size: hidden size passed to ``ScannedRNN.initialize_carry``.
        popsize: length of the all-zero conditioning vector fed to the partner.

    Returns:
        ``(reward, length)``: the summed ``reward["agent_0"]`` and the number of
        environment steps taken until every agent reports done.
    """
    def _shift_in(history, newest):
        # Drop the oldest entry along axis 1 and write ``newest`` into the last slot.
        history = history.at[:, :-1].set(history[:, 1:])
        return history.at[:, -1].set(newest)

    def _cond_fn(carry):
        # Keep stepping while at least one agent is not done.
        done = carry[-1]
        return jnp.logical_not(done).any()

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, past_5_sa_pairs, done = carry
        rng, rng_action, rng_o_action, rng_step = jax.random.split(rng, 4)

        in_past = jax.tree.map(
            lambda x: x[np.newaxis], past_5_sa_pairs["agent_0"])
        avail = env.get_avail_actions(env_state.env_state)
        ac_in = (
            last_obs["agent_0"][np.newaxis, np.newaxis, :],
            done[np.newaxis, ...][:, 0:1],
            in_past,
            avail["agent_0"].astype(jnp.float32)[np.newaxis, np.newaxis, :],
        )
        hstate, pi, _, _ = network.apply(params, hstate, ac_in)
        a0 = pi.sample(seed=rng_action).squeeze()

        # partner from population (conditional actor), conditioned on an all-zero vector
        pi1, _ = other_network.apply(
            other_params,
            (
                last_obs["agent_1"][np.newaxis, :],
                jnp.zeros(popsize)[np.newaxis, :],
                avail["agent_1"].astype(jnp.float32)[np.newaxis, :],
            ),
        )
        a1 = pi1.sample(seed=rng_o_action).squeeze()

        env_act = {"agent_0": a0, "agent_1": a1}

        # Roll the ego's 5-step state-action history forward. A fresh dict is
        # built rather than mutating the loop carry in place.
        ego_history = past_5_sa_pairs['agent_0']
        past_5_sa_pairs = {
            'agent_0': {
                'obs': _shift_in(ego_history['obs'], last_obs['agent_0']),
                'action': _shift_in(ego_history['action'], env_act['agent_0']),
            }
        }

        obsv, env_state, reward, done, _ = env.step(
            rng_step, env_state, env_act)
        stats = stats.replace(reward=stats.reward + reward["agent_0"],
                              length=stats.length + 1)
        done = batchify(done, env.agents, 2)
        return (rng, env_state, stats, obsv, hstate, past_5_sa_pairs, done.squeeze())

    key, key_r = jax.random.split(rng)
    obs, state = env.reset(key_r)

    # History starts as five copies of the first observation and action id 4
    # (presumably the no-op action — TODO confirm against the env's action set).
    past_5_sa_pairs = {
        'agent_0': {
            'obs': obs['agent_0'][:, None].repeat(5, axis=1),
            'action': jnp.ones((1, 5)) * 4,
        }
    }

    init_hstate = ScannedRNN.initialize_carry(1, hidden_size)
    # Carry the fresh ``key``: ``rng`` has already been consumed by the split
    # above and must not be split a second time inside the loop.
    init_carry = (key, state, RolloutStats(), obs, init_hstate,
                  past_5_sa_pairs, jnp.array([False, False]))
    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    return final_carry[2].reward.squeeze(), final_carry[2].length.squeeze()


def rollout_single_r(rng, env, network, params, other_network, other_params, hidden_size, popsize):
    """Play one episode with the ego policy as ``agent_1`` and a population partner as ``agent_0``.

    Args:
        rng: PRNG key; split once into the environment-reset key and the key
            threaded through the step loop.
        env: environment exposing ``reset``, ``step``, ``get_avail_actions`` and ``agents``.
        network: recurrent ego network; ``network.apply(params, hstate, inputs)``
            is expected to return ``(hstate, pi, _, _)``.
        params: ego network parameters.
        other_network: partner network; its ``apply`` is expected to return ``(pi, _)``.
        other_params: partner network parameters.
        hidden_size: hidden size passed to ``ScannedRNN.initialize_carry``.
        popsize: length of the all-zero conditioning vector fed to the partner.

    Returns:
        ``(reward, length)``: the summed ``reward["agent_0"]`` and the number of
        environment steps taken until every agent reports done.
    """
    def _shift_in(history, newest):
        # Drop the oldest entry along axis 1 and write ``newest`` into the last slot.
        history = history.at[:, :-1].set(history[:, 1:])
        return history.at[:, -1].set(newest)

    def _cond_fn(carry):
        # Keep stepping while at least one agent is not done.
        done = carry[-1]
        return jnp.logical_not(done).any()

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, past_5_sa_pairs, done = carry
        rng, rng_action, rng_o_action, rng_step = jax.random.split(rng, 4)

        in_past = jax.tree.map(
            lambda x: x[np.newaxis], past_5_sa_pairs["agent_1"])
        avail = env.get_avail_actions(env_state.env_state)
        ac_in = (
            last_obs["agent_1"][np.newaxis, np.newaxis, :],
            done[np.newaxis, ...][:, 0:1],
            in_past,
            avail["agent_1"].astype(jnp.float32)[np.newaxis, np.newaxis, :],
        )
        hstate, pi, _, _ = network.apply(params, hstate, ac_in)
        a1 = pi.sample(seed=rng_action).squeeze()

        # partner from population (conditional actor), conditioned on an all-zero vector
        pi0, _ = other_network.apply(
            other_params,
            (
                last_obs["agent_0"][np.newaxis, :],
                jnp.zeros(popsize)[np.newaxis, :],
                avail["agent_0"].astype(jnp.float32)[np.newaxis, :],
            ),
        )
        a0 = pi0.sample(seed=rng_o_action).squeeze()

        env_act = {"agent_0": a0, "agent_1": a1}

        # Roll the ego's 5-step state-action history forward. A fresh dict is
        # built rather than mutating the loop carry in place.
        ego_history = past_5_sa_pairs['agent_1']
        past_5_sa_pairs = {
            'agent_1': {
                'obs': _shift_in(ego_history['obs'], last_obs['agent_1']),
                'action': _shift_in(ego_history['action'], env_act['agent_1']),
            }
        }

        obsv, env_state, reward, done, _ = env.step(
            rng_step, env_state, env_act)
        # NOTE(review): agent_0's reward is accumulated even though the ego is
        # agent_1 — presumably the reward is shared between agents; confirm.
        stats = stats.replace(reward=stats.reward + reward["agent_0"],
                              length=stats.length + 1)
        done = batchify(done, env.agents, 2)
        return (rng, env_state, stats, obsv, hstate, past_5_sa_pairs, done.squeeze())

    key, key_r = jax.random.split(rng)
    obs, state = env.reset(key_r)

    # History starts as five copies of the first observation and action id 4
    # (presumably the no-op action — TODO confirm against the env's action set).
    past_5_sa_pairs = {
        'agent_1': {
            'obs': obs['agent_1'][:, None].repeat(5, axis=1),
            'action': jnp.ones((1, 5)) * 4,
        }
    }

    init_hstate = ScannedRNN.initialize_carry(1, hidden_size)
    # Carry the fresh ``key``: ``rng`` has already been consumed by the split
    # above and must not be split a second time inside the loop.
    init_carry = (key, state, RolloutStats(), obs, init_hstate,
                  past_5_sa_pairs, jnp.array([False, False]))
    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    return final_carry[2].reward.squeeze(), final_carry[2].length.squeeze()


def rollout_both_ways(eval_rng, env, network, params, partner_pop_actor, partner_pop_params, gru_hidden_dim, popsize):
    """Return the mean episode return over both ego roles and all keys in ``eval_rng``.

    The ego policy is rolled out once as ``agent_0`` and once as ``agent_1``
    for every key in ``eval_rng``; the same keys are used for both roles.
    """
    # Only the leading rng argument is mapped; the other seven are broadcast.
    axes = (0,) + (None,) * 7
    broadcast_args = (env, network, params, partner_pop_actor,
                      partner_pop_params, gru_hidden_dim, popsize)

    returns_per_role = []
    for rollout_fn in (rollout_single_l, rollout_single_r):
        episode_returns, _ = jax.vmap(rollout_fn, in_axes=axes)(
            eval_rng, *broadcast_args)
        returns_per_role.append(episode_returns)
    return jnp.array(returns_per_role).mean()


# ---------- hardcoded partners (LBF) ----------

def rollout_vs_hardcoded(rng, env, network, params, init_hstate, agent_switch, hardcoded_partner):
    """Play one episode of the ego policy against a hardcoded partner.

    Args:
        rng: PRNG key; split once into the environment-reset key and the key
            threaded through the step loop.
        env: environment exposing ``reset``, ``step``, ``get_avail_actions`` and ``agents``.
        network: recurrent ego network; ``network.apply(params, hstate, inputs)``
            is expected to return ``(hstate, pi, _, _)``.
        params: ego network parameters.
        init_hstate: initial recurrent state of the ego network.
        agent_switch: Python bool resolved at trace time. ``True`` evaluates the
            ego policy as ``agent_0``; ``False`` evaluates it as ``agent_1``.
        hardcoded_partner: partner wrapper exposing ``init_hstate`` and ``get_action``.

    Returns:
        ``(reward, length)``: the summed ``reward["agent_0"]`` and the number of
        environment steps taken until every agent reports done.
    """
    agent_ids = ("agent_0", "agent_1")
    if agent_switch:
        ego_id, partner_id = "agent_0", "agent_1"
    else:
        ego_id, partner_id = "agent_1", "agent_0"

    def _shift_in(history, newest):
        # Drop the oldest entry along axis 1 and write ``newest`` into the last slot.
        history = history.at[:, :-1].set(history[:, 1:])
        return history.at[:, -1].set(newest)

    def _cond_fn(carry):
        # Keep stepping while at least one agent is not done.
        done = carry[-1]
        return jnp.logical_not(done).any()

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, other_hstate, past_5_sa_pairs, done = carry
        rng, rng_action, rng_partner_action, rng_step = jax.random.split(
            rng, 4)

        avail = env.get_avail_actions(env_state.env_state)

        # Ego action from the recurrent network.
        in_past = jax.tree.map(
            lambda x: x[np.newaxis], past_5_sa_pairs[ego_id])
        ac_in = (
            last_obs[ego_id].reshape(-1)[np.newaxis, np.newaxis, :],
            done[np.newaxis, np.newaxis, 0],
            in_past,
            avail[ego_id].astype(jnp.float32)[np.newaxis, np.newaxis, :],
        )
        hstate, pi, _, _ = network.apply(params, hstate, ac_in)
        a_ego = pi.sample(seed=rng_action).squeeze()

        # Partner action from the hardcoded policy (it takes no parameters).
        a_partner, other_hstate = hardcoded_partner.get_action(
            params=None,
            obs=last_obs[partner_id],
            done=done[0],
            avail_actions=avail[partner_id],
            hstate=other_hstate,
            rng=rng_partner_action,
            env_state=env_state.env_state,
        )
        chosen = {ego_id: a_ego, partner_id: a_partner}
        # Always build the action dict in agent_0, agent_1 order.
        env_act = {aid: chosen[aid] for aid in agent_ids}

        obsv, env_state, reward, done, _ = env.step(
            rng_step, env_state, env_act)

        # Roll both agents' 5-step state-action histories forward. A fresh dict
        # is built rather than mutating the loop carry in place.
        past_5_sa_pairs = {
            aid: {
                'obs': _shift_in(past_5_sa_pairs[aid]['obs'], last_obs[aid]),
                'action': _shift_in(past_5_sa_pairs[aid]['action'], env_act[aid]),
            }
            for aid in agent_ids
        }

        stats = stats.replace(reward=stats.reward + reward["agent_0"],
                              length=stats.length + 1)
        done = batchify(done, env.agents, 2)
        return (rng, env_state, stats, obsv, hstate, other_hstate, past_5_sa_pairs, done.squeeze())

    key, key_r = jax.random.split(rng)
    obs, state = env.reset(key_r)

    # Histories start as five copies of the first observation and action id 4
    # (presumably the no-op action — TODO confirm against the env's action set).
    past_5_sa_pairs = {
        aid: {
            'obs': obs[aid][:, None].repeat(5, axis=1),
            'action': jnp.ones((1, 5)) * 4,
        }
        for aid in agent_ids
    }

    other_id = 1 if agent_switch else 0
    other_hstate = hardcoded_partner.init_hstate(
        None, aux_info={"agent_id": other_id})

    # Carry the fresh ``key``: ``rng`` has already been consumed by the split
    # above and must not be split a second time inside the loop.
    init_carry = (key, state, RolloutStats(), obs, init_hstate,
                  other_hstate, past_5_sa_pairs, jnp.array([False, False]))
    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    return final_carry[2].reward.squeeze(), final_carry[2].length.squeeze()


def rollout_both_ways_vs_hardcoded(rng, env, network, params, init_hstate, hardcoded_partner):
    """Run one hardcoded-partner episode per ego role, both with the same ``rng``.

    Returns:
        ``(rewards, lengths)``: two length-2 arrays ordered as
        ``[ego as agent_0, ego as agent_1]``.
    """
    rewards, lengths = [], []
    for ego_is_agent_0 in (True, False):
        episode_reward, episode_length = rollout_vs_hardcoded(
            rng, env, network, params, init_hstate, ego_is_agent_0, hardcoded_partner)
        rewards.append(episode_reward)
        lengths.append(episode_length)
    return jnp.array(rewards), jnp.array(lengths)


# ---------------- loading / parsing ----------------

def get_all_steps(base_xpid):
    """Return the sorted, de-duplicated checkpoint steps saved for seed 0 of ``base_xpid``.

    Looks for files named ``params_<step>_*.pt`` in
    ``checkpoints/<base_xpid>_SEED_0`` and returns each distinct ``<step>`` once,
    in ascending order. Returns an empty list when nothing matches.
    """
    save_dir = f"checkpoints/{base_xpid}_SEED_0"
    # Escape the directory so glob metacharacters in the XPID are taken literally.
    pattern = os.path.join(glob.escape(save_dir), "params_*_*.pt")
    step_re = re.compile(r"params_(\d+)_")

    steps = set()
    for path in glob.glob(pattern):
        # Match on the file name only: the directory itself may contain
        # something that looks like ``params_<n>_``.
        m = step_re.match(os.path.basename(path))
        if m:
            steps.add(int(m.group(1)))
    return sorted(steps)


def get_last_step(base_xpid):
    """Return a one-element list holding the path of seed 0's final ``params.pt``."""
    final_checkpoint = os.path.join(
        "checkpoints", f"{base_xpid}_SEED_0", "params.pt")
    return [final_checkpoint]


def load_config(xpid):
    """Load ``checkpoints/<xpid>/config.pckl`` and wrap it in the matching TrainConfig.

    The XPID string selects the class: ``TrainConfigDPD`` only when it contains
    "DPD" and not "E3T"; otherwise ``TrainConfigE3T``, which is also the default
    when neither token is present.
    """
    config_path = os.path.join(f"checkpoints/{xpid}", "config.pckl")
    with open(config_path, "rb") as fh:
        raw_kwargs = pickle.load(fh)

    # "E3T" takes precedence over "DPD" when both appear in the XPID.
    if "E3T" not in xpid and "DPD" in xpid:
        return TrainConfigDPD(**raw_kwargs)
    return TrainConfigE3T(**raw_kwargs)


def load_params_for_seed(base_xpid, seed, step):
    """Load the actor parameters saved at ``step`` for one seed of a run.

    Searches ``checkpoints/<base_xpid>_SEED_<seed>`` for ``params_<step>_*.pt``.
    When several files match, the lexicographically first path is used so the
    choice is deterministic.

    Raises:
        FileNotFoundError: no checkpoint file exists for ``step``.
    """
    xpid_seed = f"{base_xpid}_SEED_{seed}"
    save_dir = f"checkpoints/{xpid_seed}"
    # Escape the directory so glob metacharacters in the XPID are taken literally.
    pattern = os.path.join(glob.escape(save_dir), f"params_{step}_*.pt")
    # glob returns matches in arbitrary order; sort before picking one.
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(
            f"No checkpoint for step {step} in {xpid_seed}")
    # NOTE: pickle executes arbitrary code on load; only use trusted checkpoints.
    with open(matches[0], "rb") as f:
        return pickle.load(f)["actor_params"]


def load_final_params_for_seed(base_xpid, seed):
    """Load the final actor parameters for one seed of a run.

    Reads ``params.pt`` from ``checkpoints/<base_xpid>_SEED_<seed>``; if that
    file does not exist, falls back to ``params_seed<seed>.pt`` in the same
    directory. A missing fallback file raises ``FileNotFoundError``.
    """
    checkpoint_dir = f"checkpoints/{base_xpid}_SEED_{seed}"

    def _read_actor_params(filename):
        with open(os.path.join(checkpoint_dir, filename), "rb") as fh:
            return pickle.load(fh)["actor_params"]

    try:
        return _read_actor_params("params.pt")
    except FileNotFoundError:
        return _read_actor_params(f"params_seed{seed}.pt")


def parse_xpid(xpid: str) -> str:
    """
    Extract the method token from an XPID.

    Accepts strings like:
      - k3nvy8sr___FF_RNN_E3T_IPPO_LBF_eps_0.0_SEED_0
      - s8j9yev7___FF_RNN_DPD_WBIAS_IPPO_variance_K_512_LBF_SEED_0
    Returns the method token between 'FF_RNN_' and '_IPPO', e.g. 'E3T' or 'DPD_WBIAS'.

    XPIDs without an '_IPPO' token fall back to the older LBF-anchored patterns.

    Raises:
        ValueError: no pattern matches ``xpid``.
    """
    # Non-greedy capture: stop at the first '_IPPO' so trailing tokens such as
    # 'variance' are not swallowed into the method name.
    m = re.search(r'FF_RNN_(\w+?)_IPPO', xpid)
    if m:
        return m.group(1)

    # Legacy patterns, kept for XPIDs that do not contain '_IPPO'.
    m = re.search(r'FF_RNN_(\w+)_(.*)_LBF_(.*)_SEED', xpid)
    if not m:
        m = re.search(r'FF_RNN_(\w+)_(\w+)_(.*)_LBF_(.*)', xpid)
        if not m:
            raise ValueError(f"Could not parse method from XPID: {xpid}")
    return m.group(1)


def _fmt(x):
    # consistent float formatting; leave strings/ints as-is
    if isinstance(x, (float, np.floating)):
        return f"{x:.3f}"
    if isinstance(x, (jnp.ndarray, np.ndarray)) and x.ndim == 0:
        return f"{float(x):.3f}"
    return str(x)


def print_two_line_csv(
    xpid: str,
    method: str,
    step: str,
    brdiv_mean_overall: float,
    brdiv_std_overall: float,
    hardcoded_means: dict,   # e.g. {"random": 0.167, "seq_lexi": 0.287, ...}
    hardcoded_stds: dict,    # e.g. {"random": 0.142, "seq_lexi": 0.228, ...}
    mean_all: float,
    mean_all_std: float,
    mean_combo: float,
    mean_combo_std: float,
    order: list[str],        # column order for hardcoded partners
):
    """Print a CSV header line followed by one line of values.

    ``xpid``, ``method`` and ``step`` are emitted verbatim; every metric goes
    through ``_fmt``. The hardcoded-partner columns follow ``order``: all means
    first, then all standard deviations.
    """
    # Build (column name, rendered value) pairs so header and values stay aligned.
    columns = [
        ("xpid", xpid),
        ("method", method),
        ("step", step),
        ("Mean_all", _fmt(mean_all)),
        ("Mean_all_std", _fmt(mean_all_std)),
        ("BRDiv_overall", _fmt(brdiv_mean_overall)),
        ("BRDiv_overall_std", _fmt(brdiv_std_overall)),
    ]
    columns.extend((f"Hardcoded_{k}", _fmt(hardcoded_means[k])) for k in order)
    columns.extend((f"Hardcoded_std_{k}", _fmt(hardcoded_stds[k])) for k in order)
    columns.append(("Mean_combo", _fmt(mean_combo)))
    columns.append(("Mean_combo_std", _fmt(mean_combo_std)))

    print(",".join(name for name, _ in columns))
    print(",".join(value for _, value in columns))


# ---------------- main ----------------

def main():
    """Evaluate all seeds of an LBF run against the BRDiv population and the hardcoded partners.

    Command-line arguments:
        --base_xpid: XPID of seed 0 (must end with ``_SEED_0``).
        --max_seed: highest seed index to load (inclusive).
        --eval_all_steps: evaluate every saved checkpoint step instead of only
            the final parameters.

    For each evaluated step, prints a CSV header line followed by one line of
    values (means, standard deviations and CVaR-10 figures).

    Raises:
        ValueError: ``--base_xpid`` does not end with ``_SEED_0``.
        FileNotFoundError: ``--eval_all_steps`` was given but no step
            checkpoints exist for seed 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_xpid', type=str, required=True,
                        help="Base XPID ending with _SEED_0")
    parser.add_argument('--max_seed', type=int, required=True,
                        help='Max seed index (inclusive)')
    parser.add_argument('--eval_all_steps', action='store_true')
    args = parser.parse_args()

    seed_suffix = '_SEED_0'
    if not args.base_xpid.endswith(seed_suffix):
        raise ValueError("Expected --base_xpid to end with '_SEED_0'")
    base_xpid = args.base_xpid[:-len(seed_suffix)]  # strip suffix
    seed0_xpid = f"{base_xpid}_SEED_0"

    # Config for LBF; the method label is the same for every step, so parse it once.
    config = load_config(seed0_xpid)
    method = parse_xpid(seed0_xpid)

    # Load BRDiv population (LBF).
    # NOTE: pickle executes arbitrary code on load; only use trusted files.
    pop_dir = "eval_populations/FF_BRDiv_LBF"
    with open(os.path.join(pop_dir, "config.pckl"), "rb") as f:
        brdiv_cfg = pickle.load(f)  # has 'partner_pop_size'
    pop_params = []
    for p in sorted(os.listdir(pop_dir)):
        if "param" in p:
            with open(os.path.join(pop_dir, p), "rb") as f:
                pop_params.append(pickle.load(f)["actor_params"])
    # Stack population members along a new leading axis so they can be vmapped.
    partner_pop_params = jax.tree.map(lambda *x: jnp.stack(x), *pop_params)
    partner_pop_actor = ActorWithConditionalCritic(
        6)  # LBF action space size = 6
    pop_size = brdiv_cfg["partner_pop_size"]

    # Environment
    env = make_env("lbf")
    env = LogWrapper(env, replace_info=False)
    rng = jax.random.PRNGKey(0)

    # Ego network (shared architecture)
    network_ego = ActorCriticRNN(
        env.action_space("agent_0").n,
        gru_hidden_dim_size=config.gru_hidden_dim,
        fc_dim_size=config.fc_dim_size,
        embedding_layers=config.embedding_layers,
        actor_layers=config.actor_layers,
        critic_layers=config.critic_layers,
        other_agent_prediction=getattr(config, "other_agent_prediction", True),
        use_layernorm=getattr(config, "use_layernorm", True),
        env_has_avail_actions=True,
    )

    # Hardcoded partners (LBF)
    hardcoded_specs = [
        ("random", LBFRandomPolicyWrapper()),
        ("seq_lexi", LBFSequentialFruitPolicyWrapper(
            ordering_strategy="lexicographic")),
        ("seq_revlexi", LBFSequentialFruitPolicyWrapper(
            ordering_strategy="reverse_lexicographic")),
        ("seq_col", LBFSequentialFruitPolicyWrapper(
            ordering_strategy="column_major")),
        ("seq_revcol", LBFSequentialFruitPolicyWrapper(
            ordering_strategy="reverse_column_major")),
        ("seq_near", LBFSequentialFruitPolicyWrapper(
            ordering_strategy="nearest_agent")),
        ("seq_far", LBFSequentialFruitPolicyWrapper(
            ordering_strategy="farthest_agent")),
    ]
    # env (arg 1) and network (arg 2) are static for jit.
    hardcoded_jits = {
        name: jax.jit(partial(rollout_both_ways_vs_hardcoded, hardcoded_partner=agent),
                      static_argnums=(1, 2))
        for name, agent in hardcoded_specs
    }
    hardcoded_names = list(hardcoded_jits.keys())

    # Steps to evaluate
    if args.eval_all_steps:
        steps = get_all_steps(base_xpid)
        if not steps:
            # Fail loudly instead of exiting with no output at all.
            raise FileNotFoundError(
                f"No step checkpoints found for {seed0_xpid}")
    else:
        # returns a list with full path to params.pt
        steps = get_last_step(base_xpid)

    num_seeds = args.max_seed + 1
    num_eval_episodes = 100

    for step in steps:
        # params for all seeds at this step (or final)
        params_list = []
        for seed in range(num_seeds):
            if args.eval_all_steps:
                params = load_params_for_seed(
                    base_xpid, seed, step)  # step is int here
            else:
                params = load_final_params_for_seed(base_xpid, seed)
            params_list.append(params)

        ego_params_stacked = jax.tree.map(
            lambda *a: jnp.stack(a, axis=0), *params_list)

        # Evaluate vs BRDiv: outer vmap over seeds, inner vmap over population members.
        rng, _rng = jax.random.split(rng, 2)
        eval_rng = jax.random.split(_rng, num_eval_episodes)
        rewards = jax.vmap(
            jax.vmap(rollout_both_ways,
                     in_axes=(None, None, None, None, None, 0, None, None)),
            in_axes=(None, None, None, 0, None, None, None, None)
        )(
            eval_rng, env, network_ego, ego_params_stacked,
            partner_pop_actor, partner_pop_params, config.gru_hidden_dim, pop_size
        )

        rewards_np = np.array(rewards)  # shape: (seeds, partners)
        brdiv_cvars = [cvar(seed_returns, alpha=0.1)
                       for seed_returns in rewards_np]
        brdiv_cvar_10 = float(np.mean(brdiv_cvars))
        brdiv_overall_mean = float(rewards.mean())
        brdiv_overall_std = float(rewards.std())

        # Hardcoded evals
        init_h = ScannedRNN.initialize_carry(1, config.gru_hidden_dim)
        # Draw a fresh key for the hardcoded rollouts: splitting ``rng`` itself
        # here and again at the top of the next iteration would reuse a key.
        rng, _rng_hardcoded = jax.random.split(rng, 2)
        _rngs = jax.random.split(_rng_hardcoded, num_eval_episodes)

        hardcoded_means = []
        hardcoded_stds = []
        hardcoded_cvar_10 = []
        for name in hardcoded_names:
            jit_fn = hardcoded_jits[name]
            # vals has shape (seeds, episodes, 2 roles).
            vals, _ = jax.vmap(
                jax.vmap(jit_fn, in_axes=(0, None, None, None, None)),
                in_axes=(None, None, None, 0, None)
            )(_rngs, env, network_ego, ego_params_stacked, init_h)
            hardcoded_means.append(float(vals.mean()))
            # Spread of the per-seed means across seeds.
            hardcoded_stds.append(float(vals.mean(axis=(1, 2)).std()))

            vals_mean_across_roles = np.array(vals).mean(axis=2)
            hardcoded_cvars = [cvar(seed_returns, alpha=0.1)
                               for seed_returns in vals_mean_across_roles]
            hardcoded_cvar_10.append(float(np.mean(hardcoded_cvars)))

        hardcoded_means = np.array(hardcoded_means, dtype=float)
        hardcoded_stds = np.array(hardcoded_stds, dtype=float)
        hardcoded_cvar_10 = np.array(hardcoded_cvar_10, dtype=float)

        # Aggregate metrics: every hardcoded partner and the BRDiv population
        # each count as one entry.
        mean_all = np.concatenate(
            [hardcoded_means, np.array([brdiv_overall_mean])]).mean()
        mean_all_std = np.sqrt(
            (np.concatenate([hardcoded_stds, np.array(
                [brdiv_overall_std])])**2).sum()
            / (len(hardcoded_stds) + 1)**2
        )

        cvar10_all = np.concatenate([hardcoded_cvar_10, [brdiv_cvar_10]])
        cvar10_all_mean = float(np.mean(cvar10_all))
        cvar10_all_std = float(np.std(cvar10_all))

        # "Combo": hardcoded partners as one group and BRDiv as the other,
        # weighted equally.
        hardcoded_mean = hardcoded_means.mean()
        hardcoded_std = np.sqrt(
            (hardcoded_stds**2).sum() / (len(hardcoded_stds)**2))
        mean_combo = (hardcoded_mean + brdiv_overall_mean) / 2.0
        mean_combo_std = np.sqrt(
            (hardcoded_std**2 + brdiv_overall_std**2) / 4.0)

        step_str = step if isinstance(step, int) else "final"

        # ----- CSV OUTPUT: headings on one line, values on the next -----
        # Build dynamic headings for hardcoded names
        hc_names_str = ",".join(hardcoded_names)
        hc_means_hdr = f"Hardcoded[{hc_names_str}]"
        hc_stds_hdr = f"Hardcoded_std[{hc_names_str}]"
        hc_cvar_hdr = f"HardcodedCVAR[{hc_names_str}]"

        header = (
            "xpid,method,step,"
            "Mean_all,Mean_all_std,"
            "CVAR10_all,CVAR10_all_std,"
            "BRDiv_overall,BRDiv_overall_std,BRDiv_CVaR_10,"
            f"{hc_means_hdr},"
            f"{hc_stds_hdr},"
            f"{hc_cvar_hdr},"
            "Mean_combo,Mean_combo_std"
        )

        values = [
            seed0_xpid,
            method,
            str(step_str),
            f"{mean_all:.3f}",
            f"{mean_all_std:.3f}",
            f"{cvar10_all_mean:.3f}",
            f"{cvar10_all_std:.3f}",
            f"{brdiv_overall_mean:.3f}",
            f"{brdiv_overall_std:.3f}",
            f"{brdiv_cvar_10:.3f}",
            *[f"{m:.3f}" for m in hardcoded_means],
            *[f"{s:.3f}" for s in hardcoded_stds],
            *[f"{c:.3f}" for c in hardcoded_cvar_10],
            f"{mean_combo:.3f}",
            f"{mean_combo_std:.3f}",
        ]

        print(header)
        print(",".join(values))


if __name__ == "__main__":
    main()
