import jax
import jax.numpy as jnp
import equinox as eqx

from pomdp.gridworld_jax import *

# returns nsamples from a batch of belief states
@eqx.filter_jit
def belief_sample_random_walk(key: jax.random.PRNGKey, grid: GridWorld, num_cubes: int,
                              cube_width: int, max_steps: int, batch_size: int, nsamples: int,
                              policy_temp: float, fixed: bool):
    """Roll out a batch of random walks and sample from each walk's belief.

    For every batch element a goal is drawn uniformly from ``[0, grid.size)``
    in each of ``grid.ndim`` dimensions, a ``GridPolicy`` is built from that
    goal and ``policy_temp``, and a walk length is drawn uniformly from
    ``[0, max_steps]`` (inclusive). All walks are scanned for the full
    ``max_steps`` steps; the state and belief reached after ``length`` steps
    are then picked out per batch element. A length of 0 selects the initial
    state and initial belief.

    Args:
        key: PRNG key; split internally for lengths, initialisation,
            per-step actions, belief sampling and goals.
        grid: environment providing ``step``, ``update_beliefs``, ``reset``,
            ``initial_beliefs``, ``sample_beliefs``, ``size`` and ``ndim``.
        num_cubes: forwarded to the grid generator.
        cube_width: forwarded to the grid generator.
        max_steps: number of scan steps and inclusive upper bound on the
            sampled walk length.
        batch_size: number of independent walks.
        nsamples: forwarded to ``grid.sample_beliefs`` for each walk.
        policy_temp: forwarded to ``GridPolicy`` (shared across the batch).
        fixed: if true use ``generate_fixed_grid``, otherwise
            ``generate_random_grid``.

    Returns:
        A tuple ``(batch, (final_state_for_each, final_belief_for_each),
        policies)`` where ``batch`` is the vmapped output of
        ``grid.sample_beliefs`` on the selected beliefs.
    """
    lengths_key, init_key, scan_key, samples_key, goals_key = jax.random.split(key, 5)
    lengths = jax.random.randint(lengths_key, shape=(batch_size,), minval=0,
                                 maxval=max_steps + 1)
    goals = jax.random.randint(goals_key, shape=(batch_size, grid.ndim), minval=0,
                               maxval=grid.size)
    policies = jax.vmap(GridPolicy, in_axes=(0, None))(goals, policy_temp)

    def _act(key, policy, state):
        return policy(key, state)

    # One action key per (step, batch element); scanned over the leading axis.
    scan_keys = jax.random.split(
        scan_key,
        batch_size * max_steps
    ).reshape(max_steps, batch_size, 2)

    # JIT-friendly to scan over the whole batch until max length
    def _scan_step(carry, action_keys):
        grid_state, belief_state = carry
        step_batch_actions = jax.vmap(_act)(action_keys, policies, grid_state)
        new_state, obs = jax.vmap(grid.step)(grid_state, step_batch_actions)
        new_belief_state = jax.vmap(grid.update_beliefs)(belief_state, grid_state, policies, obs)
        return (new_state, new_belief_state), (new_state, new_belief_state)

    def _make_initial_carry(key: jax.random.PRNGKey):
        # Use independent keys for grid generation and reset; reusing one key
        # for both would correlate the two draws.
        grid_key, reset_key = jax.random.split(key)
        if fixed:
            _, init_state = generate_fixed_grid(grid.size, grid.ndim, num_cubes, cube_width)
        else:
            _, init_state = generate_random_grid(grid_key, grid.size, grid.ndim, num_cubes,
                                                 cube_width)
        return grid.reset(reset_key, init_state), grid.initial_beliefs(init_state)

    init_keys = jax.random.split(init_key, batch_size)
    init_states, init_beliefs = jax.vmap(_make_initial_carry)(init_keys)
    _, (state_traj, belief_traj) = jax.lax.scan(
        f=_scan_step,
        init=(init_states, init_beliefs),
        xs=scan_keys,
    )

    batch_indices = jnp.arange(batch_size)
    # Scan output index k holds the result after k + 1 steps. A plain
    # `lengths - 1` would be -1 for zero-length walks and wrap around to the
    # last scan step, so those walks are handled explicitly below.
    took_no_steps = lengths == 0
    last_step = jnp.maximum(lengths - 1, 0)

    def _select_at_length(initial, trajectory):
        # Pick trajectory[length - 1] per batch element, or `initial` when length == 0.
        picked = trajectory[last_step, batch_indices]
        mask = took_no_steps.reshape((batch_size,) + (1,) * (picked.ndim - 1))
        return jnp.where(mask, initial, picked)

    final_state_for_each = GridState(
        agent_pos=_select_at_length(init_states.agent_pos, state_traj.agent_pos),
        walls=_select_at_length(init_states.walls, state_traj.walls),
    )
    final_belief_for_each = _select_at_length(init_beliefs, belief_traj)

    samples_batch_key = jax.random.split(samples_key, batch_size)
    batch = jax.vmap(grid.sample_beliefs,
                     in_axes=(0, 0, None))(samples_batch_key, final_belief_for_each, nsamples)
    return batch, (final_state_for_each, final_belief_for_each), policies


