import torch

from model.alternate_attention import NPT

# Seed PyTorch's global RNG so the random input below (and any random
# parameter initialisation done by NPT) is reproducible from run to run.
torch.manual_seed(42)

# Input: a random tensor of shape (N, D) -- N rows with D values each.
N, D = 2, 3
X = torch.randn((N, D))

# Model hyper-parameters. ``num_features`` is set to D, the input's last
# dimension; E is forwarded as ``hidden_dim``.
E = 6
kwargs = dict(
    num_features=D,
    hidden_dim=E,
    num_encoder_layers=2,
    num_heads=2,
    eps_layer_norm=1e-5,
    p_dropout=0.1,
)
npt = NPT(**kwargs)

# NOTE(review): npt is never switched to eval mode here, so if NPT is a
# torch.nn.Module its dropout (p_dropout above) is presumably active during
# this forward pass -- harmless for a shape check, but confirm if values matter.
output = npt(X)

# Report the input shape followed by the output shape, one per line.
print(*(t.shape for t in (X, output)), sep="\n")
