import os
import re
import glob
import pickle
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import distrax

from flax import struct


from src.envs import make_env
from src.agents.actors import ScannedRNN, ActorCriticRNN
from src.jaxzsc.dpd.dpd_ippo_overcooked_rnn import TrainConfig

# ------------ Config ------------ #
# Number of partner parameter sets sampled per checkpoint; also the size of
# the vmapped rollout batch in run_analysis.
NUM_PARTNERS = 8192
# Number of episode returns recorded per partner. run_analysis rolls out
# 400 * ROLLOUTS_PER_PARTNER steps and sizes the episode-return buffer with it.
ROLLOUTS_PER_PARTNER = 10
# Checkpoint file-name pattern; capture group 1 is parsed as the integer step
# used for ordering checkpoints.
# NOTE(review): the trailing "<digits>.<digits>" part is never read in this
# file — presumably a score or timestamp; confirm against the writer.
CHECKPOINT_REGEX = r"params_(\d+)_\d+\.\d+\.pt"

# ------------ Utility Functions ------------ #

class PartnerParametersWithBias(struct.PyTreeNode):
    """Per-partner perturbation parameters.

    Built by `sample_partner_parameters` and read by `rollout_nsteps`.
    """
    # Mixing weight: the perturbed action distribution used in the rollout is
    # (1 - epsilon) * policy_probs + epsilon * bias_mask.
    epsilon: jnp.float32
    # Selects which of the two agents acts with the perturbed distribution
    # (truthy -> first agent, falsy -> second agent).
    # NOTE(review): annotated as int32, but the sampler stores the boolean
    # result of jax.random.bernoulli and rollout_nsteps applies `~` to it,
    # which is only a logical NOT for boolean values — keep it boolean.
    epsilon_agent: jnp.int32
    # Probability vector over actions used as the "random" component.
    bias_mask: jnp.ndarray  # shape (ACTION_SPACE_SIZE,)


def sample_partner_parameters(
    num_agents,
    rng: jax.random.PRNGKey,
    dirichlet_bias_prob: float = 0.0,
) -> PartnerParametersWithBias:
    """
    Sample the perturbation parameters for a single partner.

    Args:
        num_agents: Unused; kept so existing positional callers keep working.
        rng: PRNG key consumed by this call.
        dirichlet_bias_prob: Probability of using a Dirichlet-sampled action
            distribution as `bias_mask` instead of the uniform one. The
            default of 0.0 always yields the uniform distribution, which is
            what this function previously did unconditionally (the Dirichlet
            branch existed but was never selected).

    Returns:
        A `PartnerParametersWithBias` with
        - `epsilon`: scalar drawn uniformly from [0, 1),
        - `epsilon_agent`: boolean scalar drawn from Bernoulli(0.5),
        - `bias_mask`: probability vector of length 6.
    """
    ACTION_SPACE_SIZE = 6
    DIRICHLET_ALPHA = 1.0

    # Keep the 5-way split so `epsilon` and `epsilon_agent` are drawn from the
    # same sub-keys as before; the first sub-key is intentionally discarded.
    _, rng_eps, rng_eps_agent, rng_bias_decision, rng_dirichlet = jax.random.split(
        rng, 5)

    epsilon = jax.random.uniform(rng_eps, shape=(), minval=0.0, maxval=1.0)
    epsilon_agent = jax.random.bernoulli(rng_eps_agent)

    uniform_mask = jnp.ones((ACTION_SPACE_SIZE,)) / ACTION_SPACE_SIZE
    biased_mask = jax.random.dirichlet(
        rng_dirichlet, alpha=DIRICHLET_ALPHA * jnp.ones(ACTION_SPACE_SIZE))

    # Both candidates are computed and then selected with `where`, so the
    # function stays traceable under vmap/jit. With probability 0.0 the
    # Bernoulli draw is always False and the uniform mask is returned.
    use_bias = jax.random.bernoulli(rng_bias_decision, p=dirichlet_bias_prob)
    bias_mask = jnp.where(use_bias, biased_mask, uniform_mask)

    return PartnerParametersWithBias(
        epsilon=epsilon,
        epsilon_agent=epsilon_agent,
        bias_mask=bias_mask
    )

