import torch


def get_device(allow_mps=True, allow_gpu=True):
    """Pick a torch device, preferring Apple MPS, then CUDA, then CPU.

    Args:
        allow_mps: If False, the ``"mps"`` backend is never selected (or probed).
        allow_gpu: If False, the ``"cuda"`` backend is never selected (or probed).

    Returns:
        ``torch.device("mps")``, ``torch.device("cuda")`` or ``torch.device("cpu")``,
        whichever is the first one that is both allowed and available.
    """
    # Check the caller's flag before probing a backend, so a disabled backend is
    # never touched. Look up the MPS backend defensively: torch builds that predate
    # MPS support (releases before 1.12) have no ``torch.backends.mps`` attribute.
    mps_backend = getattr(torch.backends, "mps", None)
    if allow_mps and mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    if allow_gpu and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def count_module_parameters(module):
    """Return the total number of scalar elements across ``module.parameters()``.

    Every parameter yielded by ``module.parameters()`` contributes its
    ``numel()``; a module with no parameters yields 0.
    """
    return sum(param.numel() for param in module.parameters())
