import json
import os
from typing import Callable
from collections import defaultdict
import torch
import torch.nn.utils.prune as prune
import torchmetrics
import torchvision
import wandb
from scipy import stats
from torch.nn.utils.prune import _compute_nparams_toprune, _validate_pruning_amount, _validate_pruning_amount_init


class Utilities:
    """Namespace for stateless helper functions (all exposed as static methods)."""

    @staticmethod
    @torch.no_grad()
    def get_model_norm_square(model):
        """Return the squared L2 norm of all `weight` and `bias` tensors of `model`.

        Every module yielded by `model.named_modules()` is inspected; an
        attribute that is missing or `None` is skipped. This works for a pruned
        model as well.

        Args:
            model: object exposing `named_modules()` (e.g. a `torch.nn.Module`).

        Returns:
            float: sum over all found tensors of `||tensor||_2 ** 2`.
        """
        total = 0.
        for _, module in model.named_modules():
            for attr_name in ('weight', 'bias'):
                tensor = getattr(module, attr_name, None)
                if tensor is None:
                    # Module has no such attribute, or it is explicitly disabled
                    continue
                total += torch.norm(tensor, p=2) ** 2
        return float(total)

    @staticmethod
    def dump_dict_to_json_wandb(dumpDict, name):
        """Serialize `dumpDict` to `<wandb.run.dir>/<name>.json` and pass the file to `wandb.save`.

        Args:
            dumpDict: JSON-serializable object.
            name: file name without the `.json` extension.
        """
        target = os.path.join(wandb.run.dir, f'{name}.json')
        with open(target, 'w') as handle:
            json.dump(dumpDict, handle)
        wandb.save(target)

    @staticmethod
    def get_overloaded_dataset(OriginalDataset):
        """Build a subclass of `OriginalDataset` whose items also carry their index.

        The returned class forwards construction unchanged and keeps the name
        of `OriginalDataset`. Its `__getitem__` expects the parent to return an
        `(image, label)` pair and accepts `index` as keyword argument.

        Args:
            OriginalDataset: dataset class to derive from.

        Returns:
            type: subclass whose `__getitem__(index)` returns `(image, label, index)`.
        """

        class AlteredDatasetWrapper(OriginalDataset):

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

            def __getitem__(self, index):
                # Hand the index back with the sample so that e.g. a sampler can
                # relate samples to their position in the dataset
                sample, target = super().__getitem__(index=index)
                return sample, target, index

        AlteredDatasetWrapper.__name__ = OriginalDataset.__name__
        return AlteredDatasetWrapper


class WorstClassAccuracy(torchmetrics.Accuracy):
    """Accuracy metric reduced to the minimum over the values the parent computes.

    The parent is always constructed with `average=None`; all other keyword
    arguments are forwarded untouched.
    NOTE(review): presumably `average=None` makes the parent return one
    accuracy per class, so the minimum is the worst class - confirm against
    the installed torchmetrics version.
    """

    def __init__(self, **kwargs):
        # `average` is fixed here, callers must not pass it again
        super().__init__(average=None, **kwargs)

    def compute(self):
        """Return the smallest entry of the parent's `compute()` result."""
        return super().compute().min()


