"""
Evaluation Metrics for Bayesian Neural Networks

This module implements comprehensive metrics for evaluating:
- Classification performance (NLL, AUROC)
- Uncertainty quality (AUPR-Success, AUPR-Error)
- Out-of-distribution detection (AUPR-In, AUPR-Out, AUROC-OOD)
- Calibration (ECE with equal-width or equal-mass binning)
- Both STD-based and MI-based uncertainty measures
"""

import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score


# ============================================================================
# Classification Metrics
# ============================================================================

def compute_nll(y_true: np.ndarray,
                y_pred: np.ndarray,
                epsilon: float = 1e-12) -> float:
    """
    Mean binary cross-entropy (negative log likelihood) of y_true under y_pred.

    Formula: NLL = -(1/N) Σ [y_i * log(p_i) + (1 - y_i) * log(1 - p_i)]

    Predictions are first clipped to [epsilon, 1 - epsilon], so a confidently
    wrong prediction contributes a large but finite penalty (about
    -log(epsilon)) rather than infinity.

    Lower is better; the value approaches 0 as predictions become perfect.

    Parameters:
    -----------
    y_true : (N,) true labels {0, 1}
    y_pred : (N,) predicted probabilities [0, 1]
    epsilon : small constant to avoid log(0)

    Returns:
    --------
    nll : float, negative log likelihood
    """
    p = np.clip(y_pred, epsilon, 1 - epsilon)

    # Log-probability assigned to class 1 and to class 0 respectively.
    log_p = np.log(p)
    log_q = np.log(1 - p)

    per_sample = y_true * log_p + (1 - y_true) * log_q
    return -np.mean(per_sample)


def compute_auroc(y_true: np.ndarray,
                  y_pred: np.ndarray) -> float:
    """
    AUROC: Probability that a randomly chosen positive has a higher predicted
    probability than a randomly chosen negative (ties count as 1/2).

    Formula: P(p̂(y₁=1|x₁) > p̂(y₂=1|x₂)) where y₁=1, y₂=0

    Range: [0.0, 1.0]
    Higher is better: 1.0 = perfect ranking, 0.5 = random ranking.
    Values below 0.5 mean the predicted probabilities are anti-correlated
    with the labels.

    Parameters:
    -----------
    y_true : (N,) true labels {0, 1}
    y_pred : (N,) predicted probabilities [0, 1]

    Returns:
    --------
    auroc : float in [0.0, 1.0]

    Raises:
    -------
    ValueError : raised by sklearn's roc_auc_score when y_true contains only
                 one class (AUROC is undefined in that case).
    """
    return roc_auc_score(y_true, y_pred)


# ============================================================================
# Uncertainty Quality Metrics (STD-based)
# ============================================================================

def compute_aupr_success(y_true: np.ndarray,
                         y_pred: np.ndarray,
                         uncertainties: np.ndarray) -> float:
    """
    AUPR-Success: How well LOW uncertainty identifies CORRECT predictions.

    Process:
    1. Label: correct = 1 if (p̂ >= 0.5) == y_true, else 0
       (a probability of exactly 0.5 is treated as a prediction of class 1)
    2. Score: -uncertainty (negated so that low uncertainty ranks first)
    3. Compute average precision (area under the precision-recall curve)
       with correct predictions as the positive class

    Interpretation: precision/recall trade-off when the lowest-uncertainty
    predictions are retained as "trusted". Unlike AUROC, this depends on the
    class balance: a random ranking scores roughly the fraction of correct
    predictions (the model's accuracy), not 0.5.
    Higher is better: 1.0 = every correct prediction has strictly lower
    uncertainty than every incorrect one.

    Parameters:
    -----------
    y_true : (N,) true labels {0, 1}, array-like
    y_pred : (N,) predicted probabilities, array-like
    uncertainties : (N,) uncertainty scores, array-like

    Returns:
    --------
    aupr_success : float in [0, 1]
    """
    # Accept lists as well as arrays (unary minus below needs an ndarray).
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    uncertainties = np.asarray(uncertainties)

    predictions_binary = (y_pred >= 0.5).astype(int)
    correct = (predictions_binary == y_true).astype(int)

    # Negate so that the most confident (lowest-uncertainty) samples rank top.
    scores = -uncertainties

    return average_precision_score(correct, scores)


