import numpy as np
import matplotlib.pyplot as plt
#import seaborn as sns
#from utils_old import binary_assessment
from utils_old import binary_calibration_old
from dirichletcal.calib import vectorscaling
from dirichletcal.calib import fulldirichlet

def bin_points(x, n_bins):
    """Assign each value to one of ``n_bins`` equal-width bins covering [0, 1].

    Values are first clipped into [0, 1]. The returned index of a value is the
    number of interior bin edges it is greater than or equal to, so indices lie
    in ``[0, n_bins - 1]`` and a value of exactly 1 falls in the last bin.
    """
    clipped = np.clip(x, 0, 1)
    # Drop the outer edges 0 and 1; only the n_bins - 1 interior edges matter.
    interior_edges = np.linspace(0, 1, n_bins + 1)[1:-1]
    column = clipped.reshape((-1, 1))
    return (column >= interior_edges).sum(axis=1)

def bin_points_uniform(x, n_bins):
    """Assign each value of ``x`` to one of ``n_bins`` equal-mass bins.

    Bin upper edges are obtained by linearly interpolating the sorted values at
    ``n_bins + 1`` evenly spaced positions, so bins hold roughly equal numbers
    of points. The final edge is replaced by +inf, which keeps every returned
    index within ``[0, n_bins - 1]``.
    """
    values = x.squeeze()
    count = values.size
    positions = np.linspace(0, count, n_bins + 1)
    edges = np.interp(positions, np.arange(count), np.sort(values))
    edges = edges[1:]
    assert edges.size == n_bins
    edges[-1] = np.inf
    return (values.reshape((-1, 1)) >= edges).sum(axis=1)

def bin_points_discrete(x):
    """Map each value to the index of that value among the sorted distinct values of ``x``.

    Meant for scores that take only a few discrete values: every distinct value
    becomes its own bin, numbered in ascending order of value.

    Args:
        x: array of values; singleton dimensions are squeezed away first.

    Returns:
        Integer array with one bin index per element of ``x``.

    Raises:
        AssertionError: if there are more than ``len(x) / 2`` distinct values
            (kept from the original as a sanity check that values repeat).
    """
    x = x.squeeze()
    # One sorted pass instead of a Python-level lookup per element: the
    # inverse indices are exactly each element's position in the sorted
    # distinct values.
    bin_values, bin_inds = np.unique(x, return_inverse=True)
    assert len(bin_values) <= (x.shape[0] / 2)
    return bin_inds

class identity:
    """Pass-through "classifier" whose class scores are simply its inputs."""

    def predict_proba(self, x):
        """Return ``x`` unchanged."""
        return x

    def predict(self, x):
        """Return, for each row, the column index holding the largest score."""
        scores = self.predict_proba(x)
        return np.argmax(scores, axis=1)

def lambda_factory(l, clf):
    """Return a callable that yields column ``l`` of ``clf.predict_proba(x)``.

    Creating the closure inside this function binds the current ``l`` (and
    ``clf``), which avoids Python's late-binding pitfall when such callables
    are built in a loop.
    """
    def column_score(x):
        scores = clf.predict_proba(x)
        return scores[:, l]
    return column_score
    
def _softmax_rows(logits):
    """Return the row-wise softmax of a 2-D logit array.

    The row maximum is subtracted before exponentiating so that large logits do
    not overflow ``np.exp``. The shift cancels in the ratio, so the result is
    the same softmax.
    """
    logits = np.asarray(logits)
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / np.sum(exp_shifted, axis=1, keepdims=True)


