import numpy as np
import torch
from torch import Tensor
from torchmetrics import Metric


class SymmetricSegmentationIOU(Metric):
    """Symmetric IoU between patch-level importance maps and segmentations.

    For every sample, two boolean patch grids of shape
    ``(patches_in_row, patches_in_row)`` are built:

    * from the importance map: a patch is foreground when its value is
      strictly greater than ``threshold``;
    * from the segmentation: the map is tiled into ``patch_size x patch_size``
      pixel tiles and a patch is foreground when the mean of its tile is > 0.

    The per-sample score is the average of the foreground IoU and the
    background IoU of the two grids. ``compute`` returns the mean of that
    score over every sample passed to ``update``.

    Args:
        patch_size: Side length, in pixels, of one square patch.
        threshold: Importance values strictly above this are foreground.
        image_size: Side length, in pixels, of the square segmentation map.
            Defaults to 224 (the value that used to be hard-coded).
        **kwargs: Forwarded unchanged to ``torchmetrics.Metric``.

    Raises:
        ValueError: If ``patch_size`` is not a positive divisor of
            ``image_size``.
    """

    # Masks are computed in NumPy, so no gradient flows through the metric.
    is_differentiable = False
    higher_is_better = True
    # ``update`` only adds batch totals into sum-reduced states, so a batch
    # value can be computed from the batch alone.
    full_state_update = False

    def __init__(self, patch_size, threshold=0.5, image_size=224, **kwargs):
        super().__init__(**kwargs)
        # Check positivity first so the modulo below cannot divide by zero.
        if patch_size <= 0 or image_size % patch_size != 0:
            raise ValueError(
                f"patch_size ({patch_size}) must be a positive divisor of "
                f"image_size ({image_size})"
            )
        self.patch_size = patch_size
        self.image_size = image_size
        self.patches_in_row = image_size // patch_size
        self.threshold = threshold
        # Running sum of per-sample symmetric IoU and number of samples seen;
        # both are summed across processes in distributed runs.
        self.add_state("total_iou", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="sum")

    def _preprocess(self, tensor: Tensor, is_segmentation: bool = False) -> np.ndarray:
        """Turn a batch of maps into boolean patch masks of shape ``(B, p, p)``.

        ``p`` is ``self.patches_in_row``. A singleton dimension 1 (e.g. a
        channel axis) is squeezed away first.

        Args:
            tensor: Importance scores (``p * p`` values per sample) or a
                segmentation (``image_size * image_size`` values per sample,
                assumed row-major H x W — confirm against callers).
            is_segmentation: Selects tile-mean aggregation (True) or direct
                thresholding (False).

        Returns:
            Boolean array of shape ``(B, p, p)``.
        """
        # detach(): importance maps may carry autograd history, and .numpy()
        # refuses tensors that require grad.
        tensor = tensor.detach().squeeze(1).cpu()
        # NumPy has no bfloat16 dtype; upcast so the conversion succeeds.
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        array = tensor.numpy()

        if is_segmentation:
            # (B, p*ps, p*ps) -> (B, p, ps, p, ps); average each ps x ps tile.
            mask = (
                array.reshape(
                    -1, self.patches_in_row, self.patch_size, self.patches_in_row, self.patch_size
                ).mean(axis=(2, 4))
                > 0
            )
        else:
            # Importance is already patch-level: threshold, then lay out as a grid.
            mask = (array > self.threshold).reshape(
                -1, self.patches_in_row, self.patches_in_row
            )

        return mask.astype(bool)

    def _iou(self, mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray:
        """Return the per-sample IoU of two ``(B, H, W)`` boolean mask batches.

        Always returns exactly ``B`` values, so foreground and background
        results stay aligned sample-by-sample. When a sample's union is empty
        (both masks empty) the masks agree perfectly and the IoU is 1.0.
        """
        intersection = np.logical_and(mask1, mask2).sum(axis=(1, 2))
        union = np.logical_or(mask1, mask2).sum(axis=(1, 2))
        iou = np.ones(union.shape, dtype=np.float64)
        np.divide(intersection, union, out=iou, where=union != 0)
        return iou

    def update(self, importance: Tensor, segmentation: Tensor) -> None:
        """Accumulate the symmetric IoU of one batch.

        Args:
            importance: Patch-level importance scores for the batch.
            segmentation: Pixel-level segmentation maps for the same batch.

        Raises:
            ValueError: If the two inputs yield patch masks of different
                shapes (e.g. mismatched batch sizes), which would otherwise
                be silently broadcast.
        """
        imp_mask = self._preprocess(importance, is_segmentation=False)
        seg_mask = self._preprocess(segmentation, is_segmentation=True)
        if imp_mask.shape != seg_mask.shape:
            raise ValueError(
                "importance and segmentation produce different patch-mask shapes: "
                f"{imp_mask.shape} vs {seg_mask.shape}"
            )

        # Foreground IoU and background IoU (IoU of the complements).
        iou_fg = self._iou(imp_mask, seg_mask)
        iou_bg = self._iou(~imp_mask, ~seg_mask)

        # Symmetric IoU = mean of the two, one value per sample.
        iou = (iou_fg + iou_bg) / 2

        self.total_iou += float(iou.sum())
        self.num_samples += iou.shape[0]

    def compute(self) -> Tensor:
        """Return the mean symmetric IoU over all samples, or 0.0 if none were seen."""
        if self.num_samples == 0:
            # Same dtype/device as the accumulated state.
            return torch.zeros_like(self.total_iou)
        return self.total_iou / self.num_samples