def compute_aupr_error(y_true: np.ndarray,
                       y_pred: np.ndarray,
                       uncertainties: np.ndarray) -> float:
    """
    AUPR-Error: How well HIGH uncertainty identifies INCORRECT predictions.

    Process:
    1. Label: error = 1 if (p̂ >= 0.5) != y_true, else 0
       (a probability of exactly 0.5 is treated as a prediction of class 1)
    2. Score: uncertainty (high uncertainty ranks first)
    3. Compute average precision (area under the precision-recall curve)
       with errors as the positive class

    Interpretation: precision/recall trade-off when flagging predictions as
    likely errors in order of decreasing uncertainty. Unlike AUROC, this
    depends on the error rate: a random ranking scores roughly the fraction
    of misclassified samples, not 0.5.
    Higher is better: 1.0 = every error has strictly higher uncertainty than
    every correct prediction.

    NOTE(review): if the model makes no errors there is no positive class and
    average precision is undefined; sklearn warns in that case — check what
    the installed version returns before trusting the number.

    Parameters:
    -----------
    y_true : (N,) true labels {0, 1}, array-like
    y_pred : (N,) predicted probabilities, array-like
    uncertainties : (N,) uncertainty scores, array-like

    Returns:
    --------
    aupr_error : float in [0, 1]
    """
    # Accept lists as well as arrays.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    uncertainties = np.asarray(uncertainties)

    predictions_binary = (y_pred >= 0.5).astype(int)
    errors = (predictions_binary != y_true).astype(int)

    # Raw uncertainty is the score: the most uncertain samples rank top.
    return average_precision_score(errors, uncertainties)


# ============================================================================
# Out-of-Distribution Detection Metrics (STD-based)
# ============================================================================

def compute_aupr_in_domain(uncertainties_in: np.ndarray,
                           uncertainties_out: np.ndarray) -> float:
    """
    AUPR-In: How well LOW uncertainty identifies IN-DOMAIN samples.

    Process:
    1. Concatenate in-domain and OOD uncertainties
    2. Label: in_domain = 1 for in-domain, 0 for OOD
    3. Score: -uncertainty (negated so that low uncertainty ranks first)
    4. Compute average precision (area under the precision-recall curve)
       with in-domain samples as the positive class

    Interpretation: precision/recall trade-off when retrieving in-domain
    samples in order of increasing uncertainty. Unlike AUROC, this depends on
    the class balance: a random ranking scores roughly N_in / (N_in + N_out).
    Higher is better: 1.0 = every in-domain sample has strictly lower
    uncertainty than every OOD sample.

    Parameters:
    -----------
    uncertainties_in : (N_in,) uncertainties for in-domain data
    uncertainties_out : (N_out,) uncertainties for OOD data

    Returns:
    --------
    aupr_in : float in [0, 1]
    """
    all_uncertainties = np.concatenate([uncertainties_in, uncertainties_out])

    # 1 for in-domain, 0 for OOD, in the same order as the concatenation.
    in_domain_labels = np.concatenate([
        np.ones(len(uncertainties_in)),
        np.zeros(len(uncertainties_out))
    ])

    # Negate so that the lowest-uncertainty samples rank first.
    scores = -all_uncertainties

    return average_precision_score(in_domain_labels, scores)


def compute_aupr_ood(uncertainties_in: np.ndarray,
                     uncertainties_out: np.ndarray) -> float:
    """
    AUPR-Out: How well HIGH uncertainty identifies OUT-OF-DOMAIN samples.

    Process:
    1. Concatenate in-domain and OOD uncertainties
    2. Label: out_domain = 1 for OOD, 0 for in-domain
    3. Score: uncertainty (high uncertainty ranks first)
    4. Compute average precision (area under the precision-recall curve)
       with OOD samples as the positive class

    Interpretation: precision/recall trade-off when flagging samples as OOD
    in order of decreasing uncertainty. Unlike AUROC, this depends on the
    class balance: a random ranking scores roughly N_out / (N_in + N_out).
    Higher is better: 1.0 = every OOD sample has strictly higher uncertainty
    than every in-domain sample.

    Parameters:
    -----------
    uncertainties_in : (N_in,) uncertainties for in-domain data
    uncertainties_out : (N_out,) uncertainties for OOD data

    Returns:
    --------
    aupr_out : float in [0, 1]
    """
    all_uncertainties = np.concatenate([uncertainties_in, uncertainties_out])

    # 1 for OOD, 0 for in-domain, in the same order as the concatenation.
    ood_labels = np.concatenate([
        np.zeros(len(uncertainties_in)),
        np.ones(len(uncertainties_out))
    ])

    # Raw uncertainty is the score: the most uncertain samples rank first.
    return average_precision_score(ood_labels, all_uncertainties)


