"""Evaluation Metrics for Semantic Segmentation"""

"""Code taken from https://github.com/Tramac/mobilenetv3-segmentation/blob/master/core/utils/metric.py """

import numpy as np
import torch

# Names exported by `from <this module> import *`.
__all__ = [
    "SegmentationMetric",
    "batch_pix_accuracy",
    "batch_intersection_union",
    "intersectionAndUnion",
    "hist_info",
    "compute_score",
]


class SegmentationMetric(object):
    """Accumulates per-class intersection/union areas and reports the mIoU.

    Call ``update`` once per batch, ``get`` to read the mean IoU over the
    classes seen so far, and ``reset`` to start a new evaluation run.
    """

    def __init__(self, nclass):
        super(SegmentationMetric, self).__init__()
        self.nclass = nclass
        self.reset()

    def update(self, preds, labels, masks):
        """Updates the internal evaluation result.

        Parameters
        ----------
        preds : torch.Tensor or list/tuple of torch.Tensor
            Predicted class ids.
        labels : torch.Tensor or list/tuple of torch.Tensor
            Ground-truth class ids, paired with ``preds``.
        masks : torch.Tensor or list/tuple of torch.Tensor
            Index masks applied to both ``preds`` and ``labels`` before the
            areas are counted.

        Raises
        ------
        TypeError
            If ``preds`` is neither a tensor nor a list/tuple. Such input used
            to be ignored silently, leaving the metric empty.
        """
        if isinstance(preds, torch.Tensor):
            self._accumulate(preds, labels, masks)
        elif isinstance(preds, (list, tuple)):
            for pred, label, mask in zip(preds, labels, masks):
                self._accumulate(pred, label, mask)
        else:
            raise TypeError(
                "preds must be a torch.Tensor or a list/tuple of tensors, "
                "got {}".format(type(preds).__name__)
            )

    def _accumulate(self, pred, label, mask):
        """Adds one prediction/label pair to the running totals."""
        inter, union, cls_visited = batch_intersection_union(
            pred, label, self.nclass, mask
        )

        # Keep the accumulators on the same device as the incoming areas.
        if self.total_inter.device != inter.device:
            self.total_inter = self.total_inter.to(inter.device)
            self.total_union = self.total_union.to(union.device)
        self.total_inter += inter
        self.total_union += union

        self.total_visited.update(cls_visited)

    def get(self):
        """Gets the current evaluation result.

        Returns
        -------
        mIoU : float
            Mean IoU over the classes recorded in ``total_visited``; NaN when
            no class has been visited yet.
        """
        # Sorting only makes the index order deterministic.
        visited = torch.tensor(sorted(self.total_visited)).long()
        # The tiny constant avoids 0/0 for classes with an empty union.
        IoU = 1.0 * self.total_inter / (2.220446049250313e-16 + self.total_union)
        return IoU[visited].mean().item()

    def reset(self):
        """Resets the internal evaluation result to initial state."""
        self.total_inter = torch.zeros(self.nclass)
        self.total_union = torch.zeros(self.nclass)
        # Reset here but never updated by this class.
        self.total_correct = 0
        self.total_label = 0
        self.total_visited = set()


# pytorch version
def batch_pix_accuracy(output, target, mask):
    """Counts correctly predicted and labeled pixels among the masked pixels.

    Parameters
    ----------
    output : torch.Tensor
        Predicted class ids.
    target : torch.Tensor
        Ground-truth class ids; negative values are treated as unlabeled.
    mask : torch.Tensor
        Index mask applied to both ``output`` and ``target``.

    Returns
    -------
    pixel_correct : int
        Number of labeled pixels whose prediction equals the target.
    pixel_labeled : int
        Number of labeled pixels.
    """
    predict = output[mask]
    target = target[mask]

    # Class 0 is a valid class (batch_intersection_union counts it too), so
    # only negative ids are excluded; ``target > 0`` used to drop class 0.
    labeled = target >= 0
    pixel_labeled = torch.sum(labeled).item()
    pixel_correct = torch.sum((predict == target) & labeled).item()
    assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled


