"""
Calibration metrics for evaluating model confidence predictions.

Implements four key calibration metrics from the P(IK) paper:
1. ECE (Expected Calibration Error) - 0 is perfect, higher is worse
2. RMS Calibration Error - 0 is perfect, higher is worse  
3. Brier Score - 0 is perfect, higher is worse
4. AUROC (Area Under ROC Curve) - 1 is perfect, 0.5 is random, 0 is worst

These metrics evaluate how well predicted probabilities match actual outcomes.
"""

import numpy as np
from typing import Tuple, List, Union, Optional
import warnings


def expected_calibration_error(
    predictions: np.ndarray,
    labels: np.ndarray,
    n_bins: int = 10
) -> float:
    """
    Calculate Expected Calibration Error (ECE).

    Predictions are grouped into ``n_bins`` equal-width confidence bins over
    [0, 1]. For every non-empty bin the absolute gap between the mean
    predicted confidence and the observed accuracy is taken, and the gaps are
    averaged with each bin weighted by its share of all samples.
    Lower is better (0 = perfect calibration).

    The first bin is closed on both ends so that a prediction of exactly 0.0
    is counted; every other bin is half-open ``(lower, upper]``.

    Note:
        Predictions outside [0, 1] (and NaN) fall in no bin. They add no
        term to the sum but are still counted in the total sample count.

    Args:
        predictions: Array of predicted probabilities [0, 1]
        labels: Array of binary labels (0 or 1)
        n_bins: Number of bins to use for calibration (must be >= 1)

    Returns:
        ECE value (0 = perfect calibration, higher = worse)

    Raises:
        ValueError: If predictions and labels differ in length, if they are
            empty, or if n_bins is smaller than 1.
    """
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have same length")

    if len(predictions) == 0:
        raise ValueError("Cannot calculate ECE for empty arrays")

    # Without this guard, zero bins would make the loop a no-op and the
    # function would report 0.0, i.e. "perfect calibration", for any input.
    if n_bins < 1:
        raise ValueError("n_bins must be a positive integer")

    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    total_samples = len(predictions)

    ece = 0.0
    for i, (bin_lower, bin_upper) in enumerate(
        zip(bin_boundaries[:-1], bin_boundaries[1:])
    ):
        # Only the first bin includes its lower edge, so that 0.0 is binned.
        if i == 0:
            in_bin = (predictions >= bin_lower) & (predictions <= bin_upper)
        else:
            in_bin = (predictions > bin_lower) & (predictions <= bin_upper)

        n_in_bin = in_bin.sum()
        if n_in_bin == 0:
            continue

        # Gap between observed accuracy and mean confidence, weighted by the
        # fraction of all samples that landed in this bin.
        accuracy_in_bin = labels[in_bin].mean()
        avg_confidence_in_bin = predictions[in_bin].mean()
        ece += (n_in_bin / total_samples) * abs(avg_confidence_in_bin - accuracy_in_bin)

    return float(ece)


def rms_calibration_error(
    predictions: np.ndarray,
    labels: np.ndarray,
    n_bins: int = 10
) -> float:
    """
    Calculate Root Mean Square Calibration Error.

    Predictions are grouped into ``n_bins`` equal-width confidence bins over
    [0, 1]. For every non-empty bin the squared gap between mean confidence
    and observed accuracy is taken; the result is the square root of the
    average of those squared gaps, weighted by the number of samples in each
    bin. Lower is better.

    The first bin is closed on both ends so that a prediction of exactly 0.0
    is counted; every other bin is half-open ``(lower, upper]``.

    Note:
        Predictions outside [0, 1] (and NaN) fall in no bin and are ignored;
        weights are normalized over the binned samples only. If no
        prediction lands in any bin, 0.0 is returned.

    Args:
        predictions: Array of predicted probabilities [0, 1]
        labels: Array of binary labels (0 or 1)
        n_bins: Number of bins to use for calibration (must be >= 1)

    Returns:
        RMS calibration error (0 = perfect calibration, higher = worse)

    Raises:
        ValueError: If predictions and labels differ in length, if they are
            empty, or if n_bins is smaller than 1.
    """
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have same length")

    # Reporting 0.0 ("perfect") for no data or no bins would be misleading,
    # so both cases are rejected explicitly.
    if len(predictions) == 0:
        raise ValueError("Cannot calculate RMS calibration error for empty arrays")

    if n_bins < 1:
        raise ValueError("n_bins must be a positive integer")

    bin_boundaries = np.linspace(0, 1, n_bins + 1)

    squared_diffs = []
    weights = []

    for i, (bin_lower, bin_upper) in enumerate(
        zip(bin_boundaries[:-1], bin_boundaries[1:])
    ):
        # Only the first bin includes its lower edge, so that 0.0 is binned.
        if i == 0:
            in_bin = (predictions >= bin_lower) & (predictions <= bin_upper)
        else:
            in_bin = (predictions > bin_lower) & (predictions <= bin_upper)

        n_in_bin = in_bin.sum()
        if n_in_bin == 0:
            continue

        accuracy_in_bin = labels[in_bin].mean()
        avg_confidence_in_bin = predictions[in_bin].mean()

        squared_diffs.append((avg_confidence_in_bin - accuracy_in_bin) ** 2)
        weights.append(n_in_bin)

    # Reached only when every prediction lies outside [0, 1] (or is NaN).
    if not squared_diffs:
        return 0.0

    # Weighted mean of the squared gaps, then the root.
    weights = np.array(weights) / np.sum(weights)
    rms = np.sqrt(np.sum(weights * squared_diffs))

    return float(rms)


