import numpy as np
import six
from six.moves import range

def snoek_scores(probabilities, ground_truth, bins=15):
    """Compute binned calibration scores for binary predictions.

    The predictions are histogrammed with ``bin_predictions_and_accuracies``
    and every bin is summarised by its centre of mass (from
    ``bin_centers_of_mass``) and its accuracy. Each score is a sum over the
    non-empty bins, weighted by the fraction of examples in the bin.

    Args:
      probabilities: numpy array of predicted probabilities; flattened
        before use.
      ground_truth: numpy array of binary labels with the same number of
        elements; flattened before use.
      bins: number of bins, or an array of bin edges; forwarded to
        ``bin_predictions_and_accuracies``.

    Returns:
      Tuple ``(ece, brier, nll, bin_fractions, bin_edges, bin_centers,
      accuracies)`` where ``bin_fractions`` is ``counts / sum(counts)``.
    """
    probabilities = probabilities.flatten()
    ground_truth = ground_truth.flatten()
    bin_edges, accuracies, counts = bin_predictions_and_accuracies(probabilities, ground_truth, bins)
    bin_centers = bin_centers_of_mass(probabilities, bin_edges)
    num_examples = np.sum(counts)
    bin_fractions = counts / float(num_examples)

    # Empty bins carry NaN accuracies, so restrict every sum to occupied bins.
    occupied = counts > 0
    weight = bin_fractions[occupied]
    confidence = bin_centers[occupied]
    accuracy = accuracies[occupied]

    ece = np.sum(weight * np.abs(confidence - accuracy))
    brier = np.sum(weight * (accuracy * (1 - confidence) ** 2
                             + (1 - accuracy) * confidence ** 2))

    # The confidence is shifted down by eps so that both logs stay finite at
    # confidence == 1. Without a floor that shift makes the first log's
    # argument non-positive for confidences <= eps (e.g. the 1e-8 stand-in
    # used for zero probabilities), which turns the whole NLL into NaN or
    # inf. The floor leaves every confidence >= 2 * eps untouched.
    eps = 1e-6
    log_confidence = np.log(np.maximum(confidence - eps, eps))
    log_complement = np.log(1 - confidence + eps)
    nll = np.sum(weight * (-log_confidence * accuracy - log_complement * (1 - accuracy)))

    return ece, brier, nll, bin_fractions, bin_edges, bin_centers, accuracies

def expected_calibration_error(probabilities, ground_truth, bins=15):
    """Compute the expected calibration error of a set of predictions in [0, 1].

    Args:
      probabilities: A numpy array of N probabilities assigned to each
        prediction; flattened before use.
      ground_truth: A numpy array of N ground truth labels in
        {0, 1, True, False}; flattened before use.
      bins: Number of bins to bin predictions into, or an array representing
        bin edges.

    Returns:
      Float: the expected calibration error, i.e. the sum over non-empty bins
      of the bin's share of examples times the absolute gap between the bin's
      centre of mass and its accuracy.
    """
    flat_probs = probabilities.flatten()
    flat_labels = ground_truth.flatten()

    edges, bin_accuracy, bin_count = bin_predictions_and_accuracies(flat_probs, flat_labels, bins)
    bin_confidence = bin_centers_of_mass(flat_probs, edges)

    total = float(np.sum(bin_count))
    # Only occupied bins contribute; empty ones have NaN accuracy.
    occupied = bin_count > 0
    gaps = np.abs(bin_confidence[occupied] - bin_accuracy[occupied])
    return np.sum((bin_count[occupied] / total) * gaps)

