from collections.abc import Callable
from enum import Enum
from math import ceil

import numpy as np
from numpy.typing import NDArray
from scipy.special import expit, logit


class BinningStrategy(Enum):
    """
    Strategies for choosing calibration bin edges.

    Members:
        UMB: Uniform Mass Binning.
        UWB: Uniform Width Binning.
    """

    UMB = "umb"
    UWB = "uwb"


def get_recommended_n_bins_for_ece(n_samples: int) -> int:
    """
    Get the recommended number of bins for ECE.

    The recommendation is the cube root of the sample count, rounded up.

    Args:
        n_samples (int): The number of samples.

    Returns:
        int: The recommended number of bins.
    """
    cube_root = np.cbrt(n_samples)
    return ceil(cube_root)


def get_recommended_n_bins_for_pu_ece(
    prior: float, n_positive_samples: int | float, n_unlabeled_samples: int | float
) -> int:
    return ceil(1 / np.cbrt(prior**2 / n_positive_samples + 1 / n_unlabeled_samples))


def get_logit_bin_edges_for_uwb(n_bins: int) -> NDArray[np.floating]:
    """
    Build uniform-width bin edges and express them in logit space.

    The interval [0, 1] is split into ``n_bins`` equally wide probability bins,
    and each edge is mapped through the logit function.

    Args:
        n_bins (int): The number of bins.

    Returns:
        NDArray[float]: The ``n_bins + 1`` bin edges in logit space.
    """
    edges = logit(np.linspace(0, 1, n_bins + 1))
    # Pin the outermost edges explicitly so every finite logit falls inside a bin.
    edges[0], edges[-1] = -np.inf, np.inf
    return edges


def get_logit_bin_edges(
    logits: NDArray[np.floating], n_bins: int, strategy=BinningStrategy.UMB
) -> NDArray[np.floating]:
    """
    Get the bin edges for the given logits using the specified binning strategy.

    Args:
        logits (NDArray[np.floating]): The input logits.
        n_bins (int): The number of bins.
        strategy (BinningStrategy): The binning strategy to use.

    Returns:
        NDArray[float]: The bin edges.
    """
    match strategy:
        case BinningStrategy.UMB:
            logit_bin_edges = np.quantile(logits, np.linspace(0, 1, n_bins + 1))  # type: ignore
            logit_bin_edges[0] = -np.inf  # Ensure the first edge is -inf
            logit_bin_edges[-1] = np.inf  # Ensure the last edge is inf
        case BinningStrategy.UWB:
            logit_bin_edges = get_logit_bin_edges_for_uwb(n_bins)
    return logit_bin_edges


def calculate_ece(
    logits: NDArray[np.floating],
    labels: NDArray[np.integer],
    n_bins: int,
    strategy: BinningStrategy = BinningStrategy.UMB,
    bin_edges: NDArray[np.floating] | None = None,
) -> float:
    """
    Calculate the Expected Calibration Error (ECE) for the given logits and labels.

    Each sample is assigned to a bin by its logit. Within each bin, the mean
    predicted probability (sigmoid of the logit) is compared with the mean label,
    and the absolute gaps are averaged with weights equal to the fraction of
    samples in each bin. Empty bins contribute zero.

    Args:
        logits (NDArray[float]): The input logits.
        labels (NDArray[int]): The ground truth labels.
        n_bins (int): The number of bins to use.
        strategy (BinningStrategy): The binning strategy to use.
        bin_edges (NDArray[float] | None): Precomputed bin edges. If None, they will be computed based on logits.

    Returns:
        float: The calculated ECE.
    """
    if bin_edges is None:
        bin_edges = get_logit_bin_edges(logits, n_bins, strategy)
    # digitize is 1-based relative to the edges; shift to 0-based and clamp so
    # values on or beyond the outermost edges land in the first/last bin.
    bin_indices = np.clip(np.digitize(logits, bin_edges) - 1, 0, n_bins - 1)
    bin_counts = np.bincount(bin_indices, minlength=n_bins)
    bin_prob_sums = np.bincount(bin_indices, weights=expit(logits), minlength=n_bins)
    bin_label_sums = np.bincount(bin_indices, weights=labels, minlength=n_bins)
    # Empty bins produce 0/0 here. The resulting NaNs are replaced by 0 below, so
    # silence the floating-point warnings instead of leaking them to the caller.
    with np.errstate(divide="ignore", invalid="ignore"):
        bin_means = np.nan_to_num(bin_prob_sums / bin_counts, nan=0.0)
        bin_accs = np.nan_to_num(bin_label_sums / bin_counts, nan=0.0)
        # Normalize bin counts to get the mass of each bin
        bin_masses = np.nan_to_num(bin_counts / len(logits), nan=0.0)
    # Calculate ECE
    ece = np.sum(np.abs(bin_means - bin_accs) * bin_masses)
    return ece


