"""
Probability Calibration Framework for CNCRC

This module implements probability calibration methods to improve the quality
of model probability estimates. It provides Temperature Scaling and Isotonic
Regression calibration techniques with comprehensive evaluation metrics.

Key Components:
- CalibratorBase: Abstract base class for calibrators
- TemperatureScaling: Temperature scaling calibration
- IsotonicCalibrator: Isotonic regression calibration
- CalibrationEvaluator: Evaluation metrics and diagnostics
- CalibrationConfig: Configuration and hyperparameters
- CalibrationFramework: High-level interface for fitting and comparing methods

The calibration methods help ensure that predicted probabilities are well calibrated:
among all predictions to which the model assigns probability p, approximately
a fraction p should be correct.
"""

import os
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Any, Union
from pathlib import Path
import pickle
import json
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
import warnings

import scipy.optimize as opt
from scipy.special import softmax, logsumexp
from sklearn.isotonic import IsotonicRegression
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss, log_loss
import matplotlib.pyplot as plt

# Configure logging.
# NOTE(review): basicConfig() runs when this module is imported and sets up the
# root logger at INFO level if it has no handlers yet; library modules usually
# leave that decision to the application -- confirm this is intended.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


@dataclass
class CalibrationConfig:
    """Settings shared by the calibrators, the evaluator and the framework.

    Attributes:
        temp_initial: Starting temperature; used as the optimizer's initial
            guess and as the temperature of an unfitted TemperatureScaling.
        temp_bounds: ``(lower, upper)`` bounds passed to the optimizer.
        temp_method: Name of the ``scipy.optimize.minimize`` method.
        temp_max_iter: Value forwarded as the optimizer's ``maxiter`` option.
        temp_tol: Value forwarded as the optimizer's ``ftol`` option.
        isotonic_y_min: ``y_min`` forwarded to ``IsotonicRegression``.
        isotonic_y_max: ``y_max`` forwarded to ``IsotonicRegression``.
        isotonic_increasing: ``increasing`` forwarded to ``IsotonicRegression``.
        isotonic_out_of_bounds: ``out_of_bounds`` forwarded to
            ``IsotonicRegression``.
        n_bins: Number of bins for ECE and reliability diagrams.
        bin_strategy: Binning strategy, ``"uniform"`` or ``"quantile"``.
        confidence_level: Confidence level for intervals.
        save_plots: Whether calibration plots are written to disk.
        plot_dir: Directory that receives the plots.
        save_models: Whether fitted calibrators are written to disk.
        model_dir: Directory that receives the pickled calibrators.
    """

    # --- temperature scaling -------------------------------------------
    temp_initial: float = 1.0
    temp_bounds: Tuple[float, float] = (0.1, 10.0)
    temp_method: str = "L-BFGS-B"
    temp_max_iter: int = 1000
    temp_tol: float = 1e-6

    # --- isotonic regression -------------------------------------------
    isotonic_y_min: Optional[float] = None
    isotonic_y_max: Optional[float] = None
    isotonic_increasing: bool = True
    isotonic_out_of_bounds: str = "clip"

    # --- evaluation ----------------------------------------------------
    n_bins: int = 15
    bin_strategy: str = "uniform"
    confidence_level: float = 0.95

    # --- output --------------------------------------------------------
    save_plots: bool = True
    plot_dir: str = "plots/calibration"
    save_models: bool = True
    model_dir: str = "models/calibration"