def classwise_ece(data_folder, bin_lower, bin_upper):
    """Evaluate class-wise ECE of several calibration methods over a range of bin counts.

    Loads validation and test logits/labels from ``data_folder`` and turns the
    logits into probabilities with a softmax. It then fits the following on the
    validation split:

    - vector scaling;
    - a full Dirichlet calibrator;
    - one binned binary recalibrator per class (``calibration_type='umd'``).

    Finally it evaluates class-wise ECE on the test split for every bin count
    in ``[bin_lower, bin_upper]``.

    Args:
        data_folder: directory holding ``val_logits.npy``, ``test_logits.npy``,
            ``test_logits_temp.npy``, ``val_labels.npy``, ``test_labels.npy``
            and ``test_labels_temp.npy``. The ``*_temp`` files are presumably a
            temperature-scaled test run (TODO confirm).
        bin_lower: smallest bin count to evaluate (inclusive).
        bin_upper: largest bin count to evaluate (inclusive).

    Returns:
        Tuple of six arrays, each of length ``bin_upper - bin_lower + 1``,
        containing class-wise ECE for:

        - uncalibrated probabilities;
        - the ``*_temp`` test files;
        - the per-class recalibrators, binned on their own discrete outputs;
        - the same recalibrators with rows renormalised to sum to 1, using
          equal-width bins;
        - vector scaling;
        - the Dirichlet calibrator.
    """
    assert bin_upper >= bin_lower

    x_calib = np.load(data_folder + '/val_logits.npy')
    x_test = np.load(data_folder + '/test_logits.npy')
    temp_x_test = np.load(data_folder + '/test_logits_temp.npy')
    y_calib = np.load(data_folder + '/val_labels.npy').astype('int')
    y_test = np.load(data_folder + '/test_labels.npy').astype('int')
    # Cast like the other label arrays so every label comparison is integer-based.
    temp_y_test = np.load(data_folder + '/test_labels_temp.npy').astype('int')

    # Logits -> probabilities (numerically stable softmax).
    x_calib = _softmax_rows(x_calib)
    x_test = _softmax_rows(x_test)
    temp_x_test = _softmax_rows(temp_x_test)

    clf = identity()
    n_labels = x_test.shape[1]
    bin_range = np.arange(bin_lower, bin_upper + 1)

    ECE_classwise_uncalib = np.zeros(bin_range.size)
    ECE_classwise_temp = np.zeros(bin_range.size)
    ECE_classwise_hist = np.zeros(bin_range.size)
    ECE_classwise_hist_normalized = np.zeros(bin_range.size)
    ECE_classwise_vs = np.zeros(bin_range.size)
    ECE_classwise_ds = np.zeros(bin_range.size)

    # These calibrators are built without a bin count, so fit and predict once.
    vs = vectorscaling.VectorScaling()
    vs.fit(x_calib, y_calib)
    preds_test_vs = vs.predict_proba(x_test)

    ds = fulldirichlet.FullDirichletCalibrator()
    ds.fit(x_calib, y_calib)
    preds_test_ds = ds.predict_proba(x_test)

    for bin_ind, n_bins in enumerate(bin_range):
        print(n_bins)  # progress output
        ECE_classwise_uncalib[bin_ind] = compute_ece_continuous(
            x_test, y_test, n_bins)
        ECE_classwise_temp[bin_ind] = compute_ece_continuous(
            temp_x_test, temp_y_test, n_bins)
        ECE_classwise_vs[bin_ind] = compute_ece_continuous(
            preds_test_vs, y_test, n_bins)
        ECE_classwise_ds[bin_ind] = compute_ece_continuous(
            preds_test_ds, y_test, n_bins)

        # One binary (class l vs. rest) recalibrator per class. n_bins is set on
        # each recalibrator before fitting, so they are refit for every bin count.
        umd = []
        for l in range(n_labels):
            recal = binary_calibration_old.recalibrated_classifier()
            recal.base_clf = lambda_factory(l, clf)
            recal.n_bins = n_bins
            recal.calibration_type = 'umd'
            recal.fit(x_calib, (y_calib == l).astype('int'))
            umd.append(recal)

        ECE_classwise_hist[bin_ind] = compute_ece_umd(
            x_test, y_test, umd, n_bins)

        # Stack the per-class scores and renormalise each row to sum to 1.
        full_pred_matrix = np.zeros(x_test.shape)
        for l in range(n_labels):
            full_pred_matrix[:, l] = umd[l].predict_proba(x_test)
        full_pred_matrix = full_pred_matrix / np.sum(
            full_pred_matrix, axis=1, keepdims=True)
        ECE_classwise_hist_normalized[bin_ind] = compute_ece_continuous(
            full_pred_matrix, y_test, n_bins)

    return (ECE_classwise_uncalib, ECE_classwise_temp, ECE_classwise_hist,
            ECE_classwise_hist_normalized, ECE_classwise_vs, ECE_classwise_ds)

def compute_ece_continuous(x_test, y_test, n_bins):
    """Class-wise expected calibration error using equal-width bins on [0, 1].

    For every class ``l``, the scores ``x_test[:, l]`` are split into ``n_bins``
    equal-width bins with ``bin_points``. In each bin the absolute gap between
    the number of points labelled ``l`` and the sum of their scores is
    accumulated. That gap equals ``n_b * |empirical frequency - mean score|``.
    The total is divided by the number of (point, class) pairs.

    Args:
        x_test: (N, K) array of scores, one column per class.
        y_test: N labels, compared against the class indices ``0..K-1``.
        n_bins: number of equal-width bins.

    Returns:
        The class-wise ECE as a float.
    """
    labels = np.asarray(y_test).reshape(-1)
    n_labels = x_test.shape[1]
    N_b = np.zeros(n_bins)
    Delta_b = np.zeros(n_bins)

    for l in range(n_labels):
        scores = x_test[:, l]
        bin_inds = bin_points(scores, n_bins)
        hits = (labels == l).astype(float)
        # Per-bin counts and sums in one pass each; empty bins contribute 0.
        counts = np.bincount(bin_inds, minlength=n_bins)
        score_sums = np.bincount(bin_inds, weights=scores, minlength=n_bins)
        hit_sums = np.bincount(bin_inds, weights=hits, minlength=n_bins)
        N_b += counts
        Delta_b += np.abs(hit_sums - score_sums)

    # Every (point, class) pair must have landed in exactly one bin.
    assert np.sum(N_b) == x_test.size
    return np.sum(Delta_b) / np.sum(N_b)

