import numpy as np

def eval_spread_error_multi(f_cal, C_cal, f_cal_aug_save, C_cal_aug_save,
                            f_test, C_test, nbins, peak):
    """Score the test set against the calibration level whose mean is closest.

    Level 0 is the plain calibration set ``(f_cal, C_cal)``; levels 1..5 are
    read from ``f_cal_aug_save[level]`` / ``C_cal_aug_save[level]``.  A
    histogram is built for every level, the level whose mean ``f`` is closest
    to the (sub-sampled) test mean is selected, and its per-bin probabilities
    are used as the prediction for the test-set bins.

    Parameters
    ----------
    f_cal, C_cal : arrays
        Calibration values and their "correct" indicators.
    f_cal_aug_save, C_cal_aug_save : indexable by 1..5
        Per-level calibration values and indicators.
    f_test, C_test : arrays
        Test values and indicators; ``f_test`` must support integer-array
        indexing.
    nbins : int
        Number of quantile bins requested from ``determine_edges``.
    peak : bool
        If truthy, only 100 randomly chosen test samples are used to estimate
        the test mean; otherwise all of them (in permuted order).

    Returns
    -------
    tuple
        ``(scores, spreads, calc_errs, we, edges, pt, pe)`` where ``scores``
        is the result of ``histogram_scores``, ``spreads``/``calc_errs`` come
        from ``error_metrics_histogram``, ``we``/``pe`` are the test-set bin
        weights and probabilities, ``edges`` the test-set bin edges and ``pt``
        the predicted probability per test bin.
    """
    # Histogram (edges + per-bin probability) for every level.
    edges_by_level = {0: determine_edges(f_cal, nbins)}
    _, pt_level0 = histogram_eval(f_cal, C_cal, edges_by_level[0])
    pt_by_level = {0: pt_level0}
    for level in range(1, 6):
        edges_by_level[level] = determine_edges(f_cal_aug_save[level], nbins)
        _, pt_by_level[level] = histogram_eval(
            f_cal_aug_save[level], C_cal_aug_save[level], edges_by_level[level])

    # Common grid: the union of all per-level edges.
    new_edges = np.unique(
        np.concatenate([edges_by_level[level] for level in range(0, 6)]))

    # Mean strategy: pick the level whose mean is closest to the test mean.
    # The permutation is drawn even when `peak` is false so the global NumPy
    # random state is consumed exactly as before.
    ip_peak = np.random.permutation(len(f_test))
    if peak:
        ip_peak = ip_peak[:100]
    test_mean = np.mean(f_test[ip_peak])

    istar = 0
    best = np.abs(np.mean(f_cal) - test_mean)
    for level in range(1, 6):
        gap = np.abs(np.mean(f_cal_aug_save[level]) - test_mean)
        if gap < best:
            istar, best = level, gap

    # Re-express the selected level's probabilities on the common grid: each
    # common bin is looked up by its upper edge in the level's own edges.
    lookup = np.digitize(new_edges[1:], edges_by_level[istar], right=True) - 1
    new_pt = np.array(pt_by_level[istar][lookup], dtype=float)

    # Test-set bins take the prediction of the common bin holding their centre.
    edges = determine_edges(f_test, nbins)
    centres = (edges[1:] + edges[:-1]) / 2
    pt = new_pt[np.digitize(centres, new_edges, right=True) - 1]
    we, pe = histogram_eval(f_test, C_test, edges)

    # Calculate scores
    scores = histogram_scores(we, pe, pt)

    # Calculate error metrics
    _, __, spreads, calc_errs = error_metrics_histogram(we, pe, pt)

    return (scores, spreads, calc_errs, we, edges, pt, pe)

