import torch
from torch import nn
from torch.nn import functional as F
from scipy import optimize
import numpy as np

from tqdm import tqdm


def get_brier(confidences, ground_truth):
    """Compute the Brier score of confidences against binary outcomes.

    Args:
        confidences: array of predicted probabilities for the positive outcome.
        ground_truth: array of the same shape; truthy entries are positive
            outcomes. It is cast to bool so that 0/1 integer labels act as a
            mask rather than being misread as fancy indices.

    Returns:
        Mean over all entries of ``(1 - c) ** 2`` where the outcome is positive
        and ``c ** 2`` where it is negative. Empty input yields ``nan``
        (numpy mean of an empty array).
    """
    confidences = np.asarray(confidences, dtype=float)
    outcome = np.asarray(ground_truth, dtype=bool)
    squared_errors = np.where(outcome, 1.0 - confidences, confidences) ** 2
    return np.mean(squared_errors)


def bin_confidences_and_accuracies(confidences, ground_truth, bin_edges, indices):
    """Aggregate per-bin weight, accuracy and mean confidence.

    Args:
        confidences: 1-D array of confidence values.
        ground_truth: 1-D array of the same length; truthy entries count as
            correct.
        bin_edges: array of bin edges. Only its size is used; it defines
            ``nbins = bin_edges.size - 1`` bins.
        indices: 1-D integer array of the same length giving each sample's bin.
            Samples whose index lies outside ``[0, nbins)`` are ignored.

    Returns:
        ``(weights, bin_accuracy, bin_confidence)``, each of length ``nbins``:
        ``weights[b]`` is the fraction of in-range samples that fall in bin ``b``
        (all zeros if there are none), ``bin_accuracy[b]`` is the fraction of
        correct samples in bin ``b``, and ``bin_confidence[b]`` is the mean
        confidence in bin ``b``. Empty bins get 0 for both.
    """
    nbins = max(np.asarray(bin_edges).size - 1, 0)
    confidences = np.asarray(confidences, dtype=float).ravel()
    correct = np.asarray(ground_truth, dtype=bool).ravel()
    indices = np.asarray(indices).ravel()

    # Drop samples with out-of-range bin indices so they count toward no bin.
    in_range = (indices >= 0) & (indices < nbins)
    idx = indices[in_range].astype(np.intp)

    # One O(N) pass per statistic instead of a dense (nbins x N) mask.
    counts = np.bincount(idx, minlength=nbins)
    n_correct = np.bincount(
        idx, weights=correct[in_range].astype(float), minlength=nbins
    )
    sum_confidences = np.bincount(idx, weights=confidences[in_range], minlength=nbins)

    total = counts.sum()
    if total > 0:
        weights = counts / total
    else:
        # No (in-range) samples at all: every bin has zero weight.
        weights = np.zeros(nbins, dtype=float)

    bin_accuracy = np.zeros(nbins, dtype=float)
    bin_confidence = np.zeros(nbins, dtype=float)
    non_empty = counts > 0
    bin_accuracy[non_empty] = n_correct[non_empty] / counts[non_empty]
    bin_confidence[non_empty] = sum_confidences[non_empty] / counts[non_empty]

    return weights, bin_accuracy, bin_confidence


def get_ece(confidences, ground_truth, nbins):
    """Expected Calibration Error with equal-mass bins.

    Samples are ranked by confidence and cut into ``nbins`` contiguous groups
    of (nearly) equal size. The ECE is the sample-weighted mean of
    ``|mean confidence - accuracy|`` over the groups.

    Bin membership is decided purely by rank, so numeric bin edges are not
    needed to compute the score. ``bin_confidences_and_accuracies`` only reads
    the size of the edge array, so a placeholder of the right size is passed.
    When there are fewer samples than bins, some bins stay empty and get zero
    weight instead of raising.

    Args:
        confidences: 1-D array of confidence values.
        ground_truth: 1-D array of the same length; truthy entries are correct.
        nbins: number of bins.

    Returns:
        The ECE as a float (0.0 for empty input).
    """
    confidences = np.asarray(confidences)
    n_samples = confidences.size
    order = confidences.argsort()

    # Split points in rank space. Bin b covers sorted ranks
    # [split[b], split[b + 1]), except bin 0, which starts at rank 0.
    split = np.linspace(0, n_samples - 1, nbins + 1).astype(int) + 1
    ranked_bins = np.searchsorted(split[1:-1], np.arange(n_samples), side="right")

    # Scatter the rank-ordered bin labels back to the caller's sample order.
    bin_indices = np.empty(n_samples, dtype=int)
    bin_indices[order] = ranked_bins

    # Placeholder edges: only the length (nbins + 1) is consumed downstream.
    bin_edges = np.zeros(nbins + 1)
    bin_edges[-1] = 1

    weights, bin_accuracy, bin_confidence = bin_confidences_and_accuracies(
        confidences, ground_truth, bin_edges, bin_indices
    )
    return np.dot(weights, np.abs(bin_confidence - bin_accuracy))


