import numpy as np
import torch.nn as nn
from torch.nn.modules.conv import _ConvNd


# Public API of this module: only get_layer_input_dim is exported by `import *`.
__all__ = ["get_layer_input_dim"]


def get_layer_input_dim(layer: nn.Module) -> int:
    """Return the input dimension of a linear or convolutional layer.

    For ``nn.Linear`` this is ``layer.in_features``. For any layer derived
    from ``_ConvNd`` (e.g. ``nn.Conv1d``/``nn.Conv2d``/``nn.Conv3d``) it is
    ``layer.in_channels`` multiplied by the product of all entries of
    ``layer.kernel_size``.

    Args:
        layer: An ``nn.Linear`` or ``_ConvNd`` instance.

    Returns:
        The input dimension as a plain Python ``int`` for both layer kinds.

    Raises:
        ValueError: If ``layer`` is neither ``nn.Linear`` nor ``_ConvNd``.
    """
    if isinstance(layer, nn.Linear):
        return layer.in_features
    if isinstance(layer, _ConvNd):
        # np.prod returns a NumPy scalar (np.int64 is not an instance of int);
        # cast so this branch returns the same type as the Linear branch.
        # NOTE(review): ``layer.groups`` is not taken into account, so for
        # grouped convolutions this exceeds the per-group count
        # (in_channels // groups * prod(kernel_size)) — confirm intended.
        return layer.in_channels * int(np.prod(layer.kernel_size))
    raise ValueError(
        "Layer has to be either nn.Linear or nn.Conv, "
        f"got {type(layer).__name__}"
    )
