"""Evaluates two XPIDs together that share the same network architectures
"""
import argparse
import pickle
from exceptiongroup import catch
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import glob
from flax.linen.initializers import constant, orthogonal
from typing import Sequence
import distrax
from flax import struct

from typing import Sequence

from src.envs import make_env
from src.envs.log_wrapper import LogWrapper

from src.jaxzsc.dpd.dpd_ippo_overcooked_rnn import TrainConfig as TrainConfigDPD
from src.jaxzsc.e3t.e3t_ippo_overcooked_rnn import TrainConfig as TrainConfigE3T
from src.jaxzsc.sp.sp_ippo_overcooked_rnn import TrainConfig as TrainConfigSP
from src.jaxzsc.brdiv.brdiv_ippo_overcooked import TrainConfig as TrainConfigBRDiv

from src.jaxzsc.brdiv.brdiv_ippo_overcooked import ActorWithConditionalCritic
from src.jaxzsc.dpd.dpd_ippo_overcooked_rnn import ScannedRNN, ActorCriticRNN


class RolloutStats(struct.PyTreeNode):
    """Running episode statistics carried through the rollout loop.

    Both fields start at zero; ``rollout`` adds ``reward["agent_0"]`` to
    ``reward`` and increments ``length`` by one on every environment step.
    """
    # Sum of the rewards accumulated so far (starts at 0.0).
    reward: jax.Array = jnp.asarray(0.0)
    # Number of environment steps taken so far (starts at 0).
    length: jax.Array = jnp.asarray(0)


def batchify(x: dict, agent_list, num_actors):
    """Stack per-agent entries into a single ``(num_actors, -1)`` array.

    Args:
        x: mapping from agent name to an array.
        agent_list: agent names, in the order their rows should appear.
        num_actors: size of the leading axis of the result.

    Returns:
        The entries of ``x`` stacked in ``agent_list`` order, with every
        trailing dimension flattened into the second axis.
    """
    per_agent = [x[agent] for agent in agent_list]
    stacked = jnp.stack(per_agent)
    return jnp.reshape(stacked, (num_actors, -1))


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    """Split a batched array back into a per-agent dictionary.

    Args:
        x: array reshaped here to ``(num_actors, num_envs, -1)``.
        agent_list: agent names; the i-th name receives the i-th slice
            along the leading axis.
        num_envs: size of the second axis after reshaping.
        num_actors: size of the leading axis after reshaping.

    Returns:
        Dict mapping each agent name to its ``(num_envs, -1)`` slice.
    """
    grouped = x.reshape((num_actors, num_envs, -1))
    per_agent = {}
    for index, agent in enumerate(agent_list):
        per_agent[agent] = grouped[index]
    return per_agent


