import os
import pickle
import argparse
import jax
import jax.numpy as jnp
import numpy as np
from collections import defaultdict

from src.jaxzsc.dpd.dpd_ippo_overcooked_rnn import TrainConfig as TrainConfigDPD
from src.jaxzsc.e3t.e3t_ippo_overcooked_rnn import TrainConfig as TrainConfigE3T
from src.jaxzsc.sp.sp_ippo_overcooked_rnn import TrainConfig as TrainConfigSP
from src.jaxzsc.evaluation.eval_overcooked_rnn import rollout, ActorCriticRNN, ScannedRNN, RolloutStats

from src.envs import make_env
from src.envs.log_wrapper import LogWrapper


def load_config_and_params(xpid):
    """Load the training config and actor parameters of one experiment.

    Reads ``checkpoints/<xpid>/config.pckl`` (a pickled dict of config
    keyword arguments) and ``checkpoints/<xpid>/params.pt`` (a pickled
    mapping with an ``"actor_params"`` entry).

    The config class is chosen by the first tag found in ``xpid``, checked
    in the order ``E3T``, ``DPD``, ``SP``.

    Args:
        xpid: Experiment id; also the name of the checkpoint sub-directory.

    Returns:
        A ``(config, params)`` tuple: the instantiated config object and the
        value stored under ``"actor_params"``.

    Raises:
        ValueError: If ``xpid`` contains none of the known tags.
    """
    checkpoint_dir = f"checkpoints/{xpid}"

    with open(f"{checkpoint_dir}/config.pckl", "rb") as config_file:
        config_kwargs = pickle.load(config_file)

    # Order matters: the first tag contained in the xpid wins.
    tagged_config_classes = (
        ("E3T", TrainConfigE3T),
        ("DPD", TrainConfigDPD),
        ("SP", TrainConfigSP),
    )
    for tag, config_cls in tagged_config_classes:
        if tag in xpid:
            config = config_cls(**config_kwargs)
            break
    else:
        raise ValueError(f"Unknown config type for XPID: {xpid}")

    with open(f"{checkpoint_dir}/params.pt", "rb") as params_file:
        checkpoint = pickle.load(params_file)

    return config, checkpoint["actor_params"]


def main():
    """Evaluate every ordered pair of seeds of one experiment against each other.

    Command-line arguments:
        --base_xpid: Experiment id. Everything from the last ``_SEED_`` on is
            dropped to obtain the prefix shared by all seeds. It must contain
            ``Overcooked_``; the task name is the text following it.
        --max_seed: Highest seed index; seeds ``0..max_seed`` are evaluated.

    For each pair ``(i, j)`` the ``rollout`` function is vmapped over 1024
    PRNG keys with the actor parameters of seed ``i`` and seed ``j``; the mean
    reward is stored in ``results[i, j]``. The network architecture is built
    once from the config of seed 0 and reused for every seed.

    Prints the result matrix as comma-separated text, then the overall,
    diagonal ("SP") and off-diagonal ("XP") averages.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_xpid", type=str, required=True)
    parser.add_argument("--max_seed", type=int, required=True)
    args = parser.parse_args()

    # Fail early with a readable message instead of a late
    # ZeroDivisionError / IndexError.
    if args.max_seed < 0:
        parser.error("--max_seed must be >= 0")
    if "Overcooked_" not in args.base_xpid:
        parser.error("--base_xpid must contain 'Overcooked_'")

    prefix = args.base_xpid.rsplit("_SEED_", 1)[0]
    task_name = args.base_xpid.split("Overcooked_")[1].rsplit("_SEED_", 1)[0]

    num_seeds = args.max_seed + 1
    num_rollouts = 1024  # number of PRNG keys the rollout is vmapped over
    results = np.zeros((num_seeds, num_seeds))

    config, _ = load_config_and_params(f"{prefix}_SEED_0")
    env = make_env("overcooked-v1", {"layout": config.layout_name})

    # "SP" runs are built without other-agent prediction and layernorm; the
    # config attributes are deliberately not read in that case.
    # NOTE(review): this is a substring test on the xpid, whereas
    # load_config_and_params gives "E3T"/"DPD" precedence over "SP" --
    # confirm no xpid contains several tags.
    is_self_play = "SP" in args.base_xpid
    network = ActorCriticRNN(
        env.action_space("agent_0").n,
        gru_hidden_dim_size=config.gru_hidden_dim,
        fc_dim_size=config.fc_dim_size,
        embedding_layers=config.embedding_layers,
        actor_layers=config.actor_layers,
        critic_layers=config.critic_layers,
        other_agent_prediction=(
            False if is_self_play else config.other_agent_prediction
        ),
        use_layernorm=False if is_self_play else config.use_layernorm,
    )

    rollout_keys = jax.random.split(jax.random.PRNGKey(0), num_rollouts)

    # Preload the actor parameters of every seed.
    params_list = [
        load_config_and_params(f"{prefix}_SEED_{seed}")[1]
        for seed in range(num_seeds)
    ]

    # One in_axes entry per positional argument of the call below: only the
    # PRNG keys are batched, everything else is shared across rollouts.
    batched_rollout = jax.vmap(
        rollout, in_axes=(0, None, None, None, None, None)
    )

    # Evaluate all pairs
    for i in range(num_seeds):
        for j in range(num_seeds):
            reward, _ = batched_rollout(
                rollout_keys,
                env,
                network,
                params_list[i],
                params_list[j],
                config.gru_hidden_dim,
            )
            avg_reward = reward.mean()
            results[i, j] = float(avg_reward)

            print(f"Evaluating pair {i} vs {j}: {avg_reward}")

    # Print result matrix
    print(f"\nResult Matrix for task '{task_name}' (comma-separated):")
    header = f"{task_name}," + ",".join(f"S{j}" for j in range(num_seeds))
    print(header)
    for i in range(num_seeds):
        row = f"S{i}," + ",".join(
            f"{results[i, j]:.4f}" for j in range(num_seeds)
        )
        print(row)

    # Averages, derived from the filled matrix instead of running sums.
    avg_total = results.mean()
    avg_diag = np.diag(results).mean()
    if num_seeds > 1:
        off_diagonal = ~np.eye(num_seeds, dtype=bool)
        avg_offdiag = results[off_diagonal].mean()
    else:
        # A single seed has no cross-play pairs.
        avg_offdiag = 0

    print("AVG,SP,XP")
    print(f"{avg_total:.4f},{avg_diag:.4f},{avg_offdiag:.4f}")

    print("---------")
    if isinstance(config, TrainConfigDPD):
        print(f"{args.base_xpid},{config.learnability_function},{config.layout_name}")


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
