import torch


class AssertShape(torch.nn.Module):
    """Identity module that verifies the trailing shape of its input.

    Every dimension after the first (``X.shape[1:]``) must equal ``shape``
    exactly; the leading dimension is not checked. The input is returned
    unchanged, so the module can be dropped into ``torch.nn.Sequential`` as
    a sanity check.

    Args:
        *shape: expected sizes of dimensions ``1..`` of the input.
    """

    def __init__(self, *shape):
        super().__init__()
        self._shape = shape

    def extra_repr(self):
        """Show the expected shape when the module is printed."""
        return f"shape={self._shape}"

    def forward(self, X):
        """Return ``X`` unchanged after checking ``X.shape[1:]``.

        Raises:
            AssertionError: if ``X.shape[1:]`` differs from the expected shape.
        """
        # Explicit raise instead of ``assert`` so the check is not stripped
        # under ``python -O``; AssertionError is kept for compatibility with
        # callers that already catch it.
        if X.shape[1:] != self._shape:
            raise AssertionError(
                f"expected X.shape[1:] == {self._shape}, got {tuple(X.shape)}"
            )
        return X