import numpy as np
import torch.nn as nn


def get_trainable_params(module: nn.Module) -> int:
    """Return the number of trainable scalar parameters in *module*.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    Each parameter contributes its total element count (``numel()``), so a
    ``(3, 2)`` weight contributes 6 and a 0-dim scalar parameter contributes 1.

    Args:
        module: The module whose parameters (including those of all
            submodules, as yielded by ``module.parameters()``) are counted.

    Returns:
        The count as a plain Python ``int``. This is 0 when the module has
        no trainable parameters.
    """
    # ``Tensor.numel()`` returns a Python int for every shape, including the
    # 0-dim case. ``np.prod(p.size())`` would return the float 1.0 there,
    # because of the empty product, and a NumPy scalar everywhere else.
    return sum(p.numel() for p in module.parameters() if p.requires_grad)
