import torch
from ._BaseEvaluator import _Evaluator

__all__ = ['MaskIoU']

class MaskIoU(_Evaluator):
    """Accumulate per-sample mask IoU over batches and report the mean.

    Call :meth:`add_batch` once per batch, then :meth:`compute` to obtain the
    mean IoU over every sample buffered since the previous ``compute`` call.
    ``compute`` clears the buffer.
    """

    def __init__(self):
        super().__init__()
        # One 1-D CPU tensor of per-sample IoU values per add_batch call.
        self._buffer_iou = []

    @torch.no_grad()
    def add_batch(self, pred, label) -> None:
        """Score one batch and buffer its per-sample IoU values.

        :param pred: [N x H x W] mask scores; ``mask_iou`` applies a sigmoid
            and thresholds them
        :param label: [N x H x W] target masks
        """
        iou = mask_iou(pred, label)
        # Keep the buffer on the CPU, detached from any autograd graph.
        self._buffer_iou.append(iou.detach().cpu())

    @torch.no_grad()
    def compute(self) -> dict:
        """Return ``{"mask_iou": mean IoU}`` over all buffered samples and reset.

        If no batch was added since the last reset, the value is a NaN tensor
        (the mean of an empty set) instead of an error from ``torch.cat``.
        """
        if not self._buffer_iou:
            return {"mask_iou": torch.tensor(float("nan"))}

        iou = torch.cat(self._buffer_iou, dim=0)
        res = iou.mean()

        self._buffer_iou = []

        return {"mask_iou": res}


def mask_iou(pred, target, eps=1e-7, threshold=0.5):
    """
    Compute the per-sample IoU between predicted mask scores and target masks.

    ``pred`` is passed through a sigmoid and binarised with ``> threshold``.
    For samples whose target sums to zero (no object), the score is instead
    the number of pixels where both prediction and target are background,
    divided by H * W.

    :param pred: [N x H x W] raw scores; a sigmoid is applied here
    :param target: [N x H x W] target masks (numeric or bool)
    :param eps: small constant added to the denominator, 1e-7 or so
    :param threshold: sigmoid output above which a pixel is foreground
    :return: iou: size [N], one value per sample
    :raises ValueError: if ``pred`` is not 3-D or shapes differ
    """
    if pred.dim() != 3 or pred.shape != target.shape:
        raise ValueError(f"pred shape {pred.shape} and target shape {target.shape}")

    # `1 - target` below is not supported by torch for bool tensors.
    if target.dtype == torch.bool:
        target = target.int()

    num_pixels = pred.size(-1) * pred.size(-2)
    no_obj_flag = target.sum((1, 2)) == 0

    pred = (torch.sigmoid(pred) > threshold).int()

    inter = (pred * target).sum((1, 2))
    union = torch.max(pred, target).sum((1, 2))

    # Empty target: measure agreement on the background over the whole image.
    inter_no_obj = ((1 - target) * (1 - pred)).sum((1, 2))
    inter = torch.where(no_obj_flag, inter_no_obj, inter)
    union = union.masked_fill(no_obj_flag, num_pixels)

    return inter / (union + eps)
