import equinox as eqx
import jax
import jax.numpy as jnp


class GridState(eqx.Module):
    """Environment state: where the agent is and where the walls are."""

    # Integer grid coordinates of the agent, one entry per axis; the generators in
    # this module create it as ``jnp.zeros(ndim, jnp.int32)``.
    # NOTE(review): ``GridWorld.update_beliefs`` also builds a batched GridState
    # (positions of shape (num_cells, ndim)) purely to vmap the policy over it.
    agent_position: jax.Array
    # Wall layout of shape ``(size,) * ndim``; the generators fill it with 1.0 for
    # wall cells and 0.0 for free cells.
    walls: jax.Array


class GridObservation(eqx.Module):
    """Observation emitted after each ``GridWorld.step``."""

    # True when the attempted move was blocked (by a wall cell or the grid edge),
    # i.e. the agent did not change position. ``GridWorld.step`` stores a 0-d
    # boolean jax array here despite the ``bool`` annotation.
    wall: bool


class GridPolicy(eqx.Module):
    """Softmax policy that prefers moves bringing the agent closer to ``goal``.

    Actions use the same encoding as ``create_action_displacements`` (and hence
    ``GridWorld.action_delta``): action ``2 * k`` steps -1 along axis ``k`` and
    action ``2 * k + 1`` steps +1 along it.
    """

    goal: jax.Array  # target coordinates, one entry per grid axis
    temp: float  # multiplies the logits: larger values make the policy greedier

    def __call__(self, key: jax.Array, state: GridState) -> int:
        """Sample an action index for ``state`` using the PRNG ``key``."""
        probs = self.action_probs(state)
        return jax.random.choice(key, jnp.arange(probs.shape[0]), p=probs)

    def action_probs(self, state: GridState) -> jax.Array:
        """Return the probability of each of the ``2 * ndim`` actions.

        Each action is scored by the L1 distance from the cell it would lead to
        (ignoring walls and grid bounds) to ``goal``; logits are
        ``-distance * temp`` and are normalised with a softmax.
        """
        ndim = state.agent_position.shape[0]
        # Share the environment's displacement table so the action ordering can
        # never drift from ``GridWorld.action_delta``.
        deltas = create_action_displacements(ndim)
        next_positions = state.agent_position + deltas
        distances = jnp.sum(jnp.abs(self.goal - next_positions), axis=1)
        return jax.nn.softmax(-distances * self.temp)