class LAMPUnstructured(prune.BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor by zeroing out the ones
    with the lowest LAMP-Score.

    The scores are computed separately for every entry of
    ``parameters_to_prune`` (LAMP normalizes within a layer), the selection of
    the entries to remove is then done over all scores jointly.

    Args:
        parameters_to_prune (iterable of (module, name) tuples): the parameters
            whose still-prunable entries make up, in this order, the vector
            handed to ``compute_mask``.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, parameters_to_prune, amount):
        # Reject invalid pruning amounts right away
        _validate_pruning_amount_init(amount)
        self.amount = amount
        # Needed in compute_mask to cut the flat input vector back into layers
        self.parameters_to_prune = parameters_to_prune

    @staticmethod
    def _lamp_scores(p):
        """Return the LAMP score of every entry of the 1-D slice `p`, in the order of `p`.

        Modified from https://github.com/jaeho-lee/layer-adaptive-sparsity
        """
        # Squared magnitudes in ascending order, plus the permutation that sorts them
        ascending, order = torch.sort(p.flatten().pow(2), descending=False)
        running = ascending.cumsum(dim=0)
        # preceding[i] = sum of all squared magnitudes strictly before position i
        preceding = torch.zeros(running.shape, device=p.device)
        preceding[1:] = running[:-1]

        # Normalize each entry by the sum of itself and everything larger
        ascending /= (ascending.sum() - preceding)

        # Undo the sorting
        scores = torch.zeros(preceding.shape, device=p.device)
        scores[order] = ascending
        return scores

    def compute_mask(self, t, default_mask):
        """Compute the mask that removes the entries of `t` with the lowest LAMP score.

        `t` is treated as the concatenation, in the order of
        `self.parameters_to_prune`, of one slice per listed parameter. The
        slice length is the number of ones in the parameter's existing mask if
        the module is already pruned, and the full parameter size otherwise.

        Args:
            t (torch.Tensor): flat vector of the values to score.
            default_mask (torch.Tensor): mask of the same size as `t`; it is
                cloned, not modified.

        Returns:
            torch.Tensor: clone of `default_mask` with the selected entries set to 0.
        """
        n_total = t.nelement()
        # amount itself if int, else amount * n_total
        n_prune = _compute_nparams_toprune(self.amount, n_total)
        # Raises if more units should be pruned than are available
        _validate_pruning_amount(n_prune, n_total)

        per_layer_scores = []
        offset = 0
        for module, param_type in self.parameters_to_prune:
            if prune.is_pruned(module):
                # Only the entries that are still prunable are part of t
                layer_mask = getattr(module, param_type + '_mask')
                n_layer = int((layer_mask == 1).sum())
            else:
                n_layer = int(getattr(module, param_type).numel())
            layer_slice = t[offset:offset + n_layer]
            assert layer_slice.numel() == n_layer
            offset += n_layer
            per_layer_scores.append(self._lamp_scores(layer_slice))

        all_scores = torch.cat(per_layer_scores)
        assert all_scores.numel() == t.numel()
        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if n_prune == 0:
            # Nothing to remove, torch.topk is skipped for k=0
            return mask

        # largest=False --> the k entries with the smallest score
        lowest = torch.topk(all_scores.view(-1), k=n_prune, largest=False)
        mask.view(-1)[lowest.indices] = 0
        return mask

    @classmethod
    def apply(cls, module, name, amount):
        r"""Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        NOTE(review): only ``amount`` is forwarded, although ``__init__`` also
        requires ``parameters_to_prune``; presumably this entry point cannot be
        used as written - confirm whether anything calls it.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
        """
        return super().apply(module, name, amount=amount)

