import argparse
import os
import pickle

import jax
import jax.numpy as jnp
import numpy as np

from envs.gridworld import GridState, GridWorld, generate_fixed_grid, generate_random_grid
from global_utils import build_discrete_embedding_net, load_ckpt, select_last_ckpt
from gridworld.filters import (
    BaseFilter,
    GroundTruthFilter,
    NeuralBayesFilter,
    NeuralGTFilter,
    ParticleFilter,
    ParticleSet,
    RNNFilter,
)
from gridworld.generate_data import belief_sample_random_walk
from gridworld.process_results import plot_stats, print_stats
from gridworld.rnn_belief_model import build_rnn_belief_net
from gridworld.utils import (
    js_divergence,
    negative_log_likelihood,
    top_k_accuracy,
)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    The options fall in two groups: which trained runs to load
    (``--model_dir``, ``--flow_run_name``, ``--rnn_run_name``, ``--seed``) and
    how to run the evaluation (``--k_acc``, ``--max_init_steps``,
    ``--nparticles_nf``, ``--nparticles_pf``, ``--nrepeats``, ``--nsteps``,
    ``--show_legend``).  ``--nparticles_nf`` and ``--nparticles_pf`` accept one
    or more integers and are always returned as lists.

    Returns:
        The namespace populated from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()

    # fmt: off
    # Which trained runs to load.
    parser.add_argument('--flow_run_name', default='flow', type=str, help='Flow training run name')
    parser.add_argument('--model_dir', default='gridworld_models', type=str, help='Directory name')
    parser.add_argument('--rnn_run_name', default='rnn', type=str, help='RNN training run name')
    parser.add_argument('--seed', default=42, type=int, help='Random seed')

    # How to run the evaluation.
    parser.add_argument('--k_acc', default=5, type=int, help='Top-k accuracy')
    parser.add_argument('--max_init_steps', default=1, type=int, help='Maximum number of initial steps')
    parser.add_argument('--nparticles_nf', default=[16], type=int, nargs='+', help='Number of particles for NBF')
    parser.add_argument('--nparticles_pf', default=[16], type=int, nargs='+', help='Number of particles for PF')
    parser.add_argument('--nrepeats', default=500, type=int, help='Number of repeats')
    parser.add_argument('--nsteps', default=10, type=int, help='Maximum number of filtering steps')
    # `store_true` already defaults to False, so no explicit default is needed.
    parser.add_argument('--show_legend', action='store_true', help='Show legend in plots')
    # fmt: on

    return parser.parse_args()


