import equinox as eqx
import jax
import jax.numpy as jnp

from envs.gridworld import (
    GridPolicy,
    GridState,
    GridWorld,
    generate_fixed_grid,
    generate_random_grid,
)


@eqx.filter_jit
def belief_sample_random_walk(
    key: jax.Array,
    grid: GridWorld,
    num_cubes: int,
    cube_width: int,
    max_steps: int,
    batch_size: int,
    nsamples: int,
    policy_temp: float,
    fixed: bool,
):
    """Roll out a batch of goal-directed walks and sample from each walk's final belief.

    Each batch element gets its own goal (uniform over the grid), a ``GridPolicy``
    built from that goal and ``policy_temp``, and a walk length drawn uniformly
    from ``[0, max_steps]``. All walks are simulated for the full ``max_steps``
    steps inside one ``lax.scan``. Then each element's state and belief after
    exactly its own length are gathered (length 0 selects the initial
    state/belief). Finally ``nsamples`` samples are drawn from that belief with
    ``grid.sample_beliefs``.

    Args:
        key: PRNG key. Legacy ``uint32`` keys and typed key arrays are both
            handled when splitting the per-step keys.
        grid: Environment that provides ``reset``, ``step``, ``initial_beliefs``,
            ``update_beliefs`` and ``sample_beliefs``.
        num_cubes: Passed to the grid generator.
        cube_width: Passed to the grid generator.
        max_steps: Number of simulated steps; the upper bound on walk length.
        batch_size: Number of independent walks.
        nsamples: Number of belief samples drawn per walk.
        policy_temp: Temperature passed to every ``GridPolicy``.
        fixed: If True, every walk starts from ``generate_fixed_grid``.
            Otherwise each walk gets its own ``generate_random_grid`` layout.

    Returns:
        ``(batch, (final_states, final_beliefs), policies)``:
        ``batch`` is the vmapped output of ``grid.sample_beliefs``;
        ``final_states`` / ``final_beliefs`` are batched over the leading axis;
        ``policies`` is the batched ``GridPolicy`` used for each walk.
    """
    lengths_key, init_key, scan_key, samples_key, goals_key = jax.random.split(key, 5)
    # Number of steps each walk actually takes; 0 means "stay at the initial state".
    lengths = jax.random.randint(lengths_key, shape=(batch_size,), minval=0, maxval=max_steps + 1)
    goals = jax.random.randint(goals_key, shape=(batch_size, grid.ndim), minval=0, maxval=grid.size)
    policies = jax.vmap(GridPolicy, in_axes=(0, None))(goals, policy_temp)

    def _act(key, policy, state):
        return policy(key, state)

    # One key per (step, walk). Keep any trailing key-data axis so this works for
    # legacy uint32 keys (split -> (n, 2)) as well as typed key arrays (split -> (n,)).
    flat_keys = jax.random.split(scan_key, max_steps * batch_size)
    step_keys = flat_keys.reshape((max_steps, batch_size) + flat_keys.shape[1:])

    # Every walk runs the full max_steps (a fixed trip count keeps this JIT-friendly).
    # Per-walk lengths are applied afterwards by indexing into the trajectory.
    def _scan_step(carry, action_keys):
        grid_state, belief_state = carry
        actions = jax.vmap(_act)(action_keys, policies, grid_state)
        new_state, obs = jax.vmap(grid.step)(grid_state, actions)
        # update_beliefs is given the pre-step state together with the new observation.
        new_belief_state = jax.vmap(grid.update_beliefs)(belief_state, grid_state, policies, obs)
        new_carry = (new_state, new_belief_state)
        return new_carry, new_carry

    def _make_initial_carry(key: jax.Array):
        # Use distinct keys for layout generation and reset so the two draws
        # are not correlated by reusing one key.
        grid_key, reset_key = jax.random.split(key)
        if fixed:
            _, init_state = generate_fixed_grid(grid.size, grid.ndim, num_cubes, cube_width)
        else:
            _, init_state = generate_random_grid(grid_key, grid.size, grid.ndim, num_cubes, cube_width)
        return grid.reset(reset_key, init_state), grid.initial_beliefs(init_state)

    init_keys = jax.random.split(init_key, batch_size)
    init_carry = jax.vmap(_make_initial_carry)(init_keys)
    _, rollout = jax.lax.scan(f=_scan_step, init=init_carry, xs=step_keys)

    # rollout[t] holds the (state, belief) after t + 1 steps. Prepend the initial
    # carry so that trajectory[k] is the carry after exactly k steps. Indexing
    # rollout at `lengths - 1` would wrap length 0 around to the final step.
    trajectory = jax.tree_util.tree_map(
        lambda first, rest: jnp.concatenate([first[None], rest], axis=0),
        init_carry,
        rollout,
    )
    batch_indices = jnp.arange(batch_size)
    final_state_for_each, final_belief_for_each = jax.tree_util.tree_map(
        lambda leaf: leaf[lengths, batch_indices], trajectory
    )
    samples_batch_key = jax.random.split(samples_key, batch_size)
    batch = jax.vmap(grid.sample_beliefs, in_axes=(0, 0, None))(
        samples_batch_key, final_belief_for_each, nsamples
    )
    return batch, (final_state_for_each, final_belief_for_each), policies