class GradientUnstructured(prune.BasePruningMethod):
    r"""Prune (currently unpruned) units by the score ``|p * g|``, where ``p``
    is the value handed to ``compute_mask`` and ``g`` the matching entry of
    ``gradients``. The entries with the smallest score are removed.

    Args:
        parameters_to_prune (iterable of (module, name) tuples): the parameters
            whose still-prunable entries make up, in this order, the vector
            handed to ``compute_mask``.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        gradients (dict): maps each ``(module, name)`` tuple to a tensor with
            one entry per parameter element (flattened or parameter-shaped).
        uniform (bool): if True, ``amount`` is applied to every listed
            parameter separately, otherwise once over all of them jointly.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, parameters_to_prune, amount, gradients, uniform):
        # Needed in compute_mask to cut the flat input vector back into layers
        self.parameters_to_prune = parameters_to_prune
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.gradients = gradients
        self.uniform = uniform  # If True, apply the pruning uniformly

    def compute_mask(self, t, default_mask):
        """Compute the mask that removes the entries of `t` with the lowest score.

        `t` is treated as the concatenation, in the order of
        `self.parameters_to_prune`, of one slice per listed parameter. The
        slice length is the number of ones in the parameter's existing mask if
        the module is already pruned, and the full parameter size otherwise.

        Args:
            t (torch.Tensor): flat vector of the values to score.
            default_mask (torch.Tensor): mask of the same size as `t`; it is
                cloned, not modified.

        Returns:
            torch.Tensor: clone of `default_mask` with the selected entries set to 0.
        """
        tensor_size = t.nelement()
        # amount itself if int, else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # Raises if more units should be pruned than there are in t
        _validate_pruning_amount(nparams_toprune, tensor_size)

        tensor_list = []
        length_done = 0

        for module, param_type in self.parameters_to_prune:
            # Flatten so that the gradient lines up with the flat slice of t
            d_p = self.gradients[(module, param_type)].reshape(-1)
            if prune.is_pruned(module):
                # Only the entries that are still prunable are part of t
                still_active = getattr(module, param_type + '_mask').flatten() == 1
                mask_length = int(still_active.sum())
                d_p = d_p[still_active]
            else:
                mask_length = int(getattr(module, param_type).numel())
            p = t[length_done:length_done + mask_length]
            assert p.numel() == mask_length
            assert d_p.numel() == p.numel()
            length_done += mask_length

            tensor_list.append(torch.abs(-p * d_p))

        score_tensor = torch.cat(tensor_list)
        assert score_tensor.numel() == t.numel()
        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if not self.uniform:
            # Select globally
            if nparams_toprune != 0:  # torch.topk is skipped for k=0
                # largest=False --> the k entries with the smallest score
                topk = torch.topk(score_tensor.view(-1), k=nparams_toprune, largest=False)
                mask.view(-1)[topk.indices] = 0
        else:
            # Select per layer. Every layer contributes a selection of its own
            # length, also when nothing is pruned in it - otherwise the
            # concatenated selection would no longer line up with `mask`.
            selected_list = []
            for local_score_tensor in tensor_list:
                k = _compute_nparams_toprune(self.amount, local_score_tensor.numel())
                selected = torch.zeros_like(local_score_tensor, dtype=torch.bool)
                if k != 0:
                    topk = torch.topk(local_score_tensor.view(-1), k=k, largest=False)
                    selected.view(-1)[topk.indices] = True
                selected_list.append(selected)
            mask.view(-1)[torch.cat(selected_list)] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount, parameters_to_prune, gradients):
        r"""Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        NOTE(review): ``uniform`` is not forwarded although ``__init__``
        requires it; presumably this entry point cannot be used as written -
        confirm whether anything calls it.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            parameters_to_prune: forwarded to ``__init__``.
            gradients: forwarded to ``__init__``.
        """
        return super(GradientUnstructured, cls).apply(module, name, amount=amount, parameters_to_prune=parameters_to_prune, gradients=gradients)

class UndecayedUnstructured(prune.BasePruningMethod):
    r"""Prune (currently unpruned) units by the score
    ``|-p * g + wd * p**2|``, where ``p`` is the value handed to
    ``compute_mask`` and ``g`` the matching entry of ``gradients``. The
    entries with the smallest score are removed.

    Args:
        parameters_to_prune (iterable of (module, name) tuples): the parameters
            whose still-prunable entries make up, in this order, the vector
            handed to ``compute_mask``.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        gradients (dict): maps each ``(module, name)`` tuple to a tensor with
            one entry per parameter element (flattened or parameter-shaped).
        uniform (bool): if True, ``amount`` is applied to every listed
            parameter separately, otherwise once over all of them jointly.
        wd (float or None): factor of the ``p**2`` term; a falsy value
            (``None``, ``0``) is stored as ``0.``.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, parameters_to_prune, amount, gradients, uniform, wd):
        # Needed in compute_mask to cut the flat input vector back into layers
        self.parameters_to_prune = parameters_to_prune
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.gradients = gradients
        self.uniform = uniform  # If True, apply the pruning uniformly
        self.wd = wd or 0.

    def compute_mask(self, t, default_mask):
        """Compute the mask that removes the entries of `t` with the lowest score.

        `t` is treated as the concatenation, in the order of
        `self.parameters_to_prune`, of one slice per listed parameter. The
        slice length is the number of ones in the parameter's existing mask if
        the module is already pruned, and the full parameter size otherwise.

        Args:
            t (torch.Tensor): flat vector of the values to score.
            default_mask (torch.Tensor): mask of the same size as `t`; it is
                cloned, not modified.

        Returns:
            torch.Tensor: clone of `default_mask` with the selected entries set to 0.
        """
        tensor_size = t.nelement()
        # amount itself if int, else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # Raises if more units should be pruned than there are in t
        _validate_pruning_amount(nparams_toprune, tensor_size)

        tensor_list = []
        length_done = 0

        for module, param_type in self.parameters_to_prune:
            # Flatten so that the gradient lines up with the flat slice of t
            d_p = self.gradients[(module, param_type)].reshape(-1)
            if prune.is_pruned(module):
                # Only the entries that are still prunable are part of t
                still_active = getattr(module, param_type + '_mask').flatten() == 1
                mask_length = int(still_active.sum())
                d_p = d_p[still_active]
            else:
                mask_length = int(getattr(module, param_type).numel())
            p = t[length_done:length_done + mask_length]
            assert p.numel() == mask_length
            assert d_p.numel() == p.numel()
            length_done += mask_length

            tensor_list.append(torch.abs(-p * d_p + self.wd * (p ** 2)))

        score_tensor = torch.cat(tensor_list)
        assert score_tensor.numel() == t.numel()
        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if not self.uniform:
            # Select globally
            if nparams_toprune != 0:  # torch.topk is skipped for k=0
                # largest=False --> the k entries with the smallest score
                topk = torch.topk(score_tensor.view(-1), k=nparams_toprune, largest=False)
                mask.view(-1)[topk.indices] = 0
        else:
            # Select per layer. Every layer contributes a selection of its own
            # length, also when nothing is pruned in it - otherwise the
            # concatenated selection would no longer line up with `mask`.
            selected_list = []
            for local_score_tensor in tensor_list:
                k = _compute_nparams_toprune(self.amount, local_score_tensor.numel())
                selected = torch.zeros_like(local_score_tensor, dtype=torch.bool)
                if k != 0:
                    topk = torch.topk(local_score_tensor.view(-1), k=k, largest=False)
                    selected.view(-1)[topk.indices] = True
                selected_list.append(selected)
            mask.view(-1)[torch.cat(selected_list)] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount, parameters_to_prune, gradients, wd):
        r"""Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        NOTE(review): ``uniform`` is not forwarded although ``__init__``
        requires it; presumably this entry point cannot be used as written -
        confirm whether anything calls it.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            parameters_to_prune: forwarded to ``__init__``.
            gradients: forwarded to ``__init__``.
            wd: forwarded to ``__init__``.
        """
        return super(UndecayedUnstructured, cls).apply(module, name, amount=amount, parameters_to_prune=parameters_to_prune, gradients=gradients, wd=wd)


