import numpy as np
import torch


class HistogramBinning:
    """
    Histogram Binning as a calibration method. The unit interval is divided into
    `N_bin` bins of equal width; the calibrated confidence of a bin is the mean
    label (empirical accuracy) of the fitting samples that fall into it, or 0 for
    a bin that received no samples.

    Bin ``i`` covers the half-open interval ``(i / N_bin, (i + 1) / N_bin]``; the
    first bin additionally receives ``0.0`` (and anything below it) and values
    above 1 are clamped into the last bin, so ``fit`` and ``predict`` always agree
    on which bin a value belongs to.

    The class contains two methods:
        - fit(probs, true), that should be used with validation data to train the calibration model.
        - predict(logits), this method is used to calibrate the confidences.
    """

    def __init__(self, N_bin=10):
        """
        Params:
            N_bin (int): number of equal-width bins; must be >= 1.

        Raises:
            ValueError: if N_bin < 1.
        """
        if N_bin < 1:
            raise ValueError(f"N_bin must be a positive integer, got {N_bin}")
        self.n_bins = int(N_bin)
        self.bin_size = 1. / self.n_bins
        self.bin_value = []  # per-bin calibrated confidence; filled by fit()
        # i / N computed from integers is correctly rounded and the last bound is
        # exactly 1.0 (a float-step np.arange can overshoot or undershoot it).
        self.upper_bounds = np.arange(1, self.n_bins + 1) / self.n_bins

    @staticmethod
    def _to_numpy(values):
        """Return `values` (torch tensor, ndarray, list or scalar) as a float64 ndarray, keeping its shape."""
        if isinstance(values, torch.Tensor):
            # detach/cpu so grad-tracking or GPU tensors can be converted too
            values = values.detach().cpu().numpy()
        return np.asarray(values, dtype=np.float64)

    def _bin_indices(self, values):
        """Map each value to the index of its bin, clamping out-of-range values."""
        # side="left": a value equal to an upper bound belongs to that bin, giving
        # the half-open (lower, upper] intervals described in the class docstring.
        idxs = np.searchsorted(self.upper_bounds, values, side="left")
        # Values above 1 (or NaN) land past the last bound; clamp them into the last bin.
        return np.minimum(idxs, len(self.upper_bounds) - 1)

    def _get_conf(self, conf_thresh_lower, conf_thresh_upper, probs, true):
        """
        Inner method to calculate optimal confidence for certain probability range

        Params:
            - conf_thresh_lower (float): start of the interval (not included)
            - conf_thresh_upper (float): end of the interval (included)
            - probs : list of probabilities.
            - true : list with true labels, where 1 is positive class and 0 is negative).

        Returns:
            The mean label of samples with probability in the interval, or 0 if
            the interval contains no samples.
        """
        probs = self._to_numpy(probs).ravel()
        true = self._to_numpy(true).ravel()

        mask = (probs > conf_thresh_lower) & (probs <= conf_thresh_upper)
        if not mask.any():
            return 0
        # In essence the confidence equals to the average accuracy of a bin
        return float(true[mask].mean())

    def fit(self, probs, true):
        """
        Fit the calibration model, finding optimal confidences for all the bins.

        Params:
            probs: probabilities of data (tensor, array or list; flattened)
            true: true labels of data, same number of elements as probs
                (1 is positive class and 0 is negative)

        Raises:
            ValueError: if probs and true have different numbers of elements.
        """
        probs = self._to_numpy(probs).ravel()
        true = self._to_numpy(true).ravel()
        if probs.shape != true.shape:
            raise ValueError(
                f"probs and true must have the same number of elements, "
                f"got {probs.size} and {true.size}"
            )

        n_bins = len(self.upper_bounds)
        idxs = self._bin_indices(probs)
        counts = np.bincount(idxs, minlength=n_bins)
        label_sums = np.bincount(idxs, weights=true, minlength=n_bins)
        # Mean label per bin; empty bins keep the value 0.
        self.bin_value = np.divide(
            label_sums, counts, out=np.zeros(n_bins), where=counts > 0
        )

    def predict(self, logits):
        """
        Calibrate confidences using the fitted bin values.

        Params:
            logits: values to calibrate, treated as probabilities in [0, 1]
                despite the name (they are binned against the same [0, 1] bounds
                used in fit). Any shape; the output has the same shape.

        Returns:
            torch.Tensor of calibrated confidences (float64).

        Raises:
            RuntimeError: if called before fit().
        """
        if len(self.bin_value) == 0:
            raise RuntimeError("HistogramBinning.predict called before fit")
        idxs = self._bin_indices(self._to_numpy(logits))
        # np.asarray keeps a 0-d input working (indexing would yield a numpy scalar).
        return torch.from_numpy(np.asarray(self.bin_value[idxs]))


def calibrate(train_logits, train_labels, test_logits, *args, N_bin=10, **kwargs):
    """
    Fit histogram binning on (train_logits, train_labels) and apply it to test_logits.

    Params:
        train_logits: tensor of scores used for fitting. They are passed to
            ``HistogramBinning.fit`` as probabilities, so presumably already in
            [0, 1] — TODO confirm with callers.
        train_labels: tensor of labels aligned with train_logits
            (1 = positive class, 0 = negative).
        test_logits: tensor of scores to calibrate.
        *args, **kwargs: accepted for interface compatibility and ignored.
        N_bin: number of equal-width bins (keyword-only, default 10).

    Returns:
        dict with key "logits" holding the calibrated per-sample confidences
        from ``HistogramBinning.predict`` (per-bin mean labels, not logits), or
        ``test_logits`` unchanged if any of the input tensors is empty.
    """
    # Nothing to fit on, or nothing to calibrate: pass the test scores through unchanged.
    if train_logits.numel() == 0 or train_labels.numel() == 0 or test_logits.numel() == 0:
        return {"logits": test_logits}

    model = HistogramBinning(N_bin=N_bin)
    model.fit(train_logits, train_labels)
    return {"logits": model.predict(test_logits)}