def extract_checkpoints(path):
    """Return the checkpoint file names found in `path`, ordered by step.

    Only names matching CHECKPOINT_REGEX (from the start of the name) are
    kept; ordering uses the integer captured by the pattern's first group.
    """
    def _step_of(file_name):
        return int(re.search(CHECKPOINT_REGEX, file_name).group(1))

    checkpoint_names = []
    for entry in os.listdir(path):
        if re.match(CHECKPOINT_REGEX, entry):
            checkpoint_names.append(entry)

    checkpoint_names.sort(key=_step_of)
    return checkpoint_names


def load_config(path):
    """Rebuild the TrainConfig stored as `config.pckl` inside `path`.

    The pickle is expected to hold a mapping of keyword arguments for
    TrainConfig. Only load files you produced yourself: unpickling runs
    arbitrary code.
    """
    config_file = os.path.join(path, "config.pckl")
    with open(config_file, "rb") as handle:
        config_kwargs = pickle.load(handle)
        return TrainConfig(**config_kwargs)


def load_checkpoint(path):
    """Return the "actor_params" entry of the pickled checkpoint at `path`.

    Only load files you produced yourself: unpickling runs arbitrary code.
    """
    with open(path, "rb") as handle:
        checkpoint = pickle.load(handle)
    actor_params = checkpoint["actor_params"]
    return actor_params

def batchify(x: dict, agent_list, num_actors):
    """Stack the per-agent entries of `x` and flatten them to `(num_actors, -1)`.

    Entries are taken in `agent_list` order; keys of `x` that are not in
    `agent_list` are ignored.
    """
    per_agent = [x[agent] for agent in agent_list]
    stacked = jnp.stack(per_agent)
    target_shape = (num_actors, -1)
    return stacked.reshape(target_shape)


def batchify_nested_dics(x: dict, agent_list, shape):
    """Stack per-agent pytrees leaf-wise and reshape each leaf to `(*shape, -1)`.

    `x[a]` must have the same tree structure for every `a` in `agent_list`.
    """
    def _stack_leaves(*leaves):
        return jnp.stack(leaves)

    def _to_target_shape(leaf):
        return leaf.reshape((*shape, -1))

    per_agent_trees = [x[agent] for agent in agent_list]
    stacked_tree = jax.tree.map(_stack_leaves, *per_agent_trees)
    return jax.tree.map(_to_target_shape, stacked_tree)


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    """Split a flat batch back into a dict keyed by agent name.

    `x` is reshaped to `(num_actors, num_envs, -1)`; the i-th slice along the
    first axis is assigned to the i-th name in `agent_list`.
    """
    per_agent = x.reshape((num_actors, num_envs, -1))
    result = {}
    for index, agent in enumerate(agent_list):
        result[agent] = per_agent[index]
    return result


