import torch


class Reshape(torch.nn.Module):
    """Reshape every sample of a batch to a fixed per-sample shape.

    The first dimension of the input (``X.size(0)``) is kept as-is.
    All remaining elements of each sample are rearranged into ``shape``.
    A ``-1`` entry in ``shape`` is inferred, just as in ``Tensor.reshape``.

    Args:
        *shape: Target per-sample shape. Pass it either as separate ints,
            ``Reshape(3, 4)``, or as one tuple/list/``torch.Size``,
            ``Reshape((3, 4))``.
    """

    def __init__(self, *shape):
        super().__init__()
        # Accept a single sequence argument as well as varargs ints.
        # Previously, Reshape((3, 4)) stored a nested tuple that failed in forward().
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = tuple(shape[0])
        self._shape = shape

    def forward(self, X):
        """Return ``X`` with dims after the first reshaped to ``self._shape``.

        Args:
            X: Input tensor with at least one dimension. Dimension 0 is
                preserved; the element count of the remaining dims must
                match ``self._shape``.

        Returns:
            A tensor of shape ``(X.size(0), *self._shape)``.

        Raises:
            RuntimeError: If the element counts are incompatible.
        """
        # reshape (not view): it returns a view whenever view would succeed,
        # and falls back to a copy for non-contiguous inputs (e.g. after
        # transpose/permute), where view would raise.
        return X.reshape(X.size(0), *self._shape)

    def extra_repr(self):
        """Show the target per-sample shape when the module is printed."""
        return f"shape={self._shape}"