import argparse
import os
import time

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import kl_div
from jumanji.types import TimeStep

from envs.triangulation import Triangulation, TriangulationState
from flows.distribution_embedding import FlowEmbedding
from global_utils import build_continuous_embedding_net, load_ckpt, select_last_ckpt
from triangulation.filters import NeuralBayesFilter, ParticleFilter
from triangulation.train_flow import BiasedAgent


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    Returns:
        The namespace produced by ``argparse`` with episode/step counts, the
        particle counts to evaluate for each filter type, the reference filter
        size, the histogram grid size, the random seed and the location of the
        trained flow checkpoints.
    """
    # Each entry is (flag, keyword arguments forwarded to ``add_argument``).
    option_table = (
        ('--num_episodes', {'type': int, 'default': 100}),
        ('--max_steps', {'type': int, 'default': 10}),
        (
            '--num_pf_particles',
            {'type': int, 'default': [16, 32], 'nargs': '+', 'help': 'PF sizes to evaluate'},
        ),
        (
            '--num_nbf_particles',
            {'type': int, 'default': [16, 32], 'nargs': '+', 'help': 'NBF sizes to evaluate'},
        ),
        ('--ref_particles', {'type': int, 'default': 1024}),
        ('--grid_size', {'type': int, 'default': 32, 'help': 'Histogram grid size for JS div'}),
        ('--seed', {'type': int, 'default': 0}),
        (
            '--flow_run_name',
            {
                'type': str,
                'default': 'flow',
                'help': 'Name of the NBF training run to evaluate',
            },
        ),
        (
            '--model_dir',
            {
                'type': str,
                'default': 'triangulation_models',
                'help': 'Directory containing model checkpoints',
            },
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def particles_to_dist(particles: TriangulationState, weights: jax.Array, grid_size: int):
    """Bin weighted particle locations into a flattened 2-D histogram.

    Args:
        particles: Particle set; only ``player_location`` is read, with column 0
            used as x and column 1 as y.
        weights: One weight per particle; they are normalised here.
        grid_size: Number of histogram cells along each axis of [-5, 5].

    Returns:
        The ``grid_size * grid_size`` histogram masses as a 1-D array, with the
        x bin index varying slowest.
    """
    # grid_size cells along an axis are delimited by grid_size + 1 edges.
    edges = jnp.linspace(-5.0, 5.0, grid_size + 1)
    positions = particles.player_location

    # The small offset keeps the normalisation well defined even when every
    # weight is zero.
    padded = weights + 1e-9
    normalized = padded / jnp.sum(padded)

    masses, _, _ = jnp.histogram2d(
        positions[:, 0],
        positions[:, 1],
        bins=edges,
        weights=normalized,
    )
    return masses.reshape(-1)


def nbf_to_dist(
    embedding: jax.Array,
    model: FlowEmbedding,
    active_beacon: jax.Array,
    grid_size: int,
    key: jax.Array,
):
    """Evaluate the flow density on a grid and normalise it to a discrete distribution.

    The density is queried at the centres of the ``grid_size`` x ``grid_size``
    cells covering [-5, 5]^2, and the result is flattened with the x index
    varying slowest. This is the same cell layout and ordering that
    ``jnp.histogram2d(x, y, bins=linspace(-5, 5, grid_size + 1))[0].flatten()``
    produces, so the output can be compared element-wise with
    ``particles_to_dist``.

    Args:
        embedding: Belief embedding passed through to ``model.log_prob``.
        model: Flow whose ``log_prob(features, embedding, key=...)`` returns the
            log-densities as its first output.
        active_beacon: Scalar beacon index (array or plain int); appended to
            every grid location as a float feature.
        grid_size: Number of cells along each axis.
        key: PRNG key forwarded to ``model.log_prob``.

    Returns:
        Array of ``grid_size * grid_size`` probabilities summing to one.
    """
    # Query at bin centres so each value represents one histogram cell.
    edges = jnp.linspace(-5.0, 5.0, grid_size + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])

    # 'ij' indexing makes x the leading axis, matching histogram2d's layout;
    # the default 'xy' indexing would yield the transposed grid.
    grid_x, grid_y = jnp.meshgrid(centers, centers, indexing='ij')
    locations = jnp.stack([grid_x.flatten(), grid_y.flatten()], axis=-1)

    beacon_value = jnp.asarray(active_beacon, dtype=jnp.float32)
    beacon_col = jnp.full((locations.shape[0], 1), beacon_value)
    features = jnp.concatenate([locations, beacon_col], axis=1)

    log_probs, *_ = model.log_prob(features, embedding, key=key)

    # Subtract the maximum before exponentiating: the normalised result is
    # unchanged, but it cannot degenerate to 0/0 when all log-densities are
    # very negative (or overflow when they are very large).
    probs = jnp.exp(log_probs - jnp.max(log_probs))
    return probs / jnp.sum(probs)


@eqx.filter_jit
def js_divergence(p, q):
    """Return the Jensen-Shannon divergence between two discrete distributions.

    Both inputs and their midpoint mixture are clipped to ``[1e-12, 1]`` before
    the element-wise ``kl_div`` terms are evaluated, so empty cells do not
    produce infinities or NaNs.

    Args:
        p: First distribution, as an array of probabilities.
        q: Second distribution, same shape as ``p``.

    Returns:
        Scalar: the sum over cells of ``0.5 * kl_div(p, m) + 0.5 * kl_div(q, m)``
        with ``m = 0.5 * (p + q)``.
    """
    floor = 1e-12
    mixture = 0.5 * (p + q)
    p_safe, q_safe, mix_safe = (jnp.clip(dist, floor, 1.0) for dist in (p, q, mixture))

    kl_p_to_mix = kl_div(p_safe, mix_safe)
    kl_q_to_mix = kl_div(q_safe, mix_safe)
    return jnp.sum(0.5 * kl_p_to_mix + 0.5 * kl_q_to_mix)


@eqx.filter_jit
def evaluate_tracking_step_pf(
    env: Triangulation,
    pf_eval: ParticleFilter,
    pf_ref: ParticleFilter,
    agent: BiasedAgent,
    state: TriangulationState,
    timestep: TimeStep,
    fs_eval: jax.Array,
    fs_ref: jax.Array,
    grid_size: int,
    key: jax.Array,
):
    """Advance one environment step and score a particle filter against the reference.

    The agent picks an action, the environment is stepped, both filters are
    updated with the (previous observation, action, next observation) triple,
    and the JS divergence between their gridded beliefs is computed.

    Args:
        env: Environment being tracked.
        pf_eval: Particle filter under evaluation.
        pf_ref: Reference particle filter.
        agent: Policy producing the action from the current state.
        state: Current environment state.
        timestep: Timestep that accompanies ``state``; its observation is passed
            to the filters as the previous observation.
        fs_eval: Current filter state of ``pf_eval``.
        fs_ref: Current filter state of ``pf_ref``.
        grid_size: Histogram resolution used for the divergence.
        key: PRNG key for this step.

    Returns:
        ``(next_state, next_time_step, next_fs_eval, next_fs_ref, divergence)``,
        the same layout as ``evaluate_tracking_step_nbf``, so the caller can
        thread the timestep through the episode.
    """
    # Independent keys for the two filters: sharing one key would correlate the
    # evaluated filter's noise with the reference it is being compared against.
    act_key, step_key, eval_update_key, ref_update_key = jax.random.split(key, 4)
    action = agent(state, act_key)
    next_state, next_time_step = env.step(state, action, step_key)

    next_fs_eval, _ = pf_eval.update(
        env,
        fs_eval,
        state,
        timestep.observation,
        action,
        next_time_step.observation,
        eval_update_key,
    )
    next_fs_ref, _ = pf_ref.update(
        env,
        fs_ref,
        state,
        timestep.observation,
        action,
        next_time_step.observation,
        ref_update_key,
    )

    dist_eval = particles_to_dist(next_fs_eval.particles, next_fs_eval.weights, grid_size)
    dist_ref = particles_to_dist(next_fs_ref.particles, next_fs_ref.weights, grid_size)
    divergence = js_divergence(dist_eval, dist_ref)
    return next_state, next_time_step, next_fs_eval, next_fs_ref, divergence


@eqx.filter_jit
def evaluate_tracking_step_nbf(
    env: Triangulation,
    nbf_eval: NeuralBayesFilter,
    pf_ref: ParticleFilter,
    agent: BiasedAgent,
    state: TriangulationState,
    timestep: TimeStep,
    fs_eval: jax.Array,
    fs_ref: jax.Array,
    grid_size: int,
    key: jax.Array,
):
    """Advance one environment step and score a neural Bayes filter against the reference.

    Mirrors ``evaluate_tracking_step_pf``, except that the evaluated belief is
    obtained by evaluating ``nbf_eval.model`` on a grid (``nbf_to_dist``)
    instead of histogramming particles.

    Args:
        env: Environment being tracked.
        nbf_eval: Neural Bayes filter under evaluation.
        pf_ref: Reference particle filter.
        agent: Policy producing the action from the current state.
        state: Current environment state.
        timestep: Timestep that accompanies ``state``; its observation is passed
            to the filters as the previous observation.
        fs_eval: Current filter state of ``nbf_eval``.
        fs_ref: Current filter state of ``pf_ref``.
        grid_size: Grid resolution used for the divergence.
        key: PRNG key for this step.

    Returns:
        ``(next_state, next_time_step, next_fs_eval, next_fs_ref, divergence)``.
    """
    # Each consumer of randomness gets its own key; in particular the two
    # filter updates no longer share one.
    act_key, step_key, eval_update_key, ref_update_key, dist_key = jax.random.split(key, 5)
    action = agent(state, act_key)
    next_state, next_time_step = env.step(state, action, step_key)

    next_fs_eval, _ = nbf_eval.update(
        env,
        fs_eval,
        state,
        timestep.observation,
        action,
        next_time_step.observation,
        eval_update_key,
    )
    next_fs_ref, _ = pf_ref.update(
        env,
        fs_ref,
        state,
        timestep.observation,
        action,
        next_time_step.observation,
        ref_update_key,
    )

    dist_eval = nbf_to_dist(
        next_fs_eval, nbf_eval.model, next_state.active_beacon, grid_size, dist_key
    )
    dist_ref = particles_to_dist(next_fs_ref.particles, next_fs_ref.weights, grid_size)
    divergence = js_divergence(dist_eval, dist_ref)
    return next_state, next_time_step, next_fs_eval, next_fs_ref, divergence


def run_episode_eval_pf(
    env: Triangulation,
    pf_eval: ParticleFilter,
    pf_ref: ParticleFilter,
    agent: BiasedAgent,
    max_steps: int,
    grid_size: int,
    key: jax.Array,
):
    """Run one episode and return the mean per-step JS divergence of a particle filter.

    Args:
        env: Environment to roll out.
        pf_eval: Particle filter under evaluation.
        pf_ref: Reference particle filter.
        agent: Policy used to act in the environment.
        max_steps: Number of steps to roll out; also the averaging denominator.
        grid_size: Histogram resolution used for the divergence.
        key: PRNG key for the episode.

    Returns:
        The JS divergence between ``pf_eval`` and ``pf_ref`` averaged over
        ``max_steps`` steps.
    """
    # Separate keys for the environment and each filter so none of the three
    # resets draws from the same random stream.
    key, env_reset_key, eval_reset_key, ref_reset_key = jax.random.split(key, 4)
    state, time_step = env.reset(env_reset_key)
    fs_eval = pf_eval.reset(env, state, time_step.observation, eval_reset_key)
    fs_ref = pf_ref.reset(env, state, time_step.observation, ref_reset_key)

    total_divergence = 0.0
    for _ in range(max_steps):
        key, step_key = jax.random.split(key)
        # time_step must be carried forward; otherwise every update would see
        # the initial observation as its "previous" observation.
        state, time_step, fs_eval, fs_ref, divergence = evaluate_tracking_step_pf(
            env, pf_eval, pf_ref, agent, state, time_step, fs_eval, fs_ref, grid_size, step_key
        )
        total_divergence += divergence

    return total_divergence / max_steps


def run_episode_eval_nbf(
    env: Triangulation,
    nbf_eval: NeuralBayesFilter,
    pf_ref: ParticleFilter,
    agent: BiasedAgent,
    max_steps: int,
    grid_size: int,
    key: jax.Array,
):
    """Run one episode and return the mean per-step JS divergence of a neural Bayes filter.

    Args:
        env: Environment to roll out.
        nbf_eval: Neural Bayes filter under evaluation.
        pf_ref: Reference particle filter.
        agent: Policy used to act in the environment.
        max_steps: Number of steps to roll out; also the averaging denominator.
        grid_size: Grid resolution used for the divergence.
        key: PRNG key for the episode.

    Returns:
        The JS divergence between ``nbf_eval`` and ``pf_ref`` averaged over
        ``max_steps`` steps.
    """
    # Separate keys for the environment and each filter: reusing a single key
    # would make the three resets draw from the same random stream.
    key, env_reset_key, eval_reset_key, ref_reset_key = jax.random.split(key, 4)
    state, time_step = env.reset(env_reset_key)
    fs_eval = nbf_eval.reset(env, state, time_step.observation, eval_reset_key)
    fs_ref = pf_ref.reset(env, state, time_step.observation, ref_reset_key)

    total_divergence = 0.0
    for _ in range(max_steps):
        key, step_key = jax.random.split(key)
        state, time_step, fs_eval, fs_ref, divergence = evaluate_tracking_step_nbf(
            env, nbf_eval, pf_ref, agent, state, time_step, fs_eval, fs_ref, grid_size, step_key
        )
        total_divergence += divergence

    return total_divergence / max_steps


def _parse_list_arg(s: str) -> list[int]:
    """Parse a comma-separated string of integers, skipping blank fields.

    Args:
        s: Text such as ``'16,32'``; fields that are empty or whitespace-only
            are ignored.

    Returns:
        The parsed integers in their original order.
    """
    values = []
    for field in s.split(','):
        if not field.strip():
            continue
        values.append(int(field))
    return values


def _list_model_dirs(run_dir: str) -> list[str]:
    """Return the sorted sub-directory names of ``run_dir`` (empty if it does not exist)."""
    if not os.path.isdir(run_dir):
        return []
    # Sorting makes the model order, and hence the PRNG key each model receives,
    # independent of the platform-specific order of os.listdir.
    return sorted(
        name for name in os.listdir(run_dir) if os.path.isdir(os.path.join(run_dir, name))
    )


def _mean_and_stderr(values: list[float]) -> tuple[float, float]:
    """Return the mean of ``values`` and the standard error of that mean."""
    samples = np.asarray(values, dtype=float)
    mean = float(samples.mean())
    if samples.size < 2:
        # A single run carries no spread information; report zero spread.
        return mean, 0.0
    # ddof=1 gives the sample standard deviation, which is what the standard
    # error of a mean estimated from a handful of runs requires.
    return mean, float(samples.std(ddof=1) / np.sqrt(samples.size))


def main() -> None:
    """Evaluate particle filters and neural Bayes filters against a reference filter.

    For every requested particle-filter size the evaluation is repeated once per
    trained model found under ``<model_dir>/<flow_run_name>`` (or once if there
    are none), so both filter families are summarised over the same number of
    runs. Each neural Bayes filter size is then evaluated once per model using
    that model's last checkpoint. Means and standard errors of the per-run mean
    JS divergence are printed and summarised at the end.
    """
    args = parse_args()
    key = jax.random.key(args.seed)

    env = Triangulation()
    pf_ref = ParticleFilter(args.ref_particles)

    results = {}

    # Determine the number of models to use for statistical significance.
    # A missing run directory is treated the same as an empty one.
    model_dirs: list[str] = []
    if args.flow_run_name:
        model_dirs = _list_model_dirs(os.path.join(args.model_dir, args.flow_run_name))

    num_models = len(model_dirs)
    if num_models == 0:
        print('Warning: No NBF models found or specified. PF evaluation will run once.')
        num_models = 1  # Run PF at least once

    for n_particles in args.num_pf_particles:
        print(f'--- Evaluating PF with {n_particles} particles ({num_models} runs) ---')
        run_mean_jsds = []
        start_time = time.time()
        pf_eval = ParticleFilter(num_particles=n_particles)

        for i in range(num_models):
            key, run_key = jax.random.split(key)
            total_jsd = 0.0
            for _ in range(args.num_episodes):
                run_key, agent_key, ep_key = jax.random.split(run_key, 3)
                # Each episode uses an agent biased towards a random direction.
                direction = jax.random.randint(agent_key, shape=(), minval=0, maxval=4)
                agent = BiasedAgent(direction=direction)
                jsd = run_episode_eval_pf(
                    env, pf_eval, pf_ref, agent, args.max_steps, args.grid_size, ep_key
                )
                total_jsd += jsd
            mean_jsd_for_run = total_jsd / args.num_episodes
            print(f'  Run {i + 1}/{num_models} Mean JSD: {mean_jsd_for_run:.4f}')
            run_mean_jsds.append(float(np.asarray(mean_jsd_for_run)))

        elapsed = time.time() - start_time
        mean_of_means, stderr = _mean_and_stderr(run_mean_jsds)
        print(
            f'  --> PF-{n_particles} Mean of Means: {mean_of_means:.4f} (Total time: {elapsed:.2f}s)'
        )
        print(f'  --> PF-{n_particles} StdErr: {stderr:.4f}')
        results[f'PF-{n_particles}-Mean'] = mean_of_means
        results[f'PF-{n_particles}-StdErr'] = stderr

    if model_dirs:
        for n_particles in args.num_nbf_particles:
            print(f'--- Evaluating NBF with {n_particles} particles ---')
            model_mean_jsds = []
            for model_dir_name in model_dirs:
                print(f'Evaluating model: {model_dir_name}')
                start_time = time.time()
                key, model_key = jax.random.split(key, 2)
                ckpt_file = select_last_ckpt(
                    os.path.join(args.model_dir, args.flow_run_name, model_dir_name)
                )
                if ckpt_file is None:
                    print(f'  No checkpoint found in {model_dir_name}, skipping.')
                    continue
                model, _ = load_ckpt(ckpt_file, build_continuous_embedding_net, model_key)

                nbf_eval = NeuralBayesFilter(model, num_particles=n_particles)

                total_jsd = 0.0
                for _ in range(args.num_episodes):
                    key, agent_key, ep_key = jax.random.split(key, 3)
                    direction = jax.random.randint(agent_key, shape=(), minval=0, maxval=4)
                    agent = BiasedAgent(direction=direction)
                    jsd = run_episode_eval_nbf(
                        env, nbf_eval, pf_ref, agent, args.max_steps, args.grid_size, ep_key
                    )
                    total_jsd += jsd

                mean_jsd = total_jsd / args.num_episodes
                elapsed = time.time() - start_time
                print(f'  Mean JS Divergence: {mean_jsd:.4f}  ({elapsed:.2f}s)')
                model_mean_jsds.append(float(np.asarray(mean_jsd)))

            # Every model may have been skipped for lack of a checkpoint.
            if model_mean_jsds:
                mean_of_means, stderr = _mean_and_stderr(model_mean_jsds)
                results[f'NBF-{n_particles}-Mean'] = mean_of_means
                results[f'NBF-{n_particles}-StdErr'] = stderr
                print(f'NBF-{n_particles} Mean of Means: {mean_of_means:.4f}')
                print(f'NBF-{n_particles} StdErr: {stderr:.4f}')

    print('\n--- Summary ---')
    for k, v in results.items():
        print(f'Filter: {k:<20} | Value: {v:.4f}')


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
