import copy
from abc import ABC, abstractmethod
from collections.abc import Callable

import numpy as np
from matplotlib import pyplot as plt
from numpy.typing import NDArray
from scipy.special import expit
from scipy.stats import norm
from sklearn.mixture import GaussianMixture

import research.wsl_ece.metric.ece as ece_module
from research.wsl_ece.metric.wandb_util import get_test_data_artifact


def integrate(
    integrand: Callable[[NDArray[np.floating]], NDArray[np.floating]],
    lower_bound: float,
    upper_bound: float,
    num_points: int = 1000,
) -> float:
    """
    Numerically integrates the given integrand over the specified bounds using the trapezoidal rule.

    The integrand is evaluated once, on an evenly spaced grid of ``num_points`` points that
    includes both bounds, so it must accept and return arrays.

    Args:
        integrand (Callable[[NDArray[np.floating]], NDArray[np.floating]]): The integrand function to integrate.
        lower_bound (float): The lower bound of integration.
        upper_bound (float): The upper bound of integration.
        num_points (int): The number of points to use for the integration.

    Returns:
        float: The result of the integration.
    """
    # ``np.trapezoid`` is the NumPy >= 2.0 name of ``np.trapz``; fall back to the
    # older name so this helper keeps working on NumPy 1.x installations.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:
        trapezoid = np.trapz
    x = np.linspace(lower_bound, upper_bound, num_points)
    y = integrand(x)
    return float(trapezoid(y, x))


class SyntheticDistribution(ABC):
    """
    Abstract base class for synthetic distributions used in experiments.
    This class defines the interface for distributions used to calculate TCE, ECE and PU-ECE values in experiments.
    It includes methods for computing probability density functions (PDFs), cumulative distribution functions (CDFs),
    and expected calibration errors (ECEs). ECEs only accept Uniform Mass Binning (UMB) strategy.
    """

    # Whether the subclass carries real test logits/labels (plotted instead of fresh samples).
    has_real_example: bool
    # Real test data; only required when ``has_real_example`` is True.
    test_logits: NDArray[np.floating] | None = None
    test_labels: NDArray[np.integer] | None = None

    def __init__(self, seed: int = 42):
        """
        Creates the instance-level random number generator.

        Args:
            seed (int): Seed of the default generator stored in ``self.rng``.
        """
        self.rng = np.random.default_rng(seed)

    def get_rng(self, seed: int | None = None) -> np.random.Generator:
        """
        Returns a random number generator with the specified seed.
        If no seed is provided, returns the instance generator ``self.rng``.

        Args:
            seed (int | None): Optional random seed for reproducibility.

        Returns:
            np.random.Generator: A random number generator instance.
        """
        if seed is None:
            return self.rng
        return np.random.default_rng(seed)

    @property
    def prior(self) -> float:
        """
        Returns the prior probability of the positive class.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    @property
    def tce(self) -> float:
        """
        Returns the true calibration error (TCE) of the distribution.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    @abstractmethod
    def sample(self, n_samples: int, seed: int | None = None) -> tuple[NDArray[np.floating], NDArray[np.integer]]:
        """
        Generates samples from the distribution.

        Args:
            n_samples (int): The number of samples to generate.
            seed (int | None): Optional random seed for reproducibility.

        Returns:
            tuple[NDArray[np.floating], NDArray[np.integer]]:
                A tuple containing:
                - An array of scores (float) sampled from the distribution.
                - An array of labels (int) belonging to those scores.
        """
        ...

    @abstractmethod
    def positive_sample(self, n_samples: int, seed: int | None = None) -> NDArray[np.floating]:
        """
        Generates positive samples from the distribution.

        Args:
            n_samples (int): The number of positive samples to generate.
            seed (int | None): Optional random seed for reproducibility.

        Returns:
            NDArray[np.floating]: An array of positive scores sampled from the distribution.
        """
        ...

    @abstractmethod
    def unlabeled_truncated_expectation(self, lower_bound: float, upper_bound: float) -> float:
        """
        Computes the expected value for unlabeled samples over a specified interval.

        Args:
            lower_bound (float): The lower bound of the interval.
            upper_bound (float): The upper bound of the interval.

        Returns:
            float: The expected value for unlabeled samples.
        """
        ...

    @abstractmethod
    def positive_cdf_diff(self, lower_bound: float, upper_bound: float) -> float:
        """
        Computes the difference in cumulative distribution function (CDF) for positive samples.

        Args:
            lower_bound (float): The lower bound of the interval.
            upper_bound (float): The upper bound of the interval.

        Returns:
            float: The difference in CDF for positive samples.
        """
        ...

    @abstractmethod
    def ece(
        self,
        n_samples: int,
        binning_strategy: ece_module.BinningStrategy,
        n_bins: int | None = None,
        seed: int | None = None,
    ) -> float:
        """
        Computes the Expected Calibration Error (ECE) for the distribution.

        Args:
            n_samples (int): The number of samples to use for the computation.
            binning_strategy (ece_module.BinningStrategy): The binning strategy to use for ECE calculation.
            n_bins (int | None): The number of bins to use for ECE calculation. If None, the implementation
                chooses the number of bins itself (e.g. from the sample size).
            seed (int | None): Optional random seed for reproducibility.

        Returns:
            float: The computed ECE.
        """
        ...

    @abstractmethod
    def pu_ece(
        self,
        n_positive: int | float,
        n_unlabeled: int | float,
        binning_strategy: ece_module.BinningStrategy,
        n_bins: int | None = None,
        seed: int | None = None,
        prior_estimation_error: float = 1.0,
    ) -> float:
        """
        Computes the Positive-Unlabeled Expected Calibration Error (PU-ECE).

        Args:
            n_positive (int | float): The number of positive samples.
            n_unlabeled (int | float): The number of unlabeled samples.
            binning_strategy (ece_module.BinningStrategy): The binning strategy to use for PU-ECE calculation.
            n_bins (int | None): The number of bins to use for PU-ECE calculation. If None, the implementation
                chooses the number of bins itself (e.g. from the prior and the sample sizes).
            seed (int | None): Optional random seed for reproducibility.
            prior_estimation_error (float): The bias factor for prior estimation.

        Returns:
            float: The computed PU-ECE.
        """
        ...

    def plot_distribution(self, n_samples: int = 100000, seed: int | None = None):
        """
        Plots class-wise histograms of logits and probabilities and saves them to a PNG file.

        When ``has_real_example`` is True the stored test logits/labels are plotted and ``n_samples``
        is not used to draw data; otherwise ``n_samples`` points are drawn with ``sample``.
        The figure is written to ``<ClassName>_distribution.png`` in the working directory and closed.

        Args:
            n_samples (int): The number of samples to generate for plotting.
            seed (int | None): Optional random seed for reproducibility. When given, it also replaces
                the instance generator ``self.rng``.
        """
        self.rng = self.get_rng(seed)
        if self.has_real_example:
            assert self.test_logits is not None and self.test_labels is not None, (
                "Test logits and labels must be set for plotting real examples."
            )
            logits, labels = self.test_logits, self.test_labels
        else:
            logits, labels = self.sample(n_samples)

        probs = expit(logits)
        # Cube-root rule on the number of points actually plotted: for real examples this is the
        # size of the test set, not the (unused) ``n_samples`` argument. Keep at least one bin.
        n_bins = max(1, int(np.cbrt(len(logits))))
        fig, axes = plt.subplots(1, 2, figsize=(10, 6))
        for axis, values, quantity in ((axes[0], logits, "Logits"), (axes[1], probs, "Probabilities")):
            axis.hist(values[labels == 1], bins=n_bins, alpha=0.5, label="Positive Samples", color="blue")
            axis.hist(values[labels == 0], bins=n_bins, alpha=0.5, label="Negative Samples", color="red")
            axis.set_title(f"Distribution of {quantity}")
            axis.set_xlabel(quantity)
            axis.set_ylabel("Frequency")
            axis.legend()
        fig.tight_layout()

        fig.savefig(f"{self.__class__.__name__}_distribution.png", bbox_inches="tight")
        # Release the figure; pyplot otherwise keeps every figure alive across repeated calls.
        plt.close(fig)