def rollout_nsteps(
    rng: jax.Array,
    env,
    partner_params,
    network,
    params,
    init_hstate: jax.Array,
    num_steps: int = 400,
    sfl_rollout_factor: int = 2,
    use_dense_rewards: bool = False,
    dense_rewards_coeff: jax.Array = jnp.asarray(0.0),
):
    """Roll out `num_steps` environment steps for one set of partner parameters.

    A single `network.apply` call produces the action distribution for both
    agents. For one agent (chosen by `partner_params.epsilon_agent`) that
    distribution is mixed with `Categorical(partner_params.bias_mask)` using
    weight `partner_params.epsilon`; the other agent samples from the
    unmodified distribution. The reward of "agent_0" is accumulated per
    episode and written into a buffer whenever `done["__all__"]` is set.

    Args:
        rng: PRNG key for the reset, action sampling and environment steps.
        env: Environment providing `reset`, `step`, `agents`, `num_agents`
            and `observation_space`.
        partner_params: Object with `epsilon`, `epsilon_agent` and
            `bias_mask` attributes.
        network: Module whose `apply(params, hstate, ac_in)` returns
            `(hstate, pi, value, other_pi)`.
        params: Parameters forwarded to `network.apply`.
        init_hstate: Initial recurrent state forwarded to `network.apply`.
        num_steps: Number of scan iterations, i.e. environment steps.
        sfl_rollout_factor: Length of the `episode_return` buffer (how many
            finished episodes can be recorded).
        use_dense_rewards: Not read anywhere in this function.
        dense_rewards_coeff: Not read anywhere in this function.

    Returns:
        Tuple `(stats, env_state)`: the final `RolloutEpisodeStats` and the
        final environment state.
    """

    # Running statistics carried through the scan.
    #   reward:          return accumulated so far in the current episode
    #   episode_return:  one slot per finished episode
    #   length:          total number of steps taken
    #   episode_counter: number of finished episodes (next slot to write)
    #   done:            `done["__all__"]` of the most recent step
    class RolloutEpisodeStats(struct.PyTreeNode):
        reward: jax.Array = jnp.asarray(0.0)
        episode_return: jax.Array = jnp.zeros((sfl_rollout_factor,))
        length: jax.Array = jnp.asarray(0)
        episode_counter: jax.Array = jnp.asarray(0)
        done: jax.Array = jnp.asarray(False)

    def _env_step(carry, unused):
        rng, env_state, stats, last_obs, last_done, hstate, past_5_sa_pairs = carry
        # NOTE(review): `rng_action` is never used; action sampling uses the
        # key split off further below instead.
        rng, rng_action, rng_step = jax.random.split(rng, 3)

        # NOTE(review): the actor count is hard-coded to 2 here and below.
        obs_batch = batchify(last_obs, env.agents, 2)

        # Sample from the epsilon-mixture of the policy and the bias distribution.
        def get_e3t_action(args):
            pi_ego, k, e3t_epsilon, bias_mask = args
            pi_random = distrax.Categorical(probs=bias_mask)
            pi_e3t_probs = (1 - e3t_epsilon) * pi_ego.probs + \
                e3t_epsilon * pi_random.probs
            pi_e3t = distrax.Categorical(probs=pi_e3t_probs)
            sampled_a = pi_e3t.sample(seed=k)
            log_prob_a = pi_e3t.log_prob(sampled_a)
            entropy_a = pi_e3t.entropy()
            return sampled_a, log_prob_a, entropy_a

        # Sample from the unmodified policy; `e3t_epsilon` and `bias_mask` are
        # accepted only to mirror get_e3t_action's argument tuple.
        def get_base_action(args):
            pi_ego, k, e3t_epsilon, bias_mask = args
            sampled_a = pi_ego.sample(seed=k)
            log_prob_a = pi_ego.log_prob(sampled_a)
            entropy_a = pi_ego.entropy()
            return sampled_a, log_prob_a, entropy_a

        batched_sa_pairs = batchify_nested_dics(
            past_5_sa_pairs, env.agents, (1, 2, 5))
        # Network input: observations and done flags get a leading axis of
        # size 1, followed by the batched state/action history.
        ac_in = (
            obs_batch[np.newaxis, :],
            last_done[np.newaxis, :],
            batched_sa_pairs,
        )

        # `value` and `other_pi` are not used in this rollout.
        hstate, pi, value, other_pi = network.apply(
            params, hstate, ac_in)

        # The same key seeds both samplers; only one of the two samples is
        # kept per agent by the `jnp.where` below.
        rng, _rng = jax.random.split(rng, 2)
        e3t_action, e3t_log_prob, e3t_entropy = get_e3t_action(
            (pi, _rng, partner_params.epsilon, partner_params.bias_mask))
        base_action, base_log_prob, base_entropy = get_base_action(
            (pi, _rng, partner_params.epsilon, partner_params.bias_mask))

        # Per-agent selector: exactly one agent uses the mixed distribution.
        # `~` must be a logical NOT here, so `epsilon_agent` has to be boolean.
        epsilon_agent_both = jnp.array(
            [partner_params.epsilon_agent, ~partner_params.epsilon_agent])
        action = jnp.where(epsilon_agent_both, e3t_action, base_action)
        action = action.squeeze()

        # NOTE(review): `log_prob` and `entropy` are computed but never read.
        log_prob = jnp.where(epsilon_agent_both,
                             e3t_log_prob, base_log_prob).squeeze()
        entropy = jnp.where(epsilon_agent_both, e3t_entropy,
                            base_entropy).squeeze()

        env_act = unbatchify(action, env.agents, 1, env.num_agents)
        env_act = {k: v.flatten().squeeze() for k, v in env_act.items()}

        # Slide each agent's history window: shift entries one slot towards
        # index 0 along axis 1, then write the observation that was just
        # acted on and the chosen action into the last slot.
        past_5_sa_pairs['agent_0']['obs'] = past_5_sa_pairs[
            'agent_0']['obs'].at[:, :-1].set(past_5_sa_pairs['agent_0']['obs'][:, 1:])
        past_5_sa_pairs['agent_0']['obs'] = past_5_sa_pairs[
            'agent_0']['obs'].at[:, - 1].set(last_obs['agent_0'])
        past_5_sa_pairs['agent_0']['action'] = past_5_sa_pairs[
            'agent_0']['action'].at[:, :-1].set(past_5_sa_pairs['agent_0']['action'][:, 1:])
        past_5_sa_pairs['agent_0']['action'] = past_5_sa_pairs[
            'agent_0']['action'].at[:, -1].set(env_act['agent_0'])

        past_5_sa_pairs['agent_1']['obs'] = past_5_sa_pairs[
            'agent_1']['obs'].at[:, :-1].set(past_5_sa_pairs['agent_1']['obs'][:, 1:])
        past_5_sa_pairs['agent_1']['obs'] = past_5_sa_pairs[
            'agent_1']['obs'].at[:, -1].set(last_obs['agent_1'])
        past_5_sa_pairs['agent_1']['action'] = past_5_sa_pairs[
            'agent_1']['action'].at[:, :-1].set(past_5_sa_pairs['agent_1']['action'][:, 1:])
        past_5_sa_pairs['agent_1']['action'] = past_5_sa_pairs[
            'agent_1']['action'].at[:, -1].set(env_act['agent_1'])

        # STEP ENV
        obsv, env_state, reward, done, info = env.step(
            rng_step, env_state, env_act)

        done_flag = done["__all__"]
        # Episode return so far, including this step (agent_0's reward only).
        final_episode_return = stats.reward + reward["agent_0"]

        # Update buffer only if done, otherwise leave as is
        # NOTE(review): once `episode_counter` reaches the buffer length the
        # index is out of range; presumably JAX drops such writes — confirm.
        new_episode_return = jax.lax.cond(
            done_flag,
            lambda: stats.episode_return.at[stats.episode_counter].set(final_episode_return),
            lambda: stats.episode_return,
        )

        # Reset reward accumulator if done
        new_reward = jax.lax.cond(
            done_flag,
            lambda: jnp.array(0.0),
            lambda: final_episode_return,
        )

        # Increment episode counter only if done
        new_counter = stats.episode_counter + done_flag.astype(jnp.int32)

        # Update stats
        stats = stats.replace(
            reward=new_reward,
            length=stats.length + 1,
            done=done_flag,
            episode_counter=new_counter,
            episode_return=new_episode_return,
        )
        # Only the per-agent done flags are kept for the next network input;
        # batchify ignores the "__all__" entry.
        done = batchify(done, env.agents, 2)
        carry = (rng, env_state, stats, obsv,
                 done.squeeze(), hstate, past_5_sa_pairs)
        return carry, None

    # NOTE(review): `key` is never used and the parent `rng` is placed in the
    # scan carry below after having been split here. JAX's PRNG guidance is
    # not to reuse a key once it has been split; consider carrying `key`.
    key, key_r = jax.random.split(rng)
    obs, state = env.reset(key_r)

    init_x = jnp.zeros(env.observation_space("agent_0").shape)
    init_x = init_x.flatten()

    # Placeholder history buffers; every entry is overwritten right below.
    past_5_sa_pairs = {
        'agent_0': {
            'obs': jnp.zeros((1, 5, init_x.shape[0])),
            'action': jnp.zeros((1, 5, 1))
        },
        'agent_1': {
            'obs': jnp.zeros((1, 5, init_x.shape[0])),
            'action': jnp.zeros((1, 5, 1))
        }
    }

    # Seed the history with five copies of the initial observation and with
    # action index 4.
    # NOTE(review): 4 is presumably the environment's no-op action — confirm.
    past_5_sa_pairs['agent_0']['obs'] = obs['agent_0'][:,
                                                       None].repeat(5, axis=1)
    past_5_sa_pairs['agent_0']['action'] = jnp.ones(
        (1, 5)) * 4
    past_5_sa_pairs['agent_1']['obs'] = obs[
        'agent_1'][:, None].repeat(5, axis=1)
    past_5_sa_pairs['agent_1']['action'] = jnp.ones(
        (1, 5)) * 4

    # Carry layout: (rng, env_state, stats, last_obs, last_done, hstate, history).
    init_carry = (rng, state, RolloutEpisodeStats(), obs,
                  jnp.array([False, False]), init_hstate, past_5_sa_pairs)

    final_carry, _ = jax.lax.scan(
        _env_step, init_carry, None, length=num_steps)

    # (stats, env_state)
    return final_carry[2], final_carry[1]


