import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
import distrax
import optax
from tqdm import tqdm
import argparse
import os
import json

import flowee
import nn
from distribution_embedding import FlowEmbedding, train
from pomdp.gridworld_jax import generate_random_grid, generate_fixed_grid
from pomdp.generate_grid_data import belief_sample_random_walk
from pomdp.utils import base_dir_string

# Turn on JAX's 64-bit mode via the `jax_enable_x64` flag. This runs at import
# time, before any of the code below creates arrays.
jax.config.update("jax_enable_x64", True)

def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, which
            is the behaviour of the original zero-argument call.

    Returns:
        The ``argparse.Namespace`` holding experiment, model, training and
        grid parameters.
    """
    parser = argparse.ArgumentParser()
    # experiment parameters
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")
    parser.add_argument("--dir", default="results", type=str, help="Output directory.")
    parser.add_argument("--checkpoint_interval", default=5000, type=int, help="Checkpoint interval.")
    parser.add_argument("--eval_interval", default=1000, type=int, help="Evaluation interval.")
    # model parameters
    parser.add_argument("--embedding_size", default=32, type=int, help="Embedding size.")
    parser.add_argument("--embedding_hidden_size", default=128, type=int, help="Embedding hidden size.")
    parser.add_argument("--embedding_nlayers", default=3, type=int, help="Number of embedding layers.")
    parser.add_argument("--flow_hidden_size", default=32, type=int, help="Flow hidden size.")
    parser.add_argument("--flow_nlayers", default=5, type=int, help="Number of flow layers.")
    # training parameters
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")
    parser.add_argument("--nsamples", default=64, type=int,
                        help="Number of samples for embedding generation.")
    parser.add_argument("--nsteps", default=100000, type=int, help="Number of training steps.")
    parser.add_argument("--lr", default=0.1, type=float, help="Learning rate.")
    # grid parameters
    parser.add_argument("--grid_size", default=5, type=int, help="Grid size.")
    parser.add_argument("--max_walk_depth", default=15, type=int, help="Maximum depth of random walk")
    parser.add_argument("--ndim", default=2, type=int, help="Grid dimensionality.")
    parser.add_argument("--ncubes", default=1, type=int, help="Number of cubes in grid.")
    parser.add_argument("--cube_width", default=2, type=int, help="Width of cubes.")
    parser.add_argument("--policy_temp", default=1., type=float, help="Policy softmax temperature.")
    parser.add_argument("--fixed", default=False, action="store_true", help="Fixed grid and policy.")
    return parser.parse_args(argv)

def dir_string(args):
    """Return the relative results sub-directory for this configuration.

    The result is ``base_dir_string(args)`` followed by ``/flow/`` and a name
    that encodes the embedding sizes, the flow sizes, the learning rate, the
    number of samples, the batch size and the number of training steps.
    """
    pieces = [
        base_dir_string(args),
        "/flow/",
        f"e_{args.embedding_size}_{args.embedding_hidden_size}_{args.embedding_nlayers}_",
        f"f_{args.flow_hidden_size}_{args.flow_nlayers}_lr_{args.lr}_ns_{args.nsamples}",
        f"_bs_{args.batch_size}_s_{args.nsteps}",
    ]
    return "".join(pieces)

def mask(ndim, i):
    """Return the ``uint8`` coupling mask used at layer index ``i``.

    Even indices use ``flowee.checkerboard_mask``; odd indices use
    ``flowee.create_mask`` called with ``(3,)``. Both are requested for the
    shape ``(ndim,)``.
    """
    shape = (ndim,)
    layer_is_odd = i % 2 != 0
    if layer_is_odd:
        return flowee.create_mask(shape, (3,), dtype=jnp.uint8)
    return flowee.checkerboard_mask(shape, dtype=jnp.uint8)

def build_model(key, env, args):
    """Assemble the embedding network and the flow into a ``FlowEmbedding``.

    Args:
        key: JAX PRNG key; split internally to initialise every sub-network.
        env: Environment object; only ``env.size`` is read here, to set the
            dequantizer's ``max_val``.
        args: Parsed options supplying ``ndim`` and the embedding / flow sizes.

    Returns:
        ``FlowEmbedding(emb_net, flow)`` where ``flow`` is a
        ``flowee.Sequential`` of a dequantizer, five coupling layers and a
        final ``flowee.Sigmoid``, with a uniform prior attached.
    """
    ndim = args.ndim
    n_couplings = 5

    key, emb_key, deq_key = jax.random.split(key, 3)

    # Embedding MLP: ndim inputs, embedding_size outputs.
    emb_net = eqx.nn.MLP(
        ndim,
        args.embedding_size,
        args.embedding_hidden_size,
        args.embedding_nlayers,
        key=emb_key,
    )

    # Dequantization layer with its own coupling as the variational flow.
    max_val = env.size - 1
    deq_mask = mask(ndim, 0)
    deq_net = nn.MultiMLP((ndim, ndim), 2 * ndim, args.flow_hidden_size, 2, key=deq_key)
    deq = flowee.Dequantize(
        max_val=max_val,
        in_dtype=jnp.uint8,
        var_flow=flowee.Coupling(deq_mask, deq_net, dual=True),
    )

    # One fresh key per coupling layer (the first split result is the carry).
    key, *flow_keys = jax.random.split(key, n_couplings + 1)

    layers = [deq]
    for idx in range(n_couplings):
        layer_mask = mask(ndim, idx)
        # The network takes ndim + embedding_size inputs and emits 5 values
        # per dimension for the ParameterizedNLSq transform.
        layer_net = nn.MLP(
            ndim + args.embedding_size,
            ndim * 5,
            args.flow_hidden_size,
            args.flow_nlayers,
            key=flow_keys[idx],
        )
        transform = flowee.ParameterizedNLSq(mask(ndim, 0).shape)
        layers.append(flowee.Coupling(layer_mask, layer_net, transform, dual=True))
    layers.append(flowee.Sigmoid(1e-5))

    flow = flowee.Sequential(layers)
    flow.add_prior(distrax.Uniform(-1e-4, 1 + 1e-4), (ndim,))
    return FlowEmbedding(emb_net, flow)

def checkpoint(args, model, step):
    """Write the run arguments and the current model leaves to disk.

    Files are placed under ``<args.dir>/<dir_string(args)>/<args.seed>/``:
    ``args.json`` holds ``vars(args)`` and ``model/<step>`` is the path handed
    to ``eqx.tree_serialise_leaves``.

    Args:
        args: Parsed options; also determines the output location.
        model: The model pytree to serialise.
        step: Training step number used as the checkpoint file name.
    """
    run_dir = os.path.join(args.dir, dir_string(args), str(args.seed))
    model_dir = os.path.join(run_dir, "model")
    # Do not rely on the caller having created the directory tree already;
    # creating model_dir also creates run_dir.
    os.makedirs(model_dir, exist_ok=True)

    args_path = os.path.join(run_dir, "args.json")
    with open(args_path, "w") as f:
        json.dump(vars(args), f)

    model_path = os.path.join(model_dir, str(step))
    eqx.tree_serialise_leaves(model_path, model)

def main():
    """Run training: build the grid and model, then optimise and checkpoint."""
    args = parse_args()
    key = jax.random.PRNGKey(args.seed)
    key, grid_key = jax.random.split(key, 2)
    if args.fixed:
        env, _ = generate_fixed_grid(args.grid_size, args.ndim, args.ncubes, args.cube_width)
        # Overridden before dir_string(args) is evaluated below, so the output
        # path is computed from the value that is actually used.
        args.policy_temp = 1e-5
    else:
        # the data generation handles creating new random walls for every element of the batch
        env, _ = generate_random_grid(grid_key, args.grid_size, args.ndim, args.ncubes, args.cube_width)
    os.makedirs(os.path.join(args.dir, dir_string(args), str(args.seed), "model"),
                exist_ok=True)

    # Split off a dedicated key for model initialisation so that the key
    # consumed by build_model is not reused by the training loop below.
    key, model_key = jax.random.split(key, 2)
    model = build_model(model_key, env, args)
    opt = optax.adagrad(args.lr)
    # Optimiser state is created only for the inexact-array leaves of the model.
    opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))

    @eqx.filter_jit
    def _train_step(key, model, opt_state):
        """Sample a fresh batch and apply one `train` update."""
        key, train_key, gen_key = jax.random.split(key, 3)
        # Only the first of the three returned values is used for training.
        xs, _, _ = belief_sample_random_walk(gen_key, env, args.ncubes, args.cube_width,
                                             args.max_walk_depth, args.batch_size, args.nsamples,
                                             args.policy_temp, args.fixed)
        return train(model, opt, opt_state, xs, key=train_key)

    losses = []
    progress_bar = tqdm(range(args.nsteps))
    for i in progress_bar:
        key, step_key = jax.random.split(key, 2)
        model, opt_state, loss_value = _train_step(step_key, model, opt_state)
        losses.append(loss_value)
        if i % args.eval_interval == 0:
            # Show the mean loss since the previous report, then reset the window.
            progress_bar.set_postfix({'loss': np.mean(losses)})
            losses = []
        if (i + 1) % args.checkpoint_interval == 0:
            checkpoint(args, model, i + 1)


if __name__ == "__main__":
    main()