from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import torch

class MeanIoU:
    """Accumulate per-class label counts over steps and report IoU in percent.

    Usage: ``_before_epoch()`` resets the counters, ``_after_step`` /
    ``_after_step_within_predregion`` add the counts of one batch, and
    ``_after_epoch`` / ``_after_epoch_ipr`` turn the accumulated counts into
    per-class percentages.

    Predictions and targets are read from a dict under the keys
    ``output_tensor`` and ``target_tensor``. They may be ``np.ndarray`` or
    ``torch.Tensor`` label arrays. They are assumed to have matching shapes,
    because a boolean mask built from one is used to index the other.
    """

    def __init__(self,
                 num_classes: int,
                 ignore_label: int,
                 output_tensor: str = 'outputs',
                 target_tensor: str = 'targets',
                 name: str = 'iou') -> None:
        """
        Args:
            num_classes: number of classes; labels ``0 .. num_classes - 1`` are
                counted, any other label value is ignored by the counters.
            ignore_label: label value excluded before counting.
            output_tensor: key of the predicted labels in ``output_dict``.
            target_tensor: key of the ground-truth labels in ``output_dict``.
            name: identifier stored on the instance (not used by this class).
        """
        self.num_classes = num_classes
        self.ignore_label = ignore_label
        self.name = name
        self.output_tensor = output_tensor
        self.target_tensor = target_tensor
        # Create the counters up front so that reporting or accumulating
        # before the first ``_before_epoch()`` call does not hit an
        # AttributeError.
        self._before_epoch()

    def _before_epoch(self) -> None:
        """Reset the per-class counters to zero."""
        # total_seen[c]:     ground-truth elements of class c
        # total_correct[c]:  elements of class c predicted as c
        # total_positive[c]: elements predicted as class c
        self.total_seen = np.zeros(self.num_classes)
        self.total_correct = np.zeros(self.num_classes)
        self.total_positive = np.zeros(self.num_classes)

    def _accumulate(self, outputs: Any, targets: Any) -> None:
        """Add the per-class counts of already-filtered predictions/targets.

        Uses only ``==``, ``&`` and ``.sum()``, which ``np.ndarray`` and
        ``torch.Tensor`` both support, so one code path serves both.
        """
        # Invariant across classes: hoisted out of the loop.
        hits = outputs == targets
        for cls in range(self.num_classes):
            is_target = targets == cls
            self.total_seen[cls] += int(is_target.sum())
            self.total_correct[cls] += int((is_target & hits).sum())
            self.total_positive[cls] += int((outputs == cls).sum())

    def _after_step(self, output_dict: Dict[str, Any]) -> None:
        """Accumulate counts, dropping elements whose TARGET is ``ignore_label``."""
        outputs = output_dict[self.output_tensor]
        targets = output_dict[self.target_tensor]
        valid = targets != self.ignore_label
        self._accumulate(outputs[valid], targets[valid])

    def _after_step_within_predregion(self, output_dict: Dict[str, Any]) -> None:
        """Accumulate counts, dropping elements whose PREDICTION is ``ignore_label``.

        Only the region where a prediction was made is evaluated.
        """
        outputs = output_dict[self.output_tensor]
        targets = output_dict[self.target_tensor]
        valid = outputs != self.ignore_label
        self._accumulate(outputs[valid], targets[valid])

    def _class_scores(self, cls: int) -> Tuple[float, float, float]:
        """Return ``(iou, precision, recall)`` of class ``cls`` as fractions.

        By convention, a class with no ground-truth elements
        (``total_seen == 0``) scores 1 on all three metrics, even if it was
        predicted somewhere.
        """
        seen = self.total_seen[cls]
        if seen == 0:
            return 1, 1, 1
        correct = self.total_correct[cls]
        positive = self.total_positive[cls]
        # seen > 0 and correct <= positive, so this denominator is >= seen > 0.
        iou = correct / (seen + positive - correct)
        # A class that occurs but is never predicted has precision 0/0.
        # Report 0 (the class was entirely missed, like its IoU and recall)
        # instead of a NaN accompanied by a RuntimeWarning.
        precision = correct / positive if positive > 0 else 0.0
        recall = correct / seen
        return iou, precision, recall

    def _after_epoch(self,
                     ignore_label_list: Optional[Iterable[int]] = None) -> List[float]:
        """Return the per-class IoU in percent.

        Args:
            ignore_label_list: class indices to omit from the result.

        Returns:
            One value per class not in ``ignore_label_list``, in class order.
        """
        # Materialise once so a one-shot iterator works for every class check.
        ignored = tuple(ignore_label_list) if ignore_label_list is not None else ()
        return [self._class_scores(cls)[0] * 100
                for cls in range(self.num_classes) if cls not in ignored]

    def _after_epoch_ipr(
        self, ignore_label_list: Optional[Iterable[int]] = None
    ) -> Tuple[List[float], List[float], List[float]]:
        """Return per-class ``(ious, precisions, recalls)`` in percent.

        Args:
            ignore_label_list: class indices to omit from all three lists
                (defaults to none, i.e. every class is reported).
        """
        ignored = tuple(ignore_label_list) if ignore_label_list is not None else ()
        ious: List[float] = []
        precisions: List[float] = []
        recalls: List[float] = []
        for cls in range(self.num_classes):
            if cls in ignored:
                continue
            iou, precision, recall = self._class_scores(cls)
            # 0.xx to 100%
            ious.append(iou * 100)
            precisions.append(precision * 100)
            recalls.append(recall * 100)
        return (ious, precisions, recalls)