# ------------ Main Analysis ------------ #
def run_analysis(base_xpid):
    """Evaluate every checkpoint of a run against freshly sampled partners.

    For each checkpoint under `checkpoints/<base_xpid>`, NUM_PARTNERS partner
    parameter sets are sampled and rolled out; the variance ("learnability")
    and mean of the recorded episode returns are collected per partner.

    Args:
        base_xpid: Name of the run directory inside `checkpoints`.

    Returns:
        Tuple `(results_per_ckpt, layout_name)`. `results_per_ckpt` holds one
        dict per checkpoint with keys "step", "learnability", "mean_return"
        and "epsilon"; `layout_name` comes from the stored config.
    """
    run_name = f"{base_xpid}"
    run_dir = os.path.join("checkpoints", run_name)

    config = load_config(run_dir)
    checkpoint_files = extract_checkpoints(run_dir)
    env = make_env("overcooked-v1", {"layout": config.layout_name, "random_reset": False})

    network = ActorCriticRNN(
        env.action_space("agent_0").n,
        gru_hidden_dim_size=config.gru_hidden_dim,
        fc_dim_size=config.fc_dim_size,
        embedding_layers=config.embedding_layers,
        actor_layers=config.actor_layers,
        critic_layers=config.critic_layers,
        other_agent_prediction=config.other_agent_prediction,
        use_layernorm=config.use_layernorm,
    )

    # Only the PRNG keys (arg 0) and the partner parameters (arg 2) are
    # batched; every other rollout argument is shared across partners.
    rollout_in_axes = (0, None, 0, None, None, None, None, None, None, None)
    sampler_in_axes = (None, 0)

    rng = jax.random.PRNGKey(config.seed)
    results_per_ckpt = []

    for ckpt_name in tqdm(checkpoint_files, desc="Analyzing checkpoints"):
        step = int(re.search(CHECKPOINT_REGEX, ckpt_name).group(1))
        params = load_checkpoint(os.path.join(run_dir, ckpt_name))

        rng, sample_key, rollout_key = jax.random.split(rng, 3)

        # One independent key per partner, for sampling and for the rollout.
        sample_keys = jax.random.split(sample_key, NUM_PARTNERS)
        partner_params = jax.vmap(sample_partner_parameters, in_axes=sampler_in_axes)(
            2, sample_keys)
        rollout_keys = jax.random.split(rollout_key, NUM_PARTNERS)

        initial_hidden = ScannedRNN.initialize_carry(2, config.gru_hidden_dim)

        batched_rollout = jax.vmap(rollout_nsteps, in_axes=rollout_in_axes)
        stats, _ = batched_rollout(
            rollout_keys,
            env,
            partner_params,
            network,
            params,
            initial_hidden,
            400*ROLLOUTS_PER_PARTNER,
            ROLLOUTS_PER_PARTNER,
            False,
            0.0,
        )

        episode_returns = stats.episode_return
        record = {
            "step": step,
            "learnability": jnp.var(episode_returns, axis=-1).squeeze(),
            "mean_return": episode_returns.mean(axis=-1),
            "epsilon": partner_params.epsilon,
        }
        results_per_ckpt.append(record)

    return results_per_ckpt, config.layout_name