def eval_spread_error_individual(f_cal, C_cal, f_cal_aug_save, C_cal_aug_save,
                                 f_test, C_test, nbins, strategy):
    """Score the test set against a calibration histogram built by ``strategy``.

    Level 0 is the plain calibration histogram of ``(f_cal, C_cal)``; levels
    1..5 are built from ``f_cal_aug_save[level]`` / ``C_cal_aug_save[level]``.

    Strategies
    ----------
    1 : Splice.  For each level 1..5 a cut point is placed halfway between the
        means of ``f_cal_aug_save[level]`` and ``f_cal_aug_save[level - 1]``;
        bins below the cut take that level's edges and probabilities, bins
        above keep the current ones.
    2 : Likelihood-weighted blend of all six levels on the union of their
        edges (upper-edge lookup).  The likelihood of a level in a bin is the
        fraction of ``f_cal_aug_save[level]`` falling in the matching cell of
        a 100-cell histogram over [0, 1].  The blend is then sampled at the
        centres of the test-set quantile bins.
    3 : Same blend restricted to levels 0 and 5, using left-edge lookup; the
        union of edges is kept as the final binning.
    other : the level-0 calibration histogram is used unchanged.

    Note that strategies 1-3 read ``f_cal_aug_save[0]`` as well, so the
    container must be indexable by 0 in those cases.

    Parameters
    ----------
    f_cal, C_cal : arrays
        Calibration values and their "correct" indicators.
    f_cal_aug_save, C_cal_aug_save : indexable containers
        Per-level calibration values and indicators.
    f_test, C_test : arrays
        Test values and indicators.
    nbins : int
        Number of quantile bins requested from ``determine_edges``.
    strategy : int
        Strategy selector, see above.

    Returns
    -------
    tuple
        ``(scores, spreads, calc_errs, we, edges, pt, pe)``.
    """
    edges = determine_edges(f_cal, nbins)
    _, pt = histogram_eval(f_cal, C_cal, edges)

    edges_by_level = {0: edges}
    pt_by_level = {0: pt}
    for level in range(1, 6):
        edges_by_level[level] = determine_edges(f_cal_aug_save[level], nbins)
        _, pt_by_level[level] = histogram_eval(
            f_cal_aug_save[level], C_cal_aug_save[level], edges_by_level[level])

    if strategy == 1:
        for level in range(1, 6):
            cut = (np.mean(f_cal_aug_save[level])
                   + np.mean(f_cal_aug_save[level - 1])) / 2
            level_edges = edges_by_level[level]
            edges_left = level_edges[level_edges < cut]
            edges_right = edges[edges > cut]
            n_right = len(edges_right)
            # pt[-0:] would return the whole array, so handle "no edges to
            # the right of the cut" explicitly.
            pt_right = pt[-n_right:] if n_right else pt[:0]
            pt = np.concatenate((pt_by_level[level][:len(edges_left)], pt_right))
            edges = np.concatenate((edges_left, [cut], edges_right))
    elif strategy == 2:
        levels = range(0, 6)
        new_edges = np.unique(
            np.concatenate([edges_by_level[level] for level in levels]))
        upper = new_edges[1:]
        cell_bins = np.linspace(0, 1, 101)
        pts, likelihoods = [], []
        for level in levels:
            pts.append(pt_by_level[level][
                np.digitize(upper, edges_by_level[level], right=True) - 1])
            counts, cell_edges = np.histogram(f_cal_aug_save[level], bins=cell_bins)
            likelihoods.append(
                counts[np.digitize(upper, cell_edges, right=True) - 1] / counts.sum())
        new_pt = _likelihood_weighted_average(pts, likelihoods)

        # Sample the blended curve at the centres of the test-set bins.
        edges = determine_edges(f_test, nbins)
        centres = (edges[1:] + edges[:-1]) / 2
        pt = new_pt[np.digitize(centres, new_edges, right=True) - 1]
    elif strategy == 3:
        levels = [0, 5]
        new_edges = np.unique(
            np.concatenate([edges_by_level[level] for level in levels]))
        lower = new_edges[:-1]
        cell_bins = np.linspace(0, 1, 101)
        pts, likelihoods = [], []
        for level in levels:
            pts.append(pt_by_level[level][
                np.digitize(lower, edges_by_level[level][:-1]) - 1])
            counts, cell_edges = np.histogram(f_cal_aug_save[level], bins=cell_bins)
            likelihoods.append(
                counts[np.digitize(lower, cell_edges[:-1]) - 1] / counts.sum())
        # The uniform 1/len(levels) prior of the original formulation cancels
        # between numerator and denominator, so it is not applied here.  Bins
        # with no likelihood mass now fall back to the mean over *all* levels
        # (previously only the last level, halved, was used).
        pt = _likelihood_weighted_average(pts, likelihoods)
        edges = new_edges

    we, pe = histogram_eval(f_test, C_test, edges)

    # Calculate scores
    scores = histogram_scores(we, pe, pt)

    # Calculate error metrics
    _, __, spreads, calc_errs = error_metrics_histogram(we, pe, pt)

    return (scores, spreads, calc_errs, we, edges, pt, pe)