class CalibrationEvaluator:
    """
    Evaluation metrics and diagnostics for probability calibration.

    Computes Expected / Maximum Calibration Error over probability bins,
    draws reliability diagrams, and summarises a before/after comparison
    (ECE, MCE, Brier score) for a calibration method.
    """

    def __init__(self, config: "CalibrationConfig"):
        """
        Initialize calibration evaluator.

        Args:
            config: Calibration configuration. ``config.plot_dir`` is created
                (including parents) as a side effect.
        """
        self.config = config

        # Create output directories
        Path(config.plot_dir).mkdir(parents=True, exist_ok=True)

    def expected_calibration_error(
        self,
        y_true: np.ndarray,
        y_prob: np.ndarray,
        n_bins: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Calculate Expected Calibration Error (ECE).

        Args:
            y_true: True binary labels (0 or 1); any array-like is accepted
            y_prob: Predicted probabilities; any array-like is accepted
            n_bins: Number of bins (uses config default if None or 0)

        Returns:
            Dictionary with:
                - 'ece': sum over bins of |confidence - accuracy| weighted by
                  the fraction of samples in the bin
                - 'mce': largest per-bin |confidence - accuracy|
                - 'n_bins': number of non-empty bins
                - 'bin_stats': one dict per non-empty bin (bounds, count,
                  proportion, accuracy, confidence, error)
        """
        y_true = np.asarray(y_true)
        y_prob = np.asarray(y_prob)
        n_bins = n_bins or self.config.n_bins

        # Nothing to bin: report a zero error instead of failing in
        # np.quantile or averaging an empty slice.
        if y_prob.size == 0:
            return {'ece': 0.0, 'mce': 0.0, 'n_bins': 0, 'bin_stats': []}

        # Create bins
        if self.config.bin_strategy == "uniform":
            bin_boundaries = np.linspace(0, 1, n_bins + 1)
        else:  # quantile
            bin_boundaries = np.quantile(y_prob, np.linspace(0, 1, n_bins + 1))
            bin_boundaries[0] = 0.0
            bin_boundaries[-1] = 1.0

        ece = 0.0
        mce = 0.0  # Maximum Calibration Error
        bin_stats = []

        bin_edges = zip(bin_boundaries[:-1], bin_boundaries[1:])
        for bin_index, (bin_lower, bin_upper) in enumerate(bin_edges):
            # Bins are (lower, upper], except the first one, which is closed
            # on the left so that predictions of exactly 0.0 are counted.
            in_bin = y_prob <= bin_upper
            if bin_index == 0:
                in_bin &= (y_prob >= bin_lower)
            else:
                in_bin &= (y_prob > bin_lower)

            prop_in_bin = in_bin.mean()
            if prop_in_bin <= 0:
                continue

            accuracy_in_bin = y_true[in_bin].mean()
            avg_confidence_in_bin = y_prob[in_bin].mean()

            # Calibration error for this bin
            bin_error = abs(avg_confidence_in_bin - accuracy_in_bin)

            # Weight by proportion of samples in bin
            ece += bin_error * prop_in_bin
            mce = max(mce, bin_error)

            bin_stats.append({
                'bin_lower': bin_lower,
                'bin_upper': bin_upper,
                'count': int(in_bin.sum()),
                'proportion': prop_in_bin,
                'accuracy': accuracy_in_bin,
                'confidence': avg_confidence_in_bin,
                'error': bin_error
            })

        return {
            'ece': ece,
            'mce': mce,
            'n_bins': len(bin_stats),
            'bin_stats': bin_stats
        }

    def reliability_diagram(
        self,
        y_true: np.ndarray,
        y_prob: np.ndarray,
        title: str = "Reliability Diagram",
        save_path: Optional[Union[str, Path]] = None
    ) -> Dict[str, Any]:
        """
        Create reliability diagram (calibration plot).

        The figure is saved when ``save_path`` is given or when
        ``config.save_plots`` is true (then a file name derived from ``title``
        inside ``config.plot_dir`` is used). The figure is always closed.

        Args:
            y_true: True binary labels
            y_prob: Predicted probabilities
            title: Plot title
            save_path: Path to save plot

        Returns:
            Dictionary with 'fraction_pos', 'mean_pred_prob', 'ece_results'
            and 'brier_score'
        """
        y_true = np.asarray(y_true)
        y_prob = np.asarray(y_prob)

        # Calculate calibration curve
        fraction_pos, mean_pred_prob = calibration_curve(
            y_true, y_prob,
            n_bins=self.config.n_bins,
            strategy=self.config.bin_strategy
        )

        # Calculate ECE
        ece_results = self.expected_calibration_error(y_true, y_prob)

        # Create plot
        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            # Perfect calibration line
            ax.plot([0, 1], [0, 1], "k--", label="Perfect calibration")

            # Calibration curve
            ax.plot(mean_pred_prob, fraction_pos, "o-", label=f"Model (ECE={ece_results['ece']:.3f})")

            # Histogram of predictions; the range is pinned to [0, 1] so the
            # bars line up with the probability axis of the diagram instead
            # of spanning only the observed data range.
            ax2 = ax.twinx()
            ax2.hist(y_prob, bins=self.config.n_bins, range=(0, 1), alpha=0.3, color="gray",
                     label="Prediction distribution")
            ax2.set_ylabel("Count")

            # Formatting
            ax.set_xlabel("Mean Predicted Probability")
            ax.set_ylabel("Fraction of Positives")
            ax.set_title(title)
            ax.legend(loc="upper left")
            ax2.legend(loc="upper right")
            ax.grid(True, alpha=0.3)

            # Save plot
            if save_path or self.config.save_plots:
                target = Path(
                    save_path
                    or Path(self.config.plot_dir) / f"reliability_{title.lower().replace(' ', '_')}.png"
                )
                # A caller-supplied path may point outside plot_dir.
                target.parent.mkdir(parents=True, exist_ok=True)
                fig.savefig(target, dpi=300, bbox_inches="tight")
                logger.info(f"Reliability diagram saved to {target}")
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)

        return {
            'fraction_pos': fraction_pos,
            'mean_pred_prob': mean_pred_prob,
            'ece_results': ece_results,
            'brier_score': brier_score_loss(y_true, y_prob)
        }

    def evaluate_calibration(
        self,
        y_true: np.ndarray,
        y_prob_before: np.ndarray,
        y_prob_after: np.ndarray,
        method_name: str = "Calibration"
    ) -> Dict[str, Any]:
        """
        Comprehensive calibration evaluation.

        Computes ECE, MCE and Brier score before and after calibration, logs
        a short summary and, when ``config.save_plots`` is true, writes one
        reliability diagram for each of the two probability vectors.

        Args:
            y_true: True labels
            y_prob_before: Probabilities before calibration
            y_prob_after: Probabilities after calibration
            method_name: Name of calibration method

        Returns:
            Dictionary with keys 'method', 'before', 'after' and
            'improvement' (positive improvement means the metric decreased)
        """
        y_true = np.asarray(y_true)
        y_prob_before = np.asarray(y_prob_before)
        y_prob_after = np.asarray(y_prob_after)

        # Before calibration
        ece_before = self.expected_calibration_error(y_true, y_prob_before)
        brier_before = brier_score_loss(y_true, y_prob_before)

        # After calibration
        ece_after = self.expected_calibration_error(y_true, y_prob_after)
        brier_after = brier_score_loss(y_true, y_prob_after)

        # Calculate improvement
        ece_improvement = ece_before['ece'] - ece_after['ece']
        brier_improvement = brier_before - brier_after

        results = {
            'method': method_name,
            'before': {
                'ece': ece_before['ece'],
                'mce': ece_before['mce'],
                'brier_score': brier_before
            },
            'after': {
                'ece': ece_after['ece'],
                'mce': ece_after['mce'],
                'brier_score': brier_after
            },
            'improvement': {
                'ece': ece_improvement,
                'brier_score': brier_improvement,
                # Guard against dividing by a zero baseline ECE.
                'ece_relative': ece_improvement / ece_before['ece'] if ece_before['ece'] > 0 else 0.0
            }
        }

        # Create reliability diagrams
        if self.config.save_plots:
            self.reliability_diagram(
                y_true, y_prob_before,
                title=f"{method_name} - Before",
                save_path=Path(self.config.plot_dir) / f"{method_name.lower()}_before.png"
            )

            self.reliability_diagram(
                y_true, y_prob_after,
                title=f"{method_name} - After",
                save_path=Path(self.config.plot_dir) / f"{method_name.lower()}_after.png"
            )

        logger.info(f"{method_name} Evaluation:")
        logger.info(f"  ECE: {ece_before['ece']:.4f} → {ece_after['ece']:.4f} (Δ={ece_improvement:.4f})")
        logger.info(f"  Brier: {brier_before:.4f} → {brier_after:.4f} (Δ={brier_improvement:.4f})")

        return results


class CalibratorBase(ABC):
    """Abstract base class for probability calibrators.

    Subclasses implement ``fit`` and ``predict_proba`` and set
    ``is_fitted`` to True once fitting succeeded. Persistence is provided
    here through pickle-based ``save`` / ``load``.
    """

    def __init__(self, config: "CalibrationConfig"):
        """Initialize calibrator with configuration.

        Args:
            config: Calibration configuration; also used to build the
                evaluator stored on ``self.evaluator``.
        """
        self.config = config
        self.is_fitted = False
        self.evaluator = CalibrationEvaluator(config)

    @abstractmethod
    def fit(self, logits: np.ndarray, labels: np.ndarray) -> 'CalibratorBase':
        """
        Fit calibrator to data.

        Args:
            logits: Input logits or probabilities
            labels: True labels

        Returns:
            Self for method chaining
        """
        pass

    @abstractmethod
    def predict_proba(self, logits: np.ndarray) -> np.ndarray:
        """
        Apply calibration to get calibrated probabilities.

        Args:
            logits: Input logits or probabilities

        Returns:
            Calibrated probabilities
        """
        pass

    def save(self, filepath: Union[str, Path]) -> None:
        """Save fitted calibrator to file.

        The whole calibrator object is pickled; missing parent directories
        are created.

        Args:
            filepath: Destination file.

        Raises:
            ValueError: If the calibrator has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("Calibrator must be fitted before saving")

        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        with open(filepath, 'wb') as f:
            pickle.dump(self, f)

        logger.info(f"Calibrator saved to {filepath}")

    @classmethod
    def load(cls, filepath: Union[str, Path]) -> 'CalibratorBase':
        """Load fitted calibrator from file.

        Args:
            filepath: File previously written by ``save``.

        Returns:
            The unpickled calibrator, guaranteed to be an instance of the
            class ``load`` was called on.

        Raises:
            TypeError: If the file does not contain an instance of ``cls``.
            ValueError: If the loaded calibrator is not fitted.
        """
        # SECURITY: unpickling executes arbitrary code embedded in the file;
        # only load calibrators from trusted locations.
        with open(filepath, 'rb') as f:
            calibrator = pickle.load(f)

        # Without this check, e.g. TemperatureScaling.load() would happily
        # return an unrelated object (or fail later with AttributeError).
        if not isinstance(calibrator, cls):
            raise TypeError(
                f"Expected a pickled {cls.__name__}, got {type(calibrator).__name__}"
            )

        if not calibrator.is_fitted:
            raise ValueError("Loaded calibrator is not fitted")

        logger.info(f"Calibrator loaded from {filepath}")
        return calibrator


class TemperatureScaling(CalibratorBase):
    """
    Temperature Scaling calibration method.

    Calibrates by scaling logits with a temperature parameter T:
    calibrated_probs = softmax(logits / T)

    The temperature T is found by minimizing negative log-likelihood
    on a validation set.

    A 1-D logits array is interpreted as binary (sigmoid) logits ``z`` of the
    positive class and expanded to the two-class logits ``[0, z]``, whose
    softmax equals ``[1 - sigmoid(z), sigmoid(z)]``. (A single-column softmax
    would be identically 1 and could not be calibrated.)
    """

    def __init__(self, config: CalibrationConfig):
        """Initialize temperature scaling calibrator."""
        super().__init__(config)
        self.temperature = config.temp_initial
        self.optimization_result = None

    @staticmethod
    def _as_class_logits(logits: np.ndarray) -> np.ndarray:
        """Return logits as a float array with one column per class."""
        logits = np.asarray(logits, dtype=float)
        if logits.ndim == 1:
            # Binary logit z  ->  two-class logits [0, z].
            logits = np.column_stack((np.zeros_like(logits), logits))
        return logits

    def _negative_log_likelihood(self, temperature: float, logits: np.ndarray, labels: np.ndarray) -> float:
        """
        Calculate negative log-likelihood for given temperature.

        Args:
            temperature: Temperature parameter (a scalar or the length-1
                parameter vector supplied by the optimizer)
            logits: Input logits [n_samples, n_classes]
            labels: True labels (one-hot or integer)

        Returns:
            Mean negative log-likelihood over the samples
        """
        # The optimizer hands over a 1-element array; reduce it to a scalar.
        temp = float(np.ravel(temperature)[0])

        # Scale logits by temperature
        scaled_logits = logits / temp

        # Log-softmax, computed via logsumexp for numerical stability
        log_probs = scaled_logits - logsumexp(scaled_logits, axis=1, keepdims=True)

        # Handle both one-hot and integer labels
        if labels.ndim == 2:
            # One-hot labels
            nll = -np.sum(labels * log_probs) / len(labels)
        else:
            # Integer labels; float or bool arrays holding class ids are
            # cast so they can be used as column indices.
            class_ids = labels if np.issubdtype(labels.dtype, np.integer) else labels.astype(int)
            nll = -np.mean(log_probs[np.arange(len(class_ids)), class_ids])

        return float(nll)

    def fit(self, logits: np.ndarray, labels: np.ndarray) -> 'TemperatureScaling':
        """
        Fit temperature scaling to validation data.

        Args:
            logits: Validation logits [n_samples, n_classes], or [n_samples]
                binary logits of the positive class
            labels: Validation labels [n_samples] or [n_samples, n_classes]

        Returns:
            Self for method chaining

        Raises:
            ValueError: If logits and labels disagree on the number of samples.
        """
        logits = self._as_class_logits(logits)
        labels = np.asarray(labels)

        if labels.shape[0] != logits.shape[0]:
            raise ValueError(
                f"logits and labels have different sample counts: "
                f"{logits.shape[0]} vs {labels.shape[0]}"
            )

        logger.info("Fitting temperature scaling...")

        # Optimize temperature
        result = opt.minimize(
            fun=self._negative_log_likelihood,
            x0=self.config.temp_initial,
            args=(logits, labels),
            method=self.config.temp_method,
            bounds=[self.config.temp_bounds],
            options={
                'maxiter': self.config.temp_max_iter,
                'ftol': self.config.temp_tol
            }
        )

        if not result.success:
            warnings.warn(f"Temperature optimization did not converge: {result.message}")

        # Store a plain Python float rather than a numpy scalar.
        self.temperature = float(result.x[0])
        self.optimization_result = result
        self.is_fitted = True

        logger.info(f"Optimal temperature: {self.temperature:.4f}")
        logger.info(f"Final NLL: {result.fun:.4f}")

        return self

    def predict_proba(self, logits: np.ndarray) -> np.ndarray:
        """
        Apply temperature scaling to get calibrated probabilities.

        Args:
            logits: Input logits [n_samples, n_classes], or [n_samples]
                binary logits of the positive class

        Returns:
            Calibrated probabilities [n_samples, n_classes] (two columns for
            1-D binary input)

        Raises:
            ValueError: If the calibrator has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("Calibrator must be fitted before prediction")

        # Scale by temperature and convert to probabilities
        scaled_logits = self._as_class_logits(logits) / self.temperature
        return softmax(scaled_logits, axis=1)

    def get_temperature(self) -> float:
        """Get the fitted temperature parameter.

        Raises:
            ValueError: If the calibrator has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("Calibrator must be fitted to get temperature")
        return self.temperature


class IsotonicCalibrator(CalibratorBase):
    """
    Isotonic Regression calibration method.

    Fits one isotonic regressor per class (one-vs-rest on the class
    probability) and renormalizes the per-class outputs so every row sums
    to 1.

    Inputs containing values outside [0, 1] are treated as logits and passed
    through softmax; otherwise they are taken as probabilities. The decision
    made while fitting is remembered so that a prediction batch of logits
    that happens to lie inside [0, 1] is still converted the same way.
    """

    def __init__(self, config: CalibrationConfig):
        """Initialize isotonic regression calibrator."""
        super().__init__(config)
        self.regressors = {}
        self.n_classes = None
        # Set by fit(): True when the training inputs were detected as logits.
        self._fitted_on_logits = False

    @staticmethod
    def _looks_like_logits(values: np.ndarray) -> bool:
        """Return True when any value lies outside the probability range."""
        return bool(np.any(values < 0) or np.any(values > 1))

    @staticmethod
    def _as_2d_float(values: np.ndarray) -> np.ndarray:
        """Return ``values`` as a 2-D float array or raise ValueError."""
        values = np.asarray(values, dtype=float)
        if values.ndim != 2:
            raise ValueError(
                f"Expected a 2-D array [n_samples, n_classes], got shape {values.shape}"
            )
        return values

    def fit(self, logits: np.ndarray, labels: np.ndarray) -> 'IsotonicCalibrator':
        """
        Fit isotonic regression calibrators.

        Args:
            logits: Input logits or probabilities [n_samples, n_classes]
            labels: True labels [n_samples] or [n_samples, n_classes]

        Returns:
            Self for method chaining

        Raises:
            ValueError: If ``logits`` is not 2-D.
        """
        logits = self._as_2d_float(logits)
        labels = np.asarray(labels)

        # Convert logits to probabilities if needed and remember the choice
        # so predict_proba() applies the same transformation.
        self._fitted_on_logits = self._looks_like_logits(logits)
        probs = softmax(logits, axis=1) if self._fitted_on_logits else logits

        # Handle label format
        if labels.ndim == 1:
            # Integer labels -> one-hot (class ids stored as float/bool are
            # cast so they can be used as indices)
            self.n_classes = probs.shape[1]
            class_ids = labels if np.issubdtype(labels.dtype, np.integer) else labels.astype(int)
            labels_onehot = np.zeros((len(class_ids), self.n_classes))
            labels_onehot[np.arange(len(class_ids)), class_ids] = 1
        else:
            # Already one-hot
            labels_onehot = labels
            self.n_classes = labels.shape[1]

        logger.info(f"Fitting isotonic regression for {self.n_classes} classes...")

        # Fit separate regressor for each class
        self.regressors = {}
        for class_idx in range(self.n_classes):
            regressor = IsotonicRegression(
                y_min=self.config.isotonic_y_min,
                y_max=self.config.isotonic_y_max,
                increasing=self.config.isotonic_increasing,
                out_of_bounds=self.config.isotonic_out_of_bounds
            )

            # Fit regressor for this class
            regressor.fit(probs[:, class_idx], labels_onehot[:, class_idx])
            self.regressors[class_idx] = regressor

        self.is_fitted = True
        logger.info("Isotonic regression fitting completed")

        return self

    def predict_proba(self, logits: np.ndarray) -> np.ndarray:
        """
        Apply isotonic regression to get calibrated probabilities.

        Args:
            logits: Input logits or probabilities [n_samples, n_classes]

        Returns:
            Calibrated probabilities [n_samples, n_classes]; every row sums
            to 1. Rows for which all class regressors return 0 fall back to
            the uniform distribution.

        Raises:
            ValueError: If the calibrator is not fitted, the input is not
                2-D, or its number of columns differs from the fitted number
                of classes.
        """
        if not self.is_fitted:
            raise ValueError("Calibrator must be fitted before prediction")

        logits = self._as_2d_float(logits)
        if logits.shape[1] != self.n_classes:
            raise ValueError(
                f"Expected {self.n_classes} classes, got {logits.shape[1]}"
            )

        # Convert to probabilities if needed. getattr() keeps calibrators
        # pickled before this flag existed working with the old heuristic.
        treat_as_logits = getattr(self, "_fitted_on_logits", False) or self._looks_like_logits(logits)
        probs = softmax(logits, axis=1) if treat_as_logits else logits

        # Apply isotonic regression for each class
        calibrated_probs = np.zeros(probs.shape, dtype=float)
        for class_idx in range(self.n_classes):
            calibrated_probs[:, class_idx] = self.regressors[class_idx].predict(
                probs[:, class_idx]
            )

        # Renormalize to ensure probabilities sum to 1
        row_sums = calibrated_probs.sum(axis=1, keepdims=True)
        has_mass = row_sums > 0
        calibrated_probs = calibrated_probs / np.where(has_mass, row_sums, 1.0)
        # No class received any probability: use the uniform distribution
        # instead of returning an all-zero row.
        calibrated_probs[~has_mass.ravel()] = 1.0 / self.n_classes

        return calibrated_probs


class CalibrationFramework:
    """
    High-level framework for probability calibration.

    Provides easy interface for fitting, evaluating, and comparing
    different calibration methods. Fitted calibrators are kept in
    ``fitted_calibrators`` under the name they were fitted with.
    """

    def __init__(self, config: Optional["CalibrationConfig"] = None):
        """Initialize calibration framework.

        Args:
            config: Calibration configuration; a default ``CalibrationConfig``
                is created when omitted. ``config.model_dir`` is created as a
                side effect.
        """
        self.config = config or CalibrationConfig()
        self.evaluator = CalibrationEvaluator(self.config)
        self.fitted_calibrators = {}

        # Create output directories
        Path(self.config.model_dir).mkdir(parents=True, exist_ok=True)

    def _register(self, name: str, calibrator: "CalibratorBase") -> "CalibratorBase":
        """Remember a fitted calibrator and persist it when configured."""
        self.fitted_calibrators[name] = calibrator

        if self.config.save_models:
            calibrator.save(Path(self.config.model_dir) / f"{name}.pkl")

        return calibrator

    def fit_temperature_scaling(
        self,
        logits: np.ndarray,
        labels: np.ndarray,
        name: str = "temperature_scaling"
    ) -> "TemperatureScaling":
        """Fit temperature scaling calibrator.

        Args:
            logits: Logits used for fitting
            labels: Labels used for fitting
            name: Key under which the calibrator is stored (and file stem
                used when ``config.save_models`` is true)

        Returns:
            The fitted calibrator
        """
        calibrator = TemperatureScaling(self.config)
        calibrator.fit(logits, labels)
        return self._register(name, calibrator)

    def fit_isotonic_regression(
        self,
        logits: np.ndarray,
        labels: np.ndarray,
        name: str = "isotonic_regression"
    ) -> "IsotonicCalibrator":
        """Fit isotonic regression calibrator.

        Args:
            logits: Logits or probabilities used for fitting
            labels: Labels used for fitting
            name: Key under which the calibrator is stored (and file stem
                used when ``config.save_models`` is true)

        Returns:
            The fitted calibrator
        """
        calibrator = IsotonicCalibrator(self.config)
        calibrator.fit(logits, labels)
        return self._register(name, calibrator)

    @staticmethod
    def _as_probabilities(scores: np.ndarray) -> np.ndarray:
        """Softmax scores that leave [0, 1]; otherwise return them unchanged."""
        if np.any(scores < 0) or np.any(scores > 1):
            return softmax(scores, axis=1)
        return scores

    @staticmethod
    def _binary_scores(probs: np.ndarray) -> np.ndarray:
        """Reduce class probabilities to one score per sample.

        Two classes: probability of class 1. Otherwise: the maximum class
        probability (confidence of the predicted class).
        """
        if probs.shape[1] == 2:
            return probs[:, 1]
        return np.max(probs, axis=1)

    @staticmethod
    def _binary_targets(labels: np.ndarray, probs: np.ndarray) -> np.ndarray:
        """Build the 0/1 targets that match ``_binary_scores``.

        Two classes: the class id itself. Otherwise: 1 where the predicted
        class (argmax of ``probs``) equals the true class, else 0.
        One-hot labels are first reduced to class ids.
        """
        labels = np.asarray(labels)
        if labels.ndim > 1:
            labels = np.argmax(labels, axis=1)

        if probs.shape[1] == 2:
            return labels

        predicted_classes = np.argmax(probs, axis=1)
        return (labels == predicted_classes).astype(int)

    def compare_methods(
        self,
        logits_val: np.ndarray,
        labels_val: np.ndarray,
        logits_test: np.ndarray,
        labels_test: np.ndarray
    ) -> Dict[str, Any]:
        """
        Compare calibration methods on test data.

        Each method is fitted on the validation split and evaluated on the
        test split against the uncalibrated probabilities.

        Args:
            logits_val: Validation logits for fitting
            labels_val: Validation labels for fitting
            logits_test: Test logits for evaluation
            labels_test: Test labels for evaluation

        Returns:
            Dictionary with one evaluation per method name plus a 'summary'
            entry
        """
        logger.info("Comparing calibration methods...")

        logits_test = np.asarray(logits_test)
        labels_test = np.asarray(labels_test)

        # Convert logits to probabilities for baseline
        probs_uncalibrated = self._as_probabilities(logits_test)
        probs_uncalibrated_binary = self._binary_scores(probs_uncalibrated)

        # Convert labels to binary format for evaluation.
        # NOTE(review): for more than two classes the targets are derived
        # from the *uncalibrated* argmax and reused for the calibrated
        # confidences; a method that changes the argmax is therefore scored
        # against the original prediction.
        labels_binary = self._binary_targets(labels_test, probs_uncalibrated)

        results = {}

        # Fit and evaluate each method (insertion order defines the order
        # of evaluation and of the summary table)
        fitters = {
            'temperature_scaling': self.fit_temperature_scaling,
            'isotonic_regression': self.fit_isotonic_regression,
        }
        methods = list(fitters)

        for method in methods:
            logger.info(f"Evaluating {method}...")

            calibrator = fitters[method](logits_val, labels_val, method)

            # Get calibrated probabilities
            probs_calibrated = calibrator.predict_proba(logits_test)
            probs_calibrated_binary = self._binary_scores(probs_calibrated)

            # Evaluate calibration
            results[method] = self.evaluator.evaluate_calibration(
                labels_binary,
                probs_uncalibrated_binary,
                probs_calibrated_binary,
                method_name=method.replace('_', ' ').title()
            )

        # Summary comparison
        results['summary'] = self._create_comparison_summary(results, methods)

        logger.info("Calibration comparison completed")
        return results

    def _create_comparison_summary(
        self,
        results: Dict[str, Any],
        methods: List[str]
    ) -> Dict[str, Any]:
        """Create summary comparison of methods.

        Args:
            results: Per-method evaluation dictionaries
            methods: Method names to include, in table order

        Returns:
            Dictionary with the methods achieving the largest ECE and Brier
            improvement (first one wins ties), a ranking by ECE improvement
            (best first) and one comparison-table row per method
        """
        summary = {
            'best_ece_improvement': None,
            'best_brier_improvement': None,
            'method_ranking': [],
            'comparison_table': []
        }

        # Find best improvements
        best_ece_improvement = -float('inf')
        best_brier_improvement = -float('inf')

        for method in methods:
            improvement = results[method]['improvement']

            if improvement['ece'] > best_ece_improvement:
                best_ece_improvement = improvement['ece']
                summary['best_ece_improvement'] = method

            if improvement['brier_score'] > best_brier_improvement:
                best_brier_improvement = improvement['brier_score']
                summary['best_brier_improvement'] = method

        # Create ranking based on ECE improvement (stable for ties)
        summary['method_ranking'] = sorted(
            methods,
            key=lambda method: results[method]['improvement']['ece'],
            reverse=True
        )

        # Create comparison table
        for method in methods:
            entry = results[method]
            summary['comparison_table'].append({
                'method': method,
                'ece_before': entry['before']['ece'],
                'ece_after': entry['after']['ece'],
                'ece_improvement': entry['improvement']['ece'],
                'brier_before': entry['before']['brier_score'],
                'brier_after': entry['after']['brier_score'],
                'brier_improvement': entry['improvement']['brier_score']
            })

        return summary


# Convenience functions

def create_synthetic_calibration_data(
    n_samples: int = 1000,
    n_classes: int = 2,
    miscalibration_factor: float = 2.0,
    random_state: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Create synthetic data for calibration testing.

    True class probabilities are drawn from a flat Dirichlet distribution and
    labels are sampled from them. The miscalibrated probabilities are the
    true ones raised to ``miscalibration_factor`` and renormalized, and the
    returned logits are their logarithm.

    Args:
        n_samples: Number of samples
        n_classes: Number of classes
        miscalibration_factor: Exponent applied to the true probabilities
        random_state: Random seed

    Returns:
        Tuple of (logits, labels, probs_true, probs_miscalibrated)
    """
    # A private legacy generator yields exactly the same numbers as
    # np.random.seed(random_state) followed by np.random.* calls, but leaves
    # the caller's global NumPy random state untouched.
    rng = np.random.RandomState(random_state)

    # Generate true probabilities
    probs_true = rng.dirichlet(np.ones(n_classes), n_samples)

    # Generate labels from true probabilities. Kept as a per-sample loop on
    # purpose: the draw order defines the values produced for a given seed.
    labels = np.array([rng.choice(n_classes, p=p) for p in probs_true])

    # Create miscalibrated probabilities
    # Add overconfidence by raising probabilities to a power
    probs_miscalibrated = probs_true ** miscalibration_factor
    probs_miscalibrated = probs_miscalibrated / probs_miscalibrated.sum(axis=1, keepdims=True)

    # Convert to logits; the small offset avoids log(0)
    logits = np.log(probs_miscalibrated + 1e-10)

    return logits, labels, probs_true, probs_miscalibrated


def quick_calibration_comparison(
    logits_val: np.ndarray,
    labels_val: np.ndarray,
    logits_test: np.ndarray,
    labels_test: np.ndarray,
    config: Optional[CalibrationConfig] = None
) -> Dict[str, Any]:
    """
    One-call shortcut: build a CalibrationFramework and compare its methods.

    Args:
        logits_val: Validation logits (used for fitting)
        labels_val: Validation labels (used for fitting)
        logits_test: Test logits (used for evaluation)
        labels_test: Test labels (used for evaluation)
        config: Calibration configuration, forwarded to the framework

    Returns:
        The result of ``CalibrationFramework.compare_methods``
    """
    return CalibrationFramework(config).compare_methods(
        logits_val=logits_val,
        labels_val=labels_val,
        logits_test=logits_test,
        labels_test=labels_test,
    )
