import numpy as np
import matplotlib.pyplot as plt
#import seaborn as sns
#from utils import binary_calibration
from utils_old import binary_calibration_old
from dirichletcal.calib import vectorscaling
from dirichletcal.calib import fulldirichlet

def bin_points(x, n_bins):
    """Assign each value of ``x`` to one of ``n_bins`` equal-width bins on [0, 1].

    Bins are right-closed: a value exactly equal to an interior edge
    ``k / n_bins`` falls in the lower bin. Only the ``n_bins - 1`` interior
    edges are compared against, so values below 0 land in bin 0 and values
    above 1 land in bin ``n_bins - 1``.

    Args:
        x: array-like of scores; any shape, it is flattened.
        n_bins: number of bins (>= 1).

    Returns:
        1-D integer array of bin indices in ``0 .. n_bins - 1``, one per
        element of ``x``.
    """
    interior_edges = np.linspace(0, 1, n_bins + 1)[1:-1]
    # Bin index = number of interior edges strictly below the value.
    return np.sum(np.asarray(x).reshape((-1, 1)) > interior_edges, axis=1)

def bin_points_uniform(x, n_bins):
    """Assign each point to one of ``n_bins`` approximately equal-mass bins.

    Bin upper edges are interpolated from the sorted sample at evenly spaced
    rank positions. The last edge is replaced by +inf so every point falls
    in a bin ``0 .. n_bins - 1``. Bins are right-closed: a value equal to an
    edge goes to the lower bin.
    """
    values = x.squeeze()
    count = values.size
    rank_positions = np.linspace(0, count, n_bins + 1)
    edges = np.interp(rank_positions, np.arange(count), np.sort(values))
    upper_edges = edges[1:]
    assert(upper_edges.size == n_bins)
    # Catch-all last bin: nothing can exceed +inf.
    upper_edges[-1] = np.inf
    return (values.reshape((-1, 1)) > upper_edges).sum(axis=1)

def bin_points_discrete(x):
    """Index each value of ``x`` by its position among the sorted distinct values.

    Intended for scores that take only a few distinct values: every distinct
    value is its own bin. As a sanity check that the input really is
    discrete, at most half as many distinct values as points are allowed.

    Args:
        x: 1-D array, or an array that squeezes to 1-D (e.g. shape (n, 1)).

    Returns:
        1-D integer array of length n with values in ``0 .. n_distinct - 1``.

    Raises:
        AssertionError: if more than ``n / 2`` distinct values are present.
    """
    x = np.asarray(x).squeeze()
    # return_inverse maps every element to its index in the sorted unique
    # values in one vectorized pass; unique() is computed only once.
    bin_values, bin_inds = np.unique(x, return_inverse=True)
    assert len(bin_values) <= x.shape[0] / 2, (
        'expected discrete scores: %d distinct values among %d points'
        % (len(bin_values), x.shape[0]))
    # reshape(-1) keeps the result 1-D regardless of NumPy's inverse shape.
    return bin_inds.reshape(-1)

class identity:
    """Pass-through "classifier" for inputs that are already class scores.

    Exposes ``predict_proba`` / ``predict`` so a precomputed score matrix can
    be used wherever an object with those methods is expected.
    """

    def predict_proba(self, x):
        """Return ``x`` itself (same object, no copy)."""
        scores = x
        return scores

    def predict(self, x):
        """Return, for each row of ``x``, the column index of its maximum."""
        top = np.argmax(x, axis=1)
        return top

def lambda_factory(l, clf):
    """Build a callable mapping ``x`` to column ``l`` of ``clf.predict_proba(x)``.

    Creating the closure in its own function binds ``l`` and ``clf`` at
    call time, so callables built in a loop do not all share the last ``l``.
    """
    def column_score(x):
        return clf.predict_proba(x)[:, l]
    return column_score
    
def _softmax(logits):
    """Return the row-wise softmax of a 2-D array of logits.

    Each row is shifted by its maximum before exponentiating. The result is
    mathematically unchanged, but ``np.exp`` can no longer overflow to
    ``inf`` (which would turn the whole row into NaNs) for large logits.
    """
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=1, keepdims=True)


def _fit_umd_calibrator(l, clf, n_bins, x, y_binary):
    """Fit a 'umd'-type recalibrator on the class-``l`` score of ``clf``.

    ``y_binary`` holds 0/1 targets (1 where the true label is ``l``). The
    meaning of calibration_type 'umd' is defined in binary_calibration_old.
    """
    calibrator = binary_calibration_old.recalibrated_classifier()
    # lambda_factory binds l now, avoiding the late-binding closure pitfall.
    calibrator.base_clf = lambda_factory(l, clf)
    calibrator.n_bins = n_bins
    calibrator.calibration_type = 'umd'
    calibrator.fit(x, y_binary)
    return calibrator