def rollout(
    rng,
    env,
    network1,
    network2,
    params1,
    params2,
    init_hstate,  # Only the ego agent is recurrent
    popsize,
):
    """Play one episode with two policies and return its reward and length.

    ``network1`` controls ``agent_0`` and ``network2`` controls ``agent_1``.
    Whichever network is an ``ActorWithConditionalCritic`` is treated as the
    feed-forward partner; the other one is treated as the recurrent ego
    network and is the only one that consumes and updates the hidden state.

    Args:
        rng: PRNG key for this episode.
        env: environment providing ``reset``, ``step`` and ``agents``.
        network1: policy module acting as ``agent_0``.
        network2: policy module acting as ``agent_1``.
        params1: parameters for ``network1``.
        params2: parameters for ``network2``.
        init_hstate: initial recurrent state of the ego network; a leading
            axis is added before it enters the loop.
        popsize: length of the all-zeros vector passed as the second input
            of the partner network.

    Returns:
        ``(reward, length)``: the sum of ``reward["agent_0"]`` over the
        episode and the number of environment steps, both squeezed.
    """
    # Resolved once at trace time: which slot holds the feed-forward partner.
    network1_is_partner = isinstance(network1, ActorWithConditionalCritic)

    def _cond_fn(carry):
        done = carry[-1]
        # Keep stepping while at least one agent has not finished.
        return jnp.logical_not(done).any()

    def _push(buffer, newest):
        # Shift the window along axis 1 by one and write `newest` last.
        return buffer.at[:, :-1].set(buffer[:, 1:]).at[:, -1].set(newest)

    def _partner_input(obs):
        # Flattened observation plus an all-zeros vector of length popsize.
        return (
            obs.reshape(-1)[np.newaxis, np.newaxis, :],
            jnp.array([0] * popsize)[np.newaxis, np.newaxis, :],
        )

    def _ego_input(obs, done, history):
        # NOTE(review): the first entry of `done` is used for the ego agent
        # in both orderings, exactly as before -- confirm this is intended.
        return (
            obs.reshape(-1)[np.newaxis, np.newaxis, :],
            done[np.newaxis, np.newaxis, 0],
            jax.tree.map(lambda x: x[np.newaxis], history),
        )

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, past_5_sa_pairs, done = carry

        # Independent keys for each agent's action and for the env step;
        # sharing one key between both samplers would correlate their noise.
        rng, rng_action1, rng_action2, rng_step = jax.random.split(rng, 4)

        if network1_is_partner:
            pi1, _ = network1.apply(
                params1, _partner_input(last_obs["agent_0"]))
            hstate, pi2, _, _ = network2.apply(
                params2,
                hstate,
                _ego_input(
                    last_obs["agent_1"], done, past_5_sa_pairs["agent_1"]),
            )
        else:
            hstate, pi1, _, _ = network1.apply(
                params1,
                hstate,
                _ego_input(
                    last_obs["agent_0"], done, past_5_sa_pairs["agent_0"]),
            )
            pi2, _ = network2.apply(
                params2, _partner_input(last_obs["agent_1"]))

        action1 = pi1.sample(seed=rng_action1).squeeze()
        action2 = pi2.sample(seed=rng_action2).squeeze()

        env_act = {
            "agent_0": action1,
            "agent_1": action2
        }

        # Build a fresh history pytree instead of mutating the carried one.
        past_5_sa_pairs = {
            agent: {
                'obs': _push(
                    past_5_sa_pairs[agent]['obs'], last_obs[agent]),
                'action': _push(
                    past_5_sa_pairs[agent]['action'], env_act[agent]),
            }
            for agent in ('agent_0', 'agent_1')
        }

        obsv, env_state, reward, done, info = env.step(
            rng_step, env_state, env_act
        )

        stats = stats.replace(
            reward=stats.reward + reward["agent_0"],
            length=stats.length + 1
        )
        done = batchify(done, env.agents, 2)

        return (rng, env_state, stats, obsv,
                hstate, past_5_sa_pairs, done.squeeze())

    # `rng` is consumed by this split; only the fresh `key` may be used
    # afterwards (the loop previously carried the already-split `rng`).
    key, key_r = jax.random.split(rng)
    obs, state = env.reset(key_r)

    # History window of 5 entries per agent: the first observation repeated
    # five times along a new axis 1, and the action buffer filled with 4.
    # NOTE(review): presumably action 4 is a no-op/"stay" action and obs has
    # a leading axis of size 1 -- confirm against the environment.
    past_5_sa_pairs = {
        agent: {
            'obs': obs[agent][:, None].repeat(5, axis=1),
            'action': jnp.ones((1, 5)) * 4,
        }
        for agent in ('agent_0', 'agent_1')
    }

    init_carry = (
        key, state, RolloutStats(), obs,
        init_hstate[np.newaxis, ...], past_5_sa_pairs,
        jnp.array([False, False]))

    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    final_stats = final_carry[2]
    return final_stats.reward.squeeze(), final_stats.length.squeeze()


