import torch
import cv2
import numpy as np

__all__ = ['SegmentationMetric']


class SegmentationMetric(object):
    """Accumulate a confusion matrix over batches and derive segmentation metrics.

    The matrix is indexed ``confusionMatrix[label, prediction]``: rows are
    ground-truth classes and columns are predicted classes (see
    ``genConfusionMatrix``). Ratios whose denominator is zero come out as NaN.
    """

    def __init__(self, numClass):
        """Create an empty ``numClass x numClass`` confusion matrix.

        Args:
            numClass: Number of valid classes. Label values outside
                ``[0, numClass)`` are skipped when accumulating.
        """
        self.numClass = numClass
        # NOTE(review): default float32 storage represents integer counts
        # exactly only up to 2**24 per cell; very large accumulations may
        # lose precision.
        self.confusionMatrix = torch.zeros((self.numClass,) * 2)

    def pixelAccuracy(self):
        """Return overall accuracy: correctly classified pixels / all counted pixels."""
        acc = torch.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()
        return acc

    def classPixelAccuracy(self):
        """Return per-class accuracy: diagonal / row sum (pixels labelled as that class).

        Classes with no labelled pixels yield NaN.
        """
        classAcc = torch.diag(self.confusionMatrix) / self.confusionMatrix.sum(dim=1)
        return classAcc

    def meanPixelAccuracy(self):
        """Return the mean of ``classPixelAccuracy`` over its finite entries.

        NaN entries are dropped because ``NaN < inf`` evaluates to False.
        """
        classAcc = self.classPixelAccuracy()
        meanAcc = classAcc[classAcc < float('inf')].mean()
        return meanAcc

    def meanCODSODConfusionAccuracy(self):
        """Average of the 1->2 and 2->1 confusion counts, each over all counted pixels.

        NOTE(review): assumes numClass >= 3; classes 1 and 2 presumably are the
        COD/SOD classes — confirm with the caller.
        """
        total = self.confusionMatrix.sum()
        acc1 = self.confusionMatrix[1][2] / total
        acc2 = self.confusionMatrix[2][1] / total
        return (acc1 + acc2) / 2

    def meanCODSODConfusionAccuracy2(self):
        """Average of the 1->2 and 2->1 confusions, each normalised by the predicted-column sum.

        Prints the column sums as a side effect.
        NOTE(review): assumes numClass >= 3.
        """
        print(self.confusionMatrix.sum(dim=0))
        acc1 = self.confusionMatrix[1][2] / self.confusionMatrix.sum(dim=0)[2]
        acc2 = self.confusionMatrix[2][1] / self.confusionMatrix.sum(dim=0)[1]
        return (acc1 + acc2) / 2

    def meanCODSODConfusionAccuracy3(self):
        """Average of the 1->2 and 2->1 confusions, each normalised by its label-row sum.

        A class with no labelled pixels contributes 0 instead of NaN.
        Prints the column sums as a side effect. Note that the normalisation
        itself uses the row sums.
        NOTE(review): assumes numClass >= 3.
        """
        print(self.confusionMatrix.sum(dim=0))
        rowSum = self.confusionMatrix.sum(dim=1)
        if rowSum[1] == 0:
            acc1 = 0
        else:
            acc1 = self.confusionMatrix[1][2] / rowSum[1]
        if rowSum[2] == 0:
            acc2 = 0
        else:
            acc2 = self.confusionMatrix[2][1] / rowSum[2]
        return (acc1 + acc2) / 2

    def printColSum(self):
        """Print the per-column (predicted-class) sums of the confusion matrix."""
        print(torch.sum(self.confusionMatrix, dim=0))

    def IntersectionOverUnion(self):
        """Return per-class IoU = TP / (row sum + column sum - TP); NaN for absent classes."""
        intersection = torch.diag(self.confusionMatrix)
        union = (torch.sum(self.confusionMatrix, dim=1)
                 + torch.sum(self.confusionMatrix, dim=0)
                 - intersection)
        IoU = intersection / union
        return IoU

    def meanIntersectionOverUnion(self):
        """Return the mean of ``IntersectionOverUnion`` over its finite (non-NaN) entries."""
        IoU = self.IntersectionOverUnion()
        mIoU = IoU[IoU < float('inf')].mean()
        return mIoU

    def genConfusionMatrix(self, imgPredict, imgLabel, ignore_labels=()):
        """Build the confusion matrix for a single prediction/label pair.

        Args:
            imgPredict: Integer tensor of predicted class ids, same shape as
                ``imgLabel``. Values must lie in ``[0, numClass)`` for counted
                pixels; otherwise the final ``view`` / ``bincount`` raises.
            imgLabel: Integer tensor of ground-truth class ids.
            ignore_labels: Label values to skip (e.g. ``[255]``).

        Returns:
            int64 tensor of shape ``(numClass, numClass)``, rows = label,
            columns = prediction.
        """
        mask = (imgLabel >= 0) & (imgLabel < self.numClass)
        for IgLabel in ignore_labels:
            mask &= (imgLabel != IgLabel)
        # Promote to int64 before combining: with uint8 inputs the product
        # numClass * label would otherwise wrap around and miscount pixels.
        label = self.numClass * imgLabel[mask].long() + imgPredict[mask].long()
        count = torch.bincount(label, minlength=self.numClass ** 2)
        confusionMatrix = count.view(self.numClass, self.numClass)
        return confusionMatrix

    def Frequency_Weighted_Intersection_over_Union(self):
        """Return FWIoU: per-class IoU weighted by each class's label frequency.

        Classes with no labelled pixels (zero frequency) are excluded.
        """
        rowSum = torch.sum(self.confusionMatrix, dim=1)
        freq = rowSum / torch.sum(self.confusionMatrix)
        iu = self.IntersectionOverUnion()
        present = freq > 0
        FWIoU = (freq[present] * iu[present]).sum()
        return FWIoU

    def addBatch(self, imgPredict, imgLabel, ignore_labels=()):
        """Accumulate one batch into the running confusion matrix and return it.

        Args:
            imgPredict: Predicted class ids; must have the same shape as
                ``imgLabel``.
            imgLabel: Ground-truth class ids.
            ignore_labels: Label values to skip (e.g. ``[255]``).

        Returns:
            The accumulated confusion matrix (the internal tensor, not a copy).
        """
        assert imgPredict.shape == imgLabel.shape
        batchMatrix = self.genConfusionMatrix(imgPredict, imgLabel, ignore_labels)
        # Move counts to the accumulator's device so inputs on another device
        # (e.g. CUDA) can be added into the matrix.
        self.confusionMatrix += batchMatrix.to(self.confusionMatrix.device)
        return self.confusionMatrix

    def reset(self):
        """Clear all accumulated counts."""
        self.confusionMatrix = torch.zeros((self.numClass, self.numClass))