def _is_valid_sample_count(count: int | float) -> bool:
    """Return True if ``count`` is a positive integer (Python or NumPy) or +infinity."""
    if isinstance(count, (int, np.integer)):
        return bool(count > 0)
    return bool(count == np.inf)


def calculate_pu_ece(
    positive_logits: NDArray[np.floating] | None,
    unlabeled_logits: NDArray[np.floating] | None,
    prior: float,
    n_bins: int,
    n_positive: int | float,
    n_unlabeled: int | float,
    binning_strategy: BinningStrategy = BinningStrategy.UMB,
    positive_cdf_diff: Callable | None = None,
    unlabeled_truncated_expectation: Callable | None = None,
    bin_edges: NDArray[np.floating] | None = None,
) -> float:
    """
    Calculate the Positive-Unlabeled Expected Calibration Error (PU-ECE).

    For each bin, the prior-weighted fraction of positive samples falling in the bin
    is compared with the unlabeled samples' summed predicted probability in that bin
    divided by ``n_unlabeled``. The absolute differences are summed over all bins.

    Args:
        positive_logits (NDArray[np.floating] | None): The logits for positive samples.
            Required when ``n_positive`` is finite.
        unlabeled_logits (NDArray[np.floating] | None): The logits for unlabeled samples.
            Required when ``n_unlabeled`` is finite or when ``bin_edges`` is None
            (the bin edges are then estimated from these logits).
        prior (float): The prior probability of the positive class.
        n_bins (int): The number of bins to use.
        n_positive (int | float): The number of positive samples or np.inf
            if the number of positive samples is infinite. NumPy integers are accepted.
        n_unlabeled (int | float): The number of unlabeled samples or np.inf
            if the number of unlabeled samples is infinite. NumPy integers are accepted.
        binning_strategy (BinningStrategy): The binning strategy to use.
        positive_cdf_diff (callable, optional): Function called as ``f(lower_edge, upper_edge)`` to calculate
            the CDF difference for positive samples. Required when ``n_positive`` is infinity.
        unlabeled_truncated_expectation (callable, optional): Function called as ``f(lower_edge, upper_edge)``
            to calculate the truncated expectation for unlabeled samples. Required when ``n_unlabeled``
            is infinity.
        bin_edges (NDArray[np.floating] | None): Precomputed bin edges. If None, they will be computed based on
            logits.

    Returns:
        float: The calculated PU-ECE.

    Raises:
        TypeError: If ``n_positive`` or ``n_unlabeled`` is neither a positive integer nor infinity.
        ValueError: If an input required by the chosen finite/infinite configuration is missing.
    """
    if not (_is_valid_sample_count(n_positive) and _is_valid_sample_count(n_unlabeled)):
        raise TypeError("n_positive and n_unlabeled must be positive integers or infinity.")

    if bin_edges is None:
        if unlabeled_logits is None:
            raise ValueError("unlabeled_logits must be provided to compute bin edges.")
        bin_edges = get_logit_bin_edges(logits=unlabeled_logits, n_bins=n_bins, strategy=binning_strategy)

    # Positive side: per-bin probability mass of the positive distribution.
    if n_positive == np.inf:
        if positive_cdf_diff is None:
            raise ValueError("positive_cdf_diff must be provided if n_positive is infinity.")
        bin_positive_probs = np.array([positive_cdf_diff(bin_edges[i], bin_edges[i + 1]) for i in range(n_bins)])
    else:
        if positive_logits is None:
            raise ValueError("positive_logits must be provided if n_positive is not infinity.")
        # digitize is 1-based relative to the edges; shift and clamp into [0, n_bins - 1].
        positive_bins = np.clip(np.digitize(positive_logits, bin_edges) - 1, 0, n_bins - 1)
        bin_positive_probs = np.bincount(positive_bins, minlength=n_bins) / n_positive

    # Unlabeled side: per-bin sum of predicted probabilities, normalized by n_unlabeled.
    if n_unlabeled == np.inf:
        if unlabeled_truncated_expectation is None:
            raise ValueError("unlabeled_truncated_expectation must be provided if n_unlabeled is infinity.")
        bin_unlabeled_expectations = np.array(
            [unlabeled_truncated_expectation(bin_edges[i], bin_edges[i + 1]) for i in range(n_bins)]
        )
    else:
        if unlabeled_logits is None:
            raise ValueError("unlabeled_logits must be provided if n_unlabeled is not infinity.")
        unlabeled_bins = np.clip(np.digitize(unlabeled_logits, bin_edges) - 1, 0, n_bins - 1)
        bin_unlabeled_expectations = (
            np.bincount(unlabeled_bins, weights=expit(unlabeled_logits), minlength=n_bins) / n_unlabeled
        )
    return np.sum(np.abs(prior * bin_positive_probs - bin_unlabeled_expectations))