def rollout_both_ways(
    rng,
    env,
    network,
    network2,
    params1,
    params2,
    gru_hidden_dim,
    popsize,
):
    """Run 1000 episodes in each seating order and return both reward sets.

    The first batch plays ``network``/``params1`` as ``agent_0`` against
    ``network2``/``params2`` as ``agent_1``; the second batch swaps the two.
    Both batches reuse the same 1000 episode keys and initial hidden states.

    Args:
        rng: PRNG key, split into one key per episode.
        env: environment forwarded to ``rollout``.
        network: first policy module.
        network2: second policy module.
        params1: parameters for ``network``.
        params2: parameters for ``network2``.
        gru_hidden_dim: hidden size passed to ``ScannedRNN.initialize_carry``.
        popsize: forwarded unchanged to ``rollout``.

    Returns:
        ``jnp.array([rewards, rewards_swapped])``; episode lengths are
        discarded.
    """
    episode_keys = jax.random.split(rng, 1000)
    # One initial carry per episode; mapped over axis 0 together with the keys.
    hidden_states = ScannedRNN.initialize_carry(1000, gru_hidden_dim)

    # Only the keys (arg 0) and hidden states (arg 6) vary per episode.
    batched_rollout = jax.vmap(
        rollout, in_axes=(0, None, None, None, None, None, 0, None))

    rewards_forward, _ = batched_rollout(
        episode_keys, env, network, network2,
        params1, params2, hidden_states, popsize)
    rewards_swapped, _ = batched_rollout(
        episode_keys, env, network2, network,
        params2, params1, hidden_states, popsize)

    # Kept separate rather than averaged: (reward + reward2) / 2
    return jnp.array([rewards_forward, rewards_swapped])


def interquartile_mean_vec(x, axis=-1):
    """
    Compute IQM along specified axis of x (vectorized).

    The values are sorted along ``axis`` and the mean is taken over the
    entries with sorted index in ``[floor(0.25 * N), ceil(0.75 * N))``.
    The slice bounds are derived from the static shape with plain integer
    arithmetic, so the function can also be called inside ``jax.jit``.

    Args:
        x: JAX array of shape (..., N)
        axis: axis along which to compute IQM (default: last axis)

    Returns:
        iqm: JAX array of shape (...), with IQM along the given axis.
        If the axis has length 0 the slice is empty and the mean is NaN.
    """
    # Sort along the axis
    x_sorted = jnp.sort(x, axis=axis)

    n = x.shape[axis]

    # Exact integer forms of floor(0.25 * n) and ceil(0.75 * n); avoids
    # converting a device/traced value back to a Python int.
    q1_idx = n // 4
    q3_idx = (3 * n + 3) // 4

    # Slice the interquartile range
    slices = [slice(None)] * x.ndim
    slices[axis] = slice(q1_idx, q3_idx)
    interquartile_slice = x_sorted[tuple(slices)]

    # Mean over the axis
    return jnp.mean(interquartile_slice, axis=axis)


def load_config_and_params(xpid, seed):
    """Load the training config and actor parameters of one checkpoint.

    Args:
        xpid: experiment id; the checkpoint is read from
            ``checkpoints/<xpid>``. The config class is chosen by the first
            of the tags ``E3T``, ``DPD``, ``SP`` that occurs in ``xpid``.
        seed: seed used for the fallback file name ``params_seed<seed>.pt``
            when ``params.pt`` does not exist.

    Returns:
        ``(config, params)`` where ``params`` is the ``"actor_params"``
        entry of the pickled parameter file.

    Raises:
        ValueError: if ``xpid`` contains none of the known tags.
        FileNotFoundError: if neither parameter file exists.
    """
    checkpoint_dir = f"checkpoints/{xpid}"

    # NOTE: pickle executes arbitrary code on load; only use trusted files.
    with open(f"{checkpoint_dir}/config.pckl", "rb") as config_file:
        raw_config = pickle.load(config_file)

    # Checked in this order; the first tag found in the xpid wins.
    config_classes = (
        ("E3T", TrainConfigE3T),
        ("DPD", TrainConfigDPD),
        ("SP", TrainConfigSP),
    )
    for tag, config_cls in config_classes:
        if tag in xpid:
            config = config_cls(**raw_config)
            break
    else:
        raise ValueError(f"Unknown config type for XPID: {xpid}")

    try:
        with open(f"{checkpoint_dir}/params.pt", "rb") as params_file:
            checkpoint = pickle.load(params_file)
            params = checkpoint["actor_params"]
    except FileNotFoundError:
        # Fall back to the per-seed parameter file.
        with open(f"{checkpoint_dir}/params_seed{seed}.pt", "rb") as params_file:
            checkpoint = pickle.load(params_file)
            params = checkpoint["actor_params"]

    return config, params