# returns a batch of belief states and the corresponding sequence of observations
@eqx.filter_jit
def belief_action_obs_random_walk(key: jax.random.PRNGKey, grid: GridWorld, num_cubes: int,
                cube_width: int, max_steps: int, batch_size: int, policy_temp: float, fixed: bool):
    """Roll out a batch of random walks, recording beliefs and wall observations.

    For every batch element a goal is drawn uniformly from ``[0, grid.size)``
    in each of ``grid.ndim`` dimensions and a ``GridPolicy`` is built from
    that goal and ``policy_temp``. Every walk runs for exactly ``max_steps``
    steps.

    Args:
        key: PRNG key; split internally for initialisation, per-step actions
            and goals.
        grid: environment providing ``step``, ``update_beliefs``, ``reset``,
            ``initial_beliefs``, ``size`` and ``ndim``.
        num_cubes: forwarded to the grid generator.
        cube_width: forwarded to the grid generator.
        max_steps: number of scan steps.
        batch_size: number of independent walks.
        policy_temp: forwarded to ``GridPolicy`` (shared across the batch).
        fixed: if true use ``generate_fixed_grid``, otherwise
            ``generate_random_grid``.

    Returns:
        A tuple ``(beliefs, wall, policies)``. ``beliefs`` is reshaped to
        ``(batch_size, max_steps, -1)`` and holds, for each step, the belief
        carried *into* that step (i.e. before it is updated with that step's
        observation). ``wall`` is the ``wall`` field of the per-step
        observations, made batch-major with a trailing singleton axis.
    """
    init_key, scan_key, goals_key = jax.random.split(key, 3)
    goals = jax.random.randint(goals_key, shape=(batch_size, grid.ndim), minval=0,
                               maxval=grid.size)
    policies = jax.vmap(GridPolicy, in_axes=(0, None))(goals, policy_temp)

    def _act(key, policy, state):
        return policy(key, state)

    # One action key per (step, batch element); scanned over the leading axis.
    scan_keys = jax.random.split(
        scan_key,
        batch_size * max_steps
    ).reshape(max_steps, batch_size, 2)

    def _scan_step(carry, action_keys):
        grid_state, belief_state = carry
        step_batch_actions = jax.vmap(_act)(action_keys, policies, grid_state)
        new_state, obs = jax.vmap(grid.step)(grid_state, step_batch_actions)
        new_belief_state = jax.vmap(grid.update_beliefs)(belief_state, grid_state, policies, obs)
        # Emit the pre-update belief together with the observation from this step.
        return (new_state, new_belief_state), (belief_state, obs)

    def _make_initial_state(key: jax.random.PRNGKey):
        # Use independent keys for grid generation and reset; reusing one key
        # for both would correlate the two draws.
        grid_key, reset_key = jax.random.split(key)
        if fixed:
            _, init_state = generate_fixed_grid(grid.size, grid.ndim, num_cubes, cube_width)
        else:
            _, init_state = generate_random_grid(grid_key, grid.size, grid.ndim, num_cubes,
                                                 cube_width)
        return grid.reset(reset_key, init_state), grid.initial_beliefs(init_state)

    init_keys = jax.random.split(init_key, batch_size)
    init_states, init_beliefs = jax.vmap(_make_initial_state)(init_keys)
    _, (beliefs, obss) = jax.lax.scan(
        f=_scan_step,
        init=(init_states, init_beliefs),
        xs=scan_keys,
    )
    # Scan stacks along time first; move batch to the front for the outputs.
    wall = jnp.expand_dims(jnp.swapaxes(obss.wall, 0, 1), -1)
    return jnp.swapaxes(beliefs, 0, 1).reshape((batch_size, max_steps, -1)), wall, policies