def compute_auroc_ood(unc_in: np.ndarray,
                      unc_out: np.ndarray) -> float:
    """
    AUROC for OOD detection: P(u(x_out) > u(x_in)), ties counted as 1/2.

    OOD samples are the positive class and the raw uncertainty is the score,
    so higher uncertainty is expected to indicate OOD.

    Range: [0.0, 1.0]
    Higher is better: 1.0 = perfect separation, 0.5 = no better than random.
    Values below 0.5 mean OOD samples tend to receive LOWER uncertainty than
    in-domain samples.

    Parameters:
    -----------
    unc_in : (N_in,) uncertainties for in-domain data
    unc_out : (N_out,) uncertainties for OOD data

    Returns:
    --------
    auroc_ood : float in [0.0, 1.0]

    Raises:
    -------
    ValueError : raised by sklearn's roc_auc_score when only one class is
                 present, i.e. when either input is empty.
    """
    all_unc = np.concatenate([unc_in, unc_out])
    # 0 = in-domain, 1 = OOD, in the same order as the concatenation.
    labels = np.concatenate([np.zeros(len(unc_in)), np.ones(len(unc_out))])
    return roc_auc_score(labels, all_unc)


# ============================================================================
# Calibration Metrics
# ============================================================================

def compute_ece(y_true: np.ndarray,
                y_pred: np.ndarray,
                n_bins: int = 15) -> float:
    """
    ECE: Expected Calibration Error - measures whether predicted probabilities
    match the observed frequency of the positive class.

    Process:
    1. Bin predictions into n_bins equally-spaced bins over [0, 1]
       (equal-width version). Bins are (lower, upper], except the first bin,
       which is [0, upper] so that predictions of exactly 0.0 are counted.
    2. For each non-empty bin: compute the mean predicted probability
       ("confidence") and the fraction of samples with y_true == 1
       ("accuracy")
    3. ECE = Σ (bin_size/N) * |confidence - accuracy|

    Note: this is the binary, positive-class form of ECE. "Confidence" is
    p̂(y=1|x) itself, not max(p̂, 1 - p̂). Predictions outside [0, 1] fall in
    no bin and are ignored.

    Formula: ECE = Σ (|Bᵢ|/N) * |acc(Bᵢ) - conf(Bᵢ)|

    Range: [0, 1]
    Lower is better: 0 = perfectly calibrated

    Parameters:
    -----------
    y_true : (N,) true labels {0, 1}, array-like
    y_pred : (N,) predicted probabilities [0, 1], array-like
    n_bins : number of bins for calibration

    Returns:
    --------
    ece : float in [0, 1]; 0.0 for empty input
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_pred.size == 0:
        return 0.0

    # Create bins
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    ece = 0.0

    for i, (bin_lower, bin_upper) in enumerate(zip(bin_lowers, bin_uppers)):
        # Half-open (lower, upper] bins put every prediction in exactly one
        # bin; the first bin is also closed on the left, otherwise p == 0.0
        # would belong to no bin and be silently dropped from the estimate.
        if i == 0:
            in_bin = (y_pred >= bin_lower) & (y_pred <= bin_upper)
        else:
            in_bin = (y_pred > bin_lower) & (y_pred <= bin_upper)
        prop_in_bin = in_bin.mean()

        if prop_in_bin > 0:
            # Mean predicted probability of class 1 in this bin
            confidence_in_bin = y_pred[in_bin].mean()

            # Observed frequency of class 1 in this bin
            accuracy_in_bin = y_true[in_bin].mean()

            # Add weighted difference to ECE
            ece += prop_in_bin * abs(accuracy_in_bin - confidence_in_bin)

    return float(ece)


def compute_ece_equal_mass(y_true: np.ndarray,
                           y_pred: np.ndarray,
                           n_bins: int = 15) -> float:
    """
    ECE: Expected Calibration Error with EQUAL-MASS binning

    Equal-mass binning (quantile-based):
    - Bin boundaries are quantiles of the predicted probabilities, so each
      bin holds approximately the same number of samples (exact equality is
      impossible when predictions are tied)
    - Duplicate boundaries are merged, so fewer than n_bins bins may be used
    - Typically gives lower-variance estimates than equal-width binning when
      predictions are concentrated in a narrow range

    Process:
    1. Compute n_bins + 1 quantiles of y_pred, force the outer boundaries to
       0 and 1, and drop duplicate boundaries
    2. Assign each prediction to exactly one bin: bins are (lower, upper],
       except the first, which is [0, upper]
    3. For each non-empty bin: compute the mean predicted probability
       ("confidence") and the fraction of samples with y_true == 1
       ("accuracy")
    4. ECE = Σ (bin_size/N) * |confidence - accuracy|

    Formula: ECE = Σ (|Bᵢ|/N) * |acc(Bᵢ) - conf(Bᵢ)|

    Range: [0, 1]
    Lower is better: 0 = perfectly calibrated

    Parameters:
    -----------
    y_true : (N,) true labels (0 or 1), array-like
    y_pred : (N,) predicted probabilities [0, 1], array-like
    n_bins : number of bins for calibration

    Returns:
    --------
    ece : float in [0, 1]; 0.0 for empty input
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_pred.size == 0:
        # np.quantile cannot be evaluated on an empty array.
        return 0.0

    # Get bin boundaries based on quantiles of predictions
    quantiles = np.linspace(0, 1, n_bins + 1)
    bin_boundaries = np.quantile(y_pred, quantiles)

    # Ensure boundaries are exactly at 0 and 1
    bin_boundaries[0] = 0.0
    bin_boundaries[-1] = 1.0

    # Handle duplicates (can happen with discrete / tied predictions)
    bin_boundaries = np.unique(bin_boundaries)

    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    ece = 0.0

    for i, (bin_lower, bin_upper) in enumerate(zip(bin_lowers, bin_uppers)):
        # Quantile boundaries usually coincide with actual predictions, so
        # intervals closed on both ends would count those samples in two
        # bins (total weight > 1). Use (lower, upper], closing the first bin
        # on the left so that predictions at 0.0 are still included.
        if i == 0:
            in_bin = (y_pred >= bin_lower) & (y_pred <= bin_upper)
        else:
            in_bin = (y_pred > bin_lower) & (y_pred <= bin_upper)
        prop_in_bin = in_bin.mean()

        if prop_in_bin > 0:
            # Mean predicted probability of class 1 in this bin
            confidence_in_bin = y_pred[in_bin].mean()

            # Observed frequency of class 1 in this bin
            accuracy_in_bin = y_true[in_bin].mean()

            # Add weighted difference to ECE
            ece += prop_in_bin * abs(accuracy_in_bin - confidence_in_bin)

    return float(ece)