def batch_intersection_union(output, target, nclass, mask):
    """Computes per-class intersection and union areas for the masked pixels.

    Parameters
    ----------
    output : torch.Tensor
        Predicted class ids.
    target : torch.Tensor
        Ground-truth class ids; ids outside ``[0, nclass)`` are not counted
        as label area.
    nclass : int
        Number of classes.
    mask : torch.Tensor
        Index mask applied to both ``output`` and ``target``.

    Returns
    -------
    area_inter : torch.Tensor
        Float tensor of length ``nclass`` with the intersection areas.
    area_union : torch.Tensor
        Float tensor of length ``nclass`` with the union areas.
    cls_visited : list of int
        Class ids in ``[0, nclass)`` present in the masked target.
    """
    # Ids are shifted to 1..nclass so that 0 can stand for "ignore": histc
    # drops everything below ``min``.
    lo, hi, nbins = 1, nclass, nclass
    predict = (output[mask] + 1).float()
    target = (target[mask] + 1).float()

    # Report only real class ids. Ignore/out-of-range labels (e.g. -1 or 255)
    # would otherwise be used as indices into the per-class IoU vector.
    labels = torch.unique(target - 1)
    in_range = (labels >= 0) & (labels < nclass)
    cls_visited = labels[in_range].long().tolist()

    # Zero out predictions on pixels whose shifted label is not positive.
    predict = predict * (target > 0).float()
    # Keeps the class id where prediction and label agree, 0 elsewhere.
    intersection = predict * (predict == target).float()

    area_inter = torch.histc(intersection.cpu(), bins=nbins, min=lo, max=hi)
    area_pred = torch.histc(predict.cpu(), bins=nbins, min=lo, max=hi)
    area_lab = torch.histc(target.cpu(), bins=nbins, min=lo, max=hi)
    area_union = area_pred + area_lab - area_inter

    assert (
        torch.sum(area_inter > area_union).item() == 0
    ), "Intersection area should be smaller than Union area"
    return area_inter.float(), area_union.float(), cls_visited


def intersectionAndUnion(imPred, imLab, numClass):
    """
    Per-class intersection and union areas for one prediction/label pair.

    Class ids are histogrammed over the range 1..numClass; values outside
    that range are not counted.
    To aggregate over many images:
    for i in range(Nimages):
        (area_intersection[:,i], area_union[:,i]) = intersectionAndUnion(imPred[i], imLab[i])
    IoU = 1.0 * np.sum(area_intersection, axis=1) / np.sum(np.spacing(1)+area_union, axis=1)
    """

    def _class_counts(values):
        counts, _edges = np.histogram(values, bins=numClass, range=(1, numClass))
        return counts

    # Drop predictions on unlabeled ground-truth pixels so they are not
    # penalized.
    # NOTE(review): only negative labels are filtered here, while counting
    # starts at class id 1 -- confirm which id marks "unlabeled".
    kept_pred = imPred * (imLab >= 0)

    # Non-zero only where the kept prediction agrees with the label.
    matched = kept_pred * (kept_pred == imLab)
    area_intersection = _class_counts(matched)

    area_union = _class_counts(kept_pred) + _class_counts(imLab) - area_intersection
    return (area_intersection, area_union)


def hist_info(pred, label, num_cls):
    """Builds the confusion matrix for one prediction/label pair.

    Parameters
    ----------
    pred : numpy.ndarray
        Predicted class ids, same shape as ``label``. Assumed to lie in
        ``[0, num_cls)`` wherever the label is valid.
    label : numpy.ndarray
        Ground-truth class ids; pixels outside ``[0, num_cls)`` are skipped.
    num_cls : int
        Number of classes.

    Returns
    -------
    hist : numpy.ndarray
        ``(num_cls, num_cls)`` matrix; rows are labels, columns predictions.
    labeled : int
        Number of pixels with a valid label.
    correct : int
        Number of valid pixels whose prediction equals the label.
    """
    assert pred.shape == label.shape
    valid = (label >= 0) & (label < num_cls)
    pred_valid = pred[valid]
    label_valid = label[valid]

    labeled = np.sum(valid)
    correct = np.sum(pred_valid == label_valid)

    # Cast both operands: np.bincount rejects float input, so float-typed
    # predictions used to raise TypeError here.
    flat_index = num_cls * label_valid.astype(int) + pred_valid.astype(int)
    hist = np.bincount(flat_index, minlength=num_cls**2).reshape(num_cls, num_cls)

    return hist, labeled, correct


def compute_score(hist, correct, labeled):
    """Derives IoU and pixel-accuracy scores from a confusion matrix.

    Parameters
    ----------
    hist : numpy.ndarray
        Square confusion matrix.
    correct : int
        Number of correctly predicted pixels.
    labeled : int
        Number of labeled pixels.

    Returns
    -------
    iu : numpy.ndarray
        Per-class IoU; NaN for classes with an empty union.
    mean_IU : float
        NaN-ignoring mean of ``iu``.
    mean_IU_no_back : float
        NaN-ignoring mean of ``iu`` without class 0.
    mean_pixel_acc : float
        ``correct / labeled``; NaN when ``labeled`` is 0.
    """
    true_pos = np.diag(hist)
    # Classes missing from both labels and predictions give 0/0; the NaN is
    # intentional (nanmean skips it), so the warning is silenced.
    with np.errstate(divide="ignore", invalid="ignore"):
        iu = true_pos / (hist.sum(1) + hist.sum(0) - true_pos)
    mean_IU = np.nanmean(iu)
    mean_IU_no_back = np.nanmean(iu[1:])
    # Guard the empty case instead of raising ZeroDivisionError.
    mean_pixel_acc = correct / labeled if labeled else float("nan")

    return iu, mean_IU, mean_IU_no_back, mean_pixel_acc
