import torch
from src.models.rnn import SeqPredGRU

def test_model_forward_shapes():
    """Assert the shapes of the two tensors returned by a SeqPredGRU forward call.

    A model is built with small fixed dimensions. It is fed a random
    ``(batch, obs_dim)`` initial observation and a random
    ``(batch, k, feat_dim)`` feature tensor. The test then checks that:

    * the first output is ``(batch, k, obs_dim)``;
    * the second output is ``(batch, k, hidden)``.
    """
    # Arbitrary sizes; obs_dim, feat_dim and hidden are kept distinct so a
    # transposed or mis-wired dimension would show up as a shape mismatch.
    batch, steps = 4, 3
    obs_dim, feat_dim, hidden = 175, 10, 64

    net = SeqPredGRU(obs_dim=obs_dim, feat_dim=feat_dim, hidden=hidden, k=steps)

    # Random inputs; only their shapes matter for this test.
    initial_obs = torch.randn(batch, obs_dim)
    step_feats = torch.randn(batch, steps, feat_dim)

    predictions, step_outputs = net(initial_obs, step_feats)

    assert predictions.shape == (batch, steps, obs_dim)
    assert step_outputs.shape == (batch, steps, hidden)
