import argparse
import os

import equinox as eqx
import jax
import jax.numpy as jnp
import optax
from tqdm import tqdm

from envs.triangulation import Triangulation, TriangulationState
from flows.distribution_embedding import train
from global_utils import build_continuous_embedding_net, save_ckpt
from triangulation.filters import ParticleFilter, ParticleSet


def parse_args():
    """Parse the command-line options for this training script.

    Options fall into three groups: experiment bookkeeping, training
    hyperparameters, and model-architecture hyperparameters. They are
    registered in that order, which is also the order ``--help`` lists them.

    Returns:
        argparse.Namespace: the parsed options, one attribute per flag.
    """
    # Each entry is (flag, default, type, help text).
    # fmt: off
    experiment_options = (
        ('--checkpoint_interval', 500, int, 'Number of training steps between saving model checkpoints'),
        ('--model_dir', 'triangulation_models', str, 'Directory to save model checkpoints'),
        ('--run_name', 'flow', str, 'Name of the training run'),
        ('--seed', 42, int, 'Random seed'),
    )
    training_options = (
        ('--batch_size', 32, int, 'Batch size'),
        ('--lr', 1e-3, float, 'Learning rate'),
        ('--nsteps', 20000, int, 'Number of training steps'),
        ('--optimizer', 'adam', str, 'Optimizer to use'),
        ('--max_steps', 10, int, 'Max steps per episode'),
    )
    model_options = (
        ('--embedding_hidden_size', 128, int, 'Hidden size of the embedding network'),
        ('--embedding_nlayers', 3, int, 'Number of layers in the embedding network'),
        ('--embedding_size', 32, int, 'Output size of the embedding network'),
        ('--flow_hidden_size', 64, int, 'Hidden size of the flow network'),
        ('--flow_nlayers', 2, int, 'Number of layers in each coupling layer'),
        ('--ncoupling_layers', 6, int, 'Number of coupling layers in the flow model'),
        ('--num_particles', 1024, int, 'Number of particles in the ground truth particle filter'),
    )
    # fmt: on

    parser = argparse.ArgumentParser()
    for option_group in (experiment_options, training_options, model_options):
        for flag, default_value, value_type, help_text in option_group:
            parser.add_argument(flag, default=default_value, type=value_type, help=help_text)
    return parser.parse_args()


class BiasedAgent(eqx.Module):
    """Random policy that mixes one preferred action with action 5.

    On every call the agent samples, uniformly at random, either its own
    ``direction`` action or the fixed action ``5``.
    """

    # Preferred action index. When the constructor is applied under
    # ``jax.vmap`` (as ``generate_batch`` does), this field holds a batched
    # integer array rather than a Python int.
    direction: int

    def __call__(self, state: TriangulationState, key: jax.Array) -> jax.Array:
        # ``state`` is part of the agent call signature but is not consulted.
        candidate_actions = jnp.array([self.direction, 5])
        return jax.random.choice(key, candidate_actions)


@eqx.filter_jit
def generate_episode(
    env: Triangulation,
    agent: BiasedAgent,
    particle_filter: ParticleFilter,
    max_steps: int,
    key: jax.Array,
) -> ParticleSet:
    """Roll out one episode and record the particle-filter state at each step.

    The environment and filter are reset. The loop then runs ``max_steps``
    iterations of: agent picks an action, environment steps, filter is
    updated with the previous and new observations.

    Args:
        env: Environment providing ``reset`` and ``step``.
        agent: Callable mapping ``(state, key)`` to an action.
        particle_filter: Filter providing ``reset`` and ``update``.
        max_steps: Number of scan iterations; must be a static Python int.
        key: PRNG key for the whole episode.

    Returns:
        Filter states stacked along a new leading axis of length
        ``max_steps``. Entry ``t`` is the filter state held *before* the
        update of iteration ``t``. Entry 0 is therefore the freshly reset
        filter, and the state produced by the last update is not included.
    """

    def rollout_step(carry, _unused):
        env_state, env_timestep, belief, rng = carry

        rng, act_key = jax.random.split(rng, 2)
        act = agent(env_state, act_key)

        rng, env_key = jax.random.split(rng, 2)
        new_env_state, new_env_timestep = env.step(env_state, act, env_key)

        rng, filter_key = jax.random.split(rng, 2)
        new_belief, _ = particle_filter.update(
            env,
            belief,
            env_state,
            env_timestep.observation,
            act,
            new_env_timestep.observation,
            filter_key,
        )

        # Emit the belief held at the start of this iteration (pre-update).
        return (new_env_state, new_env_timestep, new_belief, rng), belief

    key, env_reset_key = jax.random.split(key, 2)
    init_state, init_timestep = env.reset(env_reset_key)

    key, filter_reset_key = jax.random.split(key, 2)
    init_belief = particle_filter.reset(
        env, init_state, init_timestep.observation, filter_reset_key
    )

    initial_carry = (init_state, init_timestep, init_belief, key)
    _, belief_trajectory = jax.lax.scan(
        rollout_step, initial_carry, None, length=max_steps
    )
    return belief_trajectory