def brier_score(predictions: np.ndarray, labels: np.ndarray) -> float:
    """
    Calculate Brier Score.

    Brier score is the mean squared difference between predicted probabilities
    and actual outcomes. Lower is better (0 = perfect predictions).

    Args:
        predictions: Array of predicted probabilities [0, 1]
        labels: Array of binary labels (0 or 1)

    Returns:
        Brier score (0 = perfect predictions, higher = worse)

    Raises:
        ValueError: If predictions and labels differ in length or are empty.
    """
    # Cast to float up front: numpy refuses to subtract two boolean arrays,
    # and unsigned-integer subtraction would wrap around instead of going
    # negative.
    predictions = np.asarray(predictions, dtype=float)
    labels = np.asarray(labels, dtype=float)

    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have same length")

    # The mean of an empty array is NaN (with a RuntimeWarning); fail loudly
    # instead.
    if len(predictions) == 0:
        raise ValueError("Cannot calculate Brier score for empty arrays")

    return float(np.mean((predictions - labels) ** 2))


def auroc(predictions: np.ndarray, labels: np.ndarray) -> float:
    """
    Calculate Area Under the ROC Curve (AUROC).

    AUROC measures the model's ability to distinguish between classes.
    Higher is better (1 = perfect discrimination, 0.5 = random).

    It is computed as the fraction of (positive, negative) pairs in which the
    positive example has the strictly higher score, with tied pairs counted
    as half. The pair counting is done by sorting the negative scores once
    and locating every positive score with a binary search, so the cost is
    O(n log n) rather than one comparison per pair.

    Note:
        A NaN score never compares greater than or equal to anything, so
        pairs involving NaN add nothing to the numerator while still being
        counted in the denominator.

    Args:
        predictions: Array of predicted probabilities [0, 1]
        labels: Array of binary labels (0 or 1)

    Returns:
        AUROC value (1 = perfect, 0.5 = random, 0 = worst), or NaN when only
        one class is present (a warning is emitted in that case).

    Raises:
        ValueError: If predictions and labels differ in length, or labels
            contain values other than 0 and 1.
    """
    predictions = np.asarray(predictions, dtype=float)
    labels = np.asarray(labels)

    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have same length")

    # Ensure labels are binary
    unique_labels = np.unique(labels)
    if not (len(unique_labels) <= 2 and all(label in [0, 1] for label in unique_labels)):
        raise ValueError("labels must be binary (0 or 1)")

    is_positive = labels == 1
    is_negative = labels == 0
    n_pos = int(np.count_nonzero(is_positive))
    n_neg = int(np.count_nonzero(is_negative))

    if n_pos == 0 or n_neg == 0:
        warnings.warn("Only one class present in labels. AUROC is undefined.")
        return float("nan")

    pos_scores = predictions[is_positive]
    neg_scores = predictions[is_negative]

    # Drop NaN scores before ranking: sorting would otherwise place NaN above
    # every real score, whereas a direct comparison treats it as neither
    # greater than nor equal to anything.
    pos_scores = pos_scores[~np.isnan(pos_scores)]
    neg_sorted = np.sort(neg_scores[~np.isnan(neg_scores)])

    # For each positive score: how many negatives are strictly below it, and
    # how many are below or equal. The difference is the number of ties.
    strictly_below = np.searchsorted(neg_sorted, pos_scores, side="left")
    below_or_equal = np.searchsorted(neg_sorted, pos_scores, side="right")

    wins = int(strictly_below.sum())
    ties = int((below_or_equal - strictly_below).sum())

    # Normalize by the number of positive-negative pairs (ties count half).
    return float((wins + 0.5 * ties) / (n_pos * n_neg))


