import equinox as eqx
import jax
import jax.numpy as jnp


class GridState(eqx.Module):
    """Environment state: the agent's coordinates plus the wall grid."""

    # Grid coordinates, one entry per axis (the generators below create an int32
    # vector of length ndim). ``GridWorld.update_beliefs`` also builds a batched
    # GridState with a leading cell axis for vmapping.
    agent_pos: jax.Array
    # Wall occupancy grid; nonzero cells are treated as walls and ``1 - walls``
    # is used as a free-cell mask, so presumably a 0/1 array — confirm for
    # custom layouts. Shape is (size,) * ndim in the generators below.
    walls: jax.Array


class GridPolicy(eqx.Module):
    """Softmax policy that favours unit moves bringing the agent closer to ``goal``.

    For an ``ndim``-dimensional position there are ``2 * ndim`` candidate moves:
    move ``2*d`` steps -1 along axis ``d`` and move ``2*d + 1`` steps +1. Each
    move is scored by the negative L1 distance from the resulting position to
    ``goal``, multiplied by ``temp``, and the scores are softmax-normalised.
    """

    goal: jax.Array
    temp: float  # multiplies the logits; for positive values, larger means greedier

    def action_probs(self, state: GridState):
        """Return the probability of each of the ``2 * ndim`` moves from ``state.agent_pos``."""
        ndim = state.agent_pos.shape[0]
        move_ids = jnp.arange(2 * ndim)
        # Odd move ids step forwards (+1), even ids step backwards (-1),
        # along axis ``move_id // 2``.
        signs = jnp.where(move_ids % 2 == 1, 1.0, -1.0)
        moves = jnp.zeros((2 * ndim, ndim)).at[move_ids, move_ids // 2].set(signs)
        candidates = state.agent_pos + moves
        l1_to_goal = jnp.abs(self.goal - candidates).sum(axis=1)
        return jax.nn.softmax(-l1_to_goal * self.temp)

    def __call__(self, key, state: GridState):
        """Sample one move index from ``action_probs(state)`` using PRNG ``key``."""
        num_moves = 2 * state.agent_pos.shape[0]
        choices = jnp.arange(num_moves)
        probs = self.action_probs(state)
        return jax.random.choice(key, choices, p=probs)


class GridObs(eqx.Module):
    """Observation emitted by ``GridWorld.step``."""

    # True when the step left the agent's position unchanged (it hit a wall cell
    # or pushed against the grid edge). ``GridWorld.step`` stores a JAX boolean
    # scalar here rather than a Python bool.
    wall: bool


class GridWorld(eqx.Module):
    """``ndim``-dimensional grid world with ``size`` cells along every axis.

    Actions are row indices into ``action_delta``: action ``2*d`` moves -1
    along axis ``d`` and action ``2*d + 1`` moves +1. ``step`` is
    deterministic. The ``*_beliefs`` methods track a probability grid over
    the agent's cell, using the policy being followed and the collision
    observations.
    """

    size: int
    ndim: int
    action_delta: jax.Array  # (2 * ndim, ndim) unit displacement, one row per action
    # Slip probability read by ``action_weights``: this much probability mass
    # is spread uniformly over all actions. ``step`` does not use it.
    movement_err: float

    def __init__(self, size: int, ndim: int, movement_err: float = 0.0):
        self.size = size
        self.ndim = ndim
        self.action_delta = create_action_displacements(ndim)
        self.movement_err = movement_err

    def reset(self, key: jax.random.PRNGKey, state: GridState) -> GridState:
        """Return ``state`` with the agent moved to a cell sampled from ``initial_beliefs``.

        The wall grid is carried over unchanged.
        """
        init_beliefs = self.initial_beliefs(state)
        start_pos = self.sample_beliefs(key, init_beliefs, 1)[0]
        return GridState(agent_pos=start_pos, walls=state.walls)

    def step(self, state: GridState, action: int) -> tuple[GridState, GridObs]:
        """Apply ``action`` and report whether the agent failed to move.

        The target cell is clipped to the grid. If it is a wall cell, the agent
        stays where it was. ``GridObs.wall`` is True whenever the position is
        unchanged, i.e. on a wall hit or when pushing against the grid edge.
        """
        new_pos = self._update_position(state.agent_pos, action)
        flat_pos = jnp.ravel_multi_index(tuple(new_pos), state.walls.shape, mode='clip')
        blocked = state.walls.ravel()[flat_pos]
        agent_pos = jax.lax.cond(blocked, lambda: state.agent_pos, lambda: new_pos)
        return (GridState(agent_pos=agent_pos, walls=state.walls),
                GridObs(wall=jnp.all(agent_pos == state.agent_pos)))

    def _update_position(self, agent_pos: jnp.ndarray, action: int):
        """Add ``action``'s displacement and clip every coordinate to ``[0, size - 1]``."""
        pos = agent_pos + self.action_delta[action]
        return jnp.clip(pos, 0, self.size - 1)

    def action_weights(self, intended_action):
        """Return a distribution over the ``2 * ndim`` actions given the intended one.

        A fraction ``movement_err`` of the mass is spread uniformly over all
        actions (the intended one included); the remainder goes to
        ``intended_action``. The weights sum to 1.
        """
        num_actions = self.ndim * 2
        p = jnp.ones(num_actions) * self.movement_err / num_actions
        return p.at[intended_action].set(1 - (num_actions - 1) * self.movement_err / num_actions)

    def sample_beliefs(self, key: jax.random.PRNGKey, beliefs: jnp.ndarray, nsamples: int) -> jax.Array:
        """Draw ``nsamples`` cells (with replacement) from the probability grid ``beliefs``.

        Returns:
            An integer array of shape ``(nsamples, beliefs.ndim)``; each row holds
            the coordinates of one sampled cell.
        """
        flat_beliefs = beliefs.flatten()
        indices = jax.random.choice(key, jnp.arange(len(flat_beliefs)),
                                    shape=(nsamples,), p=flat_beliefs)
        coords = jnp.unravel_index(indices, beliefs.shape)
        return jnp.stack(coords, axis=-1)

    def initial_beliefs(self, state: GridState) -> jax.Array:
        """Return ``start_positions(state)`` normalised to sum to 1."""
        p = self.start_positions(state)
        return p / jnp.sum(p)

    def update_beliefs(self, beliefs: jax.Array, state: GridState, policy: GridPolicy,
                       obs: GridObs) -> jax.Array:
        """Perform one Bayes-filter step over the agent's cell.

        Args:
            beliefs: probability grid of shape ``(size,) * ndim`` (matching
                ``state.walls``) before the step.
            state: supplies the wall grid; ``agent_pos`` is not read.
            policy: policy assumed to have chosen the action; its
                ``action_probs`` is evaluated at every cell of the grid.
            obs: observation produced by the step.

        Returns:
            The normalised posterior, zero on wall cells. If the observation has
            zero probability under ``beliefs``, the normaliser is zero and the
            result is NaN.
        """
        pos = generate_cell_indices(self.size, self.ndim)
        # Only ``agent_pos`` is read by ``action_probs``; ``walls`` is a dummy
        # with a matching leading axis so the whole GridState can be vmapped.
        policy_states = GridState(agent_pos=pos, walls=jnp.zeros_like(pos))
        action_probs = jax.vmap(policy.action_probs)(policy_states)  # (size**ndim, 2*ndim)
        beliefs = jax.lax.cond(obs.wall,
                               lambda: self._update_collision(beliefs, action_probs, state),
                               lambda: self._update_no_collision(beliefs, action_probs))
        beliefs *= (1 - state.walls)
        return beliefs / jnp.sum(beliefs)

    def _update_no_collision(self, beliefs, action_probs):
        """Propagate ``beliefs`` assuming the chosen move succeeded.

        Mass at a cell, weighted by each action's probability there, moves one
        cell in that action's direction. Mass that would leave the grid is
        dropped, since that move would have produced a collision.
        """
        new_beliefs = jnp.zeros_like(beliefs)
        dim = (self.size,) * self.ndim
        for ax in range(self.ndim):
            new_beliefs += shift_along_axis(action_probs[:, ax * 2].reshape(dim) * beliefs, ax, False)
            new_beliefs += shift_along_axis(action_probs[:, ax * 2 + 1].reshape(dim) * beliefs, ax, True)
        return new_beliefs

    def _get_axis_walls(self, state: GridState, axis: int, forward: bool = True) -> jnp.ndarray:
        """Mask of cells whose move along ``axis`` would be blocked.

        ``forward=True``: cell ``i`` is marked if cell ``i - 1`` is a wall or
        ``i`` is the first index, i.e. a -1 move from ``i`` collides.
        ``forward=False``: cell ``i`` is marked if cell ``i + 1`` is a wall or
        ``i`` is the last index, i.e. a +1 move from ``i`` collides.
        """
        shifted = shift_along_axis(state.walls, axis=axis, forward=forward)
        return add_boundary_wall(shifted, axis=axis, front=forward)

    def _update_collision(self, beliefs, action_probs, state):
        """Propagate ``beliefs`` given a collision: the agent stayed in place.

        Each cell keeps its mass weighted by the probability of the actions
        that would have been blocked from that cell.
        """
        b = jnp.zeros_like(beliefs)
        dim = (self.size,) * self.ndim
        for axis in range(self.ndim):
            minus_p = action_probs[:, axis * 2].reshape(dim)
            plus_p = action_probs[:, axis * 2 + 1].reshape(dim)
            b += minus_p * beliefs * self._get_axis_walls(state, axis, True)
            b += plus_p * beliefs * self._get_axis_walls(state, axis, False)
        return b

    def start_positions(self, state: GridState) -> jax.Array:
        """Unnormalised start weights: 1 on free cells, ``1 - walls`` elsewhere."""
        return jnp.ones((self.size,) * self.ndim) * (1 - state.walls)

    def actions(self) -> jax.Array:
        """Return all action indices, ``0 .. 2 * ndim - 1``."""
        return jnp.arange(self.ndim * 2)


# Grid generation functions
def generate_random_cubes(key: jax.random.PRNGKey, size: int, ndim: int, num_cubes: int,
                          cube_width: int) -> jax.Array:
    """Return a ``(size,) * ndim`` float grid with ``num_cubes`` wall cubes set to 1.

    Each cube is an axis-aligned block with ``cube_width`` cells per side. Its
    corner is drawn with ``jax.random.randint`` so the cube lies fully inside
    the grid. Cubes may overlap. A fresh subkey is split off ``key`` for every
    cube, so the result is deterministic given ``key``.

    Raises:
        ValueError: if ``cube_width`` is negative or larger than ``size``, in
            which case the cube cannot fit inside the grid.
    """
    if cube_width < 0 or cube_width > size:
        raise ValueError(
            f"cube_width must be between 0 and size ({size}), got {cube_width}")
    walls = jnp.zeros((size,) * ndim)
    cube = jnp.ones((cube_width,) * ndim)  # loop-invariant: build once
    for _ in range(num_cubes):
        key, subkey = jax.random.split(key)
        # randint's maxval is exclusive, so corners range over [0, size - cube_width].
        corner = jax.random.randint(subkey, (ndim,), 0, size - cube_width + 1)
        walls = jax.lax.dynamic_update_slice(walls, cube, corner)
    return walls

def generate_empty_grid(size: int, ndim: int) -> tuple[GridWorld, GridState]:
    """Build a wall-free grid world with ``size`` cells per axis.

    Returns:
        ``(world, state)``, where ``state`` has the agent at the origin (int32
        coordinates) and an all-zero wall grid of shape ``(size,) * ndim``.
    """
    grid = GridWorld(size=size, ndim=ndim)
    state = GridState(agent_pos=jnp.zeros(ndim, dtype=jnp.int32), walls=jnp.zeros((size,) * ndim))
    return grid, state

def generate_random_grid(key: jax.random.PRNGKey, size: int, ndim: int, num_cubes: int = 3,
                         cube_width: int = 3) -> tuple[GridWorld, GridState]:
    """Build a grid world whose walls come from ``generate_random_cubes(key, ...)``.

    Returns:
        ``(world, state)``. The agent is placed at the origin, which may fall
        inside a wall cube; use ``GridWorld.reset`` to sample a free start cell.
    """
    walls = generate_random_cubes(key, size, ndim, num_cubes, cube_width)
    state = GridState(agent_pos=jnp.zeros(ndim, dtype=jnp.int32), walls=walls)
    return GridWorld(size=size, ndim=ndim), state

def generate_fixed_grid(size: int, ndim: int, num_cubes: int = 3, cube_width: int = 3,
                        seed: int = 0) -> tuple[GridWorld, GridState]:
    """Like ``generate_random_grid``, but with a fixed PRNG seed.

    The same arguments always produce the same wall layout. The default
    ``seed=0`` reproduces the layout this function has always generated.
    """
    return generate_random_grid(jax.random.PRNGKey(seed), size, ndim, num_cubes, cube_width)


# Helper functions
def create_action_displacements(ndim: int):
    """Return the ``(2 * ndim, ndim)`` int32 table of unit moves.

    Row ``2*d`` is the -1 step along axis ``d``; row ``2*d + 1`` is the +1 step.
    """
    unit = jnp.eye(ndim, dtype=jnp.int32)
    moves = [sign * unit[axis] for axis in range(ndim) for sign in (-1, 1)]
    return jnp.stack(moves, axis=0)

def shift_along_axis(a: jax.Array, axis: int, forward: bool = True):
    """Shift ``a`` by one cell along ``axis``, zero-filling the vacated slot.

    ``forward=True`` gives ``out[i] = a[i - 1]`` (the first slot becomes 0);
    ``forward=False`` gives ``out[i] = a[i + 1]`` (the last slot becomes 0).
    The slice that falls off the end is discarded.
    """
    length = a.shape[axis]

    def _place(part, offset):
        # Write ``part`` into a zero array, starting ``offset`` cells along ``axis``.
        starts = [0] * a.ndim
        starts[axis] = offset
        return jax.lax.dynamic_update_slice(jnp.zeros_like(a), part, starts)

    def _forward():
        return _place(jax.lax.slice_in_dim(a, 0, length - 1, axis=axis), 1)

    def _backward():
        return _place(jax.lax.slice_in_dim(a, 1, length, axis=axis), 0)

    return jax.lax.cond(forward, _forward, _backward)

def add_boundary_wall(a: jnp.ndarray, axis: int, front: bool) -> jnp.ndarray:
    """Return a copy of ``a`` with one edge slab along ``axis`` set to 1.

    ``front=True`` marks index 0 of ``axis``; ``front=False`` marks the last
    index. ``front`` may be a Python bool or a traced boolean scalar.
    """
    edge = jnp.where(front, 0, -1)
    selector = [slice(None)] * a.ndim
    selector[axis] = edge
    return a.at[tuple(selector)].set(1)

def generate_cell_indices(size: int, ndim: int) -> jnp.ndarray:
    """Enumerate every cell of a ``(size,) * ndim`` grid as integer coordinates.

    Returns a ``(size ** ndim, ndim)`` array whose rows list the cells in
    row-major (C) order, so row ``k`` is ``unravel_index(k, (size,) * ndim)``.
    """
    grid_shape = (size,) * ndim
    per_axis = jnp.indices(grid_shape)  # (ndim, size, ..., size)
    return per_axis.reshape(ndim, -1).T