import numpy as np
import torch
from torch import Tensor
from torchmetrics import Metric


class SegmentationAbsoluteDiff(Metric):
    """Mean per-patch absolute difference between importance scores and a segmentation mask.

    For each sample, the segmentation mask is average-pooled into
    non-overlapping ``patch_size x patch_size`` patches, which gives the mean
    mask value of each patch. Each patch's absolute difference from the
    matching ``importance`` value is averaged over all patches of the sample.
    ``compute`` returns the mean of that per-sample value over every sample
    seen so far.

    Args:
        patch_size: Side length, in pixels, of one square patch.
        image_size: Side length, in pixels, of the square segmentation mask.
            Defaults to 224 (the value this metric originally hard-coded).
        **kwargs: Forwarded unchanged to ``Metric.__init__``.
    """

    def __init__(self, patch_size, *, image_size=224, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.image_size = image_size
        self.patches_in_row = image_size // self.patch_size
        # Running sum of per-sample mean absolute differences (summed across processes).
        self.add_state("absolute_diff", default=torch.tensor(0.0), dist_reduce_fx="sum")
        # Number of samples accumulated into ``absolute_diff``.
        self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, importance: Tensor, segmentation: Tensor) -> None:
        """Accumulate the per-sample mean absolute patch difference for one batch.

        Args:
            importance: Per-patch scores. Must be reshapeable to
                ``(batch, patches_in_row, patches_in_row)``; NaNs are replaced by 0
                and infinities by the largest finite value of the working dtype.
            segmentation: Masks of ``image_size x image_size`` pixels per sample.
                Dim 1 is squeezed if it has size 1, so ``(B, 1, H, W)`` is accepted.
        """
        rows, size = self.patches_in_row, self.patch_size
        device, dtype = self.absolute_diff.device, self.absolute_diff.dtype

        # Stay in torch on detached tensors: ``.numpy()`` raises on tensors that
        # require grad, and a host round-trip forces a device sync every batch.
        # Casting to the state dtype also lets integer/bool masks be averaged.
        importance = importance.detach().to(device=device, dtype=dtype)
        importance = torch.nan_to_num(importance.reshape(-1, rows, rows))

        segmentation = segmentation.detach().squeeze(1).to(device=device, dtype=dtype)
        # Split H and W into (patch index, pixel-in-patch) and average the pixel axes,
        # giving the mean mask value of every patch: shape (batch, rows, rows).
        patch_coverage = segmentation.reshape(-1, rows, size, rows, size).mean(dim=(2, 4))

        diff = (importance - patch_coverage).abs()

        # Mean over the rows * rows patches of each sample, summed over the batch.
        self.absolute_diff += diff.sum() / (rows * rows)
        self.num_samples += diff.shape[0]

    def compute(self) -> Tensor:
        """Return the mean absolute patch difference over all samples, or 0.0 if none were seen."""
        if self.num_samples == 0:
            return torch.zeros((), dtype=self.absolute_diff.dtype, device=self.absolute_diff.device)
        return self.absolute_diff / self.num_samples

