import argparse
import os
import time

import equinox as eqx
import jax
import numpy as np
import optax
from tqdm import tqdm

from envs.gridworld import generate_fixed_grid, generate_random_grid
from global_utils import save_ckpt
from gridworld.generate_data import belief_action_obs_random_walk
from gridworld.rnn_belief_model import build_rnn_belief_net, train
from gridworld.utils import iqr_filter

jax.config.update('jax_enable_x64', True)


def parse_args(argv: 'list[str] | None' = None) -> argparse.Namespace:
    """Build the command-line parser for RNN belief-model training and parse arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, i.e. the same behaviour as calling with no arguments.

    Returns:
        The parsed options as an ``argparse.Namespace``.

    Exits via ``parser.error`` (``SystemExit``) when ``--eval_interval`` or
    ``--checkpoint_interval`` is not a positive integer: both are used as modulo
    divisors by the training loop, so a zero would otherwise crash only after the
    environment and model have been built.
    """
    parser = argparse.ArgumentParser()

    # fmt: off
    # Experiment parameters
    parser.add_argument('--checkpoint_interval', default=5000, type=int, help='Checkpoint interval')
    parser.add_argument('--eval_interval', default=1000, type=int, help='Evaluation interval')
    parser.add_argument('--model_dir', default='gridworld_models', type=str, help='Model directory')
    parser.add_argument('--run_name', default='rnn', type=str, help='Run name')
    parser.add_argument('--seed', default=42, type=int, help='Random seed')
    parser.add_argument('--time_gradient_updates', default=0, type=int, help='Gradient updates timing iterations')

    # Environment parameters
    parser.add_argument('--cube_width', default=2, type=int, help='Width of cubes')
    parser.add_argument('--fixed', default=False, action='store_true', help='Fixed grid and policy')
    parser.add_argument('--grid_size', default=5, type=int, help='Grid size')
    parser.add_argument('--max_walk_depth', default=15, type=int, help='Maximum depth of random walk')
    parser.add_argument('--ncubes', default=1, type=int, help='Number of cubes')
    parser.add_argument('--ndim', default=2, type=int, help='Grid dimensionality')
    parser.add_argument('--policy_temp', default=1.0, type=float, help='Policy softmax temperature')

    # Model parameters
    parser.add_argument('--mlp_hidden_size', default=32, type=int, help='Output head hidden size')
    parser.add_argument('--mlp_nlayers', default=2, type=int, help='Output head number of layers')
    parser.add_argument('--rnn_hidden_size', default=32, type=int, help='Backbone hidden size')
    parser.add_argument('--rnn_nlayers', default=2, type=int, help='Backbone number of layers')

    # Training parameters
    parser.add_argument('--batch_size', default=32, type=int, help='Batch size')
    parser.add_argument('--lr', default=0.1, type=float, help='Learning rate')
    parser.add_argument('--nsteps', default=100000, type=int, help='Number of training steps')
    parser.add_argument('--optimizer', default='adagrad', type=str, help='Optimizer to use')
    # fmt: on

    args = parser.parse_args(argv)

    # These intervals are used as divisors (step % interval); reject values that would crash later.
    for name in ('eval_interval', 'checkpoint_interval'):
        value = getattr(args, name)
        if value <= 0:
            parser.error(f'--{name} must be a positive integer, got {value}')

    return args


def main() -> None:
    """Train an RNN belief model on a gridworld and optionally benchmark its update step.

    Parses command-line options, builds either a fixed or a randomly generated grid
    environment, and builds the model and the optax optimizer named by
    ``--optimizer``. It then runs ``args.nsteps`` jitted training steps. Each step
    samples a fresh random-walk batch and applies one ``train`` update. The mean
    loss is shown on the progress bar every ``args.eval_interval`` steps. A
    checkpoint is saved every ``args.checkpoint_interval`` steps, and once more at
    the end if the final step was not already checkpointed.

    When ``args.time_gradient_updates > 0``, it then does one untimed warm-up call
    to ``train``, times that many further calls, and prints the mean and standard
    deviation (in milliseconds) after ``iqr_filter``.
    """
    args = parse_args()
    key = jax.random.key(args.seed)

    if args.fixed:
        env, _ = generate_fixed_grid(args.grid_size, args.ndim, args.ncubes, args.cube_width)
        # A fixed grid uses a (near-)deterministic policy, so the CLI temperature is overridden.
        args.policy_temp = 1e-5
    else:
        key, grid_key = jax.random.split(key, 2)
        env, _ = generate_random_grid(
            grid_key, args.grid_size, args.ndim, args.ncubes, args.cube_width
        )

    key, model_key = jax.random.split(key, 2)
    model = build_rnn_belief_net(args, model_key)

    opt = getattr(optax, args.optimizer)(args.lr)
    opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))

    def _sample_batch(gen_key):
        """Draw one (beliefs, observations) batch of random walks in ``env``."""
        beliefs, obss, _ = belief_action_obs_random_walk(
            gen_key,
            env,
            args.ncubes,
            args.cube_width,
            args.max_walk_depth,
            args.batch_size,
            args.policy_temp,
            args.fixed,
        )
        return beliefs, obss

    @eqx.filter_jit
    def _train_step(model, opt_state, key):
        """Sample a fresh batch and apply one optimizer update inside a single jit."""
        _, gen_key = jax.random.split(key, 2)
        beliefs, obss = _sample_batch(gen_key)
        return train(model, opt, opt_state, obss, beliefs)

    losses, progress_bar = [], tqdm(range(args.nsteps))
    os.makedirs(os.path.join(args.model_dir, args.run_name, str(args.seed)), exist_ok=True)

    for i in progress_bar:
        key, step_key = jax.random.split(key, 2)
        model, opt_state, loss = _train_step(model, opt_state, step_key)
        losses.append(loss)

        step = i + 1
        if step % args.eval_interval == 0:
            progress_bar.set_postfix({'loss': np.mean(losses)})
            losses = []

        if step % args.checkpoint_interval == 0:
            save_ckpt(model, vars(args), step)

    # Without this, the last (nsteps % checkpoint_interval) updates would never be saved.
    if args.nsteps > 0 and args.nsteps % args.checkpoint_interval != 0:
        save_ckpt(model, vars(args), args.nsteps)

    # End training by optionally timing gradient updates
    if args.time_gradient_updates > 0:
        # One untimed warm-up update so that compilation is not part of the measurements.
        key, gen_key = jax.random.split(key, 2)
        beliefs, obss = _sample_batch(gen_key)
        model, opt_state, _ = jax.block_until_ready(train(model, opt, opt_state, obss, beliefs))

        times = np.zeros(args.time_gradient_updates)
        for i in tqdm(range(args.time_gradient_updates), desc='Timing gradient updates'):
            key, gen_key = jax.random.split(key, 2)
            # JAX dispatch is asynchronous: wait for the batch to exist before starting the
            # clock, otherwise unfinished data generation is attributed to the update step.
            beliefs, obss = jax.block_until_ready(_sample_batch(gen_key))

            start_time = time.perf_counter()
            model, opt_state, _ = jax.block_until_ready(
                train(model, opt, opt_state, obss, beliefs)
            )
            times[i] = (time.perf_counter() - start_time) * 1000  # Convert to milliseconds

        times = iqr_filter(times)
        print(f'Time spent on one update step: ${np.mean(times):.4f} \\pm {np.std(times):.4f}$')


if __name__ == "__main__":
    # Run training only when this file is executed directly, never on import.
    main()