@eqx.filter_jit
def generate_batch(
    env: Triangulation,
    agent_func: type[BiasedAgent],
    particle_filter: ParticleFilter,
    batch_size: int,
    max_steps: int,
    key: jax.Array,
) -> jax.Array:
    """Build one training batch of resampled, flattened particle features.

    For each of the ``batch_size`` samples, the function does the following:

    1. Builds an agent with a random direction in ``{0, 1, 2, 3}``.
    2. Rolls out an episode of ``max_steps`` steps with ``generate_episode``.
    3. Picks one recorded filter state uniformly at random.
    4. Calls ``systematic_resample`` on it.
    5. Maps ``to_flat_features`` over the result's ``particles``.

    Args:
        env: Environment shared by every episode (not vmapped).
        agent_func: Agent constructor, e.g. the ``BiasedAgent`` class itself.
            It is called under ``jax.vmap`` with a batch of integer directions.
        particle_filter: Filter shared by every episode (not vmapped).
        batch_size: Number of samples; must be a static Python int because it
            fixes array shapes.
        max_steps: Episode length and the exclusive upper bound of the
            sampled step index; must be a static Python int.
        key: PRNG key.

    Returns:
        ``to_flat_features`` applied under two nested ``jax.vmap``s, the outer
        one over the batch axis. The result is presumably shaped
        ``(batch_size, num_particles, ...)``; TODO confirm against
        ``ParticleSet``.
    """
    key, agent_key = jax.random.split(key, 2)
    directions = jax.random.randint(agent_key, (batch_size,), minval=0, maxval=4)
    # Vmapping the constructor yields a single agent pytree whose ``direction``
    # leaf is batched, so it can be mapped over (in_axes=0) below.
    agents = jax.vmap(agent_func)(directions)

    key, generate_key = jax.random.split(key, 2)
    episode_keys = jax.random.split(generate_key, batch_size)
    # Leaves of the result carry leading (batch_size, max_steps) axes: scan
    # stacks the steps, and vmap adds the batch axis in front.
    generated_filter_states = jax.vmap(generate_episode, in_axes=(None, 0, None, None, 0))(
        env, agents, particle_filter, max_steps, episode_keys
    )

    # For each episode in the batch, pick one step at random (wasteful but
    # simple: the other max_steps - 1 simulated steps are discarded).
    key, timestep_key = jax.random.split(key, 2)
    timesteps = jax.random.randint(timestep_key, (batch_size,), minval=0, maxval=max_steps)

    def _select_step(t, filter_states):
        # Index the time axis of every leaf of one episode's filter states.
        return jax.tree.map(lambda leaf: leaf[t], filter_states)

    selected_filter_states = jax.vmap(_select_step)(timesteps, generated_filter_states)

    # ``key`` is not used after this split, so only the subkey is kept.
    _, resample_key = jax.random.split(key, 2)
    resampled_filter_states = jax.vmap(lambda ps, subkey: ps.systematic_resample(subkey))(
        selected_filter_states, jax.random.split(resample_key, batch_size)
    )

    particle_features = jax.vmap(jax.vmap(lambda p: p.to_flat_features()))(
        resampled_filter_states.particles
    )

    return particle_features


def main() -> None:
    """Train the embedding network on batches of particle-filter beliefs.

    The function parses the CLI options, builds the environment, model,
    particle filter and optimizer, and then runs ``args.nsteps`` jitted
    training steps. A checkpoint is written through ``save_ckpt`` every
    ``args.checkpoint_interval`` steps.
    """
    args = parse_args()
    rng = jax.random.key(args.seed)

    env = Triangulation()
    # Read by build_continuous_embedding_net. Its meaning is not visible here
    # (presumably the per-particle feature size; TODO confirm).
    args.ndim = 3

    rng, init_key = jax.random.split(rng, 2)
    model = build_continuous_embedding_net(args, init_key)
    particle_filter = ParticleFilter(args.num_particles)

    # Look up the optimizer factory by name, e.g. optax.adam.
    make_optimizer = getattr(optax, args.optimizer)
    opt = make_optimizer(args.lr)
    trainable_params = eqx.filter(model, eqx.is_inexact_array)
    opt_state = opt.init(trainable_params)

    @eqx.filter_jit
    def _train_step(model, opt_state, step_key):
        # The first split output is intentionally discarded.
        _, batch_key, update_key = jax.random.split(step_key, 3)
        batch = generate_batch(
            env, BiasedAgent, particle_filter, args.batch_size, args.max_steps, batch_key
        )
        return train(model, opt, opt_state, batch, update_key)

    run_dir = os.path.join(args.model_dir, args.run_name, str(args.seed))
    os.makedirs(run_dir, exist_ok=True)

    # Steps are counted from 1 so the checkpoint index equals steps completed.
    for step in tqdm(range(1, args.nsteps + 1)):
        rng, step_key = jax.random.split(rng, 2)
        model, opt_state, _loss = _train_step(model, opt_state, step_key)

        if step % args.checkpoint_interval == 0:
            save_ckpt(model, vars(args), step)


if __name__ == '__main__':
    # Run training only when this file is executed as a script, not on import.
    main()
