import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
import optax
from tqdm import tqdm
import argparse
import os
import json
from typing import Tuple

import nn
from belief_model import RNNBeliefModel, train
from pomdp.gridworld_jax import generate_random_grid, generate_fixed_grid
from pomdp.generate_grid_data import belief_action_obs_random_walk
from pomdp.utils import base_dir_string

jax.config.update("jax_enable_x64", True)

def parse_args(argv=None):
    """Parse command-line options for the RNN belief-model training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the original behavior;
            passing an explicit list allows programmatic use and testing.

    Returns:
        argparse.Namespace holding experiment, model, training and grid options.
    """
    parser = argparse.ArgumentParser()
    # experiment parameters
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")
    parser.add_argument("--dir", default="results", type=str, help="Output directory.")
    parser.add_argument("--checkpoint_interval", default=5000, type=int, help="Checkpoint interval.")
    parser.add_argument("--eval_interval", default=1000, type=int, help="Evaluation interval.")
    # model parameters
    parser.add_argument("--rnn_nlayers", default=2, type=int, help="Number of RNN layers.")
    parser.add_argument("--rnn_hidden_size", default=32, type=int, help="rnn hidden size.")
    parser.add_argument("--mlp_hidden_size", default=32, type=int, help="output head hidden size.")
    # Passed as the ``depth`` of the eqx.nn.MLP output head in build_model.
    parser.add_argument("--mlp_nlayers", default=2, type=int, help="Number of MLP head layers.")
    # training parameters
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")
    parser.add_argument("--nsteps", default=100000, type=int, help="Number of training steps.")
    parser.add_argument("--lr", default=0.1, type=float, help="Learning rate.")
    # grid parameters
    parser.add_argument("--grid_size", default=5, type=int, help="Grid size.")
    parser.add_argument("--max_walk_depth", default=15, type=int, help="Maximum depth of random walk")
    parser.add_argument("--ndim", default=2, type=int, help="Grid dimensionality.")
    parser.add_argument("--ncubes", default=1, type=int, help="Number of cubes in grid.")
    parser.add_argument("--cube_width", default=2, type=int, help="Width of cubes.")
    parser.add_argument("--policy_temp", default=1., type=float,
                        help="Policy softmax temperature (overridden by --fixed).")
    parser.add_argument("--fixed", default=False, action="store_true", help="Fixed grid and policy.")
    return parser.parse_args(argv)

def dir_string(args):
    """Return the relative results directory name for an RNN run.

    The path is ``base_dir_string(args)`` followed by an ``rnn/`` segment whose
    leaf encodes the RNN and MLP sizes, learning rate, batch size and number of
    training steps, joined with underscores.
    """
    leaf_parts = [
        f"rnn_{args.rnn_nlayers}_{args.rnn_hidden_size}",
        f"m_{args.mlp_nlayers}_{args.mlp_hidden_size}",
        f"lr_{args.lr}",
        f"bs_{args.batch_size}",
        f"s_{args.nsteps}",
    ]
    return base_dir_string(args) + "/rnn/" + "_".join(leaf_parts)

def build_model(key, env, args):
    """Construct an untrained ``RNNBeliefModel`` for ``env``.

    The recurrent part is an ``nn.MultiLayerLSTM`` built with input size 1,
    ``args.rnn_nlayers`` and ``args.rnn_hidden_size``. The head is an
    ``eqx.nn.MLP`` that maps the RNN hidden size to ``env.size ** env.ndim``
    outputs (presumably one per grid cell; TODO confirm against the env).

    Args:
        key: JAX PRNG key, split into one key for the RNN and one for the head.
        env: Grid environment exposing ``size`` and ``ndim``.
        args: Parsed CLI namespace with the model-size options.
    """
    lstm_key, mlp_key = jax.random.split(key, 2)
    n_outputs = env.size ** env.ndim
    encoder = nn.MultiLayerLSTM(1, args.rnn_nlayers, args.rnn_hidden_size, key=lstm_key)
    readout = eqx.nn.MLP(
        in_size=args.rnn_hidden_size,
        out_size=n_outputs,
        width_size=args.mlp_hidden_size,
        depth=args.mlp_nlayers,
        key=mlp_key,
    )
    return RNNBeliefModel(rnn=encoder, head=readout)

def checkpoint(args, model, step):
    """Save the run's arguments and a serialised snapshot of ``model``.

    Files are written under ``<args.dir>/<dir_string(args)>/<args.seed>/``:
    ``args.json`` (rewritten on every call) and ``model/<step>``, which holds the
    model leaves written by ``eqx.tree_serialise_leaves``.

    Args:
        args: Parsed CLI namespace; it is also dumped to ``args.json``.
        model: Equinox model pytree to serialise.
        step: Training step used as the checkpoint file name.
    """
    run_dir = os.path.join(args.dir, dir_string(args), str(args.seed))
    model_dir = os.path.join(run_dir, "model")
    # Create the target directory here so checkpointing does not depend on the
    # caller having created it beforehand.
    os.makedirs(model_dir, exist_ok=True)
    with open(os.path.join(run_dir, "args.json"), "w", encoding="utf-8") as f:
        json.dump(vars(args), f)
    eqx.tree_serialise_leaves(os.path.join(model_dir, str(step)), model)
def main():
    """Train an RNN belief model on random-walk gridworld data.

    The function runs these steps in order:

    1. Parse the CLI arguments.
    2. Build a fixed or random grid environment.
    3. Train the model with Adagrad on a fresh batch of random-walk
       (belief, observation) data at every step.
    4. Show the mean loss every ``eval_interval`` steps.
    5. Checkpoint every ``checkpoint_interval`` steps and always after the
       final step.
    """
    args = parse_args()
    key = jax.random.PRNGKey(args.seed)
    key, model_key, grid_key = jax.random.split(key, 3)
    if args.fixed:
        env, _ = generate_fixed_grid(args.grid_size, args.ndim, args.ncubes, args.cube_width)
        # --fixed implies a fixed policy; a tiny temperature is used for that.
        # It must be set before dir_string(args) is computed below.
        args.policy_temp = 1e-5
    else:
        env, _ = generate_random_grid(grid_key, args.grid_size, args.ndim, args.ncubes, args.cube_width)
    run_dir = os.path.join(args.dir, dir_string(args), str(args.seed))
    os.makedirs(os.path.join(run_dir, "model"), exist_ok=True)
    model = build_model(model_key, env, args)
    opt = optax.adagrad(args.lr)
    opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))

    @eqx.filter_jit
    def _train_step(key, model, opt_state):
        # Sample a fresh batch of random walks, then take one optimiser step.
        key, gen_key = jax.random.split(key, 2)
        beliefs, obss, _ = belief_action_obs_random_walk(gen_key, env, args.ncubes, args.cube_width,
                                                         args.max_walk_depth, args.batch_size,
                                                         args.policy_temp, args.fixed)
        return train(model, opt, opt_state, obss, beliefs)

    losses = []
    progress_bar = tqdm(range(args.nsteps))
    for i in progress_bar:
        key, step_key = jax.random.split(key, 2)
        model, opt_state, loss_value = _train_step(step_key, model, opt_state)
        losses.append(loss_value)
        if i % args.eval_interval == 0:
            progress_bar.set_postfix({'loss': np.mean(losses)})
            losses = []
        step = i + 1
        # Also save after the last step, so the final model is kept when
        # nsteps is not a multiple of checkpoint_interval.
        if step % args.checkpoint_interval == 0 or step == args.nsteps:
            checkpoint(args, model, step)


if __name__ == "__main__":
    main()