# ============================================================================
# Mutual Information (MI) Based Metrics
# ============================================================================

def compute_aupr_success_mi(y_true: np.ndarray,
                            y_pred: np.ndarray,
                            mi: np.ndarray) -> float:
    """
    AUPR-Success with mutual information as the uncertainty signal.

    A prediction counts as correct when thresholding y_pred at 0.5 (inclusive,
    so 0.5 maps to class 1) reproduces y_true. Correct predictions form the
    positive class, and negated MI is the ranking score, so the samples with
    the smallest MI are ranked first.

    Parameters:
    -----------
    y_true : (N,) true labels
    y_pred : (N,) predicted probabilities
    mi : (N,) mutual information scores

    Returns:
    --------
    aupr_success_mi : float in [0, 1]
    """
    is_correct = ((y_pred >= 0.5).astype(int) == y_true).astype(int)
    return average_precision_score(is_correct, -mi)


def compute_aupr_error_mi(y_true: np.ndarray,
                          y_pred: np.ndarray,
                          mi: np.ndarray) -> float:
    """
    AUPR-Error with mutual information as the uncertainty signal.

    A prediction counts as an error when thresholding y_pred at 0.5
    (inclusive, so 0.5 maps to class 1) disagrees with y_true. Errors form
    the positive class, and raw MI is the ranking score, so the samples with
    the largest MI are ranked first.

    Parameters:
    -----------
    y_true : (N,) true labels
    y_pred : (N,) predicted probabilities
    mi : (N,) mutual information scores

    Returns:
    --------
    aupr_error_mi : float in [0, 1]
    """
    hard_labels = (y_pred >= 0.5).astype(int)
    is_wrong = hard_labels != y_true
    return average_precision_score(is_wrong.astype(int), mi)


