import time
import torch

from . import measure
from ..p_utils import get_layer_metric_array

@measure('params', copy_net=False, mode='param')
def get_param_count_array(net, inputs, targets, mode, loss_fn=None, split_data=1):
    """Count trainable parameters layer by layer.

    A per-layer callback is handed to ``get_layer_metric_array`` together
    with ``net`` and ``mode``. The way layers are selected and the way
    results are collected (presumably one value per layer) are defined by
    that helper. They are not visible here; confirm against ``p_utils``.

    Args:
        net: The model whose layers are inspected.
        inputs: Unused here. Kept for the common ``@measure`` signature.
        targets: Unused here. Kept for the common ``@measure`` signature.
        mode: Passed through unchanged to ``get_layer_metric_array``.
        loss_fn: Unused here. Kept for the common ``@measure`` signature.
        split_data: Unused here. Kept for the common ``@measure`` signature.

    Returns:
        Whatever ``get_layer_metric_array`` builds from the per-layer
        counts. Each count is a ``torch.tensor`` holding an integer.
    """
    def _trainable_param_count(layer):
        # Only parameters that will receive gradients are counted. Frozen
        # tensors (requires_grad == False) are skipped.
        total = 0
        for param in layer.parameters():
            if param.requires_grad:
                total += param.numel()
        return torch.tensor(total)

    return get_layer_metric_array(net, _trainable_param_count, mode=mode)