def evaluate(
    key: jax.Array, env: GridWorld, filters: dict[str, BaseFilter], args: argparse.Namespace
) -> dict[str, dict[str, np.ndarray]]:
    """Roll out a batch of episodes and score every filter at every step.

    A batch of ``args.nrepeats`` start states, beliefs and policies is drawn
    with ``belief_sample_random_walk``.  Every filter is reset from those
    starting beliefs/states, scored once (row 0), and then for each of
    ``args.nsteps`` steps the policies act, the environment steps, every
    filter is updated and scored again (rows 1..nsteps).

    Args:
        key: PRNG key; all randomness used here is split off from it.
        env: Environment that is stepped and handed to the sampler.
        filters: Filters to compare, keyed by display name.
        args: Parsed options.  Besides the CLI options this reads ``ncubes``,
            ``cube_width``, ``policy_temp`` and ``fixed``, which ``main``
            attaches from the flow training arguments.

    Returns:
        ``{metric name: {filter name: array of shape (nsteps + 1, nrepeats)}}``
        as filled in by ``compute_stats``.
    """
    metric_names = (
        f'Top-{args.k_acc} Accuracy',
        'Negative Log Likelihood',
        'JS Divergence',
    )
    stats = {
        metric: {name: np.zeros((args.nsteps + 1, args.nrepeats)) for name in filters}
        for metric in metric_names
    }

    def _draw_keys(rng, count):
        # Advance `rng` and hand back `count` fresh subkeys stacked into one array.
        rng, *subkeys = jax.random.split(rng, 1 + count)
        return rng, jnp.array(subkeys)

    key, rollout_key = jax.random.split(key, 2)
    _, (states, beliefs), policies = belief_sample_random_walk(
        rollout_key,
        env,
        args.ncubes,
        args.cube_width,
        args.max_init_steps,
        args.nrepeats,
        1,  # NOTE(review): meaning of this positional argument is defined by the sampler
        args.policy_temp,
        args.fixed,
    )

    # One reset key per filter; that key is shared by all repeats (in_axes=None).
    key, reset_keys = _draw_keys(key, len(filters))
    filter_states = []
    for belief_filter, reset_key in zip(filters.values(), reset_keys):
        batched_reset = jax.vmap(belief_filter.reset, in_axes=(None, 0, 0))
        filter_states.append(batched_reset(reset_key, beliefs, states))

    key, stats_key = jax.random.split(key, 2)
    stats = compute_stats(stats_key, stats, args, filters, filter_states, states, 0)

    batched_policy = jax.vmap(lambda rng, policy, state: policy(rng, state))
    batched_env_step = jax.vmap(env.step)

    for step in range(1, args.nsteps + 1):
        key, action_keys = _draw_keys(key, args.nrepeats)
        actions = batched_policy(action_keys, policies, states)
        states, obss = batched_env_step(states, actions)

        # Each filter gets its own batch of per-repeat update keys.
        updated_states = []
        for belief_filter, previous_state in zip(filters.values(), filter_states):
            key, update_keys = _draw_keys(key, args.nrepeats)
            batched_update = jax.vmap(belief_filter.update)
            updated_states.append(
                batched_update(update_keys, previous_state, states, policies, obss)
            )
        filter_states = updated_states

        key, stats_key = jax.random.split(key, 2)
        stats = compute_stats(stats_key, stats, args, filters, filter_states, states, step)

    return stats


def compute_stats(
    key: jax.Array,
    stats: dict[str, dict[str, np.ndarray]],
    args: argparse.Namespace,
    filters: dict[str, BaseFilter],
    filter_states: list[ParticleSet | jax.Array],
    states: GridState,
    step: int,
) -> dict[str, dict[str, np.ndarray]]:
    """Score every filter at one step and write the results into ``stats``.

    For each filter the beliefs are computed from its current state, then the
    top-k accuracy and negative log likelihood against ``states.agent_position``
    are stored in row ``step``.  Every filter other than ``'True Beliefs'`` also
    gets its JS divergence to the ``'True Beliefs'`` beliefs; the reference's
    own ``'JS Divergence'`` row is left untouched.

    The reference may appear anywhere in ``filters``: all beliefs are gathered
    first and the divergences are computed afterwards.

    Args:
        key: PRNG key; ``args.nrepeats`` subkeys are split off per filter, in
            the iteration order of ``filters``.
        stats: ``{metric name: {filter name: array}}``; mutated in place.
        args: Parsed options; reads ``k_acc`` and ``nrepeats``.
        filters: Filters keyed by display name.
        filter_states: Batched filter states, in the same order as ``filters``.
        states: Batched true environment states.
        step: Row of the stats arrays to fill.

    Returns:
        The same ``stats`` object that was passed in.

    Raises:
        ValueError: If there are filters to compare but no ``'True Beliefs'``
            entry to compare them against.
    """
    reference_name = 'True Beliefs'
    accuracy_name = f'Top-{args.k_acc} Accuracy'

    beliefs_by_filter = {}
    for (name, belief_filter), filter_state in zip(filters.items(), filter_states):
        key, *belief_keys = jax.random.split(key, 1 + args.nrepeats)
        beliefs = jax.vmap(belief_filter.compute_beliefs)(jnp.array(belief_keys), filter_state)
        beliefs_by_filter[name] = beliefs

        stats[accuracy_name][name][step] = jax.vmap(top_k_accuracy, in_axes=(0, None, 0))(
            beliefs, args.k_acc, states.agent_position
        )
        stats['Negative Log Likelihood'][name][step] = jax.vmap(negative_log_likelihood)(
            beliefs, states.agent_position
        )

    compared = [name for name in beliefs_by_filter if name != reference_name]
    if not compared:
        return stats

    if reference_name not in beliefs_by_filter:
        raise ValueError(
            f"filters must contain a '{reference_name}' entry to compute the JS divergence"
        )

    # Deferred until all beliefs exist so the result does not depend on dict order.
    env_beliefs = beliefs_by_filter[reference_name]
    for name in compared:
        stats['JS Divergence'][name][step] = jax.vmap(js_divergence)(
            beliefs_by_filter[name], env_beliefs
        )

    return stats


