from src.gridworld.env import LongCorridor
from src.data.dataset import SeqPredDataset

def test_dataset_shapes():
    """Check the shapes of the first item produced by ``SeqPredDataset``.

    Builds a small ``LongCorridor`` environment, wraps it in a
    ``SeqPredDataset`` built with ``k=3``, unpacks the first item as a
    3-tuple, and checks the leading dimensions of its feature and target
    arrays.
    """
    k = 3  # passed to the dataset; the leading dims below presumably track it
    expected_feat_width = 10  # NOTE(review): origin of 10 not visible here — confirm

    corridor = LongCorridor(Lx=20, Ly=5, n_colors=4, obs_size=5, seed=2)
    dataset = SeqPredDataset(corridor, n_traj=2, T=12, k=k, seed=0)

    # The first element of the item is not checked by this test.
    _first_obs, features, targets = dataset[0]

    assert features.shape[0] == k
    assert features.shape[1] == expected_feat_width
    assert targets.shape[0] == k
    # Target width must agree with the dataset's own reported observation size.
    assert targets.shape[1] == dataset.obs_dim