def get_ece_equal_width(confidences, ground_truth, nbins):
    """
    Compute ECE with equal-width bins.

    Bin edges are linearly spaced from 0 to 1, giving ``nbins`` bins of width
    ``1 / nbins``. A confidence ``c`` is assigned to bin ``b`` when
    ``edges[b] <= c < edges[b + 1]``. Any ``c >= edges[-2]`` (including 1.0)
    goes to the last bin, and any ``c < edges[1]`` goes to the first.

    Args:
        confidences: 1-D array of confidence values.
        ground_truth: 1-D array of the same length; truthy entries are correct.
        nbins: number of bins.

    Returns:
        The sample-weighted mean of ``|mean confidence - accuracy|`` over bins.
    """
    bin_edges = np.linspace(0, 1, nbins + 1)

    # Digitizing against the interior edges only returns 0-based bin indices
    # in [0, nbins - 1] directly. No offset is needed, and no epsilon on the
    # last edge is needed for 1.0 to land in the final bin.
    bin_indices = np.digitize(confidences, bin_edges[1:-1])

    weights, bin_accuracy, bin_confidence = bin_confidences_and_accuracies(
        confidences, ground_truth, bin_edges, bin_indices
    )
    return np.dot(weights, np.abs(bin_confidence - bin_accuracy))


def compute_scores(confidences, ground_truth, nbins):
    """Score one set of binary predictions with three calibration metrics.

    Args:
        confidences: array of confidence values.
        ground_truth: array of the same length; truthy entries are correct.
        nbins: number of bins used by both ECE variants.

    Returns:
        Tuple ``(ece_eq_mass, ece_eq_width, brier)``: the ECE with equal-mass
        bins, the ECE with equal-width bins, and the Brier score.
    """
    return (
        get_ece(confidences, ground_truth, nbins),
        get_ece_equal_width(confidences, ground_truth, nbins),
        get_brier(confidences, ground_truth),
    )


def compute_classwise_scores(softmaxes, labels, nbins):
    """Class-wise ECEs and Brier score, averaged with class-frequency weights.

    For every class ``k``, the column ``softmaxes[:, k]`` is scored against the
    one-vs-rest indicator ``labels == k`` with ``compute_scores``. The
    per-class results are then averaged with weight ``n_k``, the number of
    samples labelled ``k``. Classes with no samples are skipped because their
    weight would be zero.

    Args:
        softmaxes: 2-D tensor of class probabilities indexed ``[sample, class]``.
        labels: 1-D tensor of integer class labels.
        nbins: number of bins for both ECE variants.

    Returns:
        ``(ece_eq_mass, ece_eq_width, brier)``, or ``(0.0, 0.0, 0.0)`` when no
        label matches any class column.
    """
    # Tensor.numpy() only works on CPU tensors outside the autograd graph.
    probs = softmaxes.detach().cpu()
    labels = labels.detach().cpu()

    total = 0
    ece_eq_mass = 0.0
    ece_eq_width = 0.0
    brier = 0.0
    for k in range(probs.shape[1]):
        is_class_k = labels == k
        n_k = int(is_class_k.sum().item())
        if n_k == 0:  # class absent in this batch/split: zero weight anyway
            continue
        mass_k, width_k, brier_k = compute_scores(
            probs[:, k].numpy(), is_class_k.numpy(), nbins
        )
        ece_eq_mass += n_k * mass_k
        ece_eq_width += n_k * width_k
        brier += n_k * brier_k
        total += n_k

    if total == 0:  # no samples processed (e.g. empty input)
        return 0.0, 0.0, 0.0

    return ece_eq_mass / total, ece_eq_width / total, brier / total


def get_uncertainty_measures(softmaxes, labels, nbins=15):
    """Compute accuracy and calibration/uncertainty metrics for class probabilities.

    Args:
        softmaxes: 2-D tensor of predicted class probabilities indexed
            ``[sample, class]``. Rows are presumably normalized; this is not
            checked.
        labels: 1-D tensor of integer class labels in ``[0, num_classes)``.
        nbins: number of bins for every ECE variant (default 15, the value
            previously hard-coded).

    Returns:
        dict with keys
          ``"acc"``: top-1 accuracy;
          ``"top1_eq_mass"``, ``"top1_eq_width"``, ``"top1_brier"``: calibration
          of the top-1 confidence against top-1 correctness;
          ``"cw_eq_mass"``, ``"cw_eq_width"``, ``"cw_brier"``: class-wise variants;
          ``"nll"``: mean negative log-probability of the true class;
          ``"brier"``: mean over samples of the squared error summed over classes.
    """
    # Work on detached CPU copies so the .numpy() calls below also work for
    # CUDA tensors and tensors that require grad.
    softmaxes = softmaxes.detach().cpu()
    labels = labels.detach().cpu()

    results = {}

    # Accuracy of the arg-max prediction.
    confidences, predictions = torch.max(softmaxes, 1)
    accuracies = predictions.eq(labels)
    results["acc"] = accuracies.float().mean().item()

    # Top-1 ECEs (equal mass & equal width) and Brier score.
    (
        results["top1_eq_mass"],
        results["top1_eq_width"],
        results["top1_brier"],
    ) = compute_scores(confidences.numpy(), accuracies.numpy(), nbins)

    # Class-wise ECEs (equal mass & equal width) and Brier score.
    (
        results["cw_eq_mass"],
        results["cw_eq_width"],
        results["cw_brier"],
    ) = compute_classwise_scores(softmaxes, labels, nbins)

    # NLL. torch.log(0) is -inf, so a zero probability on the true class
    # yields an infinite NLL. NLLLoss expects int64 targets.
    nll_criterion = nn.NLLLoss(reduction="none")
    results["nll"] = torch.mean(
        nll_criterion(torch.log(softmaxes), labels.long())
    ).item()

    # Multi-class Brier score against one-hot targets (float32, matching the
    # previous FloatTensor-based encoding).
    labels_onehot = F.one_hot(labels.long(), num_classes=softmaxes.shape[1]).float()
    results["brier"] = torch.mean(
        torch.sum((softmaxes - labels_onehot) ** 2, dim=1)
    ).item()

    return results