def _fit_gmm_candidates(data, n_components_range, seed):
    """Fit one 1-D GMM per candidate component count and return ``[(gmm, bic), ...]``.

    Candidates that have at least as many components as there are samples are not fitted
    and are reported as ``(None, np.inf)``.
    """
    column = data.reshape(-1, 1)
    fits = []
    for n_components in n_components_range:
        if len(data) > n_components:
            gmm = GaussianMixture(n_components=n_components, random_state=seed)
            gmm.fit(column)
            fits.append((gmm, gmm.bic(column)))
        else:
            fits.append((None, np.inf))
    return fits


def select_components_by_combined_bic(preds, labels, max_components=10, seed: int = 42):
    """
    Fits GMM to positive data and negative data separately,
    and selects the number of components that minimizes the sum of both BICs.

    Args:
        preds: 1D array of prediction logits/scores.
        labels: 1D array of {0, 1} labels aligned with ``preds``.
        max_components: Largest number of components tried for each class (candidates are 1..max_components).
        seed: Random state handed to every ``GaussianMixture`` fit.

    Returns:
        dict: The best combination with keys ``n_pos``, ``n_neg``, ``bic_pos``, ``bic_neg``,
        ``total_bic``, ``gmm_pos`` and ``gmm_neg``. A class that could not be fitted for a candidate
        (not more samples than components) has BIC ``np.inf`` and its GMM set to ``None``.

    Raises:
        ValueError: If ``max_components`` is smaller than 1, so there is no candidate to select.
    """
    n_components_range = range(1, max_components + 1)
    if len(n_components_range) == 0:
        raise ValueError(f"max_components must be at least 1, got: {max_components}")

    positive_data = preds[labels == 1]
    negative_data = preds[labels == 0]

    # Each class is fitted once per candidate size. The two fits are independent of each other,
    # so refitting them for every (n_pos, n_neg) pair would only repeat identical work.
    positive_fits = _fit_gmm_candidates(positive_data, n_components_range, seed)
    negative_fits = _fit_gmm_candidates(negative_data, n_components_range, seed)

    # Enumerate pairs with n_pos as the outer index so ties resolve to the smallest n_pos, then n_neg.
    results = [
        {
            "n_pos": n_pos,
            "n_neg": n_neg,
            "bic_pos": bic_pos,
            "bic_neg": bic_neg,
            "total_bic": bic_pos + bic_neg,
            "gmm_pos": gmm_pos,
            "gmm_neg": gmm_neg,
        }
        for n_pos, (gmm_pos, bic_pos) in zip(n_components_range, positive_fits)
        for n_neg, (gmm_neg, bic_neg) in zip(n_components_range, negative_fits)
    ]
    return min(results, key=lambda candidate: candidate["total_bic"])


def _estimate_gaussian_pdf_1d(x: NDArray[np.floating], mean: float, var: float) -> NDArray[np.floating]:
    """Evaluate a univariate normal density at every element of ``x``.

    Args:
        x: points at which the density is evaluated, shape (N,)
        mean: scalar mean of the Gaussian
        var: scalar variance; values below 1e-12 are raised to 1e-12

    Returns:
        density values, shape (N,)
    """
    # Floor the variance so a degenerate component never yields a zero scale.
    floored_var = float(max(var, 1e-12))
    return norm.pdf(x, loc=mean, scale=np.sqrt(floored_var))


def _fit_class_specific_weights_with_fixed_components(
    data: NDArray[np.floating],
    means: NDArray[np.floating],  # shape (K, 1)
    covariances: NDArray[np.floating],  # shape (K, 1, 1)
    init_weights: NDArray[np.floating],  # shape (K,)
    tol: float = 1e-6,
    max_iter: int = 1000,
) -> NDArray[np.floating]:
    """Re-estimate only the mixing weights of a 1-D Gaussian mixture for `data`.

    Component means and covariances stay fixed; the weights are updated with EM steps
    (responsibilities followed by their per-component average) until the largest absolute
    weight change drops below `tol` or `max_iter` iterations have run.

    Args:
        data: 1D array of shape (N,). If N == 0, a copy of `init_weights` is returned.
        means: component means, shape (K, 1). If K == 0, a copy of `init_weights` is returned.
        covariances: component covariance matrices, shape (K, 1, 1)
        init_weights: starting weights, shape (K,); clipped to >= 1e-12 and renormalized before use
        tol: convergence tolerance on max absolute weight change
        max_iter: maximum iterations

    Returns:
        weights: mixture weights normalized to sum to one, shape (K,)
    """
    n_components = int(means.shape[0])
    # Nothing to optimize without components or without observations.
    if n_components == 0 or data.size == 0:
        return init_weights.copy()

    # Component densities are constant during the iterations, so evaluate them once:
    # density[i, k] = p_k(x_i), shape (N, K).
    n_obs = int(data.shape[0])
    density = np.empty((n_obs, n_components), dtype=float)
    for k in range(n_components):
        density[:, k] = _estimate_gaussian_pdf_1d(data, float(means[k, 0]), float(covariances[k, 0, 0]))

    # Strictly positive, normalized starting point.
    weights = np.clip(init_weights.astype(float), 1e-12, None)
    weights = weights / weights.sum()

    tiny = 1e-300  # keeps the per-sample mixture density away from an exact zero
    for _ in range(max_iter):
        # E-step: responsibilities r_{ik} = w_k p_k(x_i) / sum_j w_j p_j(x_i)
        mixture_density = np.clip(density @ weights, tiny, None)  # shape (N,)
        responsibilities = (density * weights) / mixture_density[:, None]  # shape (N, K)
        # M-step restricted to the weights, then clip/renormalize against numerical drift.
        updated = np.clip(responsibilities.mean(axis=0), 1e-12, None)
        updated = updated / updated.sum()
        largest_change = np.max(np.abs(updated - weights))
        weights = updated
        if largest_change < tol:
            break

    return weights