def toplabel(data_folder, bin_lower, bin_upper, normalized_HB=False):
    """Compare top-label ECE of several calibrators over a range of bin counts.

    Loads validation ("calib") and test logits/labels from ``data_folder``,
    converts the logits to probabilities with a softmax, and fits vector
    scaling and full Dirichlet calibration on the validation split. For every
    ``n_bins`` in ``[bin_lower, bin_upper]`` it then computes top-label ECE
    for the following:

    * the uncalibrated model,
    * the precomputed ``*_temp`` test logits (presumably temperature-scaled),
    * vector scaling,
    * Dirichlet calibration,
    * per-class 'umd' recalibrators, where calibrator ``l`` is fitted only on
      the validation points whose top prediction is ``l``.

    With ``normalized_HB`` it also fits one 'umd' calibrator per class on
    *all* validation points. It renormalizes their outputs so that each row
    sums to one and reports that matrix's top-label ECE.

    Files read from ``data_folder``: val_logits.npy, test_logits.npy,
    test_logits_temp.npy, val_labels.npy, test_labels.npy,
    test_labels_temp.npy. The current ``n_bins`` is printed on each
    iteration as progress output.

    Args:
        data_folder: directory path; file names are appended with '/'.
        bin_lower: smallest number of bins (inclusive).
        bin_upper: largest number of bins (inclusive).
        normalized_HB: also compute the renormalized class-wise variant.

    Returns:
        ``(uncalib, temp, hist, vs, ds)``. When ``normalized_HB`` is true,
        ``hist_normalized`` is appended. Each element is a float array of
        length ``bin_upper - bin_lower + 1``; entry ``i`` is for
        ``n_bins = bin_lower + i``.

    Raises:
        AssertionError: if ``bin_upper < bin_lower``.
    """
    assert(bin_upper >= bin_lower)

    x_calib = np.load(data_folder + '/val_logits.npy')
    x_test = np.load(data_folder + '/test_logits.npy')
    temp_x_test = np.load(data_folder + '/test_logits_temp.npy')
    y_calib = np.load(data_folder + '/val_labels.npy').astype('int')
    y_test = np.load(data_folder + '/test_labels.npy').astype('int')
    # Cast like the other label arrays so all labels compare as integers.
    temp_y_test = np.load(data_folder + '/test_labels_temp.npy').astype('int')

    # Logits -> class probabilities (numerically stable softmax).
    x_calib = _softmax(x_calib)
    x_test = _softmax(x_test)
    temp_x_test = _softmax(temp_x_test)

    # The inputs are already probabilities, so the "classifier" is the identity.
    clf = identity()
    n_labels = x_test.shape[1]
    # Top-label predictions on the calibration split depend on neither
    # n_bins nor the class, so compute them once.
    calib_top = clf.predict(x_calib)

    bin_range = np.arange(bin_lower, bin_upper + 1)

    ECE_toplabel_uncalib = np.zeros(bin_range.size)
    ECE_toplabel_temp = np.zeros(bin_range.size)
    ECE_toplabel_hist = np.zeros(bin_range.size)
    ECE_toplabel_hist_normalized = np.zeros(bin_range.size)
    ECE_toplabel_vs = np.zeros(bin_range.size)
    ECE_toplabel_ds = np.zeros(bin_range.size)

    # These calibrators do not depend on n_bins: fit and predict once.
    vs = vectorscaling.VectorScaling()
    vs.fit(x_calib, y_calib)
    preds_test_vs = vs.predict_proba(x_test)

    ds = fulldirichlet.FullDirichletCalibrator()
    ds.fit(x_calib, y_calib)
    preds_test_ds = ds.predict_proba(x_test)

    for bin_ind, n_bins in enumerate(bin_range):
        print(n_bins)
        ECE_toplabel_uncalib[bin_ind] = compute_ece_continuous(
            x_test, y_test, n_bins)
        ECE_toplabel_temp[bin_ind] = compute_ece_continuous(
            temp_x_test, temp_y_test, n_bins)
        ECE_toplabel_vs[bin_ind] = compute_ece_continuous(
            preds_test_vs, y_test, n_bins)
        ECE_toplabel_ds[bin_ind] = compute_ece_continuous(
            preds_test_ds, y_test, n_bins)

        # Top-label recalibration: calibrator l sees only the calibration
        # points whose top prediction is l.
        umd = []
        for l in range(n_labels):
            l_indices = np.flatnonzero(calib_top == l)
            umd.append(_fit_umd_calibrator(
                l, clf, n_bins, x_calib[l_indices],
                (y_calib[l_indices] == l).astype('int')))
        ECE_toplabel_hist[bin_ind] = compute_ece_umd(
            x_test, y_test, umd, n_bins)

        if normalized_HB:
            # Class-wise recalibration on all calibration points, followed
            # by renormalizing each test row to sum to one.
            umd_all = [_fit_umd_calibrator(l, clf, n_bins, x_calib,
                                           (y_calib == l).astype('int'))
                       for l in range(n_labels)]
            full_pred_matrix = np.zeros(x_test.shape)
            for l, calibrator in enumerate(umd_all):
                full_pred_matrix[:, l] = calibrator.predict_proba(x_test)
            row_sums = np.sum(full_pred_matrix, axis=1)
            full_pred_matrix = full_pred_matrix / row_sums[:, np.newaxis]
            ECE_toplabel_hist_normalized[bin_ind] = compute_ece_continuous(
                full_pred_matrix, y_test, n_bins)

    if normalized_HB:
        return (ECE_toplabel_uncalib, ECE_toplabel_temp, ECE_toplabel_hist,
                ECE_toplabel_vs, ECE_toplabel_ds, ECE_toplabel_hist_normalized)
    return (ECE_toplabel_uncalib, ECE_toplabel_temp, ECE_toplabel_hist,
            ECE_toplabel_vs, ECE_toplabel_ds)

