"""
Baseline Conformal Prediction Methods for CNCRC Comparison

This module implements the baseline conformal prediction methods required by
the CNCRC research plan for fair comparison. All methods are post-hoc and
follow the standard conformal prediction framework.

According to research plan section 4.2, the methods compared are:
- Standard CP: s(x,y) = 1 - P(y|x)
- Heuristic penalized CP: s_λ = (1-P) + λ·Cost  
- Cost-aware CP variants
- CNCRC variants

Key Design Principles:
- All methods are post-hoc (no model fine-tuning)
- Fair comparison with CNCRC
- Standard conformal prediction guarantees
- Compatible with MIMIC-IV × DrugBank experiments
"""

import os
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Any, Union, Callable
from pathlib import Path
import pickle
import json
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
import warnings

from scipy import sparse
from scipy.stats import bootstrap
import matplotlib.pyplot as plt

# Import CNCRC core components
from ..core.data_structures import CostMatrix, PredictionSet, ClinicalContext
from ..core.calibration import calibrate_quantile, calibrate_quantile_detailed
from ..core.prediction_set import build_prediction_set, validate_prediction_set

# Configure logging
# NOTE(review): basicConfig() runs as an import-time side effect of this
# module. Per the logging docs it is a no-op when the root logger already has
# handlers, so an application that configures logging first is unaffected --
# confirm that configuring the root logger from a library module is intended.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


@dataclass
class BaselineConfig:
    """Settings shared by the baseline conformal predictors and the comparison harness."""

    # ---- General conformal prediction settings ----
    # Risk level, i.e. the tolerated non-coverage rate.
    alpha: float = 0.1
    # Seed for reproducibility.
    random_state: int = 42

    # ---- Heuristic penalized CP ----
    # One HeuristicPenalizedCP instance is built per λ listed here.
    lambda_values: List[float] = field(
        default_factory=lambda: [0.1, 0.5, 1.0, 2.0, 5.0]
    )
    # Search strategy label: "grid" or "adaptive".
    lambda_search: str = "grid"

    # ---- Cost-aware CP variants ----
    # One CostAwareCP instance is built per variant name listed here.
    cost_aware_variants: List[str] = field(
        default_factory=lambda: ["weighted_quantile", "stratified", "expected_loss"]
    )

    # ---- Evaluation ----
    # Number of bootstrap resamples used for confidence intervals.
    n_bootstrap: int = 1000
    # Confidence level of the reported intervals.
    confidence_level: float = 0.95

    # ---- Output ----
    # Whether detailed results are written to disk, and where.
    save_results: bool = True
    results_dir: str = "results/baselines"
    # Whether comparison plots are written to disk, and where.
    save_plots: bool = True
    plot_dir: str = "plots/baselines"