def bin_predictions_and_accuracies(probabilities, ground_truth, bins=10):
    """Histogram a vector of probabilities into bins and score each bin.

    Args:
      probabilities: A numpy vector of N probabilities assigned to each
        prediction.
      ground_truth: A numpy vector of N ground truth labels in {0,1}.
      bins: Number of equal width bins spanning the observed range of
        ``probabilities``, or a sequence of bin edges. Passed to
        ``np.histogram``.

    Returns:
      bin_edges: Numpy vector of floats containing the edges of the bins
        (including leftmost and rightmost).
      accuracies: Numpy vector of floats for the average accuracy of the
        predictions in each bin; NaN for empty bins.
      counts: Numpy vector of ints containing the number of examples per bin.

    Raises:
      ValueError: if the inputs fail validation, differ in length, or
        ``ground_truth`` contains values other than 0/1 (False/True).
    """
    _validate_probabilities(probabilities)
    _check_rank_nonempty(rank=1, probabilities=probabilities, ground_truth=ground_truth)

    if len(probabilities) != len(ground_truth):
        raise ValueError('Probabilies and ground truth must have the same number of elements.')

    if not np.all((ground_truth == 0) | (ground_truth == 1)):
        raise ValueError('Ground truth must contain binary labels {0,1} or {False, True}.')

    # Ensure probabilities are never 0, since the bins in np.digitize are open
    # on one side.
    probabilities = np.where(probabilities == 0, 1e-8, probabilities)
    _, bin_edges = np.histogram(probabilities, bins=bins)
    # Derive the bin count from the edges so integer scalars of any type and
    # explicit edge sequences are all handled the same way.
    num_bins = len(bin_edges) - 1

    indices = np.digitize(probabilities, bin_edges, right=True)
    # With right=True each bin is (left, right], so a value equal to the
    # leftmost edge gets index 0. When np.histogram derives the edges from
    # the data that value is the minimum probability, which would be dropped
    # from every bin. Put it in the first bin, as np.histogram does.
    indices = np.where(probabilities == bin_edges[0], 1, indices)

    counts = np.zeros(num_bins, dtype=int)
    accuracies = np.full(num_bins, np.nan)
    for b in range(num_bins):
        members = ground_truth[indices == b + 1]
        counts[b] = members.size
        # np.mean of an empty selection warns, so empty bins keep the NaN.
        if members.size:
            accuracies[b] = np.mean(members)
    return bin_edges, accuracies, counts


def bin_centers_of_mass(probabilities, bin_edges):
    """Return the mean probability of the samples falling in each bin.

    Args:
      probabilities: numpy vector of probabilities. Exact zeros are replaced
        by 1e-8 before binning.
      bin_edges: numpy vector of bin edges, including the leftmost and
        rightmost edge.

    Returns:
      Numpy vector with one entry per bin: the mean of the probabilities in
      that bin, or the midpoint of the bin's edges when the bin is empty.
    """
    probabilities = np.where(probabilities == 0, 1e-8, probabilities)
    indices = np.digitize(probabilities, bin_edges, right=True)
    # With right=True each bin is (left, right], so a value equal to the
    # leftmost edge gets index 0 and would belong to no bin. Count it in the
    # first bin so the minimum sample is not lost.
    indices = np.where(probabilities == bin_edges[0], 1, indices)

    centers = []
    for i in range(1, len(bin_edges)):
        members = probabilities[indices == i]
        if members.size:
            centers.append(np.mean(members))
        else:
            # np.mean of an empty selection warns and yields NaN; fall back
            # to the midpoint of the bin instead.
            centers.append((bin_edges[i] + bin_edges[i - 1]) / 2)
    return np.array(centers)

def _validate_probabilities(probabilities, multiclass=False):
    """Check that every entry is a probability.

    Args:
      probabilities: array-like of values expected to lie in [0, 1].
      multiclass: if True, additionally require the values to sum to 1
        (within 1e-5) along the last axis.

    Raises:
      ValueError: if the input is empty, if any value lies outside [0, 1] or
        is NaN, or if ``multiclass`` is set and the last axis does not sum
        to 1.
    """
    probabilities = np.asarray(probabilities)
    if probabilities.size == 0:
        raise ValueError('Probabilities must not be empty.')
    # An all-elements test rather than comparing max/min: comparisons with
    # NaN are always False, so NaN entries would pass a max/min check.
    if not np.all((probabilities >= 0.) & (probabilities <= 1.)):
        raise ValueError('All probabilities must be in [0,1].')
    if multiclass and not np.allclose(1, np.sum(probabilities, axis=-1),
                                      atol=1e-5):
        raise ValueError(
            'Multiclass probabilities must sum to 1 along the last dimension.')

def _check_rank_nonempty(rank, **kwargs):
    """Require every keyword array to have the given rank and length > 1.

    Args:
      rank: required number of dimensions.
      **kwargs: numpy arrays to check; the keyword is used as the array's
        name in the error message.

    Raises:
      ValueError: if any array has a different number of dimensions or at
        most one element along its first axis.
    """
    for key, array in kwargs.items():
        # Test the rank first: len() of a 0-d array raises TypeError, which
        # would hide the intended ValueError.
        if array.ndim == 0 or array.ndim != rank or len(array) <= 1:
            raise ValueError('%s must be a rank-%s array of length > 1; actual shape is %s.' %
                             (key, rank, array.shape))