class GridWorld(eqx.Module):
    """An ``ndim``-dimensional hypercubic grid of side ``size`` containing walls.

    Actions index rows of ``action_delta``: action ``2 * k`` moves one cell in
    the negative direction along axis ``k`` and action ``2 * k + 1`` moves one
    cell in the positive direction. A move that would leave the grid or enter a
    wall cell leaves the agent in place and yields a ``wall`` observation.
    """

    size: int
    ndim: int
    action_delta: jax.Array  # (2 * ndim, ndim) int32 unit displacements, one row per action

    def __init__(self, size: int, ndim: int) -> None:
        self.size = size
        self.ndim = ndim
        self.action_delta = create_action_displacements(ndim)

    def step(self, state: GridState, action: int) -> tuple[GridState, GridObservation]:
        """Apply ``action`` and return the next state and the observation.

        The target position is clipped to the grid; if it is a wall cell the agent
        stays where it was. ``obs.wall`` is True exactly when the agent did not move.
        """
        new_pos = self._update_position(state.agent_position, action)
        # The grid generators build float walls and ``lax.cond`` rejects floating
        # point predicates, so turn the cell value into an explicit boolean.
        hit_wall = state.walls[tuple(new_pos)] != 0
        agent_position = jnp.where(hit_wall, state.agent_position, new_pos)
        return (
            GridState(agent_position, state.walls),
            GridObservation(jnp.all(agent_position == state.agent_position)),
        )

    def reset(self, key: jax.Array, state: GridState) -> GridState:
        """Return ``state`` with the agent placed on a uniformly random free cell."""
        init_beliefs = self.initial_beliefs(state)
        start_pos = self.sample_beliefs(key, init_beliefs, 1)[0]
        return GridState(start_pos, state.walls)

    def start_positions(self, state: GridState) -> jax.Array:
        """Return a ``(size,) * ndim`` mask that is 1 on free cells and 0 on walls."""
        return jnp.ones((self.size,) * self.ndim) * (1 - state.walls)

    def initial_beliefs(self, state: GridState) -> jax.Array:
        """Return the uniform distribution over free cells."""
        p = self.start_positions(state)
        return p / jnp.sum(p)

    def sample_beliefs(self, key: jax.Array, beliefs: jax.Array, nsamples: int) -> jax.Array:
        """Draw ``nsamples`` cells (with replacement) from the distribution ``beliefs``.

        Returns an integer array of shape ``(nsamples, beliefs.ndim)`` holding the
        coordinates of the sampled cells.
        """
        flat_beliefs = beliefs.flatten()
        indices = jax.random.choice(
            key, jnp.arange(len(flat_beliefs)), shape=(nsamples,), p=flat_beliefs
        )
        coords = jnp.unravel_index(indices, beliefs.shape)
        return jnp.stack(coords, axis=-1)

    def update_beliefs(
        self, beliefs: jax.Array, state: GridState, policy: GridPolicy, obs: GridObservation
    ) -> jax.Array:
        """Bayes-filter update of ``beliefs`` after one step taken under ``policy``.

        Args:
            beliefs: probability array of shape ``(size,) * ndim`` over the agent's cell.
            state: supplies the wall layout; its agent position is not read.
            policy: evaluated at every cell to obtain P(action | cell).
            obs: whether the last step was blocked.

        Returns:
            The normalised posterior, zero on wall cells.
            NOTE(review): if no cell keeps any mass the final division yields NaN.
        """
        pos = generate_cell_indices(self.size, self.ndim)
        # Batched GridState purely to vmap the policy over every cell; the dummy
        # walls are never read by ``action_probs``.
        policy_states = GridState(pos, jnp.zeros_like(pos))
        action_probs = jax.vmap(policy.action_probs)(policy_states)
        beliefs = jax.lax.cond(
            obs.wall,
            lambda: self._update_collision(beliefs, action_probs, state),
            lambda: self._update_no_collision(beliefs, action_probs),
        )
        # Mass pushed onto wall cells by the no-collision update is impossible.
        beliefs *= 1 - state.walls
        return beliefs / jnp.sum(beliefs)

    def _update_position(self, agent_position: jax.Array, action: int) -> jax.Array:
        """Return ``agent_position`` moved by ``action`` and clipped to the grid."""
        pos = agent_position + self.action_delta[action]
        return jnp.clip(pos, 0, self.size - 1)

    def _update_collision(
        self, beliefs: jax.Array, action_probs: jax.Array, state: GridState
    ) -> jax.Array:
        """Unnormalised posterior after a blocked move.

        The agent stays put, so each cell keeps its prior mass weighted by the
        probability that the policy chose a move blocked from that cell.
        """
        grid_shape = (self.size,) * self.ndim
        collided = jnp.zeros_like(beliefs)
        for ax in range(self.ndim):
            # Negative move along ``ax``: blocked where the previous cell is a wall or at index 0.
            collided += (
                action_probs[:, 2 * ax].reshape(grid_shape)
                * beliefs
                * self._get_axis_walls(state, ax, True)
            )
            # Positive move along ``ax``: blocked where the next cell is a wall or at the last index.
            collided += (
                action_probs[:, 2 * ax + 1].reshape(grid_shape)
                * beliefs
                * self._get_axis_walls(state, ax, False)
            )
        return collided

    def _update_no_collision(self, beliefs: jax.Array, action_probs: jax.Array) -> jax.Array:
        """Unnormalised posterior after a successful move.

        Mass at each cell is moved one step in the direction of every action,
        weighted by that action's probability. Mass pushed off the grid is dropped,
        and mass pushed onto walls is removed by the caller.
        """
        grid_shape = (self.size,) * self.ndim
        moved = jnp.zeros_like(beliefs)
        for ax in range(self.ndim):
            # Negative move: cell i receives mass from cell i + 1 (backward shift).
            moved += shift_along_axis(
                action_probs[:, 2 * ax].reshape(grid_shape) * beliefs, ax, False
            )
            # Positive move: cell i receives mass from cell i - 1 (forward shift).
            moved += shift_along_axis(
                action_probs[:, 2 * ax + 1].reshape(grid_shape) * beliefs, ax, True
            )
        return moved

    def _get_axis_walls(self, state: GridState, axis: int, forward: bool = True) -> jax.Array:
        """Return a mask of cells whose neighbour along ``axis`` is a wall or off-grid.

        With ``forward=True`` the mask is ``walls[i - 1]`` with index 0 set to 1,
        which marks cells where the negative move is blocked. With ``forward=False``
        it is ``walls[i + 1]`` with the last index set to 1, for the positive move.
        """
        shifted = shift_along_axis(state.walls, axis=axis, forward=forward)
        return add_boundary_wall(shifted, axis=axis, front=forward)