if __name__ == '__main__':
    # Demo: feed two small batches into one 3-class metric and print every
    # metric after each batch. The second report reflects both batches,
    # because addBatch accumulates.
    ignore_labels = [255]
    metric = SegmentationMetric(3)

    demo_batches = (
        # (prediction, label); the 255 label pixel is skipped via ignore_labels.
        (torch.tensor([[0, 1, 2], [2, 1, 1]]).long(),
         torch.tensor([[0, 1, 255], [1, 1, 2]]).long()),
        # Label identical to prediction.
        (torch.tensor([[0, 1, 2], [2, 1, 1]]).long(),
         torch.tensor([[0, 1, 2], [2, 1, 1]]).long()),
    )

    for pred, gt in demo_batches:
        hist = metric.addBatch(pred, gt, ignore_labels)
        pa = metric.pixelAccuracy()
        cpa = metric.classPixelAccuracy()
        mpa = metric.meanPixelAccuracy()
        IoU = metric.IntersectionOverUnion()
        mIoU = metric.meanIntersectionOverUnion()
        print('hist is :\n', hist)
        print('PA is : %f' % pa)
        print('cPA is :', cpa)
        print('mPA is : %f' % mpa)
        print('IoU is : ', IoU)
        print('mIoU is : ', mIoU)