@eqx.filter_jit
def belief_action_obs_random_walk(
    key: jax.Array,
    grid: GridWorld,
    num_cubes: int,
    cube_width: int,
    max_steps: int,
    batch_size: int,
    policy_temp: float,
    fixed: bool,
):
    """Roll out a batch of goal-directed walks and record (belief, wall-observation) pairs.

    Each batch element gets its own goal (uniform over the grid) and a
    ``GridPolicy`` built from that goal and ``policy_temp``. All walks are
    simulated for ``max_steps`` steps inside one ``lax.scan``. At every step the
    belief held *before* the step is recorded together with the observation that
    step produced.

    Args:
        key: PRNG key. Legacy ``uint32`` keys and typed key arrays are both
            handled when splitting the per-step keys.
        grid: Environment that provides ``reset``, ``step``, ``initial_beliefs``
            and ``update_beliefs``.
        num_cubes: Passed to the grid generator.
        cube_width: Passed to the grid generator.
        max_steps: Number of simulated steps per walk.
        batch_size: Number of independent walks.
        policy_temp: Temperature passed to every ``GridPolicy``.
        fixed: If True, every walk starts from ``generate_fixed_grid``.
            Otherwise each walk gets its own ``generate_random_grid`` layout.

    Returns:
        ``(beliefs, wall, policies)``:
        ``beliefs`` is reshaped to ``(batch_size, max_steps, -1)``, where entry
        ``[b, t]`` is walk ``b``'s flattened belief before step ``t``.
        ``wall`` is ``obs.wall`` from each step, made batch-major, with a
        trailing singleton axis appended.
        ``policies`` is the batched ``GridPolicy`` used for each walk.
    """
    init_key, scan_key, goals_key = jax.random.split(key, 3)
    goals = jax.random.randint(goals_key, shape=(batch_size, grid.ndim), minval=0, maxval=grid.size)
    policies = jax.vmap(GridPolicy, in_axes=(0, None))(goals, policy_temp)

    def _act(key, policy, state):
        return policy(key, state)

    # One key per (step, walk). Keep any trailing key-data axis so this works for
    # legacy uint32 keys (split -> (n, 2)) as well as typed key arrays (split -> (n,)).
    flat_keys = jax.random.split(scan_key, max_steps * batch_size)
    step_keys = flat_keys.reshape((max_steps, batch_size) + flat_keys.shape[1:])

    def _scan_step(carry, action_keys):
        grid_state, belief_state = carry
        actions = jax.vmap(_act)(action_keys, policies, grid_state)
        new_state, obs = jax.vmap(grid.step)(grid_state, actions)
        # update_beliefs is given the pre-step state together with the new observation.
        new_belief_state = jax.vmap(grid.update_beliefs)(belief_state, grid_state, policies, obs)
        # Emit the belief held *before* this step alongside the observation it produced.
        return (new_state, new_belief_state), (belief_state, obs)

    def _make_initial_state(key: jax.Array):
        # Use distinct keys for layout generation and reset so the two draws
        # are not correlated by reusing one key.
        grid_key, reset_key = jax.random.split(key)
        if fixed:
            _, init_state = generate_fixed_grid(grid.size, grid.ndim, num_cubes, cube_width)
        else:
            _, init_state = generate_random_grid(grid_key, grid.size, grid.ndim, num_cubes, cube_width)
        return grid.reset(reset_key, init_state), grid.initial_beliefs(init_state)

    init_keys = jax.random.split(init_key, batch_size)
    init_carry = jax.vmap(_make_initial_state)(init_keys)
    _, (beliefs, obss) = jax.lax.scan(f=_scan_step, init=init_carry, xs=step_keys)
    # scan stacks outputs time-major (max_steps, batch, ...); swap to batch-major.
    wall = jnp.expand_dims(jnp.swapaxes(obss.wall, 0, 1), -1)
    return jnp.swapaxes(beliefs, 0, 1).reshape((batch_size, max_steps, -1)), wall, policies