def select_components_by_bic_and_fit_weights(
    preds: NDArray[np.floating],
    labels: NDArray[np.integer],
    max_components: int = 10,
    seed: int = 42,
    tol: float = 1e-6,
    max_iter: int = 1000,
):
    """Fit a single GMM on all `preds` by BIC, then fit class-specific mixing weights.

    This differs from `select_components_by_combined_bic`, which fits separate GMMs to
    positive and negative data. Here we:
      1) Fit a tied-covariance GMM to the entire `preds` for K in 1..max_components and select K by BIC.
      2) Keep the selected GMM component means/covariance fixed.
      3) For each class subset (positive_data, negative_data), re-estimate only the
         mixture weights to maximize that subset's likelihood.

    Args:
        preds: 1D array of prediction logits/scores, shape (N,)
        labels: 1D array of {0,1} labels, shape (N,)
        max_components: maximum number of components to consider
        seed: random seed for GMM fitting
        tol: convergence tolerance for weight-only EM
        max_iter: max iterations for weight-only EM

    Returns:
        dict with keys:
            - 'n_components': selected number of components (int)
            - 'bic': BIC of the selected model on all preds (float)
            - 'gmm': fitted GaussianMixture on all preds (GaussianMixture, covariance_type="tied")
            - 'positive_weights': optimized weights for positive_data (NDArray[K])
            - 'negative_weights': optimized weights for negative_data (NDArray[K])

    Raises:
        ValueError: If no candidate could be fitted (every K in 1..max_components is >= the sample count).
    """
    preds = np.asarray(preds).reshape(-1)
    labels = np.asarray(labels).reshape(-1)
    assert preds.shape[0] == labels.shape[0], "preds and labels must have the same length"

    rng = np.random.default_rng(seed)
    column = preds.reshape(-1, 1)
    best = None
    for n_candidate in range(1, max_components + 1):
        if preds.shape[0] <= n_candidate:
            continue  # cannot fit more components than samples
        gmm = GaussianMixture(
            n_components=n_candidate, random_state=int(rng.integers(2**32 - 1)), covariance_type="tied"
        )
        gmm.fit(column)
        bic = gmm.bic(column)
        if best is None or bic < best["bic"]:
            best = {"n_components": n_candidate, "bic": bic, "gmm": gmm}

    if best is None:
        raise ValueError("Failed to fit any GMM model. Check the number of samples vs max_components.")

    gmm = best["gmm"]
    K = gmm.n_components  # type: ignore
    # Extract params
    means = gmm.means_.astype(float)  # (K, 1)
    # A "tied" mixture stores ONE (n_features, n_features) covariance matrix shared by all
    # components, i.e. shape (1, 1) here. The weight fitter indexes covariances[k, 0, 0], so the
    # shared matrix is replicated into the per-component layout (K, 1, 1).
    shared_covariance = np.asarray(gmm.covariances_, dtype=float).reshape(1, 1, 1)
    covariances = np.repeat(shared_covariance, K, axis=0)  # (K, 1, 1)
    init_weights = gmm.weights_.astype(float)  # (K,)

    # Split data by labels
    positive_data = preds[labels == 1]
    negative_data = preds[labels == 0]

    pos_weights = _fit_class_specific_weights_with_fixed_components(
        positive_data, means, covariances, init_weights, tol=tol, max_iter=max_iter
    )
    neg_weights = _fit_class_specific_weights_with_fixed_components(
        negative_data, means, covariances, init_weights, tol=tol, max_iter=max_iter
    )

    return {
        "n_components": K,
        "bic": best["bic"],
        "gmm": gmm,
        "positive_weights": pos_weights,
        "negative_weights": neg_weights,
    }


