import torch
from ._BaseEvaluator import _Evaluator

__all__ = ['ColorMiou']


class ColorMiou(_Evaluator):
    """Accumulate per-class IoU and F-score over batches of semantic predictions.

    Call :meth:`add_batch` once per batch, then :meth:`compute` to obtain the
    aggregated metrics. ``compute`` also resets the running sums, so the same
    instance can be reused for the next evaluation round.
    """

    def __init__(self):
        super().__init__()
        self._reset_buffers()

    def _reset_buffers(self):
        """Clear the running per-class sums."""
        # A plain ``0`` lets the first ``+=`` in add_batch adopt the shape and
        # dtype of the per-class tensors, so nclass need not be known up front.
        self._buffer_iou_per_class = 0
        self._buffer_fscore_per_class = 0
        self._buffer_cls_per_class = 0

    @torch.no_grad()
    def add_batch(self, pred, label):
        """Add one batch of predictions to the running per-class sums.

        Args:
            pred: per-pixel class scores, ``[BF, C, H, W]`` per
                ``calc_color_miou_fscore``'s docstring.
            label: per-pixel class labels, ``[BF, H, W]``.
        """
        iou_pc, fscore_pc, cls_pc, _ = calc_color_miou_fscore(pred, label)
        # Accumulate on CPU regardless of the device the batch lives on.
        self._buffer_iou_per_class += iou_pc.cpu()
        self._buffer_fscore_per_class += fscore_pc.cpu()
        self._buffer_cls_per_class += cls_pc.cpu()

    @staticmethod
    def _per_class_mean(total, count):
        """Divide per-class sums by per-class counts; classes never seen (0/0) become 0."""
        avg = total / count
        avg[torch.isnan(avg)] = 0
        return avg

    @torch.no_grad()
    def compute(self) -> dict:
        """Return the aggregated metrics and reset the running sums.

        Per-class scores are averaged over the frames in which each class
        occurred. The final means are taken over all classes, so classes that
        never occurred contribute 0. The ``*_noBg`` variants drop the last class
        entry (``[:-1]``). NOTE(review): this assumes the background class is
        the last index; confirm against the label mapping.

        Returns:
            dict with float values for keys ``"miou"``, ``"f_score"``,
            ``"f_score_noBg"`` and ``"miou_noBg"``.

        Raises:
            RuntimeError: if called before any :meth:`add_batch`.
        """
        if not torch.is_tensor(self._buffer_cls_per_class):
            raise RuntimeError("ColorMiou.compute() called before any add_batch()")

        cls_pc = self._buffer_cls_per_class
        miou_pc = self._per_class_mean(self._buffer_iou_per_class, cls_pc)
        f_score_pc = self._per_class_mean(self._buffer_fscore_per_class, cls_pc)
        self._reset_buffers()

        return {"miou": torch.mean(miou_pc).item(),
                "f_score": torch.mean(f_score_pc).item(),
                "f_score_noBg": torch.mean(f_score_pc[:-1]).item(),
                "miou_noBg": torch.mean(miou_pc[:-1]).item()}


def _batch_miou_fscore(output, target, nclass, beta2=0.3):
    """Sum per-class IoU and F-beta scores over every frame of a batch.

    Args:
        output: class scores of shape ``[BF, C, H, W]``. Each pixel's predicted
            class is its argmax over dim 1.
        target: integer labels of shape ``[BF, H, W]``. Pixels with a negative
            label are ignored: both their prediction and their label are dropped.
            NOTE: labels >= ``nclass`` fall outside the label histogram, but the
            predictions at those pixels are still counted.
        nclass: number of classes / histogram bins (class ids ``0 .. nclass-1``).
        beta2: the ``beta**2`` weight of the F-score. Values below 1 weight
            precision more than recall.

    Returns:
        ``(ious, fscores, cls_count, vid_miou_list)``. The first three are
        float64 tensors of length ``nclass`` on ``output.device``:
        ``ious`` and ``fscores`` are per-class scores summed over frames, and
        ``cls_count`` counts the frames in which a class occurs in the union of
        prediction and label. ``vid_miou_list`` holds one 0-dim tensor per
        frame: the mean IoU over classes with non-zero IoU (NaN when there are
        none).
    """
    # Shift class ids to 1..nclass so that 0 can mark ignored pixels;
    # histc over [1, nclass] then silently drops those zeros.
    pred_cls = torch.argmax(output, 1) + 1
    gt_cls = target.float() + 1
    pred_cls = pred_cls.float() * (gt_cls > 0).float()  # [BF, H, W]
    # Class id where prediction and label agree, 0 elsewhere.
    hit_cls = pred_cls * (pred_cls == gt_cls).float()  # [BF, H, W]

    def class_area(label_map):
        # Pixel count per class id 1..nclass; values outside that range are dropped.
        return torch.histc(label_map, bins=nclass, min=1, max=nclass)

    iou_sum = torch.zeros(nclass, device=output.device, dtype=float)
    fscore_sum = torch.zeros(nclass, device=output.device, dtype=float)
    present = torch.zeros(nclass, device=output.device, dtype=float)
    frame_mious = []

    for hit_map, pred_map, gt_map in zip(hit_cls, pred_cls, gt_cls):
        tp = class_area(hit_map)        # TP
        tp_fp = class_area(pred_map)    # TP + FP
        tp_fn = class_area(gt_map)      # TP + FN
        union = tp_fp + tp_fn - tp
        assert torch.sum(tp > union).item() == 0, "Intersection area should be smaller than Union area"

        # float64 machine epsilon keeps classes with an empty union at IoU 0, not NaN.
        frame_iou = 1.0 * tp.float() / (2.220446049250313e-16 + union.float())
        iou_sum += frame_iou
        present += (union != 0).to(present.dtype)

        precision = tp / tp_fp
        recall = tp / tp_fn
        frame_fscore = (1 + beta2) * precision * recall / (beta2 * precision + recall)
        # NaN from 0/0 (class missing from prediction or label, or zero overlap) scores 0.
        frame_fscore = frame_fscore.masked_fill(torch.isnan(frame_fscore), 0.)
        fscore_sum += frame_fscore

        frame_mious.append(torch.sum(frame_iou) / (torch.sum(frame_iou != 0).float()))

    return iou_sum, fscore_sum, present, frame_mious


def calc_color_miou_fscore(pred, target):
    r"""Per-class IoU (J measure) and F-score for a batch of frames.

    Args:
        pred: class scores of size ``[BF, C, H, W]``. ``C`` is the number of
            categories, background included.
        target: integer labels of size ``[BF, H, W]``.

    Returns:
        ``(miou, fscore, cls_count, vid_miou_list)``, exactly as produced by
        ``_batch_miou_fscore`` with ``nclass = C``.
    """
    nclass = pred.shape[1]
    # Only the per-pixel argmax over dim 1 is used downstream. Softmax is
    # monotonic, so normalising the whole [BF, C, H, W] tensor first is wasted
    # work and can only blur near-equal scores through float rounding. Metric
    # computation needs no autograd graph either.
    with torch.no_grad():
        return _batch_miou_fscore(pred, target, nclass)