def compute_ece_umd(x_test, y_test, umd, n_bins):
    """Class-wise ECE of per-class recalibrators, binned on their own output values.

    For each class ``l``, the scores ``umd[l].predict_proba(x_test)`` are grouped
    by distinct value with ``bin_points_discrete``. Each group is one bin,
    numbered in ascending order of value. Per bin, the absolute gap between the
    number of points labelled ``l`` and the sum of their scores is accumulated.
    The total is divided by the number of (point, class) pairs.

    Args:
        x_test: (N, K) array passed to every recalibrator. Here it only fixes
            K (its column count) and the expected total count (its size).
        y_test: N labels, compared against the class indices ``0..K-1``.
        umd: sequence of K fitted objects whose ``predict_proba(x_test)``
            returns one score per row.
        n_bins: number of bins counted. Points whose distinct-value index is
            ``>= n_bins`` are not counted, which makes the sanity assertion
            below fail.

    Returns:
        The class-wise ECE as a float.
    """
    labels = np.asarray(y_test).reshape(-1)
    n_labels = x_test.shape[1]
    N_b = np.zeros(n_bins)
    Delta_b = np.zeros(n_bins)

    for l in range(n_labels):
        l_preds = np.asarray(umd[l].predict_proba(x_test)).reshape(-1)
        bin_inds = bin_points_discrete(l_preds)
        hits = (labels == l).astype(float)
        # Only bins 0..n_bins-1 are kept, matching a per-bin loop over range(n_bins).
        counts = np.bincount(bin_inds, minlength=n_bins)[:n_bins]
        score_sums = np.bincount(bin_inds, weights=l_preds, minlength=n_bins)[:n_bins]
        hit_sums = np.bincount(bin_inds, weights=hits, minlength=n_bins)[:n_bins]
        N_b += counts
        # n_b * |mean(hit) - mean(score)| == |sum(hit) - sum(score)|
        Delta_b += np.abs(hit_sums - score_sums)

    assert np.sum(N_b) == x_test.size
    return np.sum(Delta_b) / np.sum(N_b)

def compute_ece_umd_normalized(x_test, y_test, umd, n_bins):
    """Class-wise ECE of per-class recalibrators after renormalising rows to sum to 1.

    The scores ``umd[l].predict_proba(x_test)`` are stacked into an (N, K)
    matrix and each row is divided by its sum. Each column is then grouped by
    distinct value with ``bin_points_discrete``. Per bin, the absolute gap
    between the number of points labelled ``l`` and the sum of their scores is
    accumulated and divided by the number of (point, class) pairs.

    NOTE(review): renormalisation can create more distinct values per column
    than the recalibrator produced. Indices ``>= n_bins`` are not counted and
    would then trip the sanity assertion below. A row whose scores sum to 0
    yields NaN entries.

    Args:
        x_test: (N, K) array passed to every recalibrator.
        y_test: N labels, compared against the class indices ``0..K-1``.
        umd: sequence of K fitted objects whose ``predict_proba(x_test)``
            returns one score per row.
        n_bins: number of bins counted.

    Returns:
        The class-wise ECE as a float.
    """
    labels = np.asarray(y_test).reshape(-1)
    n_labels = x_test.shape[1]

    full_pred_matrix = np.zeros(x_test.shape)
    for l in range(n_labels):
        full_pred_matrix[:, l] = umd[l].predict_proba(x_test)
    # Broadcast each row's sum across its columns.
    full_pred_matrix = full_pred_matrix / np.sum(full_pred_matrix, axis=1, keepdims=True)

    N_b = np.zeros(n_bins)
    Delta_b = np.zeros(n_bins)
    for l in range(n_labels):
        l_preds = full_pred_matrix[:, l]
        bin_inds = bin_points_discrete(l_preds)
        hits = (labels == l).astype(float)
        # Only bins 0..n_bins-1 are kept, matching a per-bin loop over range(n_bins).
        counts = np.bincount(bin_inds, minlength=n_bins)[:n_bins]
        score_sums = np.bincount(bin_inds, weights=l_preds, minlength=n_bins)[:n_bins]
        hit_sums = np.bincount(bin_inds, weights=hits, minlength=n_bins)[:n_bins]
        N_b += counts
        # n_b * |mean(hit) - mean(score)| == |sum(hit) - sum(score)|
        Delta_b += np.abs(hit_sums - score_sums)

    assert np.sum(N_b) == x_test.size
    return np.sum(Delta_b) / np.sum(N_b)