class MixNMatchDistribution(SyntheticDistribution):
    """
    Represents the distribution that was demonstrated in the Mix-n-Match paper experiment.
    The distribution is defined as follows:
    $p(Y=1)=0.5$, $p(X|Y=1)=\\mathcal{N}(1, 1)$, $p(X|Y=0)=\\mathcal{N}(-1, 1)$,
    $p(Y=1|X=x)=1/(1+exp(-2x))$, $f(x) = \\frac{1}{1+exp(-\\beta_0 - \\beta_1 x)}$.
    $p(Y=1|f(X)=s)=1/(1+exp(2*(\\beta_0 + \\log(1/s - 1))/\\beta_1))$
    """

    has_real_example = False

    def __init__(self, beta_0: float, beta_1: float, seed: int = 42):
        """
        Initializes the MixNMatchDistribution with parameters beta_0 and beta_1, and a random seed.

        Args:
            beta_0 (float): The intercept parameter for the logistic function.
            beta_1 (float): The slope parameter for the logistic function.
            seed (int): Random seed for reproducibility.
        """
        super().__init__(seed=seed)
        self.beta_0 = beta_0
        self.beta_1 = beta_1
        # Finite x-window used in place of (-inf, inf) for numerical integration and bisection.
        self._x_upper_bound = 20.0
        self._x_lower_bound = -20.0

    @property
    def prior(self) -> float:
        """
        Returns the prior probability of the positive class (fixed at 0.5).
        """
        return 0.5

    @property
    def tce(self) -> float:
        """
        Returns the true calibration error: the integral over x of
        |expit(2x) - expit(beta_0 + beta_1 x)| weighted by the marginal density of x,
        evaluated numerically on the finite x-window.
        """

        def _integrand(x: NDArray[np.floating]) -> NDArray[np.floating]:
            """
            Computes the integrand for the expectation calculation.
            """
            pdf = self.prior * norm.pdf(x, loc=1, scale=1) + (1 - self.prior) * norm.pdf(x, loc=-1, scale=1)
            return np.abs(expit(2 * x) - expit(self.beta_0 + self.beta_1 * x)) * pdf

        return integrate(
            _integrand, lower_bound=self._x_lower_bound, upper_bound=self._x_upper_bound, num_points=10_000
        )

    def sample(self, n_samples: int, seed: int | None = None) -> tuple[NDArray[np.floating], NDArray[np.integer]]:
        """
        Draws labels from Bernoulli(prior) and x from the matching class-conditional normal,
        then returns the logit scores of x together with the labels.

        Args:
            n_samples (int): The number of samples to generate.
            seed (int | None): Optional random seed for reproducibility.

        Returns:
            tuple[NDArray[np.floating], NDArray[np.integer]]: Logit scores and their labels.
        """
        rng = self.get_rng(seed)
        labels = rng.binomial(1, self.prior, size=n_samples)
        x = np.where(labels == 1, rng.normal(1, 1, size=n_samples), rng.normal(-1, 1, size=n_samples)).astype(
            np.float64
        )
        return self.logit_score(x), labels

    def positive_sample(self, n_samples: int, seed: int | None = None) -> NDArray[np.floating]:
        """
        Draws x from the positive class-conditional N(1, 1) and returns its logit scores.

        Args:
            n_samples (int): The number of positive samples to generate.
            seed (int | None): Optional random seed for reproducibility.

        Returns:
            NDArray[np.floating]: Logit scores of the positive samples.
        """
        rng = self.get_rng(seed)
        x = rng.normal(1, 1, size=n_samples)
        return self.logit_score(x)

    def cdf(self, x: NDArray[np.floating] | float) -> NDArray[np.floating]:
        """
        Computes the marginal CDF of x (the prior-weighted mixture of both class-conditional normals).
        """
        return norm.cdf(x, loc=1, scale=1) * self.prior + norm.cdf(x, loc=-1, scale=1) * (1 - self.prior)

    def quantiles(self, n_bins: int):
        """Calculate quantiles for binning

        Use binary search on ``cdf`` to find the ``n_bins + 1`` equally spaced quantile levels in [0, 1].
        The returned values are positions in x (the argument of ``cdf``), searched inside the
        finite x-window to a width of 1e-6.
        """
        levels = np.linspace(0, 1, n_bins + 1)
        quantile_values = np.zeros(n_bins + 1)
        for i, level in enumerate(levels):
            # Bisection: shrink [low, high] around the point where the CDF reaches ``level``.
            low, high = self._x_lower_bound, self._x_upper_bound
            while high - low > 1e-6:
                mid = (low + high) / 2
                if self.cdf(mid) < level:
                    low = mid
                else:
                    high = mid
            quantile_values[i] = (low + high) / 2
        return quantile_values

    def _logit_interval_to_x(self, lower_bound: float, upper_bound: float) -> tuple[float, float]:
        """Map a logit interval to the matching x interval, ordered as (x_lower, x_upper).

        Infinite bounds stay infinite. For a negative slope ``beta_1`` the affine map reverses
        order, so the infinite ends change sign and the two mapped bounds are swapped.
        """

        def _to_x(logit_bound: float) -> float:
            if np.isinf(logit_bound):
                return logit_bound if self.beta_1 >= 0 else -logit_bound
            return self.inverse_logit_fn(logit_bound)

        x_lower, x_upper = _to_x(lower_bound), _to_x(upper_bound)
        if self.beta_1 < 0:
            x_lower, x_upper = x_upper, x_lower
        return x_lower, x_upper

    def unlabeled_truncated_expectation(self, lower_bound: float, upper_bound: float) -> float:
        """
        Integrates expit(beta_0 + beta_1 x) times the marginal density of x over the x interval
        that corresponds to the given logit interval.

        Args:
            lower_bound (float): The lower logit bound of the interval (may be -inf).
            upper_bound (float): The upper logit bound of the interval (may be inf).

        Returns:
            float: The truncated expectation for unlabeled samples.
        """
        x_lower_bound, x_upper_bound = self._logit_interval_to_x(lower_bound, upper_bound)
        # Open ends are replaced by the finite x-window so the grid below stays finite.
        if x_upper_bound == np.inf:
            x_upper_bound = self._x_upper_bound
        if x_lower_bound == -np.inf:
            x_lower_bound = self._x_lower_bound

        def _integrand(x: NDArray[np.floating]) -> NDArray[np.floating]:
            """
            Computes the integrand for the expectation calculation.
            """
            pdf = self.prior * norm.pdf(x, loc=1, scale=1) + (1 - self.prior) * norm.pdf(x, loc=-1, scale=1)
            return pdf * expit(self.beta_0 + self.beta_1 * x)

        # Scale the grid with the interval width (10_000 points for the full window), at least 2 points.
        num_points_recommended = max(
            int(10_000 * ((x_upper_bound - x_lower_bound) / (self._x_upper_bound - self._x_lower_bound))), 2
        )

        return integrate(
            _integrand,
            lower_bound=x_lower_bound,
            upper_bound=x_upper_bound,
            num_points=num_points_recommended,
        )

    def positive_cdf_diff(self, lower_bound: float, upper_bound: float) -> float:
        """
        Computes the probability mass of the positive class-conditional N(1, 1) on the x interval
        that corresponds to the given logit interval.

        Args:
            lower_bound (float): The lower logit bound of the interval (may be -inf).
            upper_bound (float): The upper logit bound of the interval (may be inf).

        Returns:
            float: The difference in CDF for positive samples.
        """
        x_lower_bound, x_upper_bound = self._logit_interval_to_x(lower_bound, upper_bound)
        lower_cdf = norm.cdf(x_lower_bound, loc=1, scale=1)
        upper_cdf = norm.cdf(x_upper_bound, loc=1, scale=1)
        return float(upper_cdf - lower_cdf)

    def logit_score(self, x: NDArray[np.floating]) -> NDArray[np.floating]:
        """
        Computes the logit score for the given input.

        Args:
            x (NDArray[np.floating]): The input value(s) to apply the logit score.

        Returns:
            NDArray[np.floating]: The logit ``beta_0 + beta_1 * x`` of the input value(s).
        """
        return self.beta_0 + self.beta_1 * x

    def inverse_logit_fn(self, logit: float) -> float:
        """
        Inverts ``logit_score``: returns the x whose logit score equals the given logit.

        Args:
            logit (float): The logit value to map back to x.

        Returns:
            float: ``(logit - beta_0) / beta_1``.
        """
        return (logit - self.beta_0) / self.beta_1

    def ece(
        self,
        n_samples: int,
        binning_strategy: ece_module.BinningStrategy,
        n_bins: int | None = None,
        seed: int | None = None,
    ) -> float:
        """
        Computes the ECE on ``n_samples`` freshly drawn scores and labels.

        Args:
            n_samples (int): The number of samples to use for the computation.
            binning_strategy (ece_module.BinningStrategy): The binning strategy to use for ECE calculation.
            n_bins (int | None): The number of bins. If None, ``ece_module.get_recommended_n_bins_for_ece``
                is asked for a value based on ``n_samples``.
            seed (int | None): Optional random seed; when given it replaces the instance generator.

        Returns:
            float: The computed ECE.
        """
        self.rng = self.get_rng(seed)
        # generate random scores and labels
        logits, labels = self.sample(n_samples)
        if n_bins is None:
            n_bins = ece_module.get_recommended_n_bins_for_ece(n_samples)
        return ece_module.calculate_ece(logits, labels, n_bins=n_bins, strategy=binning_strategy)

    def pu_ece(
        self,
        n_positive: int | float,
        n_unlabeled: int | float,
        binning_strategy: ece_module.BinningStrategy,
        n_bins: int | None = None,
        seed: int | None = None,
        prior_estimation_error: float = 1.0,
    ) -> float:
        """
        Computes the PU-ECE from positive and unlabeled samples, or from the analytic
        population quantities when a sample size is ``np.inf``.

        Args:
            n_positive (int | float): The number of positive samples; ``np.inf`` passes
                ``positive_cdf_diff`` instead of drawn positive logits.
            n_unlabeled (int | float): The number of unlabeled samples; ``np.inf`` passes
                ``unlabeled_truncated_expectation`` and explicit bin edges instead of drawn logits.
            binning_strategy (ece_module.BinningStrategy): The binning strategy to use for PU-ECE calculation.
            n_bins (int | None): The number of bins. If None, ``ece_module.get_recommended_n_bins_for_pu_ece``
                is asked for a value based on the prior and both sample sizes.
            seed (int | None): Optional random seed; when given it replaces the instance generator.
            prior_estimation_error (float): Multiplicative factor applied to the prior passed on.

        Returns:
            float: The computed PU-ECE.
        """
        self.rng = self.get_rng(seed)
        if n_positive == np.inf:
            positive_logits = None
        else:
            positive_logits = self.positive_sample(int(n_positive))
        if n_unlabeled == np.inf:
            unlabeled_logits = None
        else:
            unlabeled_logits, _ = self.sample(int(n_unlabeled))

        if n_bins is None:
            n_bins = ece_module.get_recommended_n_bins_for_pu_ece(self.prior, n_positive, n_unlabeled)
        bin_edges = None
        if n_unlabeled == np.inf:
            if binning_strategy == ece_module.BinningStrategy.UMB:
                # NOTE(review): quantiles() returns positions in x, while positive_cdf_diff and
                # unlabeled_truncated_expectation treat their bounds as logits (they apply
                # inverse_logit_fn). Confirm whether calculate_pu_ece expects these edges in logit space.
                bin_edges = self.quantiles(n_bins)
            else:
                bin_edges = ece_module.get_logit_bin_edges_for_uwb(n_bins)
        ece = ece_module.calculate_pu_ece(
            positive_logits=positive_logits,
            unlabeled_logits=unlabeled_logits,
            prior=self.prior * prior_estimation_error,
            n_bins=n_bins,
            n_positive=n_positive,
            n_unlabeled=n_unlabeled,
            binning_strategy=binning_strategy,
            positive_cdf_diff=self.positive_cdf_diff if n_positive == np.inf else None,
            unlabeled_truncated_expectation=self.unlabeled_truncated_expectation if n_unlabeled == np.inf else None,
            bin_edges=bin_edges,
        )
        return float(ece)