class ConformalBaselineBase(ABC):
    """Abstract base class for all conformal prediction baseline methods.

    Subclasses only implement :meth:`compute_nonconformity_score`. This class
    supplies calibration (:meth:`fit`), prediction-set construction
    (:meth:`predict_set`, :meth:`predict_set_batch`) and pickle-based
    persistence (:meth:`save`, :meth:`load`).

    Attributes:
        config: The BaselineConfig passed to the constructor.
        is_fitted: True once :meth:`fit` has completed.
        q_threshold: Calibrated score threshold; None until fitted.
        method_name: Display name used in log and error messages.
    """

    def __init__(self, config: BaselineConfig):
        """Initialize baseline method with configuration.

        Args:
            config: Baseline configuration. An output directory is created
                only when the corresponding ``save_*`` flag is enabled.
        """
        self.config = config
        self.is_fitted = False
        self.q_threshold = None
        self.method_name = self.__class__.__name__

        # Only touch the filesystem for outputs that were actually requested;
        # save() creates the parent directory of its target on its own.
        if config.save_results:
            Path(config.results_dir).mkdir(parents=True, exist_ok=True)
        if config.save_plots:
            Path(config.plot_dir).mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _as_cost_array(cost_matrix: Union[np.ndarray, CostMatrix]) -> np.ndarray:
        """Return the raw array behind a CostMatrix object or an array-like."""
        if isinstance(cost_matrix, CostMatrix):
            return cost_matrix.matrix
        return np.asarray(cost_matrix)

    @abstractmethod
    def compute_nonconformity_score(
        self, 
        probabilities: np.ndarray, 
        cost_matrix: Union[np.ndarray, CostMatrix],
        y_true: int
    ) -> float:
        """
        Compute non-conformity score for a single sample.
        
        Args:
            probabilities: Model probabilities P(y|x), shape (n_classes,)
            cost_matrix: Cost matrix or CostMatrix object
            y_true: True label index
            
        Returns:
            Non-conformity score (float)
        """
        pass

    def fit(
        self, 
        probabilities_cal: np.ndarray,
        labels_cal: np.ndarray, 
        cost_matrix: Union[np.ndarray, CostMatrix],
        **kwargs
    ) -> 'ConformalBaselineBase':
        """
        Fit conformal predictor on calibration data.

        Scores every calibration sample with its true label and calibrates
        the threshold via ``calibrate_quantile(scores, config.alpha)``.
        
        Args:
            probabilities_cal: Calibration probabilities [n_cal, n_classes]
            labels_cal: Calibration labels [n_cal]
            cost_matrix: Cost matrix, shape (n_classes, n_classes)
            **kwargs: Method-specific parameters (not used by this base
                implementation)
            
        Returns:
            Self for method chaining

        Raises:
            ValueError: If the probabilities are not 2D, the label count does
                not match, or the cost matrix is not (n_classes, n_classes).
        """
        # Input validation
        probabilities_cal = np.asarray(probabilities_cal)
        labels_cal = np.asarray(labels_cal)

        if probabilities_cal.ndim != 2:
            raise ValueError(f"probabilities_cal must be 2D, got shape {probabilities_cal.shape}")

        n_cal, n_classes = probabilities_cal.shape

        if len(labels_cal) != n_cal:
            raise ValueError(f"labels_cal length {len(labels_cal)} != n_cal {n_cal}")

        cost_array = self._as_cost_array(cost_matrix)

        if cost_array.shape != (n_classes, n_classes):
            raise ValueError(f"Cost matrix shape {cost_array.shape} != ({n_classes}, {n_classes})")

        # Non-conformity score of each calibration sample w.r.t. its true label
        scores = np.array([
            self.compute_nonconformity_score(sample_probs, cost_array, label)
            for sample_probs, label in zip(probabilities_cal, labels_cal)
        ])

        # Calibrate quantile threshold
        self.q_threshold = calibrate_quantile(scores, self.config.alpha)
        self.is_fitted = True

        logger.info(f"{self.method_name} fitted: q_threshold={self.q_threshold:.6f}")

        return self

    def predict_set(
        self, 
        probabilities: np.ndarray, 
        cost_matrix: Union[np.ndarray, CostMatrix]
    ) -> List[int]:
        """
        Predict conformal set for a single sample.

        A class is included when its non-conformity score does not exceed
        the calibrated threshold.
        
        Args:
            probabilities: Test probabilities P(y|x), shape (n_classes,)
            cost_matrix: Cost matrix
            
        Returns:
            Prediction set as list of class indices (ascending)

        Raises:
            ValueError: If the method is not fitted or the probabilities are
                not 1D.
        """
        if not self.is_fitted:
            raise ValueError(f"{self.method_name} must be fitted before prediction")

        probabilities = np.asarray(probabilities)

        if probabilities.ndim != 1:
            raise ValueError(f"probabilities must be 1D, got shape {probabilities.shape}")

        cost_array = self._as_cost_array(cost_matrix)

        # Score every candidate label as if it were the true one
        return [
            y_candidate
            for y_candidate in range(len(probabilities))
            if self.compute_nonconformity_score(probabilities, cost_array, y_candidate)
            <= self.q_threshold
        ]

    def predict_set_batch(
        self, 
        probabilities_batch: np.ndarray, 
        cost_matrix: Union[np.ndarray, CostMatrix]
    ) -> List[List[int]]:
        """
        Predict conformal sets for multiple samples.
        
        Args:
            probabilities_batch: Test probabilities [n_test, n_classes]
            cost_matrix: Cost matrix
            
        Returns:
            List of prediction sets, one per row of the batch

        Raises:
            ValueError: If the batch is not 2D (or the method is not fitted).
        """
        probabilities_batch = np.asarray(probabilities_batch)

        if probabilities_batch.ndim != 2:
            raise ValueError(f"probabilities_batch must be 2D, got shape {probabilities_batch.shape}")

        return [
            self.predict_set(sample_probs, cost_matrix)
            for sample_probs in probabilities_batch
        ]

    def save(self, filepath: str) -> None:
        """Save fitted baseline method as a pickle file.

        Args:
            filepath: Target file; missing parent directories are created.

        Raises:
            ValueError: If the method has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError(f"{self.method_name} must be fitted before saving")

        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        with open(filepath, 'wb') as f:
            pickle.dump(self, f)

        logger.info(f"{self.method_name} saved to {filepath}")

    @classmethod
    def load(cls, filepath: str) -> 'ConformalBaselineBase':
        """Load fitted baseline method.

        Args:
            filepath: Pickle file previously written by :meth:`save`.

        Returns:
            The unpickled, fitted baseline method.

        Raises:
            TypeError: If the file does not contain a baseline method.
            ValueError: If the loaded method is not fitted.
        """
        # SECURITY: unpickling can execute arbitrary code. Only load files
        # that this project wrote itself; never load untrusted input.
        with open(filepath, 'rb') as f:
            method = pickle.load(f)

        # Fail with a clear message rather than an AttributeError further down
        if not isinstance(method, ConformalBaselineBase):
            raise TypeError(
                f"Expected a ConformalBaselineBase in {filepath}, "
                f"got {type(method).__name__}"
            )

        if not method.is_fitted:
            raise ValueError(f"Loaded {cls.__name__} is not fitted")

        logger.info(f"{cls.__name__} loaded from {filepath}")
        return method


class StandardCP(ConformalBaselineBase):
    """
    Cost-agnostic baseline: classic conformal prediction.

    The non-conformity score is s(x,y) = 1 - P(y|x), so only the model's
    probability for the label matters; the cost matrix is never consulted.
    """

    def compute_nonconformity_score(
        self,
        probabilities: np.ndarray,
        cost_matrix: Union[np.ndarray, CostMatrix],
        y_true: int
    ) -> float:
        """
        Return the standard CP score s(x,y) = 1 - P(y|x).

        Args:
            probabilities: Model probabilities P(y|x)
            cost_matrix: Accepted for interface compatibility; ignored here
            y_true: Index of the label being scored

        Returns:
            Non-conformity score; larger means less conforming

        Raises:
            ValueError: If y_true is not a valid index into probabilities
        """
        probs = np.asarray(probabilities)
        n_classes = len(probs)

        # Reject labels outside [0, n_classes - 1] (this also rules out
        # negative indices, which numpy would otherwise wrap around)
        label_in_range = 0 <= y_true < n_classes
        if not label_in_range:
            raise ValueError(f"y_true={y_true} out of range [0, {n_classes-1}]")

        return float(1.0 - probs[y_true])


class HeuristicPenalizedCP(ConformalBaselineBase):
    """
    Conformal prediction with an additive, heuristic cost penalty.

    Non-conformity score: s_λ(x,y) = (1 - P(y|x)) + λ * max_j Cost(y, j)

    The first term is the model uncertainty for the label; the second is the
    largest entry of the label's cost-matrix row, scaled by λ.
    """

    def __init__(self, config: BaselineConfig, lambda_weight: float = 1.0):
        """
        Set up the penalized predictor.

        Args:
            config: Baseline configuration
            lambda_weight: Multiplier λ applied to the cost penalty term
        """
        super().__init__(config)
        self.lambda_weight = lambda_weight
        # Embed λ in the display name so log lines tell instances apart
        self.method_name = f"HeuristicPenalizedCP(λ={lambda_weight})"

    def compute_nonconformity_score(
        self,
        probabilities: np.ndarray,
        cost_matrix: Union[np.ndarray, CostMatrix],
        y_true: int
    ) -> float:
        """
        Return s_λ(x,y) = (1 - P(y|x)) + λ * max_j Cost(y, j).

        Args:
            probabilities: Model probabilities P(y|x)
            cost_matrix: Cost matrix (array or CostMatrix object)
            y_true: Index of the label being scored

        Returns:
            Non-conformity score

        Raises:
            ValueError: If y_true is not a valid index into probabilities
        """
        probs = np.asarray(probabilities)
        costs = (
            cost_matrix.matrix
            if isinstance(cost_matrix, CostMatrix)
            else np.asarray(cost_matrix)
        )

        n_classes = len(probs)
        label_in_range = 0 <= y_true < n_classes
        if not label_in_range:
            raise ValueError(f"y_true={y_true} out of range [0, {n_classes-1}]")

        # Largest cost found in the row belonging to the scored label
        worst_row_cost = np.max(costs[y_true, :])

        return float((1.0 - probs[y_true]) + self.lambda_weight * worst_row_cost)



class CostAwareCP(ConformalBaselineBase):
    """
    Cost-aware conformal prediction baselines.

    Three scoring variants are available, all based on the cost-matrix row
    of the scored label:

    - "weighted_quantile": (1 - P(y|x)) * (1 + max_j Cost(y, j))
    - "stratified": (1 - P(y|x)) * (1 + 0.5 * stratum), where the stratum is
      0, 1 or 2 for a row maximum below 0.3, below 0.7, or otherwise
    - "expected_loss": sum_j P(j|x) * Cost(y, j)
    """

    def __init__(self, config: BaselineConfig, variant: str = "weighted_quantile"):
        """
        Set up the cost-aware predictor.

        Args:
            config: Baseline configuration
            variant: One of "weighted_quantile", "stratified", "expected_loss"

        Raises:
            ValueError: If the variant name is not recognised
        """
        super().__init__(config)
        self.variant = variant
        self.method_name = f"CostAwareCP({variant})"

        known_variants = ("weighted_quantile", "stratified", "expected_loss")
        if variant not in known_variants:
            raise ValueError(f"Unknown variant: {variant}")

    def compute_nonconformity_score(
        self,
        probabilities: np.ndarray,
        cost_matrix: Union[np.ndarray, CostMatrix],
        y_true: int
    ) -> float:
        """
        Return the non-conformity score of the configured variant.

        Args:
            probabilities: Model probabilities P(y|x)
            cost_matrix: Cost matrix (array or CostMatrix object)
            y_true: Index of the label being scored

        Returns:
            Non-conformity score

        Raises:
            ValueError: If y_true is out of range or the variant is unknown
        """
        probs = np.asarray(probabilities)
        costs = (
            cost_matrix.matrix
            if isinstance(cost_matrix, CostMatrix)
            else np.asarray(cost_matrix)
        )

        n_classes = len(probs)
        label_in_range = 0 <= y_true < n_classes
        if not label_in_range:
            raise ValueError(f"y_true={y_true} out of range [0, {n_classes-1}]")

        if self.variant == "expected_loss":
            # Probability-weighted sum over the label's cost row
            return float(np.sum(probs * costs[y_true, :]))

        if self.variant not in ("weighted_quantile", "stratified"):
            raise ValueError(f"Unknown variant: {self.variant}")

        # Both remaining variants rescale the uncertainty by a factor
        # derived from the largest cost in the label's row
        max_cost = np.max(costs[y_true, :])
        uncertainty = 1.0 - probs[y_true]

        if self.variant == "weighted_quantile":
            # Higher cost = higher weight
            multiplier = 1.0 + max_cost
        else:
            # Cost strata: 0 = low (< 0.3), 1 = medium (< 0.7), 2 = high
            cost_stratum = 0 if max_cost < 0.3 else (1 if max_cost < 0.7 else 2)
            multiplier = 1.0 + cost_stratum * 0.5

        return float(uncertainty * multiplier)


class BaselineComparison:
    """
    Framework for comparing all baseline methods with CNCRC.

    Provides unified evaluation metrics and bootstrap confidence intervals
    according to research plan section 4.3.

    Attributes:
        config: The BaselineConfig passed to the constructor.
        methods: Extra methods registered through :meth:`add_method`; they
            are fitted and evaluated together with the built-in baselines.
        results: Per-method evaluation results of the most recent
            :meth:`compare_all_methods` call.
    """

    def __init__(self, config: BaselineConfig):
        """Initialize baseline comparison framework.

        Args:
            config: Baseline configuration. An output directory is created
                only when the corresponding ``save_*`` flag is enabled.
        """
        self.config = config
        self.methods = {}
        self.results = {}

        # Only touch the filesystem for outputs that were actually requested
        if config.save_results:
            Path(config.results_dir).mkdir(parents=True, exist_ok=True)
        if config.save_plots:
            Path(config.plot_dir).mkdir(parents=True, exist_ok=True)

    def add_method(self, name: str, method: ConformalBaselineBase) -> None:
        """Register an extra method to include in :meth:`compare_all_methods`.

        Args:
            name: Key under which the method appears in the results. A name
                equal to a built-in baseline's name replaces that baseline.
            method: Method instance; it is (re)fitted during the comparison.
        """
        self.methods[name] = method

    @staticmethod
    def _as_cost_array(cost_matrix: Union[np.ndarray, CostMatrix]) -> np.ndarray:
        """Return the raw array behind a CostMatrix object or an array-like."""
        if isinstance(cost_matrix, CostMatrix):
            return cost_matrix.matrix
        return np.asarray(cost_matrix)

    @staticmethod
    def _to_jsonable(value: Any) -> Any:
        """Recursively convert numpy containers/scalars to plain Python types."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        # np.bool_ is not a bool subclass, so json.dump rejects it as-is
        if isinstance(value, np.bool_):
            return bool(value)
        if isinstance(value, np.integer):
            return int(value)
        if isinstance(value, np.floating):
            return float(value)
        if isinstance(value, dict):
            return {
                key: BaselineComparison._to_jsonable(item)
                for key, item in value.items()
            }
        if isinstance(value, (list, tuple)):
            return [BaselineComparison._to_jsonable(item) for item in value]
        return value

    def evaluate_method(
        self,
        method: ConformalBaselineBase,
        probabilities_test: np.ndarray,
        labels_test: np.ndarray,
        cost_matrix: Union[np.ndarray, CostMatrix],
        method_name: str
    ) -> Dict[str, Any]:
        """
        Evaluate a single method according to research plan metrics.
        
        Per test sample:
        - uncovered loss: max of the true label's cost row when the true
          label is missing from the prediction set, else 0
        - covered ambiguity loss: max Cost(true, pred) over the other labels
          in the set when the true label is covered, else 0
        - set size and a 0/1 coverage indicator
        
        Args:
            method: Fitted baseline method
            probabilities_test: Test probabilities [n_test, n_classes]
            labels_test: Test labels [n_test]
            cost_matrix: Cost matrix
            method_name: Name for logging
            
        Returns:
            Evaluation results dictionary with aggregate metrics, the fitted
            ``q_threshold``, the raw per-sample arrays and, when
            ``config.n_bootstrap > 0``, bootstrap confidence intervals.
            ``satisfies_risk_constraint`` is None until the comparison
            summary fills it in.

        Raises:
            ValueError: If the number of probability rows and labels differ.
        """
        probabilities_test = np.asarray(probabilities_test)
        labels_test = np.asarray(labels_test)

        n_test = len(labels_test)
        if len(probabilities_test) != n_test:
            raise ValueError(
                f"probabilities_test length {len(probabilities_test)} != n_test {n_test}"
            )

        cost_array = self._as_cost_array(cost_matrix)

        # Get prediction sets
        prediction_sets = method.predict_set_batch(probabilities_test, cost_matrix)

        # Per-sample metrics; entries left at 0.0 mean "not applicable"
        uncovered_losses = np.zeros(n_test)
        covered_ambiguity_losses = np.zeros(n_test)
        coverages = np.zeros(n_test)
        set_sizes = np.array([len(pred_set) for pred_set in prediction_sets])

        for i, (pred_set, true_label) in enumerate(zip(prediction_sets, labels_test)):
            if true_label in pred_set:
                coverages[i] = 1.0
                # Ambiguity: worst cost of confusing the true label with any
                # other label that is also in the set
                other_costs = [
                    cost_array[true_label, pred]
                    for pred in pred_set
                    if pred != true_label
                ]
                if other_costs:
                    covered_ambiguity_losses[i] = max(other_costs)
            else:
                # Uncovered: default penalty is the maximum cost in the row
                uncovered_losses[i] = np.max(cost_array[true_label, :])

        # Compute research plan metrics (constraint optimization pattern)
        results = {
            # Constraint check (primary)
            'empirical_non_coverage_risk': np.mean(uncovered_losses),
            'satisfies_risk_constraint': None,  # Will be set by comparison framework

            # Optimization target (among constraint-satisfying methods)
            'covered_ambiguity_loss_mean': np.mean(covered_ambiguity_losses),

            # Traditional metrics
            'empirical_coverage': np.mean(coverages),
            'miscoverage_rate': 1.0 - np.mean(coverages),
            'uncovered_loss_mean': np.mean(uncovered_losses),

            # Efficiency
            'average_set_size': np.mean(set_sizes),
            'set_size_std': np.std(set_sizes),

            # Detailed statistics
            'n_test_samples': n_test,
            'method_name': method_name,
            'alpha': self.config.alpha,
            'theoretical_coverage': 1.0 - self.config.alpha,
            # Stored as a plain value so the summary table can report it
            'q_threshold': getattr(method, 'q_threshold', None),

            # Raw data for further analysis
            'uncovered_losses': uncovered_losses,
            'covered_ambiguity_losses': covered_ambiguity_losses,
            'set_sizes': set_sizes,
            'coverages': coverages,
            'prediction_sets': prediction_sets
        }

        # Bootstrap confidence intervals
        if self.config.n_bootstrap > 0:
            bootstrap_results = self._bootstrap_metrics(
                uncovered_losses, covered_ambiguity_losses, set_sizes, coverages
            )
            results.update(bootstrap_results)

        constraint_status = results['satisfies_risk_constraint']
        if constraint_status is None:
            constraint_status = 'TBD'

        logger.info(f"{method_name} Evaluation (Constraint Optimization):")
        logger.info(f"  Non-coverage Risk: {results['empirical_non_coverage_risk']:.4f} (constraint ≤ {self.config.alpha})")
        logger.info(f"  Ambiguity Risk: {results['covered_ambiguity_loss_mean']:.4f} (optimization target)")
        logger.info(f"  Coverage: {results['empirical_coverage']:.4f}")
        logger.info(f"  Avg Set Size: {results['average_set_size']:.2f}")
        logger.info(f"  Satisfies Constraint: {constraint_status}")

        return results

    def _bootstrap_metrics(
        self, 
        uncovered_losses: np.ndarray,
        covered_ambiguity_losses: np.ndarray,
        set_sizes: np.ndarray,
        coverages: np.ndarray
    ) -> Dict[str, Any]:
        """Compute bootstrap confidence intervals for metrics.

        Draws ``config.n_bootstrap`` resamples (with replacement, seeded by
        ``config.random_state``), using the same indices for all four
        per-sample arrays, and reports percentile intervals at
        ``config.confidence_level``.

        Args:
            uncovered_losses: Per-sample uncovered losses
            covered_ambiguity_losses: Per-sample ambiguity losses
            set_sizes: Per-sample prediction-set sizes
            coverages: Per-sample 0/1 coverage indicators

        Returns:
            Dict with ``<metric>_ci_lower``, ``<metric>_ci_upper`` and
            ``<metric>_std`` for each bootstrapped metric.
        """
        n = len(uncovered_losses)
        rng = np.random.RandomState(self.config.random_state)

        # Metric name -> per-sample array whose mean defines the metric
        per_sample = {
            'empirical_non_coverage_risk': uncovered_losses,
            'covered_ambiguity_loss_mean': covered_ambiguity_losses,
            'empirical_coverage': coverages,
            'average_set_size': set_sizes
        }
        bootstrap_stats = {metric: [] for metric in per_sample}

        for _ in range(self.config.n_bootstrap):
            # One shared resample keeps the metrics jointly consistent
            indices = rng.choice(n, size=n, replace=True)
            for metric, values in per_sample.items():
                bootstrap_stats[metric].append(np.mean(values[indices]))

        # Percentile confidence intervals
        ci_results = {}
        alpha_ci = 1.0 - self.config.confidence_level

        for metric, replicates in bootstrap_stats.items():
            replicates = np.array(replicates)
            ci_results[f'{metric}_ci_lower'] = np.quantile(replicates, alpha_ci / 2)
            ci_results[f'{metric}_ci_upper'] = np.quantile(replicates, 1 - alpha_ci / 2)
            ci_results[f'{metric}_std'] = np.std(replicates)

        return ci_results

    def compare_all_methods(
        self,
        probabilities_cal: np.ndarray,
        labels_cal: np.ndarray,
        probabilities_test: np.ndarray,
        labels_test: np.ndarray,
        cost_matrix: Union[np.ndarray, CostMatrix]
    ) -> Dict[str, Any]:
        """
        Compare all baseline methods according to research plan.

        Builds StandardCP, one HeuristicPenalizedCP per configured λ and one
        CostAwareCP per configured variant, adds any methods registered via
        :meth:`add_method`, then fits each on the calibration data and
        evaluates it on the test data. A method that raises is logged and
        skipped so the remaining methods are still compared.
        
        Args:
            probabilities_cal: Calibration probabilities [n_cal, n_classes]
            labels_cal: Calibration labels [n_cal]
            probabilities_test: Test probabilities [n_test, n_classes]
            labels_test: Test labels [n_test]
            cost_matrix: Cost matrix
            
        Returns:
            Dict with ``summary``, ``detailed_results`` (per method) and
            ``config``
        """
        logger.info("Starting baseline method comparison...")

        # 1. Standard CP
        methods_to_compare = {'StandardCP': StandardCP(self.config)}

        # 2. Heuristic Penalized CP (multiple λ values)
        for lambda_val in self.config.lambda_values:
            name = f'HeuristicCP_λ{lambda_val}'
            methods_to_compare[name] = HeuristicPenalizedCP(self.config, lambda_val)

        # 3. Cost-aware CP variants
        for variant in self.config.cost_aware_variants:
            name = f'CostAwareCP_{variant}'
            methods_to_compare[name] = CostAwareCP(self.config, variant)

        # 4. Methods registered through add_method()
        methods_to_compare.update(self.methods)

        # Fit and evaluate all methods
        all_results = {}

        for method_name, method in methods_to_compare.items():
            try:
                method.fit(probabilities_cal, labels_cal, cost_matrix)

                all_results[method_name] = self.evaluate_method(
                    method, probabilities_test, labels_test, cost_matrix, method_name
                )

                # Save method if configured
                if self.config.save_results:
                    save_path = Path(self.config.results_dir) / f"{method_name}.pkl"
                    method.save(save_path)

            except Exception as e:
                # Best effort: one failing baseline must not abort the rest
                logger.error(f"Failed to evaluate {method_name}: {e}", exc_info=True)
                continue

        self.results = all_results

        # Create summary comparison
        summary = self._create_comparison_summary(all_results)

        # Save results
        if self.config.save_results:
            results_dir = Path(self.config.results_dir)
            results_dir.mkdir(parents=True, exist_ok=True)
            results_path = results_dir / "baseline_comparison_results.json"

            # Everything goes through the numpy-aware converter; raw numpy
            # scalars such as np.bool_ would make json.dump raise TypeError
            payload = {
                'summary': self._to_jsonable(summary),
                'detailed_results': {k: self._serialize_results(v) for k, v in all_results.items()},
                'config': self._to_jsonable(vars(self.config)),
                'timestamp': datetime.now().isoformat()
            }
            with open(results_path, 'w') as f:
                json.dump(payload, f, indent=2)

            logger.info(f"Results saved to {results_path}")

        # Generate plots
        if self.config.save_plots:
            self._generate_comparison_plots(all_results)

        logger.info("Baseline comparison completed")

        return {
            'summary': summary,
            'detailed_results': all_results,
            'config': self.config
        }

    def _serialize_results(self, results: Dict[str, Any]) -> Dict[str, Any]:
        """Serialize results for JSON storage.

        Arrays become lists and numpy scalars become the matching Python
        type (bool, int or float), including inside nested containers.

        Args:
            results: Evaluation results of a single method

        Returns:
            A new dict containing only JSON-compatible values
        """
        return {key: self._to_jsonable(value) for key, value in results.items()}

    def _create_comparison_summary(self, all_results: Dict[str, Any]) -> Dict[str, Any]:
        """Create summary comparison table using constraint optimization logic.

        A method satisfies the constraint when its empirical non-coverage
        risk is at most ``config.alpha``. Satisfying methods are ranked
        first, by ascending ambiguity risk; violating methods follow, by
        ascending non-coverage risk.

        Side effect: writes ``satisfies_risk_constraint`` into each
        per-method results dict.

        Args:
            all_results: Mapping of method name to evaluation results

        Returns:
            Summary dict with constraint groups, best-method names, the
            ranking and one comparison-table row per method
        """
        summary = {
            'constraint_satisfying_methods': [],
            'constraint_violating_methods': [],
            'best_ambiguity_risk_among_satisfying': None,
            'best_coverage': None,
            'best_set_size': None,
            'method_ranking': [],
            'comparison_table': [],
            'crc_threshold': self.config.alpha  # Non-coverage risk threshold
        }

        # Step 1: Check constraint satisfaction and find best methods
        best_coverage_gap = float('inf')
        best_set_size = float('inf')
        best_ambiguity_risk = float('inf')

        constraint_satisfying = []
        constraint_violating = []

        for method_name, results in all_results.items():
            non_coverage_risk = results['empirical_non_coverage_risk']
            ambiguity_risk = results['covered_ambiguity_loss_mean']
            coverage = results['empirical_coverage']
            set_size = results['average_set_size']

            # Plain bool: the numpy comparison yields np.bool_, which is
            # not JSON serializable
            satisfies_constraint = bool(non_coverage_risk <= self.config.alpha)
            results['satisfies_risk_constraint'] = satisfies_constraint

            if satisfies_constraint:
                constraint_satisfying.append((method_name, ambiguity_risk))
                summary['constraint_satisfying_methods'].append(method_name)

                # Track best ambiguity risk among constraint-satisfying methods
                if ambiguity_risk < best_ambiguity_risk:
                    best_ambiguity_risk = ambiguity_risk
                    summary['best_ambiguity_risk_among_satisfying'] = method_name
            else:
                constraint_violating.append((method_name, non_coverage_risk))
                summary['constraint_violating_methods'].append(method_name)

            # Track best traditional metrics (coverage closest to the target)
            coverage_gap = abs(coverage - (1.0 - self.config.alpha))
            if coverage_gap < best_coverage_gap:
                best_coverage_gap = coverage_gap
                summary['best_coverage'] = method_name

            if set_size < best_set_size:
                best_set_size = set_size
                summary['best_set_size'] = method_name

            # Prefer the threshold recorded by evaluate_method; fall back to
            # a stored method object for results dicts built elsewhere
            q_threshold = results.get('q_threshold')
            if q_threshold is None:
                q_threshold = getattr(results.get('method_object'), 'q_threshold', None)

            summary['comparison_table'].append({
                'method': method_name,
                'non_coverage_risk': non_coverage_risk,
                'satisfies_constraint': satisfies_constraint,
                'ambiguity_risk': ambiguity_risk,
                'coverage': coverage,
                'avg_set_size': set_size,
                'q_threshold': q_threshold
            })

        # Step 2: Create ranking
        constraint_satisfying.sort(key=lambda x: x[1])  # Lowest ambiguity risk first
        constraint_violating.sort(key=lambda x: x[1])   # Mildest constraint violation first

        # Final ranking: satisfying methods first, then violating methods
        ranking = [method for method, _ in constraint_satisfying] + [method for method, _ in constraint_violating]
        summary['method_ranking'] = ranking

        return summary

    def _generate_comparison_plots(self, all_results: Dict[str, Any]) -> None:
        """Generate comparison plots using constraint optimization perspective.

        Writes two PNG files into ``config.plot_dir``: a scatter of
        non-coverage risk against ambiguity risk, and a bar chart of
        empirical coverage. Plotting is best effort: any failure is logged
        and swallowed so that it cannot lose the computed results.

        Args:
            all_results: Mapping of method name to evaluation results
        """
        try:
            plot_dir = Path(self.config.plot_dir)
            plot_dir.mkdir(parents=True, exist_ok=True)

            methods = list(all_results)

            # 1. Constraint Satisfaction vs Ambiguity Risk
            ambiguity_risks = [all_results[m]['covered_ambiguity_loss_mean'] for m in methods]
            non_coverage_risks = [all_results[m]['empirical_non_coverage_risk'] for m in methods]
            satisfies_constraint = [all_results[m]['satisfies_risk_constraint'] for m in methods]

            # Color by constraint satisfaction
            colors = ['green' if satisfies else 'red' for satisfies in satisfies_constraint]

            fig, ax = plt.subplots(figsize=(12, 8))
            try:
                ax.scatter(ambiguity_risks, non_coverage_risks, c=colors, alpha=0.7, s=100)

                # Add constraint line
                ax.axhline(y=self.config.alpha, color='blue', linestyle='--',
                           label=f'Constraint Threshold (α={self.config.alpha})')

                # Annotate points
                for method, x_val, y_val in zip(methods, ambiguity_risks, non_coverage_risks):
                    ax.annotate(method, (x_val, y_val),
                                xytext=(5, 5), textcoords='offset points', fontsize=8)

                ax.set_xlabel('Ambiguity Risk (optimization target)')
                ax.set_ylabel('Non-Coverage Risk (constraint)')
                ax.set_title('Constraint Optimization Perspective: Non-Coverage Risk vs Ambiguity Risk')
                ax.legend()  # Without this the threshold label is never drawn
                ax.grid(True, alpha=0.3)

                fig.tight_layout()
                # NOTE: the file name is historical (the plot shows risk vs
                # risk, not set size); kept so existing consumers find it
                fig.savefig(plot_dir / 'risk_vs_setsize.png',
                            dpi=300, bbox_inches='tight')
            finally:
                # Release the figure even if drawing or saving failed
                plt.close(fig)

            # 2. Coverage Comparison
            coverages = [all_results[m]['empirical_coverage'] for m in methods]
            target_coverage = 1.0 - self.config.alpha
            x_pos = np.arange(len(methods))

            fig, ax = plt.subplots(figsize=(12, 6))
            try:
                ax.bar(x_pos, coverages, alpha=0.7)
                ax.axhline(y=target_coverage, color='red', linestyle='--',
                           label=f'Target Coverage ({target_coverage:.3f})')

                ax.set_xlabel('Methods')
                ax.set_ylabel('Empirical Coverage')
                ax.set_title('Coverage Comparison Across Baseline Methods')
                ax.set_xticks(x_pos)
                ax.set_xticklabels(methods, rotation=45, ha='right')
                ax.legend()
                ax.grid(True, alpha=0.3)

                fig.tight_layout()
                fig.savefig(plot_dir / 'coverage_comparison.png',
                            dpi=300, bbox_inches='tight')
            finally:
                plt.close(fig)

            logger.info(f"Comparison plots saved to {self.config.plot_dir}")

        except Exception as e:
            logger.error(f"Failed to generate plots: {e}")


