from typing import List, Union

from torch import Tensor

# Names exported by ``from <module> import *``.
__all__ = ["tensor_sparsity", "tensor_density"]

def tensor_sparsity(x: Tensor, dim: Union[None, int] = None) -> float:
    """
    Fraction of zero entries in a tensor.

    Args:
        x: tensor
        dim: dimension to reduce. If None, every element of ``x`` is counted
            individually. Otherwise ``dim`` is collapsed: a position of the
            reduced tensor counts as zero only when every value along ``dim``
            is zero, and the result is the fraction of such positions.
            Negative values index from the end, as in torch.
    Returns:
        sparsity: float in [0, 1]
    Raises:
        IndexError: if ``dim`` is out of range for ``x`` (raised by torch).
        ZeroDivisionError: if the counted tensor has no elements
            (e.g. an empty ``x`` with ``dim=None``).
    """
    is_zero = x.eq(0)
    if dim is not None:
        # Collapse ``dim``: a reduced position is zero only if the whole
        # slice along ``dim`` is zero. Torch validates ``dim`` here.
        is_zero = is_zero.all(dim=dim)
    # Normalise by the number of positions actually counted. After reducing
    # over ``dim`` that is numel / shape[dim], not shape[dim]; dividing by
    # shape[dim] could yield values outside [0, 1].
    return is_zero.sum().item() / is_zero.numel()

def tensor_density(x: Tensor, dim: Union[None, int] = None) -> float:
    """
    Complement of ``tensor_sparsity``: ``1 - tensor_sparsity(x, dim)``.

    Args:
        x: tensor
        dim: dimension to reduce; forwarded unchanged to ``tensor_sparsity``
    Returns:
        density: float in [0, 1]
    """
    sparsity = tensor_sparsity(x, dim=dim)
    return 1 - sparsity