class PredictedLogitGMMDistribution(SyntheticDistribution, ABC):
    """
    Represents a Gaussian Mixture Model (GMM) distribution for the predicted logits.
    """

    MODEL_ARTIFACT_NAME: str
    has_real_example = True

    def __init__(self, seed: int = 42, gmm_fit_mode: str = "separate", max_components: int = 10):
        """Initialize the prediction distribution from model artifacts.

        Loads the test logits/labels of ``MODEL_ARTIFACT_NAME``, fits the class-conditional GMMs,
        and precomputes the TCE by numerical integration over the logit range.

        Args:
            seed: Random seed for reproducibility.
            gmm_fit_mode: How to fit GMMs for positive/negative densities.
                - 'separate' (default): Fit GMMs to positive and negative data separately using
                  `select_components_by_combined_bic`.
                - 'shared': Fit a single GMM to all preds by BIC and then optimize class-specific
                  weights using `select_components_by_bic_and_fit_weights`, cloning the shared GMM
                  and replacing weights for pos/neg.
            max_components: Maximum number of components to consider in model selection.

        Raises:
            ValueError: If ``gmm_fit_mode`` is unknown, or if no GMM could be fitted for a class.
        """
        # Reject an unknown mode before any artifact is loaded.
        if gmm_fit_mode not in {"separate", "shared"}:
            raise ValueError(f"gmm_fit_mode must be 'separate' or 'shared', got: {gmm_fit_mode}")

        super().__init__(seed=seed)
        self.test_logits, self.test_labels = get_test_data_artifact(self.MODEL_ARTIFACT_NAME)
        print(f"Accuracy: {np.mean(self.test_labels == (self.test_logits > 0.0).astype(np.int64)) * 100:.2f}%")
        # Empirical positive rate of the test labels serves as the class prior.
        self._prior = float(np.mean(self.test_labels))

        if gmm_fit_mode == "separate":
            results = select_components_by_combined_bic(
                self.test_logits, self.test_labels, max_components=max_components, seed=seed
            )
            print(
                f"Mode=separate: Selected GMM with {results['n_pos']} positive components and "
                f"{results['n_neg']} negative components, with BICs: {results['bic_pos']} (positive), "
                f"{results['bic_neg']} (negative), and total BIC: {results['total_bic']}"
            )
            self.gmm_pos: GaussianMixture = results["gmm_pos"]
            self.gmm_neg: GaussianMixture = results["gmm_neg"]
        else:
            results_shared = select_components_by_bic_and_fit_weights(
                self.test_logits, self.test_labels, max_components=max_components, seed=seed
            )
            print(
                f"Mode=shared: Selected {results_shared['n_components']} components by "
                f"BIC={results_shared['bic']:.3f}. Fitted class-specific weights."
            )
            shared_gmm: GaussianMixture = results_shared["gmm"]
            pos_weights = results_shared["positive_weights"]
            neg_weights = results_shared["negative_weights"]

            # Clone the shared GMM and override weights for each class
            gmm_pos = copy.deepcopy(shared_gmm)
            gmm_neg = copy.deepcopy(shared_gmm)
            # Ensure weights are valid simplex
            gmm_pos.weights_ = (pos_weights / np.sum(pos_weights)).astype(float)
            gmm_neg.weights_ = (neg_weights / np.sum(neg_weights)).astype(float)

            self.gmm_pos = gmm_pos
            self.gmm_neg = gmm_neg

        if self.gmm_pos is None or self.gmm_neg is None:
            raise ValueError(
                f"Failed to fit GMM to positive or negative data. "
                f"Positive GMM: {self.gmm_pos}, Negative GMM: {self.gmm_neg}"
            )

        # Integration window: push each extreme logit away from the data by its own magnitude.
        # For the usual case (max > 0 > min) this equals 2 * max and 2 * min; unlike plain doubling
        # it never moves a bound into the observed range when all logits share one sign.
        max_logit = float(np.max(self.test_logits))
        min_logit = float(np.min(self.test_logits))
        self._logit_upper_bound = max_logit + abs(max_logit)
        self._logit_lower_bound = min_logit - abs(min_logit)

        # TCE calculation.
        def _integrand(logit_val: NDArray[np.floating]) -> NDArray[np.floating]:
            """
            Computes the integrand for the TCE calculation:
            |prior * p_pos(z) - expit(z) * p(z)| with p the prior-weighted mixture of both class densities.
            """
            column = logit_val.reshape(-1, 1)
            # score_samples returns log-densities; exponentiate to get densities.
            pos_pdf = np.exp(self.gmm_pos.score_samples(column))
            neg_pdf = np.exp(self.gmm_neg.score_samples(column))
            pdf = self.prior * pos_pdf + (1 - self.prior) * neg_pdf
            return np.abs(self.prior * pos_pdf - expit(logit_val) * pdf)

        self._tce = integrate(
            _integrand,
            lower_bound=self._logit_lower_bound,
            upper_bound=self._logit_upper_bound,
            num_points=10_000,
        )

        assert 0 <= self._tce <= 1, f"TCE should be in [0, 1], but got {self._tce}."

    @property
    def prior(self) -> float:
        """
        Returns the prior probability of the positive class stored in ``self._prior``
        (set in ``__init__`` to the mean of the test labels).
        """
        return self._prior

    @property
    def tce(self) -> float:
        """
        Returns the True Calibration Error (TCE) cached in ``self._tce``.

        The value is computed once in ``__init__`` by numerically integrating
        ``|prior * p_pos(z) - expit(z) * p(z)|`` over the logit range, where ``p_pos`` is the fitted
        positive-class density and ``p`` the prior-weighted mixture of both class densities. This is
        the expected absolute difference between the positive-class posterior and the predicted
        probability ``expit(z)``.

        Returns:
            float: The computed TCE.
        """
        return self._tce

    def sample(self, n_samples: int, seed: int | None = None) -> tuple[NDArray[np.floating], NDArray[np.integer]]:
        """
        Generates samples from the GMM distribution.

        Args:
            n_samples (int): The number of samples to generate.
            seed (int | None): Optional random seed for reproducibility.

        Returns:
            tuple[NDArray[np.floating], NDArray[int]]:
                A tuple containing:
                - An array of scores (float) sampled from the GMM distribution.
                - An array of labels (int) generated based on the scores.
        """
        rng = self.get_rng(seed)
        self.gmm_pos.random_state = rng.integers(2**32 - 1)  # type: ignore
        self.gmm_neg.random_state = rng.integers(2**32 - 1)  # type: ignore

        n_positive = rng.binomial(n_samples, self.prior)
        n_negative = n_samples - n_positive
        if n_negative > 0:
            neg_samples = self.gmm_neg.sample(n_negative)[0].flatten()
        else:
            neg_samples = np.array([])
        if n_positive > 0:
            pos_samples = self.gmm_pos.sample(n_positive)[0].flatten()
        else:
            pos_samples = np.array([])
        scores = np.concatenate([pos_samples, neg_samples]).astype(np.float64)
        labels = np.concatenate([np.ones(n_positive), np.zeros(n_negative)]).astype(np.int64)
        return scores, labels

    def positive_sample(self, n_samples: int, seed: int | None = None) -> NDArray[np.floating]:
        """
        Generates positive samples from the GMM distribution.

        Args:
            n_samples (int): The number of positive samples to generate.
            seed (int | None): Optional random seed for reproducibility.

        Returns:
            NDArray[np.floating]: An array of positive scores sampled from the GMM distribution.
        """
        rng = self.get_rng(seed)
        self.gmm_pos.random_state = rng.integers(2**32 - 1)  # type: ignore
        pos_samples = self.gmm_pos.sample(n_samples)[0].flatten()
        return pos_samples.astype(np.float64)

    def positive_pdf(self, logits: NDArray[np.floating] | float) -> NDArray[np.floating]:
        """
        Computes the probability density function (PDF) for the positive class.

        Args:
            logits (NDArray[np.floating] | float): An array of logit values.
        Returns:
            NDArray[np.floating]: The PDF values for the positive class.
        """
        pdf_vals = np.zeros_like(logits)
        for i in range(self.gmm_pos.n_components):  # type: ignore
            weight: float = self.gmm_pos.weights_[i]  # type: ignore
            mean: float = self.gmm_pos.means_[i, 0]  # type: ignore
            std: float = np.sqrt(self.gmm_pos.covariances_[i, 0, 0])  # type: ignore
            pos_pdf = norm.pdf(logits, loc=mean, scale=std)
            pdf_vals += weight * pos_pdf
        return pdf_vals

    def negative_pdf(self, logits: NDArray[np.floating] | float) -> NDArray[np.floating]:
        """
        Computes the probability density function (PDF) for the negative class.

        Args:
            logits (NDArray[np.floating] | float): An array of logit values.
        Returns:
            NDArray[np.floating]: The PDF values for the negative class.
        """
        pdf_vals = np.zeros_like(logits)
        for i in range(self.gmm_neg.n_components):  # type: ignore
            weight: float = self.gmm_neg.weights_[i]  # type: ignore
            mean: float = self.gmm_neg.means_[i, 0]  # type: ignore
            std: float = np.sqrt(self.gmm_neg.covariances_[i, 0, 0])  # type: ignore
            neg_pdf = norm.pdf(logits, loc=mean, scale=std)
            pdf_vals += weight * neg_pdf
        return pdf_vals

    def pdf(self, logits: NDArray[np.floating] | float) -> NDArray[np.floating]:
        """
        Computes the probability density function (PDF) for the given logits.

        Args:
            logits (NDArray[np.floating] | float): An array of logit values.
        Returns:
            NDArray[np.floating]: The PDF values for the given logits.
        """
        pdf_vals = self.prior * self.positive_pdf(logits) + (1 - self.prior) * self.negative_pdf(logits)
        return pdf_vals

    def positive_posterior(self, logits: NDArray[np.floating] | float) -> NDArray[np.floating]:
        """
        Computes the posterior probability of the positive class given the logits using Bayes' theorem.

        Args:
            logits (NDArray[np.floating] | float): An array of logit values.

        Returns:
            NDArray[np.floating]: The posterior probabilities of the positive class for the given logits.
        """
        pos_pdf = self.positive_pdf(logits)
        neg_pdf = self.negative_pdf(logits)
        total_pdf = self.prior * pos_pdf + (1 - self.prior) * neg_pdf
        with np.errstate(divide="ignore", invalid="ignore"):
            posterior_probs = np.where(total_pdf > 0, (self.prior * pos_pdf) / total_pdf, 0.0)
        return posterior_probs

    def cdf(self, logits: NDArray[np.floating] | float) -> NDArray[np.floating]:
        """
        Computes the cumulative distribution function (CDF) for the given logits.

        Args:
            logits (NDArray[np.floating] | float): An array of logit values.

        Returns:
            NDArray[np.floating]: The CDF values for the given logits.
        """
        cdf_vals = np.zeros_like(logits)
        for i in range(self.gmm_pos.n_components):  # type: ignore
            weight: float = self.gmm_pos.weights_[i]  # type: ignore
            mean: float = self.gmm_pos.means_[i, 0]  # type: ignore
            std: float = np.sqrt(self.gmm_pos.covariances_[i, 0, 0])  # type: ignore
            pos_cdf = norm.cdf(logits, loc=mean, scale=std)
            cdf_vals += self.prior * weight * pos_cdf
        for i in range(self.gmm_neg.n_components):  # type: ignore
            weight: float = self.gmm_neg.weights_[i]  # type: ignore
            mean: float = self.gmm_neg.means_[i, 0]  # type: ignore
            std: float = np.sqrt(self.gmm_neg.covariances_[i, 0, 0])  # type: ignore
            neg_cdf = norm.cdf(logits, loc=mean, scale=std)
            cdf_vals += (1 - self.prior) * weight * neg_cdf
        return cdf_vals

    def quantiles(self, n_bins: int):
        """Locate the logit values at evenly spaced CDF levels, for use as bin edges.

        ``n_bins + 1`` target levels are spaced evenly over [0, 1]. Each level is inverted through
        ``self.cdf`` by bisection on the interval [``self._logit_lower_bound``, ``self._logit_upper_bound``],
        stopping once the bracket is no wider than 1e-6; the midpoint of the final bracket is reported.

        Args:
            n_bins (int): The number of bins; ``n_bins + 1`` edges are produced.

        Returns:
            An array of ``n_bins + 1`` logit values, one per target level.
        """
        tolerance = 1e-6
        targets = np.linspace(0, 1, n_bins + 1)
        edges = np.zeros(n_bins + 1)
        for idx, target in enumerate(targets):
            lo = self._logit_lower_bound
            hi = self._logit_upper_bound
            # Halve the bracket, keeping the half in which the CDF crosses the target level.
            while hi - lo > tolerance:
                midpoint = (lo + hi) / 2
                if self.cdf(midpoint) < target:
                    lo = midpoint
                else:
                    hi = midpoint
            edges[idx] = (lo + hi) / 2
        return edges

    def unlabeled_truncated_expectation(self, lower_bound: float, upper_bound: float) -> float:
        """
        Numerically integrates ``expit(logit) * density(logit)`` over [lower_bound, upper_bound].

        The density is the prior-weighted mixture of ``self.gmm_pos`` and ``self.gmm_neg``. An ``upper_bound``
        of ``+inf`` is replaced by ``self._logit_upper_bound`` and a ``lower_bound`` of ``-inf`` by
        ``self._logit_lower_bound``. The number of integration points scales with the fraction of the full
        logit range that is covered (10,000 points for the full range, never fewer than 2).

        Args:
            lower_bound (float): The lower integration limit (``-np.inf`` allowed).
            upper_bound (float): The upper integration limit (``np.inf`` allowed).

        Returns:
            float: The value of the integral.
        """
        full_lower, full_upper = self._logit_lower_bound, self._logit_upper_bound
        if upper_bound == np.inf:
            upper_bound = full_upper
        if lower_bound == -np.inf:
            lower_bound = full_lower

        def _weighted_density(logit_val: NDArray[np.floating]) -> NDArray[np.floating]:
            """
            Evaluates the mixture density multiplied by the sigmoid of the logit.
            """
            # score_samples is fed a single-feature column and its output is exponentiated to get densities.
            column = logit_val.reshape(-1, 1)
            pos_density = np.exp(self.gmm_pos.score_samples(column))
            neg_density = np.exp(self.gmm_neg.score_samples(column))
            mixture = self.prior * pos_density + (1 - self.prior) * neg_density
            return mixture * expit(logit_val)

        coverage = (upper_bound - lower_bound) / (full_upper - full_lower)
        grid_size = max(int(10_000 * coverage), 2)

        return integrate(
            _weighted_density,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            num_points=grid_size,
        )

    def positive_cdf_diff(self, lower_bound: float, upper_bound: float) -> float:
        """
        Returns the positive-class probability mass between ``lower_bound`` and ``upper_bound``.

        For every Gaussian component of ``self.gmm_pos`` the difference of its CDF at the two bounds is taken,
        and the differences are summed using the component weights.

        Args:
            lower_bound (float): The lower end of the interval.
            upper_bound (float): The upper end of the interval.

        Returns:
            float: The weighted sum of per-component CDF differences.
        """
        gmm = self.gmm_pos

        def _component_mass(k: int):
            """Weighted CDF difference contributed by component ``k``."""
            mu = gmm.means_[k, 0]  # type: ignore
            sigma = np.sqrt(gmm.covariances_[k, 0, 0])  # type: ignore
            below_lower = norm.cdf(lower_bound, loc=mu, scale=sigma)
            below_upper = norm.cdf(upper_bound, loc=mu, scale=sigma)
            return gmm.weights_[k] * (below_upper - below_lower)  # type: ignore

        total = sum((_component_mass(k) for k in range(gmm.n_components)), 0.0)  # type: ignore
        return float(total)

    def ece(
        self,
        n_samples: int,
        binning_strategy: ece_module.BinningStrategy,
        n_bins: int | None = None,
        seed: int | None = None,
    ) -> float:
        """
        Estimates the ECE from a freshly drawn labeled sample.

        ``self.rng`` is replaced by ``self.get_rng(seed)`` before sampling, so this call changes the
        generator state of the instance.

        Args:
            n_samples (int): The number of (logit, label) pairs to draw via ``self.sample``.
            binning_strategy (ece_module.BinningStrategy): The binning strategy forwarded to ``calculate_ece``.
            n_bins (int | None): The number of bins. When ``None``, the value recommended by
                ``ece_module.get_recommended_n_bins_for_ece(n_samples)`` is used.
            seed (int | None): The seed forwarded to ``self.get_rng``.

        Returns:
            float: The value returned by ``ece_module.calculate_ece``.
        """
        self.rng = self.get_rng(seed)
        sampled_logits, sampled_labels = self.sample(n_samples)
        bin_count = ece_module.get_recommended_n_bins_for_ece(n_samples) if n_bins is None else n_bins
        return ece_module.calculate_ece(
            sampled_logits,
            sampled_labels,
            n_bins=bin_count,
            strategy=binning_strategy,
        )

    def pu_ece(
        self,
        n_positive: int | float,
        n_unlabeled: int | float,
        binning_strategy: ece_module.BinningStrategy,
        n_bins: int | None = None,
        seed: int | None = None,
        prior_estimation_error: float = 1.0,
    ) -> float:
        """
        Computes the PU-ECE from positive and unlabeled samples, either of which may be infinite.

        ``self.rng`` is replaced by ``self.get_rng(seed)`` before sampling. A sample size of ``np.inf`` means
        no samples are drawn for that set; instead an analytic callback of this distribution is handed to
        ``ece_module.calculate_pu_ece`` (``positive_cdf_diff`` for the positive set,
        ``unlabeled_truncated_expectation`` for the unlabeled set). With an infinite unlabeled set the bin
        edges are also precomputed here: from ``self.quantiles`` for the UMB strategy, otherwise from
        ``ece_module.get_logit_bin_edges_for_uwb``.

        Args:
            n_positive (int | float): The number of positive samples, or ``np.inf``.
            n_unlabeled (int | float): The number of unlabeled samples, or ``np.inf``.
            binning_strategy (ece_module.BinningStrategy): The binning strategy to use.
            n_bins (int | None): The number of bins. When ``None``, the value recommended by
                ``ece_module.get_recommended_n_bins_for_pu_ece`` is used.
            seed (int | None): The seed forwarded to ``self.get_rng``.
            prior_estimation_error (float): A multiplier applied to ``self.prior`` before it is passed on.

        Returns:
            float: The PU-ECE value.

        Raises:
            ValueError: If the computed value cannot be converted to ``float``.
        """
        self.rng = self.get_rng(seed)
        infinite_positive = n_positive == np.inf
        infinite_unlabeled = n_unlabeled == np.inf

        # Positives are drawn before unlabeled samples so both consume the generator in a fixed order.
        positive_logits = None if infinite_positive else self.positive_sample(int(n_positive))
        unlabeled_logits = None if infinite_unlabeled else self.sample(int(n_unlabeled))[0]

        if n_bins is None:
            n_bins = ece_module.get_recommended_n_bins_for_pu_ece(self.prior, n_positive, n_unlabeled)

        bin_edges = None
        if infinite_unlabeled:
            # No unlabeled sample is available to derive bin edges from, so they are supplied explicitly.
            use_umb = binning_strategy == ece_module.BinningStrategy.UMB
            bin_edges = self.quantiles(n_bins) if use_umb else ece_module.get_logit_bin_edges_for_uwb(n_bins)

        analytic_positive = self.positive_cdf_diff if infinite_positive else None
        analytic_unlabeled = self.unlabeled_truncated_expectation if infinite_unlabeled else None
        ece = ece_module.calculate_pu_ece(
            positive_logits=positive_logits,
            unlabeled_logits=unlabeled_logits,
            prior=self.prior * prior_estimation_error,
            n_bins=n_bins,
            n_positive=n_positive,
            n_unlabeled=n_unlabeled,
            binning_strategy=binning_strategy,
            positive_cdf_diff=analytic_positive,
            unlabeled_truncated_expectation=analytic_unlabeled,
            bin_edges=bin_edges,
        )
        try:
            return float(ece)
        except Exception as e:
            raise ValueError(
                f"Failed to calculate PU-ECE for n_positive={n_positive}, n_unlabeled={n_unlabeled}.\n"
                f"Error: {e}\n"
                f"ECE: {ece}"
            ) from e