# Grid generation functions
def generate_random_cubes(
    key: jax.Array, size: int, ndim: int, num_cubes: int, cube_width: int
) -> jax.Array:
    """Return a float ``(size,) * ndim`` grid with random solid cubes set to 1.

    Each of the ``num_cubes`` cubes has side ``cube_width``. Its lower corner is
    drawn uniformly so the cube lies fully inside the grid. Cubes may overlap.
    """
    grid = jnp.zeros((size,) * ndim)
    block = jnp.ones((cube_width,) * ndim)
    max_corner = size - cube_width + 1  # exclusive upper bound for a cube's lower corner
    for _ in range(num_cubes):
        key, corner_key = jax.random.split(key)
        corner = jax.random.randint(corner_key, (ndim,), 0, max_corner)
        grid = jax.lax.dynamic_update_slice(grid, block, corner)
    return grid


def generate_empty_grid(size: int, ndim: int) -> tuple[GridWorld, GridState]:
    """Build a wall-free ``size ** ndim`` grid and a state with the agent at the origin."""
    world = GridWorld(size=size, ndim=ndim)
    origin = jnp.zeros(ndim, jnp.int32)
    no_walls = jnp.zeros((size,) * ndim)
    return world, GridState(origin, no_walls)


def generate_random_grid(
    key: jax.Array, size: int, ndim: int, num_cubes: int = 3, cube_width: int = 3
) -> tuple[GridWorld, GridState]:
    """Build a grid with randomly placed cubic walls (see ``generate_random_cubes``).

    The returned state has the agent at the origin.
    NOTE(review): the origin may lie inside a wall cube; call ``GridWorld.reset``
    to place the agent on a free cell.
    """
    wall_grid = generate_random_cubes(key, size, ndim, num_cubes, cube_width)
    initial_state = GridState(jnp.zeros(ndim, jnp.int32), wall_grid)
    world = GridWorld(size=size, ndim=ndim)
    return world, initial_state


def generate_fixed_grid(
    size: int, ndim: int, num_cubes: int = 3, cube_width: int = 3, seed: int = 0
) -> tuple[GridWorld, GridState]:
    """Deterministic variant of ``generate_random_grid``.

    Equivalent to ``generate_random_grid(jax.random.key(seed), ...)``. The default
    ``seed=0`` reproduces the original fixed wall layout, so repeated calls with
    the same arguments always return the same grid.
    """
    # Delegate rather than duplicate the construction, so the two generators
    # cannot drift apart.
    return generate_random_grid(jax.random.key(seed), size, ndim, num_cubes, cube_width)


# Helper functions
def create_action_displacements(ndim: int) -> jax.Array:
    """Return the ``(2 * ndim, ndim)`` int32 table of unit moves, one row per action.

    Row ``2 * k`` is -1 along axis ``k``; row ``2 * k + 1`` is +1 along axis ``k``.
    """
    rows = [
        jnp.zeros(ndim, jnp.int32).at[axis].set(sign)
        for axis in range(ndim)
        for sign in (-1, 1)
    ]
    return jnp.stack(rows, axis=0)


def shift_along_axis(a: jax.Array, axis: int, forward: bool = True) -> jax.Array:
    """Shift ``a`` by one position along ``axis`` and fill the vacated slab with zeros.

    ``forward=True`` gives ``out[i] = a[i - 1]`` (``out[0] = 0``);
    ``forward=False`` gives ``out[i] = a[i + 1]`` (``out[-1] = 0``).
    Values shifted past the edge are discarded.
    """
    length = a.shape[axis]

    def _move(src_offset: int, dst_offset: int) -> jax.Array:
        # Copy the (length - 1)-wide window starting at src_offset into a zero
        # array at dst_offset.
        lo = [0] * a.ndim
        hi = list(a.shape)
        lo[axis] = src_offset
        hi[axis] = length - 1 + src_offset
        window = jax.lax.slice(a, lo, hi)
        dst = [0] * a.ndim
        dst[axis] = dst_offset
        return jax.lax.dynamic_update_slice(jnp.zeros_like(a), window, start_indices=dst)

    return jax.lax.cond(forward, lambda: _move(0, 1), lambda: _move(1, 0))


def add_boundary_wall(a: jax.Array, axis: int, front: bool) -> jax.Array:
    """Return a copy of ``a`` with its edge slab along ``axis`` set to 1.

    The first slab (index 0) is set when ``front`` is true; otherwise the last
    slab (index -1) is set.
    """
    edge = jax.lax.cond(front, lambda: 0, lambda: -1)
    selector = tuple(edge if dim == axis else slice(None) for dim in range(a.ndim))
    return a.at[selector].set(1)


def generate_cell_indices(size: int, ndim: int) -> jax.Array:
    """Return the coordinates of every cell of a ``(size,) * ndim`` grid.

    The result has shape ``(size ** ndim, ndim)``. Rows follow C (row-major)
    order, so row ``k`` holds the coordinates of flat index ``k``.
    """
    index_grids = jnp.indices((size,) * ndim)  # shape (ndim, size, ..., size)
    return index_grids.reshape(ndim, -1).T
