import numpy as np
import cvxpy as cp
from sklearn.exceptions import ConvergenceWarning
import random

def get_top_probs(X, model):
    """
    Function that returns the largest predicted value for every sample

    Parameters
    ----------
        X:
            array holding either a single sample (3 dimensions) or
            a batch of samples (any other number of dimensions)

        model:
            object exposing a ``predict`` method that is called on the batch

    Returns
    ----------
        top_probs:
            maximum of the prediction array taken over the last axis of a
            2D ``model.predict`` output, i.e. one value per sample

    """
    # a single sample is detected by its 3 dimensions; a leading batch axis
    # of length one is added so that the model always receives a batch
    if len(X.shape) == 3:
        pred_vect = model.predict(X.reshape((1,) + tuple(X.shape)))
    else:
        pred_vect = model.predict(X)
    return np.max(pred_vect.transpose(), axis=0)


def get_uniform_mass_bins(probs, n_bins):
    """
    Function that computes bin edges for uniform mass binning

    Parameters
    ----------
        probs:
            array containing probabilities that have to be evenly split into groups

        n_bins:
            number of bins for approximately even splitting

    Returns
    ----------
        bin_edges:
            array with the n_bins - 1 interior edges; every edge is the
            midpoint between the largest value of one group and the
            smallest value of the next group

    Raises
    ----------
        ValueError:
            if some group that is needed to place an edge is empty
            (there are fewer points than bins)

    """

    probs_sorted = np.sort(probs)

    # split probabilities into groups of approx equal size
    groups = np.array_split(probs_sorted, n_bins)
    bin_edges = list()

    for left_group, right_group in zip(groups[:-1], groups[1:]):
        if left_group.size == 0 or right_group.size == 0:
            raise ValueError(
                "Not enough points to form the requested number of bins")
        # groups are sorted: the last entry is the maximum of the left
        # group and the first entry is the minimum of the right group
        bin_edges.append((left_group[-1] + right_group[0]) / 2)

    return np.array(bin_edges)

def bin_points(scores, bin_edges):
    """
    Function that assigns every score to a bin

    The bin index of a score is the number of edges that are
    less than or equal to that score.

    Parameters
    ----------
        scores:
            array of scores holding a single entry per point

        bin_edges:
            array of bin edges

    Returns
    ----------
        bin_indices:
            array with one integer bin index per score

    """
    assert(bin_edges is not None), "Bins have not been defined"
    assert(scores.shape[0] == scores.size), "scores should be a 1D vector"
    # column of scores against a row of edges gives a (points x edges) table
    score_column = scores.reshape((scores.size, 1))
    edge_row = bin_edges.reshape((1, bin_edges.size))
    edge_reached = score_column >= edge_row
    return edge_reached.sum(axis=1)

def logistic_transform(x):
    """
    Function that maps x to 1 / (1 + exp(x))

    Note the sign convention: the output decreases as x grows.

    """
    denominator = 1 + np.exp(x)
    return 1 / denominator


def platt_scaling(scores, labels, regularization=True):
    """
    Function that fits the two Platt scaling parameters with CVXPY

    Parameters
    ----------
        scores:
            array of scores / output of uncalibrated classifier

        labels:
            array of true labels (0/1-labeling is assumed)

        regularization: bool
            If True, the 0/1 labels are replaced by the smoothed targets
            (N_1 + 1) / (N_1 + 2) and 1 / (N_0 + 2), where N_1 and N_0
            are the numbers of labels equal to 1 and 0

    Returns
    ----------
        calibrated_probs:
            logistic_transform(A * scores + B) for the fitted A and B

        A.value, B.value:
            fitted values of the two Platt scaling parameters

    Raises
    ----------
        ConvergenceWarning:
            if the solver status is anything other than 'optimal'

    """

    if regularization:
        # smoothed targets built from the class counts of the calibration set
        is_positive = (labels == 1)
        is_negative = (labels == 0)
        N_1 = is_positive.sum()
        N_0 = is_negative.sum()
        positive_part = is_positive.astype('int') * (N_1 + 1) / (N_1 + 2)
        negative_part = is_negative.astype('int') * 1.0 / (N_0 + 2)
        t = positive_part + negative_part
    else:
        # just use raw labels
        t = np.copy(labels)

    A = cp.Variable(1)
    B = cp.Variable(1)

    # objective to be maximized, written in terms of s = A * scores + B
    s = A * scores + B
    objective = t @ (-s) - np.ones_like(t) @ cp.logistic(-s)

    # solve the problem
    problem = cp.Problem(cp.Maximize(objective))
    problem.solve()

    if problem.status != 'optimal':
        raise ConvergenceWarning("CVXPY hasn't converged")

    calibrated_probs = logistic_transform(A.value * scores + B.value)
    return calibrated_probs, A.value, B.value

