import os
import pickle
import jax
import jax.numpy as jnp
from pathlib import Path
from tqdm import tqdm
import numpy as np
from flax import struct

from src.agents.actors import ActorWithConditionalCritic, ScannedRNN
from src.envs.ogc.ogc import OGC, Level
from src.envs.ogc.auto_replay_wrapper import AutoReplayWrapper
from src.envs import make_env
from src.envs.log_wrapper import LogWrapper


class RolloutStats(struct.PyTreeNode):
    """Running totals accumulated while stepping through one episode.

    As a ``flax.struct.PyTreeNode`` it can live inside a ``jax.lax.while_loop``
    carry and is updated functionally via ``.replace(...)``.
    """

    # Sum of rewards collected so far; starts at 0.0.
    reward: jax.Array = struct.field(pytree_node=True, default=jnp.asarray(0.0))
    # Number of environment steps taken so far; starts at 0.
    length: jax.Array = struct.field(pytree_node=True, default=jnp.asarray(0))


def batchify(x: dict, agent_list, num_actors):
    """Stack per-agent arrays into one array with ``num_actors`` rows.

    Args:
        x: Mapping from agent name to array.
        agent_list: Agent names, in the row order wanted in the result.
        num_actors: Number of rows of the returned array.

    Returns:
        ``jnp`` array of shape ``(num_actors, -1)``; all remaining axes are
        flattened into the second dimension.
    """
    stacked = jnp.stack(tuple(x[agent] for agent in agent_list))
    return jnp.reshape(stacked, (num_actors, -1))


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    """Split a batched array back into a per-agent dictionary.

    Args:
        x: Array whose total size is divisible by ``num_actors * num_envs``.
        agent_list: Agent names; the i-th name receives the i-th actor slice.
        num_envs: Size of the second axis after reshaping.
        num_actors: Size of the first axis after reshaping.

    Returns:
        Dict mapping each agent name to an array of shape ``(num_envs, -1)``.
    """
    per_actor = jnp.reshape(x, (num_actors, num_envs, -1))
    by_agent = {}
    for row, agent in enumerate(agent_list):
        by_agent[agent] = per_actor[row]
    return by_agent


def rollout(rng, env, level, network, params, hidden_size, popsize):
    """Play one self-play episode and return its accumulated reward and length.

    Every agent in ``env.agents`` acts with the same ``network``/``params``.
    The episode is stepped inside ``jax.lax.while_loop`` until every agent
    reports ``done``.

    Args:
        rng: PRNG key for this episode.
        env: Environment. ``AutoReplayWrapper`` instances are reset to
            ``level``; any other env is reset with ``env.reset`` and ``level``
            is ignored.
        level: Level passed to ``env.reset_env_to_level`` (replay envs only).
        network: Module whose ``apply`` returns ``(pi, _)``; only ``pi`` is used
            and must support ``pi.sample(seed=...)``.
        params: Parameters for ``network.apply``.
        hidden_size: Unused; kept for interface compatibility.
        popsize: Length of the conditioning vector fed to the network.

    Returns:
        ``(reward, length)``: the sum of ``"agent_0"``'s rewards over the
        episode and the number of steps taken (both squeezed).
    """
    num_agents = env.num_agents

    def _cond_fn(carry):
        # Keep stepping while at least one agent is still running.
        done = carry[-1]
        return jnp.logical_not(done).any()

    def _body_fn(carry):
        rng, env_state, stats, last_obs, _ = carry
        rng, rng_action, rng_step = jax.random.split(rng, 3)

        obs_batch = batchify(last_obs, env.agents, num_agents)
        # Conditioning input: one-hot on index 1 for every agent.
        # NOTE(review): index 1 is hard-coded (assumes popsize >= 2) — confirm intended.
        cond = jnp.zeros((1, num_agents, popsize)).at[:, :, 1].set(1)
        # Add a leading axis of size 1 so obs matches cond's leading dims.
        pi, _ = network.apply(params, (obs_batch[np.newaxis, :], cond))

        action = pi.sample(seed=rng_action).squeeze()
        env_act = unbatchify(action, env.agents, 1, num_agents)
        env_act = {k: v.flatten().squeeze() for k, v in env_act.items()}

        obsv, env_state, reward, done, _ = env.step(rng_step, env_state, env_act)

        # NOTE(review): only agent_0's reward is accumulated — presumably a shared team reward.
        stats = stats.replace(
            reward=stats.reward + reward["agent_0"],
            length=stats.length + 1,
        )
        # reshape (not squeeze) keeps the carry shape fixed at (num_agents,),
        # which while_loop requires, even when there is a single agent.
        done = batchify(done, env.agents, num_agents).reshape(num_agents)
        return (rng, env_state, stats, obsv, done)

    # Split once: one key for the reset, the other (never reused) for the loop.
    rng, rng_reset = jax.random.split(rng)
    if isinstance(env, AutoReplayWrapper):
        obs, state = env.reset_env_to_level(rng_reset, level, env.default_params)
    else:
        obs, state = env.reset(rng_reset)

    carry = (rng, state, RolloutStats(), obs, jnp.zeros(num_agents, dtype=bool))

    _, _, stats, _, _ = jax.lax.while_loop(_cond_fn, _body_fn, carry)
    return stats.reward.squeeze(), stats.length.squeeze()


