from Layers import layers


def masks(module):
    r"""Yield each mask buffer of ``module``.

    A buffer is treated as a mask when the substring ``"mask"`` appears in the
    name reported by ``module.named_buffers()``. That call uses its default
    arguments, so in PyTorch the buffers of child modules are included too.
    Evaluation is lazy: nothing is read until the iterator is advanced.
    """
    yield from (
        tensor
        for buffer_name, tensor in module.named_buffers()
        if "mask" in buffer_name
    )


def trainable(module):
    r"""Report whether ``module`` counts as trainable.

    Currently every module qualifies and ``module`` is not inspected. An earlier
    variant excluded ``layers.Identity1d`` / ``layers.Identity2d``; reinstate an
    ``isinstance`` check here if that filtering is needed again.
    """
    del module  # intentionally unused: all modules are trainable
    return True


def prunable(module, batchnorm, residual, linear, shortcut):
    r"""Decide whether ``module`` takes part in pruning.

    - ``layers.Conv2d`` qualifies. A conv with a ``(1, 1)`` kernel qualifies
      only when ``shortcut`` is truthy.
    - ``layers.BatchNorm2d`` additionally qualifies when ``batchnorm`` is truthy.
    - ``layers.Linear`` additionally qualifies when ``linear`` is truthy.
    - ``residual`` is accepted for interface compatibility but not consulted.

    Returns a truthy/falsy flag (a plain bool when ``shortcut`` is a bool).
    """
    verdict = False
    if isinstance(module, layers.Conv2d):
        # 1x1 kernels (presumably shortcut/projection convs -- TODO confirm)
        # are only eligible when the caller opts in via `shortcut`.
        verdict = shortcut or module.kernel_size != (1, 1)
    if batchnorm:
        verdict |= isinstance(module, layers.BatchNorm2d)
    if linear:
        verdict |= isinstance(module, layers.Linear)
    return verdict


def parameters(model):
    r"""Iterate over the parameters of every trainable module in ``model``.

    Modules come from ``model.modules()`` and are kept when ``trainable``
    accepts them. Each module's own parameters (``recurse=False``) are yielded,
    so no parameter is produced twice through a parent container.
    """
    for submodule in model.modules():
        if not trainable(submodule):
            continue
        yield from submodule.parameters(recurse=False)


def pointwise(module):
    r"""Return whether ``module`` is a ``layers.Conv2d`` with a ``(1, 1)`` kernel."""
    if not isinstance(module, layers.Conv2d):
        return False
    return module.kernel_size == (1, 1)


def masked_parameters(model, bias=False, batchnorm=False, residual=False, linear=True, shortcut=True):
    r"""Iterate over ``(mask, parameter)`` pairs of the prunable modules in ``model``.

    A module is visited when ``prunable(module, batchnorm, residual, linear,
    shortcut)`` is truthy. Its masks (see ``masks``) are zipped positionally
    with its own parameters. This assumes one mask per parameter, registered
    in the same order (TODO confirm for every layer type in ``layers``).
    The bias parameter is skipped unless ``bias is True``.
    """
    candidates = (
        m for m in model.modules()
        if prunable(m, batchnorm, residual, linear, shortcut)
    )
    for module in candidates:
        pairs = zip(masks(module), module.parameters(recurse=False))
        for mask, param in pairs:
            # Drop the bias unless the caller explicitly asked for it.
            if param is module.bias and bias is not True:
                continue
            yield mask, param


def pointwise_parameters(model, bias=False, batchnorm=False, residual=False, linear=True, shortcut=True):
    r"""Returns an iterator over ``(mask, parameter)`` pairs of the pointwise
    (1x1) convolutions in ``model``.

    Only modules accepted by ``pointwise`` are visited. Each module's masks
    (see ``masks``) are zipped positionally with its own parameters. The bias
    parameter is skipped unless ``bias is True``, so by default biases are
    excluded.

    ``batchnorm``, ``residual``, ``linear`` and ``shortcut`` are accepted only
    for signature parity with ``masked_parameters`` and are ignored here.
    """
    for module in model.modules():
        if not pointwise(module):
            continue
        # Positional pairing: assumes one mask per parameter, registered in the
        # same order as the parameters (TODO confirm for layers.Conv2d).
        for mask, param in zip(masks(module), module.parameters(recurse=False)):
            if param is not module.bias or bias is True:
                yield mask, param


def all_parameters(model, bias=False, batchnorm=False, residual=False, linear=True, shortcut=True):
    r"""Returns an iterator over every ``(mask, parameter)`` pair in ``model``.

    For each module from ``model.modules()``, the mask buffers registered
    directly on that module are zipped positionally with that module's own
    parameters. A buffer counts as a mask when ``"mask"`` appears in its name.
    Modules without their own mask buffers contribute nothing. Because ``zip``
    stops at the shorter input, parameters beyond the last mask are dropped.

    The keyword arguments, including ``bias``, are accepted only for signature
    parity with ``masked_parameters`` and are ignored. Bias parameters are not
    filtered out.
    """
    for module in model.modules():
        # Restrict to buffers owned by *this* module. The default
        # ``named_buffers()`` recurses into children, which would pair a
        # child's masks with this module's own parameters (e.g. a container
        # that holds a bare Parameter alongside masked sub-layers).
        own_masks = (
            buf for name, buf in module.named_buffers(recurse=False) if "mask" in name
        )
        for mask, param in zip(own_masks, module.parameters(recurse=False)):
            yield mask, param