def compute_all_metrics(
    predictions: np.ndarray,
    labels: np.ndarray,
    n_bins: int = 10
) -> dict:
    """
    Evaluate every metric in this module on one set of predictions.

    The metrics are computed in the order ECE, RMS calibration error, Brier
    score, AUROC; an exception raised by any of them propagates to the caller.

    Args:
        predictions: Array of predicted probabilities [0, 1]
        labels: Array of binary labels (0 or 1)
        n_bins: Number of bins, forwarded to the ECE and RMS calculations
            (the Brier score and AUROC do not use it)

    Returns:
        Dictionary with the keys 'ece', 'rms_calibration_error',
        'brier_score' and 'auroc'
    """
    results = {}

    # Binned calibration metrics share the same bin count.
    results['ece'] = expected_calibration_error(predictions, labels, n_bins)
    results['rms_calibration_error'] = rms_calibration_error(predictions, labels, n_bins)

    # Bin-free metrics.
    results['brier_score'] = brier_score(predictions, labels)
    results['auroc'] = auroc(predictions, labels)

    return results


def convert_grades_to_binary(grades: List[str], positive_grade: str = 'A') -> np.ndarray:
    """
    Convert letter grades to binary labels.

    Args:
        grades: List of grade strings (e.g., ['A', 'B', 'C'])
        positive_grade: Which grade to treat as positive class (default: 'A')

    Returns:
        Integer array where positive_grade = 1, others = 0. An empty input
        gives an empty integer array.
    """
    # Pin the dtype: without it numpy infers float64 for an empty list, so
    # the result type would depend on whether any grades were supplied.
    return np.array([1 if g == positive_grade else 0 for g in grades], dtype=int)


def calibration_summary(
    predictions: np.ndarray,
    labels: np.ndarray,
    model_name: str = "Model",
    n_bins: int = 10
) -> str:
    """
    Generate a human-readable summary of calibration metrics.

    The summary lists ECE, RMS calibration error, Brier score and AUROC
    (four decimal places each) followed by a one-line verdict derived from
    the ECE: Excellent (< 0.05), Good (< 0.10), Moderate (< 0.20), otherwise
    Poor.

    Args:
        predictions: Array of predicted probabilities
        labels: Array of binary labels (0 or 1)
        model_name: Name of the model for the summary
        n_bins: Number of bins passed on to compute_all_metrics
            (default: 10)

    Returns:
        Formatted string summary
    """
    metrics = compute_all_metrics(predictions, labels, n_bins)
    ece = metrics['ece']

    # Qualitative verdict based on the ECE alone.
    if ece < 0.05:
        verdict = "\nCalibration: Excellent (ECE < 0.05)"
    elif ece < 0.10:
        verdict = "\nCalibration: Good (ECE < 0.10)"
    elif ece < 0.20:
        verdict = "\nCalibration: Moderate (ECE < 0.20)"
    else:
        verdict = "\nCalibration: Poor (ECE >= 0.20)"

    # Each piece carries its own newlines; they are concatenated as-is.
    pieces = [
        f"\n=== Calibration Summary for {model_name} ===\n",
        f"ECE (Expected Calibration Error): {ece:.4f} (lower is better, 0 = perfect)\n",
        f"RMS Calibration Error: {metrics['rms_calibration_error']:.4f} (lower is better, 0 = perfect)\n",
        f"Brier Score: {metrics['brier_score']:.4f} (lower is better, 0 = perfect)\n",
        f"AUROC: {metrics['auroc']:.4f} (higher is better, 1 = perfect, 0.5 = random)\n",
        verdict,
    ]
    return "".join(pieces)


if __name__ == "__main__":
    # Smoke test: run every metric on a small hand-made example in which the
    # four highest-scored samples are the positives.
    print("Testing calibration metrics...")

    demo_predictions = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
    demo_labels = np.array([1] * 4 + [0] * 5)

    # Raw metric dictionary first, then the formatted report.
    print(f"Test metrics: {compute_all_metrics(demo_predictions, demo_labels)}")
    print(calibration_summary(demo_predictions, demo_labels, "Test Model"))