# gridworld_jax.py
from __future__ import annotations

from typing import Any, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp
from flax import struct

# Helpers: coordinate-list parsing and random distributions over a support mask.

def _parse_coords_list(coords: Optional[Sequence[Tuple[int, int]]]) -> Tuple[Tuple[int, int], ...]:
    if coords is None:
        return tuple()
    return tuple((int(r), int(c)) for (r, c) in coords)

def random_simplex_over_support(key: jax.Array, support_mask: jax.Array) -> jax.Array:
    """Draw a random probability vector that is zero outside ``support_mask``.

    One exponential variate is sampled per entry, entries outside the support
    are zeroed, and the remainder is normalised to sum to one.  If every
    retained draw is zero, the result falls back to the uniform distribution
    over the support.

    Args:
        key: PRNG key for the exponential draws.
        support_mask: boolean array; True marks entries allowed to carry mass.

    Returns:
        Array with the same shape as ``support_mask``.
    """
    draws = jax.random.exponential(key, shape=support_mask.shape)
    masked = jnp.where(support_mask, draws, 0.0)
    total = jnp.sum(masked)
    # Fallback used only when the masked draws sum to zero.
    uniform = support_mask.astype(jnp.float32) / jnp.sum(support_mask)
    return jnp.where(total > 0, masked / total, uniform)

@struct.dataclass
class GridWorldParams:
    """Description of one grid-world MDP, as assembled by ``GridWorldJAX.make_params``.

    Below, S is the number of non-wall cells and A the number of actions
    (``GridWorldJAX.A``).

    NOTE(review): no field uses ``struct.field(pytree_node=False)``, so the
    scalar settings are presumably treated as pytree leaves as well — confirm
    this is intended if params are passed through ``jax.jit``.
    """
    # Grid height and width in cells.
    rows: int
    cols: int
    # (rows, cols) boolean grid, True at wall cells.
    walls_mask: jax.Array
    # (rows, cols) boolean grid, True at cells listed as terminal states.
    terminal_mask: jax.Array
    # (rows, cols) float32 grid: default_reward everywhere except cells given an explicit reward.
    reward_grid: jax.Array
    # Reward assigned to cells without an explicit reward.
    default_reward: float
    # Probability mass given to the chosen move when building the base kernel.
    success_probability: float
    # Initial-state distribution; (S,) when derived from a start cell.
    # A caller-supplied mu0 is only normalised, its shape is not checked.
    mu0: jax.Array
    # Transition kernel, P[s, a] is the next-state distribution; (S, A, S) for the base kernel.
    P: jax.Array
    # (S, A) float32 rewards; R[s, a] is the reward_grid value at state s's cell for every a.
    R: jax.Array

@struct.dataclass
class GridWorldState:
    """Environment state carried between ``GridWorldJAX.reset`` and ``step`` calls."""
    # Index of the current state; reset/step store it as an int32 scalar.
    s: jax.Array

