from torch import nn

def split_model(model):
    """Flatten the direct children of *model* into a single list.

    Children that are ``nn.Sequential`` instances (including subclasses) are
    expanded recursively, so their contents appear in place, in order.  Any
    other child is appended unchanged, even if it has children of its own.

    Returns:
        A new list of modules; empty when *model* has no children.
    """
    flattened = []
    for child in model.children():
        if not isinstance(child, nn.Sequential):
            # Leaf from this function's point of view: keep it whole.
            flattened.append(child)
            continue
        # Sequential container: splice its (recursively flattened) contents in.
        flattened.extend(split_model(child))
    return flattened

def time_consumption_per_layer(network='resnet20'):
    """Return the hard-coded per-layer time-consumption list for *network*.

    Supported names and the values returned:

    * ``'resnet20'`` -- 12 explicit values, 546 down to 11.
    * ``'lenet5'``   -- 5, 4, ..., 1.
    * ``'ds_rn50'``  -- 40, 39, ..., 9 (32 entries).
    * ``'ntk_cnn'``  -- 15, 14, ..., 1.
    * ``'cnn'``      -- 5, 4, ..., 1.

    NOTE(review): units are not stated anywhere in this file, and the
    non-resnet20 entries look like placeholder descending sequences rather
    than measurements -- confirm with the caller before relying on them.

    Args:
        network: Name of the network whose profile is requested.

    Returns:
        A freshly built list of ints (safe for the caller to mutate).

    Raises:
        NotImplementedError: If *network* is not one of the supported names.
    """
    # Built on every call so each caller receives its own list objects.
    profiles = {
        'resnet20': [546, 508, 482, 416, 365, 318, 257, 210, 160, 102, 57, 11],
        'lenet5': list(range(5, 0, -1)),
        'ds_rn50': list(range(40, 8, -1)),
        'ntk_cnn': list(range(15, 0, -1)),
        'cnn': list(range(5, 0, -1)),
    }
    try:
        profile = profiles.get(network)
    except TypeError:
        # Unhashable argument (e.g. a list): treat it as an unknown network,
        # which is what a plain equality comparison would have concluded.
        profile = None
    if profile is None:
        raise NotImplementedError(
            "no per-layer time profile for network {!r}; supported: {}".format(
                network, ', '.join(sorted(profiles))
            )
        )
    return profile