import numpy as np
from src.gridworld.env import LongCorridor

def test_obs_shape_and_onehot():
    """Check that the egocentric observation is a flat, per-cell one-hot vector.

    Builds a small ``LongCorridor``, resets it, and inspects the observation
    that ``egocentric_obs`` returns for the starting pose. The test verifies:

    * the observation is 1-D with ``obs_size * obs_size * (n_colors + 1)``
      entries;
    * every entry is exactly 0 or 1;
    * each cell's channel slice sums to 1, so exactly one channel is active.
    """
    env = LongCorridor(Lx=20, Ly=5, n_colors=4, obs_size=5, seed=1)
    x, y, h = env.reset()
    obs = env.egocentric_obs(x, y, h)
    # One channel per colour plus one extra channel; its meaning is not
    # visible here (presumably "empty"/"wall") -- TODO confirm in env.
    C = env.n_colors + 1
    n_cells = env.obs_size * env.obs_size

    assert obs.ndim == 1
    assert obs.size == n_cells * C

    # A row-sum check alone would accept soft encodings such as
    # [0.5, 0.5, 0, ...]. Requiring binary entries makes "sum == 1" mean
    # "exactly one hot channel per cell".
    assert np.isin(obs, (0, 1)).all()

    # Per-cell one-hot check. NOTE(review): assumes the flat vector is laid
    # out cell-major with channels last -- confirm against egocentric_obs.
    obs2 = obs.reshape(n_cells, C)
    row_sums = obs2.sum(axis=1)
    assert np.allclose(row_sums, 1.0)
