
# Mene. For certain debugging purposes.

def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Iterates ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is truthy.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def traverse_model(model, name = 'main', cur_depth = 0, max_depth = 3):
    """Print a tab-indented tree of sub-modules with their trainable parameter counts.

    Args:
        model: Object exposing ``parameters()`` and ``named_children()``.
        name: Label printed for this node.
        cur_depth: Depth of this node; controls how many tabs are printed.
        max_depth: Children are not visited once ``cur_depth`` reaches this value.
    """
    indent = '\t' * cur_depth
    print(indent, name, count_parameters(model))

    # Stop descending once the depth limit is reached.
    if cur_depth >= max_depth:
        return

    for child_name, child in model.named_children():
        traverse_model(child, child_name, cur_depth + 1, max_depth)


def print_shape(obj, name = None, layer = 0):
    """Recursively print the structure of a nested container, one entry per line.

    Dicts are printed as ``<name>(dict):`` followed by their items one level
    deeper; lists and tuples as ``<name>(iter):`` followed by their (unnamed)
    elements. Any other object is printed as ``<name>: <type> <value>``, where
    the value is ``obj.shape`` when the object has a ``shape`` attribute and
    ``str(obj)`` otherwise.

    Args:
        obj: The object to describe.
        name: Label for this entry (typically the dict key); ``None`` means
            the entry is unnamed.
        layer: Nesting depth; each level is indented by one tab.
    """
    indent = '\t' * layer
    # Only None means "unnamed": falsy keys such as 0 or False stay visible,
    # and str() keeps tuple keys from being unpacked by %-formatting.
    label = '' if name is None else str(name)

    if isinstance(obj, dict):
        # The label is built separately from the indent so unnamed containers
        # still print their indentation and type marker.
        print(indent + label + '(dict):')
        for key, value in obj.items():
            print_shape(value, key, layer = layer + 1)
    elif isinstance(obj, (list, tuple)):
        print(indent + label + '(iter):')
        for item in obj:
            print_shape(item, layer = layer + 1)
    else:
        detail = obj.shape if hasattr(obj, 'shape') else obj
        print(indent + "%s: %s %s" % (label, type(obj), str(detail)))