class FairnessStatistics:
    """Keeps track of all fairness related statistics.

    An instance is called once per batch of an evaluation pass with the outputs
    of the compressed and of the dense model; `get_results` then turns the
    accumulated counters into metrics.

    General statistics used:
    - CIE: Compression impacted examples
    - CIE-P: CIE that were correctly classified by the dense model
    - CIE-N: CIE that were incorrectly classified by the dense model, but are correctly by the compressed (Note: CIEP + CIEN != CIE)

    Per-class statistics used:
    - CIE()_rel: CIE/CIEP/CIEN of a class relative to the number of samples of that class (CIE), to the samples of that
      class the dense model classified correctly (CIEP) resp. incorrectly (CIEN)

    Note: counters of classes that never occur (or e.g. a dense model without errors) lead to divisions by zero, the
    affected metrics are then NaN/inf.
    """

    def __init__(self, n_classes, device):
        """
        n_classes: number of classes, class labels are expected to be 0, ..., n_classes - 1
        device: device on which all counters are kept
        """
        self.n_classes = n_classes
        self.device = device

        self.loss_fn = torch.nn.CrossEntropyLoss(reduction='none').to(device=self.device)

        self.classes = torch.arange(0, self.n_classes, 1).to(device=self.device)

        # General class count (all counters are integer tensors with one entry per class)
        self.class_occurence = torch.zeros_like(self.classes)
        self.dense_correct_occurence = torch.zeros_like(self.classes)
        self.correct_occurence = torch.zeros_like(self.classes)
        self.confusion = {state: {} for state in ['dense', 'sparse']}
        for state in self.confusion:
            for statType in ['TP', 'FN', 'FP', 'TN']:
                self.confusion[state][statType] = torch.zeros_like(self.classes)
        # Compression Impacted Examples
        self.CIE = torch.zeros_like(self.classes)
        self.CIEP = torch.zeros_like(self.classes)
        self.CIEN = torch.zeros_like(self.classes)

        # Dense statistics (sums over all samples of a class)
        self.dense_loss_per_class = torch.zeros(self.n_classes, device=self.device)
        self.dense_confidence_per_class = torch.zeros(self.n_classes, device=self.device)

    @torch.no_grad()
    def __call__(self, output, output_dense, y_true):
        """
        Gets called for each iteration during a single evaluation pass

        output: Model output of the compressed model
        output_dense: Model output of the dense model
        y_true: true labels
        """

        # General statistics
        occ, cnt = torch.unique(y_true, return_counts=True)
        self.class_occurence[occ] += cnt
        self.dense_correct_occurence += self.get_class_correct(output=output_dense, y_true=y_true)
        self.correct_occurence += self.get_class_correct(output=output, y_true=y_true)
        for state in self.confusion:
            outputType = output if state == 'sparse' else output_dense
            TP, FN, FP, TN = self.get_confusion_matrix(output=outputType, y_true=y_true)
            self.confusion[state]['TP'] += TP
            self.confusion[state]['FN'] += FN
            self.confusion[state]['FP'] += FP
            self.confusion[state]['TN'] += TN

        # Compression Impacted Examples
        CIE, CIEP, CIEN = self.get_CIE(output=output, output_dense=output_dense, y_true=y_true)
        self.CIE += CIE
        self.CIEP += CIEP
        self.CIEN += CIEN

        # Dense statistics
        self.add_dense_class_loss(output_dense=output_dense, y_true=y_true)
        self.add_dense_class_confidence(output_dense=output_dense, y_true=y_true)

    @torch.no_grad()
    def get_results(self):
        """
        Computes all metrics from the counters accumulated so far.

        Returns a tuple of two dicts: the first holds the general metrics (Python numbers or 0-dim tensors), the second
        has the single key 'class' mapping each class index to a dict of its per-class metrics.
        """
        logDict = {'n_CIE': self.CIE.sum(),
                   'n_CIEP': self.CIEP.sum(),
                   'n_CIEN': self.CIEN.sum(),
                   'class': {cls: dict() for cls in range(self.n_classes)}
                   }

        # Compression Impact Metrics
        # Denominator for the relative values, e.g. for CIEP we want to know the fraction of CIEP(y) versus total
        # dense correct ones in class
        CIE_denominators = {
            'CIE': self.class_occurence,
            'CIEP': self.dense_correct_occurence,
            'CIEN': self.class_occurence - self.dense_correct_occurence,
        }
        for CIE_type, denom in CIE_denominators.items():
            CIE_absolute = getattr(self, CIE_type)
            CIE_relative = CIE_absolute / denom
            for cls in range(self.n_classes):
                logDict['class'][cls][CIE_type + "_abs"] = CIE_absolute[cls].item()
                logDict['class'][cls][CIE_type + "_rel"] = CIE_relative[cls].item()

        # General metrics
        distribution = self.class_occurence / self.class_occurence.sum()

        # Recall
        dense_class_recall = self.dense_correct_occurence / self.class_occurence
        class_recall = self.correct_occurence / self.class_occurence

        # Accuracy (contribution of each class to the overall accuracy)
        dense_class_accuracy = dense_class_recall * distribution
        class_accuracy = class_recall * distribution

        # Differences: Recall and Accuracy
        diffDict = dict()
        for diffType in ['recall', 'accuracy']:
            dense_values = dense_class_recall if diffType == 'recall' else dense_class_accuracy
            values = class_recall if diffType == 'recall' else class_accuracy

            diffDict[f'dense_{diffType}'] = dense_values
            diffDict[diffType] = values

            diffDict[f'abs_dense_class_{diffType}_deviation_from_mean'] = dense_values - dense_values.mean()
            diffDict[f'rel_dense_class_{diffType}_deviation_from_mean'] = diffDict[f'abs_dense_class_{diffType}_deviation_from_mean'] / dense_values.mean()

            diffDict[f'abs_{diffType}_diff'] = values - dense_values
            diffDict[f'rel_{diffType}_diff'] = (values - dense_values) / dense_values

            # Only the degradation counts, improvements are clipped to zero
            diffDict[f'negative_{diffType}_change'] = torch.clip(dense_values - values, min=0)

        cls_cond_risk = self.dense_loss_per_class * distribution
        cls_cond_risk_rel = cls_cond_risk / cls_cond_risk.sum()
        cls_pred_confidence = self.dense_confidence_per_class / self.class_occurence

        for cls in range(self.n_classes):
            # Class conditioned risk
            logDict['class'][cls]["dense_cond_risk"] = cls_cond_risk[cls].item()
            logDict['class'][cls]["dense_cond_risk_rel"] = cls_cond_risk_rel[cls].item()

            # Distribution
            logDict['class'][cls]["distribution"] = distribution[cls].item()

            # Class prediction confidence
            logDict['class'][cls]["pred_conf"] = cls_pred_confidence[cls].item()

            # Class recall and accuracy
            for name, tensor in diffDict.items():
                logDict['class'][cls][name] = tensor[cls].item()

        # Balanced/Overall accuracy
        logDict['dense_balanced_accuracy'] = dense_class_recall.mean().item()
        logDict['balanced_accuracy'] = class_recall.mean().item()
        logDict['dense_overall_accuracy'] = dense_class_accuracy.sum().item()
        logDict['overall_accuracy'] = class_accuracy.sum().item()

        ### Fairness metrics
        # Unfairness due to Joseph et al.
        dense_unfairness = dense_class_recall.max() - dense_class_recall.min()
        unfairness = class_recall.max() - class_recall.min()
        logDict['dense_unfairness'] = dense_unfairness
        logDict['unfairness'] = unfairness
        logDict['unfairness_change_ratio'] = unfairness / dense_unfairness

        # CVE: variance of the relative change of the per-class FPR and FNR
        FPR = self.confusion['sparse']['FP'] / (self.confusion['sparse']['FP'] + self.confusion['sparse']['TN'])
        FNR = self.confusion['sparse']['FN'] / (self.confusion['sparse']['FN'] + self.confusion['sparse']['TP'])
        FPR_dense = self.confusion['dense']['FP'] / (self.confusion['dense']['FP'] + self.confusion['dense']['TN'])
        FNR_dense = self.confusion['dense']['FN'] / (self.confusion['dense']['FN'] + self.confusion['dense']['TP'])
        delta_FPR, delta_FNR = (FPR - FPR_dense) / FPR_dense, (FNR - FNR_dense) / FNR_dense
        cve = torch.var(torch.cat((delta_FPR, delta_FNR)))
        logDict['CVE'] = cve

        # New metrics: statistics of the per-class recall degradation, absolute and relative to the dense recall
        for metricType in ['abs', 'rel']:
            neg_recall_change = diffDict['negative_recall_change']
            if metricType == 'rel':
                neg_recall_change = neg_recall_change / dense_class_recall
            logDict[f'{metricType}_neg_recall_change_mean'] = neg_recall_change.mean()
            logDict[f'{metricType}_neg_recall_change_std'] = neg_recall_change.std()
            # Range = max - min over the classes
            logDict[f'{metricType}_neg_recall_change_range'] = neg_recall_change.max() - neg_recall_change.min()
            logDict[f'{metricType}_neg_recall_change_max'] = neg_recall_change.max()
            logDict[f'{metricType}_neg_recall_change_min'] = neg_recall_change.min()
            logDict[f'{metricType}_neg_recall_change_quartilerange'] = torch.quantile(neg_recall_change, q=0.75) - torch.quantile(neg_recall_change, q=0.25)

        return {k: v for k, v in logDict.items() if k != 'class'}, {'class': logDict['class']}

    @torch.no_grad()
    def add_dense_class_confidence(self, output_dense, y_true):
        """Adds the dense confidence per class"""
        probs = torch.nn.functional.softmax(output_dense, dim=1)  # dim: bs x n_classes
        confidence_per_sample = torch.max(probs,
                                          dim=1).values  # Note that this is not the confidence in the right class, but rather in the prediction
        # Sum the confidences into the slot of the true class of each sample
        self.dense_confidence_per_class.scatter_add_(0, y_true.flatten(), confidence_per_sample.flatten())

    @torch.no_grad()
    def add_dense_class_loss(self, output_dense, y_true):
        """Adds the dense loss per class"""
        loss_per_sample = self.loss_fn(output_dense, y_true)  # dim: batch_size
        # Sum the losses into the slot of the true class of each sample
        self.dense_loss_per_class.scatter_add_(0, y_true.flatten(), loss_per_sample.flatten())

    @torch.no_grad()
    def get_class_correct(self, output, y_true):
        """Returns the amount of correct predictions per class in one batch"""
        prediction = output.max(dim=1).indices.t()
        # Row c marks the samples whose true class is c
        cls_bool_mask = (y_true.unsqueeze(0).expand(self.n_classes, len(y_true)) == self.classes.unsqueeze(1))
        # Every row marks the correctly predicted samples
        correct_bool_mask = prediction.eq(y_true).unsqueeze(0).expand(self.n_classes, len(y_true))
        return torch.logical_and(cls_bool_mask, correct_bool_mask).sum(dim=1)

    @torch.no_grad()
    def get_confusion_matrix(self, output, y_true):
        """Returns TP, FN, FP, TN (in this order), each with one entry per class, for one batch"""
        # Compute class occurence
        occ, cnt = torch.unique(y_true, return_counts=True)
        n_cls_occurences = torch.zeros_like(self.classes)
        n_cls_occurences[occ] += cnt

        # Get prediction of network
        prediction = output.max(dim=1).indices.t()

        # Get total population tensor (P+N)
        total_population = torch.zeros_like(self.classes).fill_(len(y_true))

        # Booltensor with n_classes rows indicating at which element in y_true the corresponding class occurs
        true_bool_mask = (y_true.unsqueeze(0).expand(self.n_classes, len(y_true)) == self.classes.unsqueeze(1))
        # Booltensor with n_classes rows indicating at which element in prediction the corresponding class occurs
        pred_bool_mask = (prediction.unsqueeze(0).expand(self.n_classes, len(y_true)) == self.classes.unsqueeze(1))
        # Booltensor with n_classes rows indicating at which element the prediction is true
        correct_bool_mask = prediction.eq(y_true).unsqueeze(0).expand(self.n_classes, len(y_true))

        TP = torch.logical_and(true_bool_mask, correct_bool_mask).sum(dim=1)    # TP: Class and Prediction coincide
        FN = n_cls_occurences - TP  # FN: Class occurs but prediction is not the same
        FP = torch.logical_and(pred_bool_mask, ~correct_bool_mask).sum(dim=1)   # FP: Class does not occur but is predicted
        TN = total_population - n_cls_occurences - FP    # TN: Class does not occur and is correctly predicted as not occurring
        return TP, FN, FP, TN

    @torch.no_grad()
    def get_CIE(self, output, output_dense, y_true):
        """Returns CIE, CIEP, CIEN per class for one batch

        CIE: samples of the class where compressed and dense prediction differ
        CIEP: CIE that the dense model classified correctly
        CIEN: CIE that the compressed model classified correctly
        """
        prediction = output.max(dim=1).indices.t()
        prediction_dense = output_dense.max(dim=1).indices.t()

        # Row c marks the samples whose true class is c
        cls_bool_mask = (y_true.unsqueeze(0).expand(self.n_classes, len(y_true)) == self.classes.unsqueeze(1))
        # Row c marks the samples of class c on which the two models disagree
        CIE_bool = torch.logical_and((~prediction.eq(prediction_dense)).unsqueeze(1).expand(-1, self.n_classes).t(),
                                     cls_bool_mask)
        correct_dense_bool_mask = prediction_dense.eq(y_true).unsqueeze(0).expand(self.n_classes, len(y_true))
        correct_sparse_bool_mask = prediction.eq(y_true).unsqueeze(0).expand(self.n_classes, len(y_true))

        CIE = CIE_bool.sum(dim=1)
        CIEP = torch.logical_and(CIE_bool, correct_dense_bool_mask).sum(dim=1)
        CIEN = torch.logical_and(CIE_bool, correct_sparse_bool_mask).sum(dim=1)
        return CIE, CIEP, CIEN