def _likelihood_weighted_average(pt_list, weight_list):
    """Blend per-level bin probabilities using per-bin weights.

    For every bin the result is ``sum(pt * weight) / sum(weight)`` over the
    levels.  Bins whose total weight is not strictly positive (zero or NaN)
    fall back to the unweighted mean of the per-level probabilities.

    ``pt_list`` and ``weight_list`` are equally long sequences of 1-D arrays
    that all share the same length (one entry per bin).
    """
    n_bins = len(pt_list[0])
    weighted = np.zeros(n_bins)
    total = np.zeros(n_bins)
    fallback = np.zeros(n_bins)
    for level_pt, level_weight in zip(pt_list, weight_list):
        weighted += level_pt * level_weight
        total += level_weight
        fallback += level_pt
    fallback /= len(pt_list)
    # `where` leaves the pre-filled fallback untouched for weightless bins and
    # avoids the 0/0 warning a plain division would emit.
    return np.divide(weighted, total, out=fallback, where=total > 0)

def eval_spread_error(f_val, C_val, f_test, C_test, nbins):
    """Evaluate test-set calibration against a validation histogram.

    The bin edges are derived from ``f_val`` via ``determine_edges``
    (quantile-based, so bins are roughly equally populated on the validation
    set).  The same edges are used to histogram both the validation set,
    giving the predicted per-bin probability ``pt``, and the test set, giving
    the bin weights ``we`` and empirical per-bin probability ``pe``.

    Parameters
    ----------
    f_val, C_val : arrays
        Validation values and their "correct" indicators.
    f_test, C_test : arrays
        Test values and their "correct" indicators.
    nbins : int
        Number of quantile bins requested from ``determine_edges``.

    Returns
    -------
    tuple
        ``(scores, spreads, calc_errs, we, edges, pt, pe)`` where ``scores``
        is the result of ``histogram_scores`` and ``spreads``/``calc_errs``
        are the last two results of ``error_metrics_histogram``.
    """
    # Determine bin edges from the validation quantiles
    edges = determine_edges(f_val, nbins)

    _, pt = histogram_eval(f_val, C_val, edges)
    we, pe = histogram_eval(f_test, C_test, edges)

    # Calculate scores
    scores = histogram_scores(we, pe, pt)

    # Calculate error metrics
    _, __, spreads, calc_errs = error_metrics_histogram(we, pe, pt)

    return (scores, spreads, calc_errs, we, edges, pt, pe)



def compute_edges_pt(f, C, nbins):
    """Split the sorted sample into ``nbins`` rank-based bins.

    The samples are ordered by ``f`` and cut into ``nbins`` consecutive groups
    of (nearly) equal size.  For each group the fraction of truthy entries of
    ``C`` is returned, together with the ``f`` values at the cut positions.

    Parameters
    ----------
    f : 1-D array-like
        Values to bin.
    C : 1-D array
        "Correct" indicators, aligned with ``f``; must support integer-array
        indexing.
    nbins : int
        Number of bins.

    Returns
    -------
    edges : ndarray
        ``nbins + 1`` edges; the first is forced to 0 and the last to 1.
    p : ndarray
        Fraction of correct samples per bin (NaN for an empty bin).
    """
    n = len(f)

    # Cut positions in the sorted sample, and the bin index of every rank.
    cuts = np.round(np.linspace(0, n, nbins + 1)).astype(int)
    bin_of_rank = np.repeat(np.arange(nbins), np.diff(cuts))
    # The last cut equals n, which is out of range as an index; point it at a
    # valid position instead -- the value read there is replaced by 1 below.
    cuts[-1] = cuts[-2]

    order = np.argsort(f)
    edges = np.asarray(f)[order][cuts]
    edges[0] = 0
    edges[-1] = 1

    # Row b of the mask marks the sorted samples that belong to bin b.
    in_bin = bin_of_rank == np.arange(nbins).reshape((-1, 1))
    count = in_bin.sum(axis=1)
    correct = np.logical_and(in_bin, C[order]).sum(axis=1)
    return edges, correct / count