# ------------ Plotting ------------ #
def _scatter_with_marginals(ax, x, y, *, x_label, title, fit_name,
                            x_hist_range, y_hist_range, x_limits=None,
                            point_color=None, fit_color='blue'):
    """Draw a scatter panel with a quadratic fit and marginal histograms on `ax`."""
    # Only pass `color` when one was requested so the default colour cycle is
    # used otherwise.
    color_kw = {} if point_color is None else {"color": point_color}
    hidden_ticks = dict(left=False, labelleft=False, bottom=False, labelbottom=False)

    ax.scatter(x, y, alpha=0.2, s=5, **color_kw)
    try:
        coeffs = np.polyfit(x, y, deg=2)
        x_fit = np.linspace(np.min(x), np.max(x), 200)
        y_fit = np.polyval(coeffs, x_fit)
        ax.plot(x_fit, y_fit, color=fit_color, linewidth=2, label='Quadratic fit')
        ax.legend()
    except Exception as e:
        # The fit is a visual aid only; a failure must not abort the figure.
        print(f"Curve fitting failed ({fit_name}): {e}")

    if x_limits is not None:
        ax.set_xlim(*x_limits)
    ax.set_ylim(bottom=0)
    ax.set_xlabel(x_label)
    ax.set_ylabel("Learnability (Variance)")
    ax.set_title(title)

    # Histogram of x above the panel, aligned with the panel's x-axis.
    counts, edges = np.histogram(x, bins=10, range=x_hist_range)
    top = ax.inset_axes([0, 1.05, 1, 0.2], transform=ax.transAxes)
    top.bar(edges[:-1], counts, width=edges[1] - edges[0], align='edge',
            edgecolor='black', **color_kw)
    top.set_xlim(ax.get_xlim())
    top.tick_params(**hidden_ticks)

    # Histogram of y to the right of the panel, aligned with the panel's y-axis.
    counts, edges = np.histogram(y, bins=10, range=y_hist_range)
    right = ax.inset_axes([1.02, 0, 0.2, 1], transform=ax.transAxes)
    right.barh(edges[:-1], counts, height=edges[1] - edges[0], align='edge',
               edgecolor='black', **color_kw)
    right.set_ylim(ax.get_ylim())
    right.tick_params(**hidden_ticks)