class GridWorldJAX:
    """Tabular grid-world MDP with walls, expressed with JAX arrays.

    States are the non-wall cells, numbered in row-major order.  There are
    ``A = 4`` actions, one per axis-aligned neighbour (see ``_build_P_base``
    for the offset of each action index).
    """

    A = 4

    def __init__(self, rows: int, cols: int):
        """Store the grid dimensions (coerced to ``int``)."""
        self.rows = int(rows)
        self.cols = int(cols)

    def build_mappings(self, walls: Sequence[Tuple[int, int]]):
        """Build the cell <-> state-index lookup tables.

        Args:
            walls: (row, col) pairs of wall cells; they receive no state index.

        Returns:
            ``(coord2idx, idx2coord)`` where ``coord2idx`` is a (rows, cols)
            int32 grid holding the state index of each free cell and -1 at
            walls, and ``idx2coord`` is an (S, 2) int32 array of the (row, col)
            of each state.  ``idx2coord`` is (0, 2) when there are no free cells.
        """
        walls_set = {(int(r), int(c)) for r, c in walls}
        # Assemble both tables on the host and convert once, instead of
        # issuing one device scatter per cell.
        grid = [[-1] * self.cols for _ in range(self.rows)]
        coords = []
        for r in range(self.rows):
            for c in range(self.cols):
                if (r, c) in walls_set:
                    continue
                grid[r][c] = len(coords)
                coords.append((r, c))
        coord2idx = jnp.array(grid, dtype=jnp.int32).reshape(self.rows, self.cols)
        # Explicit reshape keeps the (S, 2) layout even when S == 0.
        idx2coord = jnp.array(coords, dtype=jnp.int32).reshape(len(coords), 2)
        return coord2idx, idx2coord

    def make_params(
        self,
        key: jax.Array,
        start_coord: Tuple[int, int] = (0, 0),
        terminal_states: Optional[Sequence[Tuple[int, int]]] = None,
        success_probability: float = 1.0,
        reward_at: Optional[Sequence[Tuple[Tuple[int, int], float]]] = None,
        walls: Optional[Sequence[Tuple[int, int]]] = ((1, 1), (2, 2)),
        default_reward: float = 0.0,
        common: Optional[jax.Array] = None,
        epsilon_p: float = 0.0,
        mu0: Optional[jax.Array] = None,
    ) -> "GridWorldParams":
        """Assemble the masks, initial distribution, kernel and rewards of one MDP.

        Args:
            key: PRNG key; only consumed to draw the random individual kernel
                when ``common`` is given.
            start_coord: (row, col) of the start cell.  Used only when ``mu0``
                is None, in which case all initial mass is placed on it.
            terminal_states: cells recorded in ``terminal_mask``.
            success_probability: mass given to the chosen move in the base kernel.
            reward_at: ``((row, col), value)`` pairs overriding ``default_reward``.
                When None, every terminal state gets reward 1.0; with no
                terminal states the bottom-right cell gets reward 1.0.
            walls: wall cells.
            default_reward: reward of every cell not listed in ``reward_at``.
            common: optional kernel mixed with a random individual kernel as
                ``(1 - epsilon_p) * common + epsilon_p * individual`` and then
                renormalised over the last axis.  Presumably shaped like the
                base kernel, (S, A, S) — not checked here.
            epsilon_p: mixing weight of the individual kernel.  Ignored when
                ``common`` is None (the base kernel is used unchanged).
            mu0: optional initial-state weights; normalised to sum to one.

        Returns:
            A ``GridWorldParams`` instance.

        Raises:
            ValueError: if ``mu0`` is None and ``start_coord`` lies outside the
                grid or on a wall cell.
        """
        walls = _parse_coords_list(walls)
        terminal_states = _parse_coords_list(terminal_states)

        # Default reward targets: every terminal state, else the bottom-right cell.
        if reward_at is None:
            reward_at = [((r, c), 1.0) for (r, c) in terminal_states]
            if not reward_at:
                reward_at = [((self.rows - 1, self.cols - 1), 1.0)]
        reward_at = [((int(r), int(c)), float(v)) for ((r, c), v) in reward_at]

        coord2idx, idx2coord = self.build_mappings(walls)
        S = int(idx2coord.shape[0])

        walls_mask = jnp.zeros((self.rows, self.cols), dtype=bool)
        for (r, c) in walls:
            walls_mask = walls_mask.at[r, c].set(True)

        terminal_mask = jnp.zeros((self.rows, self.cols), dtype=bool)
        for (r, c) in terminal_states:
            terminal_mask = terminal_mask.at[r, c].set(True)

        reward_grid = jnp.full((self.rows, self.cols), float(default_reward), dtype=jnp.float32)
        for (r, c), v in reward_at:
            reward_grid = reward_grid.at[r, c].set(jnp.float32(v))

        # Initial distribution.
        if mu0 is None:
            r0, c0 = int(start_coord[0]), int(start_coord[1])
            # Negative coordinates keep their wrap-around meaning; anything
            # else outside the grid is rejected rather than silently clamped.
            if not (-self.rows <= r0 < self.rows and -self.cols <= c0 < self.cols):
                raise ValueError(
                    f"start_coord {(r0, c0)} is outside the {self.rows}x{self.cols} grid"
                )
            # A wall has index -1 in coord2idx, which would otherwise put the
            # initial mass on the last state.
            if (r0 % self.rows, c0 % self.cols) in set(walls):
                raise ValueError(f"start_coord {(r0, c0)} is a wall cell")
            start_idx = coord2idx[r0, c0]
            mu0 = jnp.zeros((S,), dtype=jnp.float32).at[start_idx].set(1.0)
        else:
            mu0 = mu0 / jnp.sum(mu0)

        P_base = self._build_P_base(coord2idx, idx2coord, walls_mask, success_probability)

        if common is not None:
            # The random individual kernel is only needed for the mixture.
            individual = self._build_individual_kernel(key, P_base)
            eps = float(epsilon_p)
            P = (1.0 - eps) * common + eps * individual
            P = P / jnp.sum(P, axis=-1, keepdims=True)
        else:
            P = P_base

        # R[s, a] is the reward of the cell of state s, identical for every
        # action; step() pays it for the state the agent is currently in, so a
        # rewarding cell pays on every step taken from it.
        rs = reward_grid[idx2coord[:, 0], idx2coord[:, 1]]
        R = jnp.tile(rs[:, None], (1, self.A))

        return GridWorldParams(
            rows=self.rows, cols=self.cols, walls_mask=walls_mask,
            terminal_mask=terminal_mask, reward_grid=reward_grid,
            default_reward=float(default_reward),
            success_probability=float(success_probability),
            mu0=mu0, P=P.astype(jnp.float32), R=R.astype(jnp.float32),
        )

    def _build_P_base(self, coord2idx, idx2coord, walls_mask, p_succ) -> jax.Array:
        """Build the base transition kernel, shape (S, A, S).

        Action offsets (row, col): 0 -> (0, -1), 1 -> (0, +1), 2 -> (+1, 0),
        3 -> (-1, 0).  A move is valid when its target is inside the grid and
        not a wall.

        * Invalid chosen move: the agent stays in place with probability 1.
        * Valid chosen move: it receives mass ``p_succ`` and the remaining
          ``1 - p_succ`` is split evenly over the other valid moves; if it is
          the only valid move it receives all the mass.
        """
        S, A = idx2coord.shape[0], self.A
        dr = jnp.array([0, 0, 1, -1], dtype=jnp.int32)
        dc = jnp.array([-1, 1, 0, 0], dtype=jnp.int32)
        actions = jnp.arange(A)

        def is_valid(r, c):
            in_bounds = (r >= 0) & (r < self.rows) & (c >= 0) & (c < self.cols)
            # The lookup result for out-of-range (r, c) is discarded by the where.
            return jnp.where(in_bounds, ~walls_mask[r, c], False)

        def per_state(s):
            r, c = idx2coord[s, 0], idx2coord[s, 1]
            rr, cc = r + dr, c + dc
            ok = jax.vmap(is_valid)(rr, cc)
            # Successor state per action; invalid moves map back to s.
            nxt = jnp.where(ok, coord2idx[rr, cc], s)
            n_valid = jnp.sum(ok.astype(jnp.int32))
            stay = jnp.zeros((S,)).at[s].set(1.0)

            def dist_for_action(a):
                mass_chosen = p_succ + (1.0 - p_succ) * (n_valid == 1)
                rem = (1.0 - mass_chosen) / jnp.maximum(n_valid - 1, 1)
                # Leftover mass goes to every other valid move.  Invalid moves
                # contribute 0 to state s, so repeated indices are harmless
                # (scatter-add accumulates all updates).
                slip = jnp.where((actions != a) & ok & (n_valid > 1), rem, 0.0)
                d = jnp.zeros((S,)).at[nxt[a]].set(mass_chosen).at[nxt].add(slip)
                return jnp.where(ok[a], d / jnp.sum(d), stay)

            return jax.vmap(dist_for_action)(actions)

        return jax.vmap(per_state)(jnp.arange(S))

    def _build_individual_kernel(self, key, P_base):
        """Draw a random kernel sharing the support of ``P_base``.

        For each (s, a) a distribution is sampled with
        ``random_simplex_over_support`` over the entries where
        ``P_base[s, a] > 0``, using a key derived from ``key`` by folding in
        ``s * A + a``.  Returns a float32 array shaped like ``P_base``.
        """
        S, A, _ = P_base.shape

        def per_s(s):
            def per_a(a):
                support = P_base[s, a] > 0
                k = jax.random.fold_in(key, s * A + a)
                return random_simplex_over_support(k, support)
            return jax.vmap(per_a)(jnp.arange(A))

        return jax.vmap(per_s)(jnp.arange(S)).astype(jnp.float32)

    def reset(self, key, params):
        """Sample an initial state index from ``params.mu0``.

        Returns ``(s0, GridWorldState(s=s0))`` with ``s0`` an int32 scalar.
        """
        s0 = jax.random.choice(key, a=jnp.arange(params.mu0.shape[0]), p=params.mu0).astype(jnp.int32)
        return s0, GridWorldState(s=s0)

    def step(self, key, state, action, params):
        """Advance the environment by one transition.

        Args:
            key: PRNG key used to sample the next state.
            state: current ``GridWorldState``.
            action: action index; a Python int or an array.
            params: the ``GridWorldParams`` of the MDP.

        Returns:
            ``(s2, GridWorldState(s=s2), r, done, info)`` where ``r`` is
            ``params.R[s, a]`` for the current state ``s`` (before the move),
            ``done`` is always False (``terminal_mask`` is not consulted here)
            and ``info`` is an empty dict.
        """
        # asarray also accepts plain Python ints, which have no .astype.
        s = jnp.asarray(state.s, dtype=jnp.int32)
        a = jnp.asarray(action, dtype=jnp.int32)
        r = params.R[s, a]
        p = params.P[s, a]
        s2 = jax.random.choice(key, a=jnp.arange(p.shape[0]), p=p).astype(jnp.int32)
        return s2, GridWorldState(s=s2), r, jnp.array(False), {}