def load_population(path: Path):
    """Load the actor parameters of every population member stored in ``path``.

    A directory entry counts as a member when it is a regular file whose name
    contains ``"param"`` and whose suffix is ``.pt``. Each such file is
    unpickled and its ``"actor_params"`` entry is collected.

    Entries are visited in sorted (lexicographic) path order so that member
    indices are reproducible; ``Path.iterdir`` guarantees no order.

    Security: ``pickle.load`` can execute arbitrary code. Only point this at
    checkpoints you produced or otherwise trust.

    Args:
        path: Directory containing the pickled checkpoints.

    Returns:
        ``(population, size)``: the list of loaded actor params and its length.
    """
    population = []
    for file in sorted(path.iterdir()):
        # Skip directories and unrelated files that merely match by name.
        if not (file.is_file() and "param" in file.name and file.suffix == ".pt"):
            continue
        with open(file, "rb") as f:
            population.append(pickle.load(f)["actor_params"])
    return population, len(population)


def evaluate_self_play(env, network, params, level, hidden_size, num_trials=100, popsize=3, *, seed=0):
    """Run ``num_trials`` independent self-play rollouts in parallel with ``jax.vmap``.

    Args:
        env, network, params, level, hidden_size, popsize: Forwarded unchanged
            (not mapped) to ``rollout``.
        num_trials: Number of rollouts, one per PRNG key.
        seed: Seed of the root PRNG key the per-trial keys are split from.
            Defaults to 0, which reproduces the previous fixed-seed behaviour.

    Returns:
        ``(returns, lengths)`` as produced by ``rollout``, each with a leading
        axis of size ``num_trials``.
    """
    rngs = jax.random.split(jax.random.PRNGKey(seed), num_trials)
    # Only the keys are mapped; every other argument is shared across trials.
    batched_rollout = jax.vmap(rollout, in_axes=(0, None, None, None, None, None, None))
    returns, lengths = batched_rollout(rngs, env, level, network, params, hidden_size, popsize)
    return returns, lengths


def evaluate_population(env_type: str, layout: str, pop_path: Path):
    """Evaluate every member of a saved population in self-play.

    Args:
        env_type: ``"ogc"`` (5x5 ``OGC`` wrapped in ``AutoReplayWrapper``) or
            ``"overcooked"`` (``overcooked-v1`` wrapped in ``LogWrapper``).
        layout: Layout name used for ``Level.from_layout_name`` and, for
            ``"overcooked"``, for ``make_env``.
        pop_path: Directory of checkpoints read by ``load_population``.

    Returns:
        List with one mean return (over 100 trials) per population member, in
        the order returned by ``load_population``.

    Raises:
        ValueError: if ``env_type`` is unknown or ``pop_path`` contains no
            population members.
    """
    print(f"\n=== Evaluating in {env_type.upper()} ===")

    if env_type == "ogc":
        env = AutoReplayWrapper(OGC(width=5, height=5))
    elif env_type == "overcooked":
        env = make_env("overcooked-v1", {"layout": layout, "random_reset": False})
        env = LogWrapper(env, replace_info=False)
    else:
        raise ValueError("env_type must be 'ogc' or 'overcooked'")
    # Both env types build the level from the same layout name.
    level = Level.from_layout_name(layout)

    # Parameters come from disk, so the module is never initialised here.
    network = ActorWithConditionalCritic(env.action_space("agent_0").n)

    population, popsize = load_population(pop_path)
    if popsize == 0:
        # Fail loudly instead of reporting a NaN mean over zero agents.
        raise ValueError(f"No population members found in {pop_path}")
    print(f"Loaded {popsize} population members from {pop_path}")

    returns_all = []
    progress = tqdm(enumerate(population), total=popsize, desc=f"{env_type.upper()} Self-Play")
    for idx, params in progress:
        returns, _ = evaluate_self_play(env, network, params, level, hidden_size=0, num_trials=100, popsize=popsize)
        mean_return = np.mean(np.array(returns))
        returns_all.append(mean_return)
        # tqdm.write prints without breaking the active progress bar.
        tqdm.write(f"Agent {idx}: Mean Return: {mean_return:.2f}")

    overall_mean = np.mean(returns_all)
    print(f"\n==== {env_type.upper()} Summary ====")
    print(f"Avg return over all agents: {overall_mean:.2f}")
    return returns_all


def main():
    """Evaluate one saved population in OGC and Overcooked and compare the returns."""
    layout = "coord_ring"
    pop_path = Path("eval_populations", "FF_BRDiv", layout)

    # Evaluated in order: OGC first, then Overcooked.
    returns_ogc, returns_oc = (
        evaluate_population(env_type, layout, pop_path)
        for env_type in ("ogc", "overcooked")
    )

    print("\n=== SIDE-BY-SIDE COMPARISON ===")
    for agent_idx, pair in enumerate(zip(returns_ogc, returns_oc)):
        ogc_ret, oc_ret = pair
        print(f"Agent {agent_idx}: OGC: {ogc_ret:.2f} | Overcooked: {oc_ret:.2f}")

    ogc_mean = np.mean(returns_ogc)
    oc_mean = np.mean(returns_oc)
    print(f"\nOGC Mean: {ogc_mean:.2f}")
    print(f"Overcooked Mean: {oc_mean:.2f}")


# Script entry point: run the OGC-vs-Overcooked population comparison.
if __name__ == "__main__":
    main()