def compute_auroc_ood_mi(mi_in: np.ndarray,
                         mi_out: np.ndarray) -> float:
    """
    AUROC for OOD detection using MI: P(MI(x_out) > MI(x_in)), ties counted
    as 1/2.

    OOD samples are the positive class and raw MI is the score, so higher MI
    is expected to indicate OOD.

    Range: [0.0, 1.0]
    Higher is better: 1.0 = perfect separation, 0.5 = no better than random.
    Values below 0.5 mean OOD samples tend to receive LOWER MI than in-domain
    samples.

    Parameters:
    -----------
    mi_in : (N_in,) mutual information for in-domain data
    mi_out : (N_out,) mutual information for OOD data

    Returns:
    --------
    auroc_ood_mi : float in [0.0, 1.0]

    Raises:
    -------
    ValueError : raised by sklearn's roc_auc_score when only one class is
                 present, i.e. when either input is empty.
    """
    all_mi = np.concatenate([mi_in, mi_out])
    labels = np.concatenate([
        np.zeros(len(mi_in)),   # 0 = in-domain
        np.ones(len(mi_out))    # 1 = OOD
    ])
    return roc_auc_score(labels, all_mi)


def compute_aupr_in_domain_mi(mi_in: np.ndarray,
                               mi_out: np.ndarray) -> float:
    """
    AUPR-In with mutual information as the uncertainty signal.

    In-domain samples form the positive class (label 1) and OOD samples the
    negative class (label 0). Negated MI is the ranking score, so the samples
    with the smallest MI are ranked first.

    Parameters:
    -----------
    mi_in : (N_in,) mutual information for in-domain data
    mi_out : (N_out,) mutual information for OOD data

    Returns:
    --------
    aupr_in_mi : float in [0, 1]
    """
    pooled_mi = np.concatenate([mi_in, mi_out])
    # N_in ones followed by N_out zeros, matching the pooled order.
    is_in_domain = np.repeat([1.0, 0.0], [len(mi_in), len(mi_out)])
    return average_precision_score(is_in_domain, -pooled_mi)


def compute_aupr_ood_mi(mi_in: np.ndarray,
                        mi_out: np.ndarray) -> float:
    """
    AUPR-Out with mutual information as the uncertainty signal.

    OOD samples form the positive class (label 1) and in-domain samples the
    negative class (label 0). Raw MI is the ranking score, so the samples
    with the largest MI are ranked first.

    Parameters:
    -----------
    mi_in : (N_in,) mutual information for in-domain data
    mi_out : (N_out,) mutual information for OOD data

    Returns:
    --------
    aupr_out_mi : float in [0, 1]
    """
    pooled_mi = np.concatenate([mi_in, mi_out])
    # N_in zeros followed by N_out ones, matching the pooled order.
    is_ood = np.repeat([0.0, 1.0], [len(mi_in), len(mi_out)])
    return average_precision_score(is_ood, pooled_mi)


# ============================================================================
# ECE Optimization Helper
# ============================================================================

def find_best_ece(y_true: np.ndarray,
                  y_pred: np.ndarray,
                  bin_configs: list = None) -> dict:
    """
    Evaluate ECE under several binning schemes and report the lowest one.

    Parameters:
    -----------
    y_true : (N,) true labels
    y_pred : (N,) predicted probabilities
    bin_configs : list of (method, n_bins) tuples, where method is
                  'equal_width' or 'equal_mass'. None selects both methods
                  with 10, 15 and 20 bins.

    Returns:
    --------
    results : dict mapping "<method>_<n_bins>bins" to its ECE, plus
              'best_config' (name of the lowest-ECE configuration; the first
              one wins ties; None if bin_configs is empty) and 'best_ece'
              (its value; inf if bin_configs is empty).

    Raises:
    -------
    ValueError : if a method other than 'equal_width' / 'equal_mass' is given.

    NOTE(review): reporting the minimum ECE over binning schemes gives an
    optimistically biased number; prefer reporting a fixed, pre-chosen scheme.
    """
    if bin_configs is None:
        bin_configs = [(scheme, count)
                       for scheme in ('equal_width', 'equal_mass')
                       for count in (10, 15, 20)]

    estimators = {
        'equal_width': compute_ece,
        'equal_mass': compute_ece_equal_mass,
    }

    results = {}
    best_config, best_ece = None, float('inf')

    for method, n_bins in bin_configs:
        estimator = estimators.get(method)
        if estimator is None:
            raise ValueError(f"Unknown binning method: {method}")

        score = estimator(y_true, y_pred, n_bins)
        label = f"{method}_{n_bins}bins"
        results[label] = score

        # Strict comparison keeps the earliest configuration on ties.
        if score < best_ece:
            best_config, best_ece = label, score

    results['best_config'] = best_config
    results['best_ece'] = best_ece
    return results
