import torch
from torch import Tensor
from torchmetrics import Metric


class Sparsity(Metric):
    """Running average, over all samples seen, of each sample's mean importance.

    Each ``update`` adds ``importance.sum() / importance.size(1)`` to the
    running total (for 2-D input this is the sum over the batch of every row's
    mean) and adds the batch size to the sample count. ``compute`` divides the
    total by the count. Both states use ``dist_reduce_fx="sum"``, so they are
    summed across processes when the metric is synchronized.

    NOTE(review): if ``importance`` is a 0/1 feature mask, the computed value is
    the fraction of *nonzero* (selected) entries, i.e. density rather than
    sparsity -- confirm the intended definition with callers.
    """

    # ``update`` only adds batch-local quantities into sum-reduced states and
    # never reads the accumulated global state, so ``forward`` can compute the
    # batch value with a single ``update`` call instead of two.
    full_state_update: bool = False

    def __init__(self, **kwargs):
        """Register the metric states; ``kwargs`` are forwarded to ``Metric``."""
        super().__init__(**kwargs)
        # Running sum of per-sample mean importance.
        self.add_state("sparse", default=torch.tensor(0.0), dist_reduce_fx="sum")
        # Running count of samples (entries along dim 0) accumulated so far.
        self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, importance: Tensor) -> None:
        """Accumulate one batch of importance values.

        Args:
            importance: tensor whose dim 0 is counted as samples and whose
                dim 1 is used as the per-sample normalizer. Presumably shaped
                ``(batch, num_features)`` -- TODO confirm; input with fewer
                than 2 dims raises from ``size(1)``.
        """
        # For 2-D input: sum(all entries) / num_features == sum over rows of
        # each row's mean.
        self.sparse += importance.sum() / importance.size(1)
        self.num_samples += importance.size(0)

    def compute(self) -> Tensor:
        """Return the accumulated total divided by the number of samples seen.

        Returns NaN when no samples have been accumulated (0.0 / 0).
        """
        return self.sparse / self.num_samples