def plot_results(results_per_ckpt, save_dir, seed, layout):
    """Save learnability scatter plots for every analysed checkpoint.

    Two sets of figures are written: one into `save_dir` using axis/histogram
    ranges shared across all checkpoints, and one into
    `save_dir + "_non_shared_limits"` using per-checkpoint ranges. For each
    checkpoint a two-panel figure (return vs. learnability, epsilon vs.
    learnability) and a single-panel return-only figure are saved as PNG.

    Args:
        results_per_ckpt: List of dicts with keys "step", "learnability",
            "mean_return" and "epsilon" (one array entry per partner).
        save_dir: Output directory for the shared-limits figures.
        seed: Value embedded in the output file names.
        layout: Layout name; used in file names and, abbreviated when a
            short form is known, in plot titles.

    Returns:
        None. Nothing is written when `results_per_ckpt` is empty.
    """
    if not results_per_ckpt:
        return

    # Convert to NumPy once: reducing device arrays with Python's min()/max()
    # iterates them element by element.
    host_results = [
        {
            "step": res["step"],
            "learnability": np.asarray(res["learnability"]),
            "mean_return": np.asarray(res["mean_return"]),
            "epsilon": np.asarray(res["epsilon"]),
        }
        for res in results_per_ckpt
    ]

    min_return = min(r["mean_return"].min() for r in host_results)
    max_return = max(r["mean_return"].max() for r in host_results)
    max_learn = max(r["learnability"].max() for r in host_results)

    # The Early/Middle/End naming only makes sense for exactly three
    # checkpoints; any other count falls back to a "Step <n>" label instead
    # of failing on a missing entry.
    stage_names = ("Early", "Middle", "End")
    steps = sorted(r["step"] for r in host_results)
    step_to_stage = dict(zip(steps, stage_names)) if len(steps) == len(stage_names) else {}

    layout_to_display = {
        "coord_ring": "CR",
        "cramped_room": "CRoom",
        "forced_coord": "FC",
        "asymm_advantages": "AA",
        "counter_circuit": "CC"
    }
    # Unknown layouts are shown under their own name.
    layout_label = layout_to_display.get(layout, layout)

    for shared_limits in [True, False]:
        suffix = "" if shared_limits else "_non_shared_limits"
        target_dir = save_dir + suffix
        os.makedirs(target_dir, exist_ok=True)

        for res in host_results:
            step = res["step"]
            stage = step_to_stage.get(step, f"Step {step}")

            learn = res["learnability"]
            ret = res["mean_return"]
            eps = res["epsilon"]

            local_min_ret = min_return if shared_limits else ret.min()
            local_max_ret = max_return if shared_limits else ret.max()
            local_max_learn = max_learn if shared_limits else learn.max()

            ret_range = (local_min_ret, local_max_ret + 10)
            learn_range = (0, local_max_learn)
            return_title = f"{layout_label}: Return vs Learnability ({stage})"

            # Two-panel figure: return and epsilon against learnability.
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 7))
            _scatter_with_marginals(
                ax1, ret, learn,
                x_label="Mean Return",
                title=return_title,
                fit_name="Return vs Learnability",
                x_hist_range=ret_range,
                y_hist_range=learn_range,
                x_limits=ret_range,
            )
            _scatter_with_marginals(
                ax2, eps, learn,
                x_label="Epsilon",
                title=f"{layout_label}: Epsilon vs Learnability ({stage})",
                fit_name="Epsilon vs Learnability",
                x_hist_range=(0.0, 1.0),
                y_hist_range=learn_range,
                point_color='orange',
                fit_color='red',
            )
            fig.tight_layout()
            fig.savefig(os.path.join(target_dir, f"{layout}_learnability_step_{step}_seed_{seed}_with_hist.png"))
            plt.close(fig)

            # Single-panel figure: return against learnability only.
            fig, ax_ret_hist = plt.subplots(figsize=(6, 4))
            _scatter_with_marginals(
                ax_ret_hist, ret, learn,
                x_label="Mean Return",
                title=return_title,
                fit_name="Return-only",
                x_hist_range=ret_range,
                y_hist_range=learn_range,
                x_limits=ret_range,
            )
            fig.tight_layout()
            fig.savefig(os.path.join(target_dir, f"return_only_{layout}_learnability_step_{step}_seed_{seed}_with_hist.png"))
            plt.close(fig)

# ------------ Entry Point ------------ #
if __name__ == "__main__":
    import argparse

    # NOTE(review): the --base_xpid help text says "without _SEED_0", yet the
    # value is used verbatim as the checkpoint directory name and is searched
    # for a "SEED_<n>" part below — the help text looks stale; confirm.
    cli = argparse.ArgumentParser()
    cli.add_argument("--base_xpid", type=str, required=True,
                     help="Base XPID without _SEED_0")
    cli.add_argument("--save_dir", type=str, default="figures_learnability",
                     help="Directory to save figures")
    cli_args = cli.parse_args()

    results, layout = run_analysis(cli_args.base_xpid)

    # The seed only tags the output file names; it stays None when the xpid
    # carries no "SEED_<n>" part.
    seed_match = re.search(r"SEED_(\d+)", cli_args.base_xpid)
    if seed_match:
        seed = int(seed_match.group(1))
    else:
        seed = None

    plot_results(results, cli_args.save_dir, seed, layout)
