import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
import os
import json
import argparse
import pickle
from typing import Sequence

from pomdp.gridworld_jax import generate_random_grid, GridState, generate_fixed_grid
from pomdp.filter import *
from pomdp.train_grid_flow import build_model as build_flow
from pomdp.train_grid_rnn import build_model as build_rnn
from pomdp.utils import *
from pomdp.generate_grid_data import belief_sample_random_walk
from pomdp.grid_results import load_combined_stats, print_stats, plot_stats
import glob


def parse_args():
    """Parse the command-line options of the filter evaluation script.

    Returns:
        argparse.Namespace holding the options declared below.
        ``nparticles_nf`` and ``nparticles_pf`` are always lists of ints,
        both when given on the command line and when left at their defaults.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", default="results/", type=str, help="Directory to load model.")
    parser.add_argument("--seed", default=42, type=int, help="Training seed to load associated model.")
    parser.add_argument("--checkpoint", default=-1, type=int, help="Checkpoint to load.")
    parser.add_argument("--k_acc", default=5, type=int, help="Top-k accuracy.")
    parser.add_argument("--max_init_steps", default=1, type=int, help="Maximum number of initial steps.")
    parser.add_argument("--nsteps", default=10, type=int, help="Maximum number of filtering steps.")
    parser.add_argument("--nrepeats", default=500, type=int, help="Number of repeats.")
    # nargs="+" produces a list when the flag is supplied, but argparse uses the
    # default as-is. The defaults must therefore be lists as well, otherwise
    # callers that iterate or index these options fail on a bare int.
    parser.add_argument("--nparticles_nf", default=[16], nargs="+", type=int,
                        help="Number of particles for Neural Filter.")
    parser.add_argument("--nparticles_pf", default=[16], nargs="+", type=int,
                        help="Number of particles for Particle Filter.")
    parser.add_argument("--show_legend", default=False, action="store_true", help="Show legend in plots.")
    return parser.parse_args()

def get_latest_checkpoint(path: str) -> int:
    """Return the largest checkpoint number among the ``.eqx`` files in ``path``.

    The checkpoint number of a file is the text before its first ``.``,
    parsed as an int. Files without the ``.eqx`` suffix are ignored.
    """
    checkpoint_ids = []
    for entry in os.listdir(path):
        if not entry.endswith(".eqx"):
            continue
        stem, _, _ = entry.partition(".")
        checkpoint_ids.append(int(stem))
    return max(checkpoint_ids)

def load_checkpoint(args: argparse.Namespace, model: FlowEmbedding, model_prefix: str = "flow") -> FlowEmbedding:
    """Load a stored checkpoint into ``model``.

    The checkpoint directory is the first match of
    ``<args.dir>/<model_prefix>/*/<args.seed>/model``; the file read is
    ``<checkpoint>.eqx`` inside it.

    Args:
        args: Namespace providing ``dir``, ``seed`` and ``checkpoint``. When
            ``checkpoint`` is negative it is replaced, in place, by the
            latest checkpoint number found in the directory.
        model: Model passed to ``eqx.tree_deserialise_leaves`` as the
            template for the stored leaves.
        model_prefix: Sub-directory of ``args.dir`` holding this model family.

    Returns:
        The result of ``eqx.tree_deserialise_leaves`` for the selected file.

    Raises:
        FileNotFoundError: If no model directory matches the pattern, or the
            selected checkpoint file does not exist.
    """
    path_pattern = os.path.join(args.dir, model_prefix, "*", f"{args.seed}", "model")
    matches = glob.glob(path_pattern)
    if not matches:
        # Fail with the same exception type as load_training_args instead of
        # an opaque IndexError on an empty glob result.
        raise FileNotFoundError(f"No {model_prefix} model directory matches {path_pattern!r}.")
    base_path = matches[0]
    if args.checkpoint < 0:
        # NOTE(review): the resolved number is written back onto ``args``, so a
        # later call with the same namespace (e.g. for another model_prefix)
        # reuses it instead of resolving its own latest checkpoint - confirm
        # this is intended.
        args.checkpoint = get_latest_checkpoint(base_path)
    path = os.path.join(base_path, f"{args.checkpoint}.eqx")
    with open(path, "rb") as f:
        return eqx.tree_deserialise_leaves(f, model)

def load_training_args(args: argparse.Namespace, model_prefix: str = "flow") -> argparse.Namespace:
    """Load the training arguments saved next to a trained model.

    Looks for ``<args.dir>/<model_prefix>/*/<args.seed>/args.json`` first so
    that the arguments belong to the same run whose checkpoint is loaded by
    seed. If that run has no ``args.json`` (or ``args`` has no ``seed``), it
    falls back to the first ``args.json`` of any run under ``model_prefix``.

    Args:
        args: Namespace providing ``dir`` and, optionally, ``seed``.
        model_prefix: Sub-directory of ``args.dir`` holding this model family.

    Returns:
        A Namespace built from the key/value pairs of the JSON file.

    Raises:
        FileNotFoundError: If no ``args.json`` exists under ``model_prefix``.
    """
    seed = getattr(args, "seed", None)
    files = []
    if seed is not None:
        files = glob.glob(os.path.join(args.dir, model_prefix, "*", f"{seed}", "args.json"))
    if not files:
        # Fallback: any run of this model family.
        files = glob.glob(os.path.join(args.dir, model_prefix, "*", "*", "args.json"))
    if not files:
        raise FileNotFoundError(f"No {model_prefix} model found.")
    with open(files[0], "r") as f:
        return argparse.Namespace(**json.load(f))

def write_results(eval_args: argparse.Namespace, stats: dict):
    """Persist this run's evaluation output, then print and plot combined stats.

    Writes ``eval_args.json`` and ``results.pkl`` to
    ``<eval_args.dir>/eval/<eval_args.seed>/`` (created if missing). Afterwards
    ``eval_args.dir`` is repointed, in place, at ``<eval_args.dir>/eval`` and
    the namespace is handed to ``load_combined_stats``, ``print_stats`` and
    ``plot_stats``.

    Args:
        eval_args: Evaluation arguments; must provide ``dir`` and ``seed`` and
            be JSON-serialisable via ``vars``. Its ``dir`` attribute is
            modified by this call.
        stats: Statistics of this run, stored with ``pickle``.
    """
    # Use the namespace that was passed in rather than a module-level ``args``,
    # so the function also works when it is not called from the script body.
    base_path = os.path.join(eval_args.dir, "eval")
    path = os.path.join(base_path, str(eval_args.seed))
    os.makedirs(path, exist_ok=True)
    args_path = os.path.join(path, "eval_args.json")
    results_path = os.path.join(path, "results.pkl")
    # The arguments are dumped before ``dir`` is repointed below, so the saved
    # file records the directory the evaluation was started with.
    with open(args_path, "w") as f:
        json.dump(vars(eval_args), f)
    with open(results_path, "wb") as f:
        pickle.dump(stats, f)
    eval_args.dir = base_path
    combined_stats = load_combined_stats(eval_args)
    print_stats(combined_stats)
    plot_stats(eval_args, combined_stats)

def filter_belief_dist(key: jax.random.PRNGKey, env: eqx.Module, filters: dict, args: argparse.Namespace,
                       stats: dict):
    """Run every filter along sampled trajectories and record per-step stats.

    Initial states, beliefs and policies come from
    ``belief_sample_random_walk``. Each filter is reset from those beliefs and
    states, stats are recorded for step 0, and then for ``args.nsteps`` steps
    the environment is advanced with the sampled policies, every filter is
    updated, and stats are recorded for step ``i + 1``.

    Args:
        key: PRNG key; split internally for sampling, resets, actions,
            filter updates and stat computation.
        env: Environment providing ``step(state, action)``.
        filters: Mapping from display name to filter object; each filter must
            provide ``reset``, ``update`` and ``compute_beliefs``.
        args: Namespace providing ``ncubes``, ``cube_width``,
            ``max_init_steps``, ``nrepeats``, ``policy_temp``, ``fixed`` and
            ``nsteps`` (plus whatever ``compute_stats`` reads).
        stats: Nested dict filled in place by ``compute_stats``.

    Returns:
        Tuple ``(stats, filter_states)`` where ``filter_states`` is the list of
        final filter states, in the iteration order of ``filters``. When
        ``args.nsteps`` is 0 these are the freshly reset states.
    """
    key, init_step_key, reset_key, stat_key = jax.random.split(key, 4)
    _, (states, beliefs), policies = belief_sample_random_walk(init_step_key, env, args.ncubes,
                                                               args.cube_width, args.max_init_steps,
                                                               args.nrepeats, 1, args.policy_temp,
                                                               args.fixed)
    reset_keys = jax.random.split(reset_key, len(filters))
    # One reset key per filter; that key is shared across the batch
    # (in_axes=None) while beliefs and states are mapped over their leading axis.
    filter_states = [jax.vmap(f.reset, in_axes=(None, 0, 0))(k, beliefs, states)
                     for f, k in zip(filters.values(), reset_keys)]
    stat_key, step_stat_key = jax.random.split(stat_key)
    compute_stats(step_stat_key, stats, args, filters, filter_states, states, 0)

    def _apply_policy(key, policy, state):
        return policy(key, state)

    for i in range(args.nsteps):
        key, action_key = jax.random.split(key, 2)
        action_keys = jax.random.split(action_key, args.nrepeats)
        actions = jax.vmap(_apply_policy)(action_keys, policies, states)
        states, obss = jax.vmap(env.step)(states, actions)
        next_filter_states = []
        for belief_filter, filter_state in zip(filters.values(), filter_states):
            key, *update_keys = jax.random.split(key, args.nrepeats + 1)
            next_filter_states.append(jax.vmap(belief_filter.update)(jnp.array(update_keys),
                                                                     filter_state, states, policies, obss))
        # Derive a fresh key for every step. Passing the same ``stat_key`` each
        # iteration would make all steps draw identical random numbers when
        # computing beliefs.
        stat_key, step_stat_key = jax.random.split(stat_key)
        compute_stats(step_stat_key, stats, args, filters, next_filter_states, states, i + 1)
        filter_states = next_filter_states
    return stats, filter_states


# Batched, JIT-compiled metric functions used by compute_stats.
# Top-k correctness maps over the first and third arguments; the second
# argument (k) is shared by the whole batch.
_batched_top_k_correct = jax.vmap(top_k_correct, in_axes=(0, None, 0))
# NLL and JS divergence map over the leading axis of every argument.
_batched_nll = jax.vmap(nll)
_batched_js_divergence = jax.vmap(js_divergence)

v_acc = eqx.filter_jit(_batched_top_k_correct)
v_nll = eqx.filter_jit(_batched_nll)
v_js = eqx.filter_jit(_batched_js_divergence)

def compute_stats(key: jax.random.PRNGKey, stats: dict, args: argparse.Namespace,
                  filters: dict, filter_states: Sequence[jax.Array],
                  states: GridState, step: int):
    """Record accuracy, NLL and JS divergence of every filter at one step.

    For each filter, beliefs are computed from its state with
    ``compute_beliefs`` (one PRNG key per repeat). The results are written in
    place into ``stats[metric][name][step]``:

    * ``"Top-<k> Accuracy"`` and ``"Negative Log Likelihood"`` compare the
      beliefs against ``states.agent_pos``.
    * ``"JS Divergence"`` compares the beliefs of every filter other than
      ``"True Beliefs"`` against the beliefs of the ``"True Beliefs"`` filter.
      The ``"True Beliefs"`` entry itself is left untouched.

    Args:
        key: PRNG key, split once per filter in the iteration order of
            ``filters``.
        stats: Nested dict ``metric -> filter name -> array`` indexed by step.
        args: Namespace providing ``nrepeats`` and ``k_acc``.
        filters: Mapping from display name to filter object.
        filter_states: Filter states, aligned with the order of ``filters``.
        states: Current environment states; ``agent_pos`` is the target.
        step: Index of the step being recorded.

    Raises:
        ValueError: If a JS divergence is needed but ``filters`` has no
            ``"True Beliefs"`` entry to serve as the reference.
    """
    # Pass 1: compute all beliefs, consuming keys in the order of ``filters``.
    # Collecting them first makes the JS divergence independent of where the
    # reference filter sits in the mapping.
    beliefs_by_name = {}
    for (name, belief_filter), filter_state in zip(filters.items(), filter_states):
        key, *belief_keys = jax.random.split(key, args.nrepeats + 1)
        beliefs_by_name[name] = jax.vmap(belief_filter.compute_beliefs)(jnp.array(belief_keys), filter_state)

    reference_beliefs = beliefs_by_name.get("True Beliefs")

    # Pass 2: write the metrics.
    accuracy_metric = f"Top-{args.k_acc} Accuracy"
    for name, beliefs in beliefs_by_name.items():
        stats[accuracy_metric][name][step] = v_acc(beliefs, args.k_acc, states.agent_pos)
        stats["Negative Log Likelihood"][name][step] = v_nll(beliefs, states.agent_pos)
        if name == "True Beliefs":
            continue
        if reference_beliefs is None:
            raise ValueError('JS divergence requires a "True Beliefs" filter as reference.')
        stats["JS Divergence"][name][step] = v_js(beliefs, reference_beliefs)

if __name__ == "__main__":
    # Script entry point: rebuild the training environment, load the trained
    # flow and RNN models, evaluate every filter and write the results.
    args = parse_args()
    flow_train_args = load_training_args(args, "flow")
    rnn_train_args = load_training_args(args, "rnn")

    # Carry the environment settings of the flow training run over to the
    # evaluation arguments.
    for field in ("cube_width", "ncubes", "policy_temp", "fixed"):
        setattr(args, field, getattr(flow_train_args, field))

    root_key = jax.random.PRNGKey(args.seed)
    key, gen_key, flow_key, rnn_key = jax.random.split(root_key, 4)

    if not args.fixed:
        env, _ = generate_random_grid(gen_key, size=flow_train_args.grid_size, ndim=flow_train_args.ndim,
                                      num_cubes=flow_train_args.ncubes, cube_width=flow_train_args.cube_width)
    else:
        env, _ = generate_fixed_grid(flow_train_args.grid_size, flow_train_args.ndim,
                                     flow_train_args.ncubes, flow_train_args.cube_width)

    # Build both model skeletons first, then fill them from the checkpoints.
    flow_skeleton = build_flow(flow_key, env, flow_train_args)
    rnn_skeleton = build_rnn(rnn_key, env, rnn_train_args)
    flow_model = load_checkpoint(args, flow_skeleton, "flow")
    rnn_model = load_checkpoint(args, rnn_skeleton, "rnn")

    # One particle filter / neural Bayes filter per requested particle count.
    particle_filters = {f"PF ({n})": ParticleFilter(env, n) for n in args.nparticles_pf}
    neural_filters = {f"NBF ({n})": NeuralBayesFilter(env, flow_model, n) for n in args.nparticles_nf}

    # "True Beliefs" stays first: it is the reference for the JS divergence.
    filters = {
        "True Beliefs": GroundTruthFilter(env),
        "Approx Beliefs": NeuralGTFilter(env, flow_model, nparticles=args.nparticles_nf[-1]),
        "Recurrent": RNNFilter(env, rnn_model, obs_shape=1),
        **particle_filters,
        **neural_filters,
    }

    # One (nsteps + 1, nrepeats) array per metric and filter; row 0 holds the
    # stats recorded before the first filtering step.
    metric_names = (f"Top-{args.k_acc} Accuracy", "JS Divergence", "Negative Log Likelihood")
    stats = {
        metric: {name: np.zeros((args.nsteps + 1, args.nrepeats)) for name in filters}
        for metric in metric_names
    }

    filter_belief_dist(key, env, filters, args, stats)
    write_results(args, stats)