import argparse
import os

import equinox as eqx
import jax
import numpy as np
import optax
from tqdm import tqdm

from envs.gridworld import generate_fixed_grid, generate_random_grid
from flows.distribution_embedding import train
from global_utils import build_discrete_embedding_net, save_ckpt
from gridworld.generate_data import belief_sample_random_walk

# Enable JAX's 64-bit mode process-wide. This runs at import time, before main() builds any arrays.
jax.config.update('jax_enable_x64', True)


def parse_args() -> argparse.Namespace:
    """Parse command-line options for gridworld flow training.

    Returns:
        The parsed options. ``--checkpoint_interval`` and ``--eval_interval``
        are guaranteed to be positive, because the training loop uses them as
        modulo divisors.

    Raises:
        SystemExit: via ``parser.error`` (exit status 2) when an interval is
            not positive, or on any standard ``argparse`` parsing error.
    """
    parser = argparse.ArgumentParser()

    # fmt: off
    # Experiment parameters
    parser.add_argument('--checkpoint_interval', default=5000, type=int, help='Checkpoint interval')
    parser.add_argument('--eval_interval', default=1000, type=int, help='Evaluation interval')
    parser.add_argument('--model_dir', default='gridworld_models', type=str, help='Model directory')
    parser.add_argument('--run_name', default='flow', type=str, help='Run name')
    parser.add_argument('--seed', default=42, type=int, help='Random seed')

    # Environment parameters
    parser.add_argument('--cube_width', default=2, type=int, help='Width of cubes')
    parser.add_argument('--fixed', default=False, action='store_true', help='Fixed grid and policy')
    parser.add_argument('--grid_size', default=5, type=int, help='Grid size')
    parser.add_argument('--max_walk_depth', default=15, type=int, help='Maximum depth of random walk')
    parser.add_argument('--ncubes', default=1, type=int, help='Number of cubes')
    parser.add_argument('--ndim', default=2, type=int, help='Grid dimensionality')
    parser.add_argument('--policy_temp', default=1.0, type=float, help='Policy softmax temperature')

    # Model parameters
    parser.add_argument('--embedding_hidden_size', default=128, type=int, help='Embedding hidden size')
    parser.add_argument('--embedding_nlayers', default=3, type=int, help='Number of embedding layers')
    parser.add_argument('--embedding_size', default=32, type=int, help='Embedding size')
    parser.add_argument('--flow_hidden_size', default=32, type=int, help='Flow hidden size')
    parser.add_argument('--flow_nlayers', default=5, type=int, help='Number of flow layers')
    parser.add_argument('--ncoupling_layers', default=5, type=int, help='Number of coupling layers')

    # Training parameters
    parser.add_argument('--batch_size', default=32, type=int, help='Batch size')
    parser.add_argument('--lr', default=0.1, type=float, help='Learning rate')
    parser.add_argument('--nsamples', default=64, type=int, help='Number of samples per belief')
    parser.add_argument('--nsteps', default=100000, type=int, help='Number of training steps')
    parser.add_argument('--optimizer', default='adagrad', type=str, help='Optimizer to use')
    # fmt: on

    args = parser.parse_args()

    # Reject non-positive intervals up front with a clear CLI error. A zero
    # interval would otherwise raise ZeroDivisionError on the first training step.
    for name in ('checkpoint_interval', 'eval_interval'):
        value = getattr(args, name)
        if value <= 0:
            parser.error(f'--{name} must be a positive integer, got {value}')

    return args


def main():
    """Train a distribution-embedding flow on gridworld belief samples.

    Steps:
        1. Parse CLI options.
        2. Build a fixed or random gridworld environment.
        3. Build the embedding model and an ``optax`` optimizer.
        4. Run ``args.nsteps`` jitted training steps. Each step draws a fresh
           batch with ``belief_sample_random_walk`` and updates the model via
           ``train``.

    Logging and checkpointing:
        * Every ``args.eval_interval`` steps, the mean loss over that window
          is shown on the progress bar.
        * Every ``args.checkpoint_interval`` steps, ``save_ckpt`` is called.
        * The final step always reports any partial loss window and saves a
          checkpoint. A run whose ``nsteps`` is not a multiple of the
          intervals therefore still ends with the trained model saved.
    """
    args = parse_args()
    key = jax.random.key(args.seed)

    if args.fixed:
        env, _ = generate_fixed_grid(args.grid_size, args.ndim, args.ncubes, args.cube_width)
        # Fixed mode pins the policy softmax temperature near zero (a
        # near-deterministic policy). This overrides any --policy_temp given.
        args.policy_temp = 1e-5
    else:
        key, grid_key = jax.random.split(key, 2)
        env, _ = generate_random_grid(
            grid_key, args.grid_size, args.ndim, args.ncubes, args.cube_width
        )

    # Derived settings are stored on ``args`` so that they reach the model
    # builder below. They are also passed to ``save_ckpt`` via ``vars(args)``.
    args.max_val = args.grid_size - 1
    args.flow_squash_sigmoid = True
    args.uniform_prior = True

    key, model_key = jax.random.split(key, 2)
    model = build_discrete_embedding_net(args, model_key)

    # Look up the optimizer factory by name, e.g. ``optax.adagrad``.
    opt = getattr(optax, args.optimizer)(args.lr)
    # Optimizer state covers only the model's floating-point array leaves.
    opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))

    @eqx.filter_jit
    def _train_step(model, opt_state, key):
        """Sample a fresh batch of beliefs and take one optimization step.

        ``env`` and the ``args`` fields are closed over and fixed at trace
        time. Returns whatever ``train`` returns; the loop below unpacks it
        as ``(model, opt_state, loss)``.
        """
        # The first of the three subkeys is intentionally unused.
        _, generate_key, train_key = jax.random.split(key, 3)
        xs, *_ = belief_sample_random_walk(
            generate_key,
            env,
            args.ncubes,
            args.cube_width,
            args.max_walk_depth,
            args.batch_size,
            args.nsamples,
            args.policy_temp,
            args.fixed,
        )

        return train(model, opt, opt_state, xs, train_key)

    losses, progress_bar = [], tqdm(range(args.nsteps))
    # Ensure the per-run directory exists. Presumably this is where save_ckpt
    # writes; TODO confirm.
    os.makedirs(os.path.join(args.model_dir, args.run_name, str(args.seed)), exist_ok=True)

    for i in progress_bar:
        key, step_key = jax.random.split(key, 2)
        model, opt_state, loss = _train_step(model, opt_state, step_key)
        losses.append(loss)

        step = i + 1
        is_last = step == args.nsteps

        # Report the mean loss over the current window. On the last step, also
        # flush a partial window so that trailing losses are not dropped.
        if step % args.eval_interval == 0 or is_last:
            progress_bar.set_postfix({'loss': np.mean(losses)})
            losses = []

        # Checkpoint on schedule, and always after the final step, so the
        # trained model is saved even when nsteps % checkpoint_interval != 0.
        if step % args.checkpoint_interval == 0 or is_last:
            save_ckpt(model, vars(args), step)


# Start training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