def determine_edges(pvalues, nbins):
    """Return sorted, de-duplicated quantile bin edges for ``pvalues``.

    The edges are the ``k / nbins`` quantiles for ``k = 1..nbins`` (so the
    last quantile is the sample maximum).  An edge at 0 is prepended only when
    every value is strictly positive, and an edge at 1 is appended only when
    every value is strictly below 1.

    NOTE(review): when the minimum is not above 0 no lower edge is added, so
    the first edge is the ``1 / nbins`` quantile -- confirm this is intended
    by the callers that histogram with these edges.

    Parameters
    ----------
    pvalues : array-like
        Sample values.
    nbins : int
        Number of quantile steps.

    Returns
    -------
    ndarray
        Strictly increasing edge values (duplicates removed by ``np.unique``).
    """
    q = np.quantile(pvalues, np.arange(1, nbins + 1) / nbins)

    pieces = []
    if np.min(pvalues) > 0:
        pieces.append([0.0])
    pieces.append(q)
    if np.max(pvalues) < 1:
        pieces.append([1.0])
    return np.unique(np.concatenate(pieces))

def histogram_scores(we, pe, pt):
    """Expected scoring-rule losses of predicted bin probabilities.

    Every score is the ``we``-weighted sum over bins of the expected loss,
    where the outcome is "correct" with probability ``pe`` and the prediction
    for "correct" is ``pt``.

    Parameters
    ----------
    we : array
        Bin weights.
    pe : array
        Empirical probability of "correct" per bin.
    pt : array
        Predicted probability of "correct" per bin.

    Returns
    -------
    tuple
        ``(brier, nll, e_or, e_ab)``: quadratic (Brier) loss, negative
        log-likelihood (natural log), odds-ratio loss and absolute (L1) loss.
        ``nll`` and ``e_or`` are not finite when ``pt`` hits 0 or 1.
    """
    # Brier(pe, pt) quadratic scoring error
    b = pe * (1 - pt) ** 2 + (1 - pe) * pt ** 2
    brier = np.dot(we, b)

    # NLL
    nl = -np.log(pt) * pe - np.log(1 - pt) * (1 - pe)
    nll = np.dot(we, nl)

    # Odds ratio
    oratio = (1 - pt) / pt * pe + pt / (1 - pt) * (1 - pe)
    e_or = np.dot(we, oratio)

    # L1: the error is |1 - pt| when the outcome is correct and |0 - pt| when
    # it is not (the second term previously reused |1 - pt|, which made the
    # score independent of pe).
    ab = np.abs(1 - pt) * pe + np.abs(pt) * (1 - pe)
    e_ab = np.dot(we, ab)

    return (brier, nll, e_or, e_ab)