def main():
    """Evaluate every seed of a run against the BRDiv partner population.

    Steps:
      1. Parse ``--base_xpid`` and ``--max_seed``; the text before the last
         ``_SEED_`` in the base XPID is the prefix shared by all seeds.
      2. Load the config of the base XPID and the actor parameters of
         seeds ``0..max_seed``.
      3. Load and stack all ``params_*.pt`` files of the feed-forward BRDiv
         population for the config's layout.
      4. For each seed, run ``rollout_both_ways`` against every population
         member and take the interquartile mean of all returned rewards.
      5. Print the mean and standard deviation of those per-seed IQMs.

    Raises:
        ValueError: if the base XPID contains none of E3T, DPD or SP.
        FileNotFoundError: if no population parameter files are found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--base_xpid',
        type=str,
        default=None,
        required=True,
        help='First XPID')

    parser.add_argument(
        '--max_seed',
        type=int,
        default=None,
        required=True,
        help='Largest seed index; seeds 0..max_seed are evaluated')
    args = parser.parse_args()

    xpid = args.base_xpid
    num_seeds = args.max_seed + 1

    prefix = args.base_xpid.rsplit("_SEED_", 1)[0]

    save_dir1 = f"checkpoints/{xpid}"

    # NOTE: pickle executes arbitrary code on load; only use trusted files.
    with open(f"{save_dir1}/config.pckl", 'rb') as f:
        loaded_dict = pickle.load(f)

    if "E3T" in xpid:
        config = TrainConfigE3T(**loaded_dict)
    elif "DPD" in xpid:
        config = TrainConfigDPD(**loaded_dict)
    elif "SP" in xpid:
        config = TrainConfigSP(**loaded_dict)
    else:
        raise ValueError("Not Supported.")

    # Actor parameters of every seed; the per-seed configs are not needed.
    params_list = [
        load_config_and_params(f"{prefix}_SEED_{seed}", seed)[1]
        for seed in range(num_seeds)
    ]

    #### Load population ####
    population_dir = f"eval_populations/FF_BRDiv/{config.layout_name}"
    brdiv_config = f"{population_dir}/config.pckl"

    # Sorted so that the stacking order does not depend on the filesystem.
    file_paths = sorted(glob.glob(f"{population_dir}/params_*.pt"))
    if not file_paths:
        raise FileNotFoundError(
            f"No population parameter files found in {population_dir}")

    populations = []
    for path in file_paths:
        with open(path, 'rb') as f:
            populations.append(pickle.load(f))

    with open(brdiv_config, 'rb') as f:
        config_brdiv = pickle.load(f)

    config_brdiv = TrainConfigBRDiv(**config_brdiv)

    # Stack the members leaf-wise along a new leading population axis.
    stacked_population = jax.tree.map(
        lambda *arrays: jnp.stack(arrays, axis=0),
        *populations
    )
    stacked_population = stacked_population["actor_params"]
    #########################

    env = make_env(
        "overcooked-v1", {"layout": config.layout_name, "random_reset": False})

    rng = jax.random.PRNGKey(0)

    # For SP runs both flags are forced to False instead of read from config.
    is_self_play = "SP" in xpid
    network_ego = ActorCriticRNN(
        env.action_space("agent_0").n,
        gru_hidden_dim_size=config.gru_hidden_dim,
        fc_dim_size=config.fc_dim_size,
        embedding_layers=config.embedding_layers,
        actor_layers=config.actor_layers,
        critic_layers=config.critic_layers,
        other_agent_prediction=(
            False if is_self_play else config.other_agent_prediction),
        use_layernorm=False if is_self_play else config.use_layernorm,
    )
    network_brdiv = ActorWithConditionalCritic(
        env.action_space("agent_1").n,
        activation=config_brdiv.activation,
    )

    # env, both networks, the hidden size and the population size are static.
    rollout_both_ways_jit = jax.jit(
        rollout_both_ways, static_argnums=(1, 2, 3, 6, 7))

    # Map only over the stacked population parameters (argument 5).
    evaluate_population = jax.vmap(rollout_both_ways_jit, in_axes=(
        None, None, None, None, None, 0, None, None
    ))

    iqms = []
    for ego_params in params_list:
        reward = evaluate_population(
            rng,
            env,
            network_ego,
            network_brdiv,
            ego_params,
            stacked_population,
            config.gru_hidden_dim,
            config_brdiv.partner_pop_size,
        )

        # One IQM per seed over all partners, seatings and episodes.
        iqms.append(interquartile_mean_vec(reward.flatten()))
    print(np.array(iqms).mean(), np.array(iqms).std())


# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