def save_stats(args: argparse.Namespace, stats: dict[str, jax.Array]) -> str:
    """Pickle ``stats`` to disk and return the directory it was written to.

    The file is ``<model_dir>/eval_<flow_run_name>_<rnn_run_name>/<seed>/results.pkl``;
    missing directories are created and an existing file is overwritten.

    Args:
        args: Parsed options; reads ``model_dir``, ``flow_run_name``,
            ``rnn_run_name`` and ``seed``.
        stats: Object to pickle as-is (``main`` passes the output of ``evaluate``).

    Returns:
        Path of the directory containing ``results.pkl``.
    """
    run_tag = f'eval_{args.flow_run_name}_{args.rnn_run_name}'
    results_dir = os.path.join(args.model_dir, run_tag, str(args.seed))
    os.makedirs(results_dir, exist_ok=True)

    results_path = os.path.join(results_dir, 'results.pkl')
    with open(results_path, 'wb') as out_file:
        pickle.dump(stats, out_file)

    return results_dir


def main():
    """Load the trained models, build the filters, evaluate them and report.

    The latest flow and RNN checkpoints for the requested seed are loaded, the
    environment settings stored with the flow checkpoint are copied onto
    ``args``, the grid world is rebuilt from them, and the resulting statistics
    are saved, printed and plotted.
    """
    args = parse_args()
    key = jax.random.key(args.seed)
    seed_dir = str(args.seed)

    key, flow_key = jax.random.split(key, 2)
    flow_run_dir = os.path.join(args.model_dir, args.flow_run_name, seed_dir)
    flow_model, flow_train_args = load_ckpt(
        select_last_ckpt(flow_run_dir), build_discrete_embedding_net, flow_key
    )

    key, rnn_key = jax.random.split(key, 2)
    rnn_run_dir = os.path.join(args.model_dir, args.rnn_run_name, seed_dir)
    rnn_model, _ = load_ckpt(select_last_ckpt(rnn_run_dir), build_rnn_belief_net, rnn_key)

    # Evaluate with the environment settings the flow model was trained with.
    for field in ('fixed', 'grid_size', 'ndim', 'ncubes', 'cube_width', 'policy_temp'):
        setattr(args, field, getattr(flow_train_args, field))
    args.nstates = args.grid_size**args.ndim

    grid_spec = (args.grid_size, args.ndim, args.ncubes, args.cube_width)
    if not args.fixed:
        key, grid_key = jax.random.split(key, 2)
        env, _ = generate_random_grid(grid_key, *grid_spec)
    else:
        env, _ = generate_fixed_grid(*grid_spec)

    # 'True Beliefs' is the reference compute_stats uses for the JS divergence.
    filters = {
        'True Beliefs': GroundTruthFilter(env),
        'Approx Beliefs': NeuralGTFilter(env, flow_model, args.nparticles_nf[-1]),
        'Recurrent': RNNFilter(env, rnn_model, obs_shape=1),
    }
    filters.update(
        {f'PF ({count})': ParticleFilter(env, count) for count in args.nparticles_pf}
    )
    filters.update(
        {
            f'NBF ({count})': NeuralBayesFilter(env, flow_model, count)
            for count in args.nparticles_nf
        }
    )

    stats = evaluate(key, env, filters, args)
    # plot_stats receives `args`, so the output location is attached to it first.
    args.results_dir = save_stats(args, stats)

    print_stats(stats)
    plot_stats(args, stats)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