def compute_ece_continuous(x_test, y_test, n_bins):
    """Top-label expected calibration error (ECE) with equal-width bins.

    For each class ``l``, the test points whose top (argmax) prediction is
    ``l`` are binned by their score ``x_test[:, l]`` using ``bin_points``.
    Each bin contributes ``n_b * |accuracy_b - mean_score_b|``, which equals
    ``|#{y == l in bin} - sum of scores in bin|``. The contributions over all
    classes and bins are summed and divided by the number of test points.

    Args:
        x_test: 2-D array of scores, one row per point and one column per
            class; only each row's argmax column is used.
        y_test: class labels, one per row of ``x_test``.
        n_bins: number of equal-width bins on [0, 1].

    Returns:
        The top-label ECE as a float, or NaN when ``x_test`` has no rows.
    """
    x_test = np.asarray(x_test)
    y_test = np.asarray(y_test).reshape(-1)
    n_points, n_labels = x_test.shape
    if n_points == 0:
        return np.nan
    # Top-label predictions (same as identity().predict), computed once
    # instead of once per class.
    top_labels = np.argmax(x_test, axis=1)
    total_gap = 0.0
    for l in range(n_labels):
        l_inds = np.flatnonzero(top_labels == l)
        scores = x_test[l_inds, l]
        bin_inds = bin_points(scores, n_bins)
        # Per-bin sums replace per-bin means:
        # n_b * |mean_hit - mean_score| == |sum_hit - sum_score|, and empty
        # bins contribute 0.
        score_sums = np.bincount(bin_inds, weights=scores, minlength=n_bins)
        hit_sums = np.bincount(bin_inds,
                               weights=(y_test[l_inds] == l).astype(float),
                               minlength=n_bins)
        total_gap += np.sum(np.abs(hit_sums - score_sums))
    return total_gap / n_points

def compute_ece_umd(x_test, y_test, umd, n_bins):
    """Top-label ECE of per-class recalibrators with discrete outputs.

    For each class ``l``, the test points whose top (argmax) prediction is
    ``l`` are rescored with ``umd[l].predict_proba``. The recalibrated scores
    are expected to take only a few distinct values, and each distinct value
    is treated as its own bin. If they are continuous, every point simply
    forms its own bin. Each bin contributes
    ``|#{y == l in bin} - sum of scores in bin|`` (bin size times the gap
    between accuracy and mean score). The total is divided by the number of
    test points.

    Args:
        x_test: 2-D array of scores, one row per point and one column per
            class; the argmax decides which calibrator scores a row.
        y_test: class labels, one per row of ``x_test``.
        umd: sequence of per-class calibrators. ``umd[l].predict_proba`` is
            called with a 2-D batch of rows and must return one score per row.
        n_bins: unused, kept for interface compatibility. The bins are the
            distinct recalibrated values, and all of them are counted.

    Returns:
        The top-label ECE as a float, or NaN when ``x_test`` has no rows.
    """
    x_test = np.asarray(x_test)
    y_test = np.asarray(y_test).reshape(-1)
    n_points, n_labels = x_test.shape
    if n_points == 0:
        return np.nan
    # Top-label predictions (same as identity().predict), computed once.
    top_labels = np.argmax(x_test, axis=1)
    total_gap = 0.0
    for l in range(n_labels):
        l_inds = np.flatnonzero(top_labels == l)
        if l_inds.size == 0:
            # No test point has top label l: nothing to score, and the
            # calibrator is never asked to predict on an empty batch.
            continue
        # A 1-D index array keeps the batch 2-D, shape (k, n_labels), even
        # when k == 1.
        l_preds = np.asarray(umd[l].predict_proba(x_test[l_inds]),
                             dtype=float).reshape(-1)
        # One bin per distinct recalibrated value. This deliberately does not
        # use bin_points_discrete: its "at most n/2 distinct values" assert
        # would abort the whole computation for classes with few points.
        _, bin_inds = np.unique(l_preds, return_inverse=True)
        bin_inds = bin_inds.reshape(-1)
        score_sums = np.bincount(bin_inds, weights=l_preds)
        hit_sums = np.bincount(bin_inds,
                               weights=(y_test[l_inds] == l).astype(float))
        total_gap += np.sum(np.abs(hit_sums - score_sums))
    return total_gap / n_points