class MNISTPredictionDistribution(PredictedLogitGMMDistribution):
    """
    Represents the MNIST prediction distribution using a Gaussian Mixture Model (GMM).
    This class is used to compute TCE, ECE, and PU-ECE values based on the predicted logits from a trained model.

    This subclass only sets ``MODEL_ARTIFACT_NAME``; everything else is inherited from
    ``PredictedLogitGMMDistribution``.
    """

    # Identifier of the stored MNIST predictions, pinned to the ``latest`` alias.
    # NOTE(review): presumably resolved by the base class when loading test data -- confirm there.
    MODEL_ARTIFACT_NAME = (
        "wsl-ece/predictions_table_mnist_MLP_300_300_10000_256_100_0_001_sigmoid_False_False_cross_entropy_42:latest"
    )


class CIFAR10PredictionDistribution(PredictedLogitGMMDistribution):
    """
    Represents the CIFAR10 prediction distribution using a Gaussian Mixture Model (GMM).
    This class is used to compute TCE, ECE, and PU-ECE values based on the predicted logits from a trained model.

    This subclass only sets ``MODEL_ARTIFACT_NAME``; everything else is inherited from
    ``PredictedLogitGMMDistribution``.
    """

    # Identifier of the stored CIFAR10 predictions, pinned to the ``latest`` alias.
    # NOTE(review): presumably resolved by the base class when loading test data -- confirm there.
    MODEL_ARTIFACT_NAME = (
        "wsl-ece/predictions_table_cifar10_ResNet18_10000_256_100_1e_05_sigmoid_False_False_cross_entropy_42:latest"
    )


class DDI2013PredictionDistribution(PredictedLogitGMMDistribution):
    """
    Represents the DDI2013 prediction distribution using a Gaussian Mixture Model (GMM).
    This class is used to compute TCE, ECE, and PU-ECE values based on the predicted logits from a trained model.

    This subclass only sets ``MODEL_ARTIFACT_NAME``; everything else is inherited from
    ``PredictedLogitGMMDistribution``.
    """

    # Identifier of the stored DDI2013 predictions, pinned to the ``latest`` alias.
    # NOTE(review): presumably resolved by the base class when loading test data -- confirm there.
    MODEL_ARTIFACT_NAME = (
        "wsl-ece/predictions_table_ddi2013_RBertClassifier_2000_64_10_2e_05_cross_entropy_False_False_42_pu:latest"
    )