# Convenience functions

def create_baseline_comparison(config: Optional[BaselineConfig] = None) -> BaselineComparison:
    """Build a BaselineComparison, using a default BaselineConfig when none is given.

    Args:
        config: Configuration to use; ``None`` selects ``BaselineConfig()``.

    Returns:
        A new comparison framework bound to the chosen configuration.
    """
    effective_config = BaselineConfig() if config is None else config
    return BaselineComparison(effective_config)


def quick_baseline_comparison(
    probabilities_cal: np.ndarray,
    labels_cal: np.ndarray,
    probabilities_test: np.ndarray,
    labels_test: np.ndarray,
    cost_matrix: Union[np.ndarray, CostMatrix],
    alpha: float = 0.1,
    save_results: bool = False
) -> Dict[str, Any]:
    """
    Run the full baseline comparison in a single call.

    A fresh BaselineConfig is created from ``alpha``; ``save_results``
    switches both result saving and plot saving on or off together.

    Args:
        probabilities_cal: Calibration probabilities [n_cal, n_classes]
        labels_cal: Calibration labels [n_cal]
        probabilities_test: Test probabilities [n_test, n_classes]
        labels_test: Test labels [n_test]
        cost_matrix: Cost matrix
        alpha: Risk level (default: 0.1)
        save_results: Whether to save results and plots (default: False)

    Returns:
        The dictionary returned by ``BaselineComparison.compare_all_methods``
    """
    quick_config = BaselineConfig(
        alpha=alpha,
        save_results=save_results,
        save_plots=save_results,
    )
    runner = BaselineComparison(quick_config)
    outcome = runner.compare_all_methods(
        probabilities_cal,
        labels_cal,
        probabilities_test,
        labels_test,
        cost_matrix,
    )
    return outcome


def load_baseline_method(filepath: str) -> ConformalBaselineBase:
    """Return the fitted baseline method stored at ``filepath``.

    Thin wrapper around ``ConformalBaselineBase.load``.

    Args:
        filepath: Path of the file to load.

    Returns:
        The loaded baseline method.
    """
    loaded_method = ConformalBaselineBase.load(filepath)
    return loaded_method