def error_metrics_histogram(we, pe, pt):
    """Calibration-error and spread metrics for a pair of histograms.

    Inputs are the bin weights ``we``, the empirical per-bin probability
    ``pe`` and the predicted per-bin probability ``pt``.  The overall accuracy
    ``pa`` is the weighted mean of ``pe``.

    Returns ``(brier, nll, spreads, calc_errs)``:

    * ``brier``  -- ``pa * (1 - pa)`` plus the weighted difference between the
      squared calibration error and the squared spread.
    * ``nll``    -- ``entropy(pa)`` plus the weighted difference between the
      KL calibration error and the KL spread.
    * ``spreads`` -- weighted [squared, absolute, KL, odds-ratio] spread of
      ``pe`` around ``pa``.
    * ``calc_errs`` -- weighted [squared, absolute, KL] error between ``pe``
      and ``pt``.
    """

    def weighted(per_bin):
        # Weighted sum over bins.
        return np.dot(we, per_bin)

    pa = weighted(pe)

    # Calibration errors: how far the prediction is from the empirical value.
    sq_cal = (pe - pt) ** 2
    abs_cal = np.abs(pe - pt)
    kl_cal = kla(pe, pt)

    # Spreads: how far each bin is from the overall accuracy.
    sq_spread = (pe - pa) ** 2
    abs_spread = np.abs(pe - pa)
    kl_spread = kla(pa, pe)
    or_spread = oddratio(pe, pa)

    # Brier baseline
    brier = pa * (1 - pa) + weighted(sq_cal - sq_spread)

    # KL score baseline
    nll = entropy(pa) + weighted(kl_cal - kl_spread)

    spreads = np.array(
        [weighted(term) for term in (sq_spread, abs_spread, kl_spread, or_spread)])
    calc_errs = np.array(
        [weighted(term) for term in (sq_cal, abs_cal, kl_cal)])

    return (brier, nll, spreads, calc_errs)

def kla(p, q):
    """Binary KL-type divergence in bits, weighted by ``q``.

    Computes ``-(q * log2(p / q) + (1 - q) * log2((1 - p) / (1 - q)))``
    element-wise.  NumPy emits divide-by-zero / invalid-value warnings from
    ``np.log2`` when a ratio is 0 or undefined.
    """
    hit_term = np.log2(p / q) * q
    miss_term = np.log2((1 - p) / (1 - q)) * (1 - q)
    return -(hit_term + miss_term)

def entropy(pa):
    """Binary entropy of ``pa`` in bits: ``-(pa*log2(pa) + (1-pa)*log2(1-pa))``."""
    complement = 1 - pa
    return -(pa * np.log2(pa) + complement * np.log2(complement))

def odds(p):
    """Return the odds ``p / (1 - p)`` of a probability ``p``."""
    complement = 1 - p
    return p / complement

def oddratio(p, q):
    """Symmetric odds ratio of ``p`` and ``q``.

    Returns the element-wise larger of ``odds(p) / odds(q)`` and its
    reciprocal form ``odds(q) / odds(p)``.
    """
    p_odds, q_odds = odds(p), odds(q)
    forward = p_odds / q_odds
    backward = q_odds / p_odds
    return np.maximum(forward, backward)

def histogram_eval(f, C, edges):
    """Histogram ``f`` on ``edges`` and estimate P(C | f in bin) per bin.

    Given:
      f     : array of discriminator values
      C     : "correct" array, aligned with ``f``
      edges : bin end points; bin ``k`` collects values with
              ``edges[k-1] < f <= edges[k]`` (right-closed)

    Returns ``(w, p)``:
      w : fraction of all samples that fall in each bin
      p : smoothed fraction of correct samples per bin.  Every bin receives
          one extra pseudo-sample that is correct with probability
          ``C.mean()``, which avoids dividing by zero; an empty bin therefore
          reports the overall accuracy.

    Samples at or below ``edges[0]`` or above ``edges[-1]`` belong to no bin
    but still count towards the total used for ``w``.
    """
    total = len(f)

    # Bin number (1-based) of every sample; 0 / len(edges) mean "outside".
    bin_of_sample = np.digitize(f, edges, right=True)
    bin_total = len(edges) - 1

    # Prior used for the pseudo-sample: the overall accuracy.
    prior = C.mean()

    # Row k of the mask flags the samples that landed in bin k + 1.
    bin_labels = np.arange(1, bin_total + 1)[:, np.newaxis]
    in_bin = bin_labels == bin_of_sample

    count = in_bin.sum(axis=1)
    correct = np.logical_and(in_bin, C).sum(axis=1)

    weights = count / total
    prob = (correct + prior) / (count + 1)
    return (weights, prob)