def nudge(matrix, delta):
    """
    Function that adds uniform noise from [0, delta) and rescales by 1 + delta

    Parameters
    ----------
        matrix:
            array or scalar to be perturbed

        delta:
            upper limit of the added uniform noise

    Returns
    ----------
        nudged:
            (matrix + noise) / (1 + delta), same shape as the input

    """
    # np.shape also handles plain Python scalars, which have no .shape
    noise = np.random.uniform(low=0, high=delta, size=np.shape(matrix))
    return (matrix + noise) / (1 + delta)

class recalibrated_classifier(object):
    """
    Wrapper that recalibrates the scores of a base classifier

    The wrapper is configured through attributes that are set after
    construction and before calling ``fit``:

        base_clf:
            callable that maps X to scores (1D array, two-column array
            whose second column is used, or any other array that is
            flattened)

        calibration_type:
            one of 'fixed_width', 'uniform_mass', 'umd', 'umd_soft',
            'isotonic', 'platt', 'platt_binning'

        n_bins:
            number of bins (required by every binning type; overwritten
            by the number of fitted blocks for 'isotonic')

        platt_regularization:
            passed to platt_scaling as its regularization flag

        tce:
            if set, scores are taken from ``probs_provided`` instead of
            being computed by ``base_clf``

        delta:
            noise level passed to nudge
    """

    def __init__(self):
        self.base_clf = None
        self.n_bins = None
        self.mean_pred_values = None
        self.bin_centers = None
        self.bin_width = None
        self.bin_edges = None
        self.num_calibration_examples_in_bin = None
        self.calibration_type = None

        self.platt_regularization = True
        self.A = None
        self.B = None
        self.tce = False
        self.fitted = False
        self.delta = 1e-10

    def _get_scores(self, X, probs_provided):
        """
        Return the 1D array of uncalibrated scores for X
        """
        if self.tce:
            if probs_provided is None:
                raise ValueError(
                    "probs_provided has to be supplied when tce is set")
            return probs_provided

        raw_probs = self.base_clf(X)
        if len(raw_probs.shape) == 1:
            return raw_probs
        if raw_probs.shape[1] == 2:
            # two columns: keep the second one
            return raw_probs[:, 1]
        return raw_probs.ravel()

    @staticmethod
    def _pool_adjacent_violators(values):
        """
        Return the non-decreasing least-squares fit to values

        Blocks are kept on a stack as (sum, size); a new point is merged
        with the preceding blocks for as long as their mean is not smaller
        than the mean of the block being built.
        """
        block_sums = []
        block_sizes = []
        for value in values:
            cur_sum = float(value)
            cur_size = 1
            while block_sums and (
                    block_sums[-1] / block_sizes[-1] >= cur_sum / cur_size):
                cur_sum += block_sums.pop()
                cur_size += block_sizes.pop()
            block_sums.append(cur_sum)
            block_sizes.append(cur_size)

        sizes = np.array(block_sizes, dtype=int)
        means = np.array(block_sums, dtype=float) / sizes
        # every point receives the mean of the block it belongs to
        return np.repeat(means, sizes)

    def _fit_binning(self, y_score, y):
        """
        Fit 'fixed_width', 'uniform_mass', 'umd' and 'umd_soft' calibrators
        """
        if self.n_bins is None:
            raise ValueError("Number of bins has to be specified")

        cal_size = len(y_score)
        indices = np.arange(cal_size).astype('int')

        if self.calibration_type == 'fixed_width':
            width = 1.0 / self.n_bins
            self.bin_edges = np.linspace(width, 1.0 - width, self.n_bins-1)

        elif self.calibration_type == 'uniform_mass':
            # randomly split calibration set into 2 equal parts
            # use the first half to place the bin edges
            random.shuffle(indices)
            half = cal_size // 2
            self.bin_edges = get_uniform_mass_bins(
                y_score[indices[:half]], self.n_bins)

            # only use the second part of the data for calibration
            indices = indices[half:]
            y_score = y_score[indices]

        else:
            # 'umd' / 'umd_soft': use all data for binning as well as
            # bias estimation
            y_score = nudge(y_score, self.delta)
            self.bin_edges = get_uniform_mass_bins(y_score, self.n_bins)

        # values that are averaged inside every bin: the true labels, or
        # the (nudged) scores themselves for 'umd_soft'
        if self.calibration_type != 'umd_soft':
            y_calib = y[indices]
        else:
            y_calib = y_score[indices]

        self.num_calibration_examples_in_bin = np.zeros([self.n_bins, 1])
        self.mean_pred_values = np.empty(self.n_bins)
        bin_assignment = bin_points(y_score, self.bin_edges)

        for i in range(self.n_bins):
            bin_idx = (bin_assignment == i)
            n_in_bin = np.count_nonzero(bin_idx)
            self.num_calibration_examples_in_bin[i] = n_in_bin
            if n_in_bin > 0:
                self.mean_pred_values[i] = nudge(
                    y_calib[bin_idx].mean(), self.delta)
            else:
                # empty bin: fall back to 0.5; a 0-d array is passed so
                # that nudge can read its shape
                self.mean_pred_values[i] = nudge(np.array(0.5), self.delta)

    def _fit_isotonic(self, y_score, y):
        """
        Fit the 'isotonic' calibrator and store it as a set of bins
        """
        order = np.argsort(y_score)
        y_score = y_score[order]
        g = self._pool_adjacent_violators(y[order].astype('float'))

        centers = []
        widths = []
        means = []
        counts = []
        edges = []
        prev_score = 0.0
        prev_val = g[0]
        prev_ind = 0

        # a new bin starts at the score of the first point whose fitted
        # value differs from the value of the current block
        for i in range(g.size):
            if g[i] != prev_val:
                centers.append((y_score[i] + prev_score) / 2)
                widths.append(y_score[i] - prev_score)
                means.append(prev_val)
                counts.append(i - prev_ind)
                edges.append(y_score[i])
                prev_ind = i
                prev_score = y_score[i]
                prev_val = g[i]

        # the last block always extends up to 1; it is stored even if it
        # starts exactly at 1 so that its fitted value is not lost
        centers.append((1 + prev_score) / 2)
        widths.append(1 - prev_score)
        means.append(prev_val)
        counts.append(g.size - prev_ind)

        self.bin_centers = np.array(centers, dtype=float)
        self.bin_width = np.array(widths, dtype=float)
        self.mean_pred_values = np.array(means, dtype=float)
        self.num_calibration_examples_in_bin = np.array(counts, dtype=float)
        # interior edges only: boundaries between consecutive blocks
        self.bin_edges = np.array(edges, dtype=float)
        self.n_bins = self.bin_centers.size

    def _fit_platt_binning(self, y_score, y):
        """
        Fit the 'platt_binning' calibrator on a random 3-way split
        """
        if self.n_bins is None:
            raise ValueError("Number of bins has to be specified")

        cal_size = len(y_score)
        indices = np.arange(cal_size).astype('int')

        # randomly split calibration set into 3 equal parts
        random.shuffle(indices)
        first_cut = cal_size // 3
        second_cut = (2 * cal_size) // 3

        # use first part to train platt scaler
        part_1 = indices[:first_cut]
        _, self.A, self.B = platt_scaling(
            y_score[part_1], y[part_1], self.platt_regularization)

        # use second part to perform uniform-mass-binning of the
        # Platt-scaled scores
        probs_2 = logistic_transform(
            self.A * y_score[indices[first_cut:second_cut]] + self.B)
        self.bin_edges = get_uniform_mass_bins(probs_2, self.n_bins)

        # use third part to compute the output value of every bin
        probs_3 = logistic_transform(
            self.A * y_score[indices[second_cut:]] + self.B)

        self.num_calibration_examples_in_bin = np.zeros([self.n_bins, 1])
        self.mean_pred_values = np.empty(self.n_bins)
        bin_assignment = bin_points(probs_3, self.bin_edges)

        for i in range(self.n_bins):
            bin_idx = (bin_assignment == i)
            n_in_bin = np.count_nonzero(bin_idx)
            self.num_calibration_examples_in_bin[i] = n_in_bin
            if n_in_bin > 0:
                # mean of the Platt-scaled scores in the bin, computed as
                # if one extra point with value 0.5 were present
                self.mean_pred_values[i] = (
                    probs_3[bin_idx].sum() + 0.5) / (n_in_bin + 1)
            else:
                self.mean_pred_values[i] = 0.5

    def fit(self, X, y, probs_provided=None):
        """
        Method used to fit a wrapper around classifier

        Parameters
        ----------
            X:
                input passed to base_clf (ignored when tce is set)

            y:
                array of labels of the calibration points

            probs_provided:
                precomputed scores, used only when tce is set
        """

        if self.base_clf is None:
            raise ValueError("Base Classifier has to be provided")
        if self.calibration_type is None:
            raise ValueError("Calibration type has to be specified first")

        y_score = self._get_scores(X, probs_provided)

        if self.calibration_type in ['fixed_width', 'uniform_mass', 'umd', 'umd_soft']:
            self._fit_binning(y_score, y)
        elif self.calibration_type == 'isotonic':
            self._fit_isotonic(y_score, y)
        elif self.calibration_type == 'platt':
            _, self.A, self.B = platt_scaling(
                y_score, y, self.platt_regularization)
        elif self.calibration_type == 'platt_binning':
            self._fit_platt_binning(y_score, y)
        else:
            raise ValueError("Unknown calibration type specified")

        self.fitted = True

    def predict_proba(self, X, probs_provided=None):
        """
        Method used to predict probabilities for a given dataset
        """
        if not self.fitted:
            raise ValueError("Classifier has to be fitted first")

        # get uncalibrated predictions
        y_score = self._get_scores(X, probs_provided)

        if self.calibration_type in ['fixed_width', 'uniform_mass', 'isotonic', 'umd', 'umd_soft']:
            if self.calibration_type in ['umd', 'umd_soft']:
                y_score = nudge(y_score, self.delta)
            # look up the stored value of the bin every score falls into
            y_bins = bin_points(y_score, self.bin_edges)
            return self.mean_pred_values[y_bins]

        if self.calibration_type == 'platt':
            return logistic_transform(self.A * y_score + self.B)

        if self.calibration_type == 'platt_binning':
            # bins were fitted on Platt-scaled scores, so scale first
            y_score = logistic_transform(self.A * y_score + self.B)
            y_bins = bin_points(y_score, self.bin_edges)
            return self.mean_pred_values[y_bins]

        raise ValueError("Unknown calibration type specified")

    def predict(self, X, probs_provided=None):
        """
        Method used to predict labels for a given dataset
        """
        return (self.predict_proba(X, probs_provided) >= 0.5).astype('int')

    def score(self, X, y, probs_provided=None):
        """
        Method used to evaluate accuracy of a classifier
        """
        y_pred = self.predict(X, probs_provided)
        return (y_pred == y).mean()

