#!/usr/bin/env python3
"""
Paper-Quality Visualization Module for CNCRC

This module provides high-quality matplotlib visualizations for the CNCRC paper,
including alpha-performance curves, bridging quantities, score distributions,
and other key figures needed for publication.

Features:
- Academic paper styling (LaTeX fonts, proper sizing)
- Consistent color schemes and markers
- High-resolution output (PDF + PNG)
- Comprehensive figure generation from alpha sweep data
"""

import json
import os
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.ticker import MultipleLocator
import seaborn as sns
from dataclasses import dataclass

# Paper-quality matplotlib defaults, applied once at import time.
# Settings are grouped by rc namespace; each keyword below maps to the
# rcParams key "<group>.<keyword>".
plt.rc(
    'font',
    size=12,
    family='serif',
    serif=['Times New Roman', 'DejaVu Serif'],
)
plt.rc('text', usetex=False)  # Set to True if LaTeX is available
plt.rc('figure', figsize=(10, 6), dpi=300)
plt.rc('savefig', dpi=300, bbox='tight', pad_inches=0.1)
plt.rc('axes', linewidth=1.2, grid=True)
plt.rc('axes.spines', top=False, right=False)
plt.rc('grid', alpha=0.3)
plt.rc('legend', frameon=True, fancybox=True, shadow=True, framealpha=0.9)

# Per-method plot styling, declared once as (key, color, marker, label) rows
# so the three lookup tables derived below always cover the same methods in
# the same order. The palette is intended to be colorblind-friendly.
_METHOD_STYLE_ROWS = (
    ('Standard_CP', '#1f77b4', 'o', 'Standard CP'),      # Blue
    ('CNCRC_Max', '#ff7f0e', 's', 'CNCRC-Max'),          # Orange
    ('CNCRC_Sum', '#2ca02c', '^', 'CNCRC-Sum'),          # Green
    ('Cost_Aware', '#d62728', 'D', 'Cost-Aware CP'),     # Red
    ('Heuristic_CP', '#9467bd', 'v', 'Heuristic CP'),    # Purple
    ('UDCP', '#8c564b', '<', 'UDCP'),                    # Brown
    ('CRC', '#e377c2', '>', 'CRC'),                      # Pink
)

# Method key -> line/marker color (hex).
METHOD_COLORS = {row[0]: row[1] for row in _METHOD_STYLE_ROWS}

# Method key -> matplotlib marker code.
METHOD_MARKERS = {row[0]: row[2] for row in _METHOD_STYLE_ROWS}

# Method key -> human-readable legend label.
METHOD_LABELS = {row[0]: row[3] for row in _METHOD_STYLE_ROWS}


@dataclass
class PlotConfig:
    """Configuration for plot styling and output.

    ``figure_format`` is normalized in ``__post_init__`` so that it is always
    a fresh list of format names after construction.
    """
    # Directory that receives every saved figure.
    output_dir: str = "results/figures"
    # Output formats; each one is used both as the file extension and as the
    # ``format`` passed to ``savefig``. ``None`` selects ['pdf', 'png'].
    # A single string such as 'pdf' is accepted and wrapped in a list.
    figure_format: Optional[List[str]] = None
    # Figure size setting. NOTE(review): no plotting method visible in this
    # module reads this value; each plot hard-codes its own figsize.
    figure_size: Tuple[float, float] = (10, 6)
    # Resolution passed to ``savefig``.
    dpi: int = 300

    def __post_init__(self):
        """Normalize ``figure_format`` to an independent list of strings."""
        if self.figure_format is None:
            self.figure_format = ['pdf', 'png']
        elif isinstance(self.figure_format, str):
            # A bare string would otherwise be iterated character by
            # character when figures are saved ('pdf' -> 'p', 'd', 'f').
            self.figure_format = [self.figure_format]
        else:
            # Copy so later mutation of the caller's sequence cannot change
            # this configuration (and tuples/iterables are accepted too).
            self.figure_format = list(self.figure_format)


class PaperPlotGenerator:
    """Main class for generating paper-quality CNCRC visualizations."""

    def __init__(self, config: PlotConfig = None):
        """Store the plot configuration and make sure the output folder exists.

        Args:
            config: Plot settings. When omitted (or otherwise falsy) a default
                ``PlotConfig()`` is created.
        """
        if not config:
            config = PlotConfig()
        self.config = config

        # Create the target directory up front so later saves cannot fail on
        # a missing folder; an already existing directory is fine.
        target_dir = self.config.output_dir
        os.makedirs(target_dir, exist_ok=True)

    def load_alpha_sweep_data(self, data_path: str) -> Dict[str, Any]:
        """Load alpha sweep experiment results from a JSON file.

        The file is decoded as UTF-8 rather than with the platform's locale
        encoding, so a results file containing non-ASCII text loads the same
        way on every operating system.

        Args:
            data_path: Path to the JSON results file.

        Returns:
            The parsed JSON document.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the content is not valid JSON.
        """
        with open(data_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _save_figure(self, fig: plt.Figure, filename: str):
        """Write ``fig`` to the output directory once per configured format.

        Args:
            fig: Figure to save.
            filename: Base name without extension; ``.<format>`` is appended
                for each entry of ``config.figure_format``.
        """
        settings = self.config
        for extension in settings.figure_format:
            target = os.path.join(settings.output_dir, f"{filename}.{extension}")
            fig.savefig(target, format=extension, dpi=settings.dpi)
            # Report each written file so batch runs leave a visible trail.
            print(f"Saved: {target}")

    def plot_risk_level_performance_curves(self, data: Dict[str, Any],
                                         metrics: List[str] = None) -> plt.Figure:
        """
        Generate a risk level-performance curve comparing CNCRC-Sum with the
        Standard CP and Cost-Aware baselines.

        Note: For CNCRC methods, the x-values are target risk levels R₀.
        For CP methods, they are the α parameters stored in the same field.
        Only risk levels <= 0.15 are drawn, and only the FIRST entry of
        ``metrics`` is plotted (the figure has a single axis).

        Args:
            data: Sweep results. Reads ``data['alpha_values']`` (list of
                levels) and ``data['results']`` (list of dicts with 'method',
                'alpha' and metric entries).
            metrics: Metric names; only the first is plotted. ``None`` or an
                empty list selects ``['ambiguity_cost_max']``.

        Returns:
            matplotlib Figure object (also saved as
            'risk_level_performance_curves' in every configured format).
        """
        # An empty list used to crash with IndexError below; treat it like None.
        if not metrics:
            metrics = ['ambiguity_cost_max']
        metric = metrics[0]

        # The 'alpha_values' field actually holds target risk levels. Sort
        # them so each curve is drawn left to right even when the input file
        # lists levels out of order.
        max_level = 0.15
        risk_levels = sorted({r for r in data['alpha_values'] if r <= max_level})
        wanted_levels = set(risk_levels)

        # Focus on CNCRC_Sum vs baselines (CNCRC_Sum first for legend order)
        methods = ['CNCRC_Sum', 'Standard_CP', 'Cost_Aware']

        # method -> risk level -> metric value. A later result for the same
        # (method, level) pair overwrites an earlier one.
        series = {method: {} for method in methods}
        for result in data['results']:
            method = result['method']
            level = result['alpha']  # This field contains the risk level/α value
            if method not in series or level not in wanted_levels:
                continue
            if metric in result:
                series[method][level] = result[metric]

        fig, ax = plt.subplots(1, 1, figsize=(10, 6))

        metric_titles = {
            'ambiguity_cost_max': 'Ambiguity Cost (Max)',
        }
        metric_title = metric_titles.get(metric, metric)

        for method in methods:
            points = series[method]
            if not points:
                continue

            xs = [level for level in risk_levels if level in points]
            ys = [points[level] for level in xs]

            # CNCRC methods are parameterized by R₀, CP baselines by α.
            parameter_symbol = 'R₀' if method.startswith('CNCRC') else 'α'
            method_label = f"{METHOD_LABELS.get(method, method)} ({parameter_symbol})"

            ax.plot(xs, ys,
                    color=METHOD_COLORS.get(method, 'black'),
                    marker=METHOD_MARKERS.get(method, 'o'),
                    linewidth=2.5,
                    markersize=8,
                    label=method_label,
                    alpha=0.8)

        ax.set_xlabel('Risk Level (R₀ for CNCRC, α for CP)', fontsize=12)
        ax.set_ylabel(metric_title, fontsize=12)
        ax.set_title(f'{metric_title} vs Risk Level', fontsize=14, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        self._save_figure(fig, 'risk_level_performance_curves')
        return fig


    def plot_score_distributions(self, data: Dict[str, Any], risk_level_focus: float = 0.1) -> plt.Figure:
        """
        Plot calibration score histograms and thresholds for CNCRC-Sum,
        Standard CP and Cost-Aware CP at one risk level, one panel per method.

        Exactly one panel is drawn per distinct method found at the requested
        level: if several results match the same method the first one is
        used, and methods without a matching result get no panel. (Earlier
        versions always created three axes and indexed them by match count,
        which failed on duplicates and left blank axes for missing methods.)

        Args:
            data: Sweep results. Reads ``data['results']`` (dicts with
                'method', 'alpha', 'q_threshold' and optionally
                'non_coverage_risk' / 'ambiguity_cost_max') and
                ``data['calibration_scores']`` (method -> list of scores).
            risk_level_focus: Target risk level to focus on (R₀ for CNCRC, α for CP).
                Results within 0.001 of this value are considered a match.

        Returns:
            matplotlib Figure object (also saved as
            'score_distributions_risk_level_<risk_level_focus>').

        Raises:
            ValueError: If no result of the three methods matches the level.
        """
        results = data['results']
        calibration_scores = data['calibration_scores']
        focus_methods = ('CNCRC_Sum', 'Standard_CP', 'Cost_Aware')

        # First matching result per method, in the order they appear in the file.
        selected = {}
        for r in results:
            if not abs(r['alpha'] - risk_level_focus) < 0.001:
                continue
            method = r['method']
            if method in focus_methods and method not in selected:
                selected[method] = r

        if not selected:
            raise ValueError(f"No results found for risk level = {risk_level_focus}")

        # One 6x6 inch panel per method (18x6 for the usual three methods).
        # squeeze=False keeps a 2-D axes array even for a single panel.
        n_panels = len(selected)
        fig, axes_grid = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6), squeeze=False)
        axes = list(axes_grid[0])

        # NOTE(review): the interpretation bullets and the "LARGER!" legend
        # text below are fixed strings; they are not derived from the plotted
        # data, so confirm they match the results before publishing.
        interpretations = {
            'CNCRC_Sum': ('\n\n• More samples\n  in safe region\n• Lower cost\n  through risk-\n  aware scoring',
                          'lightgreen'),
            'Standard_CP': ('\n\n• Probability-\n  based only\n• Higher cost',
                            'lightblue'),
            'Cost_Aware': ('\n\n• Cost-aware\n  but simple\n• Moderate\n  performance',
                           'lightyellow'),
        }

        for ax, (method, result) in zip(axes, selected.items()):
            scores = calibration_scores[method]
            threshold = result['q_threshold']

            # Histogram of calibration scores
            ax.hist(scores, bins=30, alpha=0.7, density=True,
                    color=METHOD_COLORS.get(method, 'gray'),
                    edgecolor='black', linewidth=0.5)

            # Threshold line
            ax.axvline(threshold, color='red', linestyle='--', linewidth=2,
                       label=f'Threshold q = {threshold:.3f}')

            # Shade the accept region (score ≤ q); CNCRC_Sum is drawn with a
            # stronger tint and a different legend entry.
            if method == 'CNCRC_Sum':
                ax.axvspan(min(scores), threshold, color='green', alpha=0.25,
                           label='Accept region (score ≤ q) - LARGER!')
            else:
                ax.axvspan(min(scores), threshold, color='green', alpha=0.12,
                           label='Accept region (score ≤ q)')

            # Mean score marker
            mean_score = np.mean(scores)
            ax.axvline(mean_score, color='blue', linestyle=':', linewidth=2,
                       label=f'Mean = {mean_score:.3f}')

            # Metrics box, only when both metrics are present in the result
            rnc = result.get('non_coverage_risk', None)
            amb = result.get('ambiguity_cost_max', None)
            if rnc is not None and amb is not None:
                # Fraction of calibration scores at or below the threshold
                accept_ratio = np.mean(np.array(scores) <= threshold)

                extra_text, box_color = interpretations.get(method, ('', 'lightgray'))
                annotation_text = (
                    f'R_NC={rnc:.3f}\nAmbCost={amb:.3f}\nAccept={accept_ratio:.1%}'
                    + extra_text
                )

                # Placed just right of the axes (x=1.02 in axes coordinates)
                ax.text(1.02, 0.5,
                        annotation_text,
                        transform=ax.transAxes, ha='left', va='center', fontsize=9,
                        bbox=dict(boxstyle='round,pad=0.4', facecolor=box_color, alpha=0.9))

            ax.set_xlabel('Non-conformity Score (lower is better; shaded=accepted)')
            ax.set_ylabel('Density')
            # CNCRC methods are parameterized by R₀, CP baselines by α.
            if method.startswith('CNCRC'):
                title = f'{METHOD_LABELS.get(method, method)} Score Distribution\n(R₀ = {risk_level_focus})'
            else:
                title = f'{METHOD_LABELS.get(method, method)} Score Distribution\n(α = {risk_level_focus})'
            ax.set_title(title,
                         fontweight='bold')
            ax.legend()
            ax.grid(True, alpha=0.3)

        # Only unify y-axis limits for density comparison, keep x-axis independent
        # since different methods have different score definitions and ranges
        max_ylim = max(ax.get_ylim()[1] for ax in axes)
        for ax in axes:
            ax.set_ylim(0, max_ylim)

        plt.tight_layout()
        self._save_figure(fig, f'score_distributions_risk_level_{risk_level_focus}')
        return fig

    def plot_calibration_quality(self, data: Dict[str, Any]) -> plt.Figure:
        """
        Plot calibration quality comparison showing |R_NC - α| for different methods.
        This shows the trade-off between cost efficiency and calibration accuracy.

        The left panel draws the calibration error per target level for
        CNCRC-Sum, Standard CP and Cost-Aware CP; the right panel shows each
        method's average error against its average ambiguity cost. Only
        results with alpha <= 0.15 are used (no lower bound is applied).

        Args:
            data: Sweep results. Reads ``data['results']``, a list of dicts
                with 'method', 'alpha', 'non_coverage_risk' and
                'ambiguity_cost_max'.

        Returns:
            matplotlib Figure object (also saved as 'calibration_quality').
        """
        results = data['results']
        max_alpha = 0.15
        methods = ['CNCRC_Sum', 'Standard_CP', 'Cost_Aware']

        # method -> list of (alpha, calibration gap, ambiguity cost)
        points = {method: [] for method in methods}

        for result in results:
            method = result['method']
            alpha = result['alpha']
            # "not <=" (rather than ">") also rejects NaN levels.
            if method not in points or not alpha <= max_alpha:
                continue

            gap = abs(result['non_coverage_risk'] - alpha)
            points[method].append((alpha, gap, result['ambiguity_cost_max']))

        # Draw each curve left to right regardless of the order of the
        # results in the file; otherwise the line doubles back on itself.
        for series in points.values():
            series.sort(key=lambda point: point[0])

        # Two side-by-side panels
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        # Plot 1: Calibration gaps
        for method in methods:
            series = points[method]
            if series:
                ax1.plot([p[0] for p in series], [p[1] for p in series],
                         color=METHOD_COLORS.get(method, 'black'),
                         marker=METHOD_MARKERS.get(method, 'o'),
                         linewidth=2.5, markersize=8,
                         label=METHOD_LABELS.get(method, method))

        ax1.set_xlabel('Target Risk Level (α)', fontsize=12)
        ax1.set_ylabel('Calibration Error |R_NC - α|', fontsize=12)
        ax1.set_title('Calibration Quality Comparison', fontweight='bold', fontsize=14)
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Reading aid for the error axis
        ax1.text(0.02, 0.98, 'Lower is better\n(Perfect calibration = 0)',
                 transform=ax1.transAxes, fontsize=10, va='top',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='lightblue', alpha=0.8))

        # Plot 2: Cost vs Calibration scatter (one averaged point per method)
        for method in methods:
            series = points[method]
            if series:
                avg_gap = np.mean([p[1] for p in series])
                avg_cost = np.mean([p[2] for p in series])

                ax2.scatter(avg_gap, avg_cost,
                            color=METHOD_COLORS.get(method, 'black'),
                            marker=METHOD_MARKERS.get(method, 'o'),
                            s=150, alpha=0.8,
                            label=METHOD_LABELS.get(method, method))

                # Method name next to its point (this panel has no legend)
                ax2.annotate(METHOD_LABELS.get(method, method),
                             (avg_gap, avg_cost),
                             xytext=(5, 5), textcoords='offset points',
                             fontsize=10, ha='left')

        ax2.set_xlabel('Average Calibration Error', fontsize=12)
        ax2.set_ylabel('Average Ambiguity Cost', fontsize=12)
        ax2.set_title('Cost-Calibration Trade-off', fontweight='bold', fontsize=14)
        ax2.grid(True, alpha=0.3)

        # Reading aid for the trade-off panel
        ax2.text(0.02, 0.98, 'Ideal: Lower-left\n(Low error + Low cost)',
                 transform=ax2.transAxes, fontsize=10, va='top',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='lightgreen', alpha=0.8))

        plt.tight_layout()
        self._save_figure(fig, 'calibration_quality')
        return fig

    def plot_set_size_distributions(self, data: Dict[str, Any]) -> plt.Figure:
        """
        Compare prediction set size distributions across methods using violin plots.

        One violin is drawn for every method in ``data['methods']`` that has
        at least one recorded set size. Colors, tick labels and the μ/M
        annotations are all keyed to the methods actually plotted, so they
        stay aligned when some method has no data.

        Args:
            data: Sweep results. Reads ``data['methods']`` (list of method
                keys) and ``data['results']`` (dicts with 'method' and an
                optional 'set_sizes' list). Set sizes of all results for the
                same method are pooled; results whose method is not listed in
                ``data['methods']`` are ignored.

        Returns:
            matplotlib Figure object (also saved as 'set_size_distributions').

        Raises:
            ValueError: If no listed method has any set sizes to plot.
        """
        results = data['results']
        methods = data['methods']

        # Pool set sizes by method
        sizes_by_method = {method: [] for method in methods}
        for result in results:
            sizes = result.get('set_sizes')
            if not sizes:
                continue
            bucket = sizes_by_method.get(result['method'])
            if bucket is not None:
                bucket.extend(sizes)

        # Only methods with data get a violin; every later step iterates
        # this list so position i always refers to the same method.
        plotted_methods = [method for method in methods if sizes_by_method[method]]
        if not plotted_methods:
            raise ValueError("No 'set_sizes' data found for any method in data['methods']")

        positions = range(len(plotted_methods))
        fig, ax = plt.subplots(figsize=(12, 8))

        parts = ax.violinplot([sizes_by_method[method] for method in plotted_methods],
                              positions=positions,
                              showmeans=True, showmedians=True)

        # Color each violin with its own method's color
        for body, method in zip(parts['bodies'], plotted_methods):
            body.set_facecolor(METHOD_COLORS.get(method, 'gray'))
            body.set_alpha(0.7)

        # Customize plot
        ax.set_xticks(positions)
        ax.set_xticklabels([METHOD_LABELS.get(method, method) for method in plotted_methods],
                           rotation=45, ha='right')
        ax.set_ylabel('Prediction Set Size')
        ax.set_title('Prediction Set Size Distributions by Method', fontweight='bold', fontsize=14)
        ax.grid(True, alpha=0.3)

        # Mean/median box placed at 90% of each method's largest set size
        for position, method in enumerate(plotted_methods):
            sizes = sizes_by_method[method]
            mean_size = np.mean(sizes)
            median_size = np.median(sizes)
            ax.text(position, max(sizes) * 0.9, f'μ={mean_size:.1f}\nM={median_size:.1f}',
                    ha='center', va='top', fontsize=10,
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

        plt.tight_layout()
        self._save_figure(fig, 'set_size_distributions')
        return fig

    def plot_risk_risk_tradeoff(self, data: Dict[str, Any]) -> plt.Figure:
        """
        Create risk-risk tradeoff plot showing the relationship between
        non-coverage risk and ambiguity cost across methods.
        X-axis: Non-coverage risk (R_NC), Y-axis: Ambiguity cost (AmbCost)

        Only CNCRC-Sum, Standard CP and Cost-Aware CP results with
        alpha <= 0.15 are drawn. The axis limits are fitted to exactly those
        plotted points, so results of other methods cannot stretch the axes.

        Args:
            data: Sweep results. Reads ``data['results']``, a list of dicts
                with 'method', 'alpha', 'non_coverage_risk' and
                'ambiguity_cost_max'.

        Returns:
            matplotlib Figure object (also saved as 'risk_risk_tradeoff').
        """
        results = data['results']
        # Focus on CNCRC_Sum vs baselines (CNCRC_Sum first for legend order)
        methods = ['CNCRC_Sum', 'Standard_CP', 'Cost_Aware']
        # Filter to alpha <= 0.15 for clearer behavior
        alpha_cut = 0.15

        # method -> list of (non-coverage risk, ambiguity cost), single pass
        points_by_method = {method: [] for method in methods}
        for r in results:
            method = r['method']
            if method in points_by_method and r['alpha'] <= alpha_cut:
                points_by_method[method].append(
                    (r['non_coverage_risk'], r['ambiguity_cost_max']))

        fig, ax = plt.subplots(figsize=(12, 8))

        for method in methods:
            points = points_by_method[method]
            if not points:
                continue

            xs, ys = zip(*points)
            ax.scatter(xs, ys,
                       color=METHOD_COLORS.get(method, 'gray'),
                       marker=METHOD_MARKERS.get(method, 'o'),
                       s=100, alpha=0.75,
                       label=METHOD_LABELS.get(method, method))
            # Connect points after sorting by x (risk)
            xs_line, ys_line = zip(*sorted(points, key=lambda point: point[0]))
            ax.plot(xs_line, ys_line,
                    color=METHOD_COLORS.get(method, 'gray'),
                    alpha=0.5, linewidth=1.5)

        ax.set_xlabel('Non-Coverage Risk (R_NC)', fontsize=12)
        ax.set_ylabel('Ambiguity Cost (AmbCost)', fontsize=12)
        ax.set_title('Risk-Risk Tradeoff: Non-Coverage vs Ambiguity', fontweight='bold', fontsize=14)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.grid(True, alpha=0.3)

        # Fit the axis limits to the plotted points to reduce empty space
        plotted = [point for series in points_by_method.values() for point in series]
        if plotted:
            all_nc = [point[0] for point in plotted]
            all_amb = [point[1] for point in plotted]
            min_nc_risk, max_nc_risk = min(all_nc), max(all_nc)
            min_amb_cost, max_amb_cost = min(all_amb), max(all_amb)

            def margin(low, high):
                """Return 5% of the range, or a small fallback for a zero range."""
                span = high - low
                if span > 0:
                    return span * 0.05
                # All values identical: avoid equal lower/upper axis limits.
                return 0.05 * (abs(high) or 1.0)

            nc_margin = margin(min_nc_risk, max_nc_risk)
            amb_margin = margin(min_amb_cost, max_amb_cost)

            ax.set_xlim(max(0, min_nc_risk - nc_margin), max_nc_risk + nc_margin)
            ax.set_ylim(max(0, min_amb_cost - amb_margin), max_amb_cost + amb_margin)

            # Arrow pointing at the lower-left corner of the data range
            ax.annotate('Closer to origin is better',
                        xy=(min_nc_risk, min_amb_cost),
                        xytext=(min_nc_risk + nc_margin*3, min_amb_cost + amb_margin*3),
                        arrowprops=dict(arrowstyle='->', color='black', lw=1.5),
                        fontsize=11, color='black',
                        bbox=dict(boxstyle='round,pad=0.3', facecolor='lightgray', alpha=0.8))

        plt.tight_layout()
        self._save_figure(fig, 'risk_risk_tradeoff')
        return fig

    def plot_risk_control_distribution(self, data: Dict[str, Any], alpha_focus: float = 0.1) -> plt.Figure:
        """
        Plot empirical cumulative distribution function (ECDF) of per-sample non-coverage costs
        to demonstrate risk control effectiveness across methods.

        For each of CNCRC-Sum, Standard CP and Cost-Aware CP at the requested
        alpha, the prediction set of every test sample is rebuilt from the
        stored threshold (labels whose score is <= ``q_threshold``). A sample
        costs 0 when its true label is in the set and
        ``cost_nc_vector[y_true]`` otherwise.

        Args:
            data: Sweep results. Reads ``data['results']`` (dicts with
                'method', 'alpha', 'q_threshold', optional 'method_params'),
                ``data['test_samples']`` (dicts with 'probs' and 'y'),
                ``data['cost_nc_vector']`` and, for the cost-based methods,
                ``data['cost_matrix']``.
            alpha_focus: Level to plot; results within 0.001 of it match.

        Returns:
            matplotlib Figure object (also saved as
            'risk_control_distribution_alpha_<alpha_focus>').

        Raises:
            ValueError: If no result of the three methods matches ``alpha_focus``.
        """
        results = data['results']
        test_samples = data['test_samples']
        cost_nc_vector = np.array(data['cost_nc_vector'])

        # Filter results for the focused alpha value and focus on key methods
        alpha_results = [r for r in results if abs(r['alpha'] - alpha_focus) < 0.001
                         and r['method'] in ['CNCRC_Sum', 'Standard_CP', 'Cost_Aware']]

        if not alpha_results:
            raise ValueError(f"No results found for alpha = {alpha_focus}")

        fig, ax = plt.subplots(figsize=(12, 8))

        # Use a step-style ECDF (more intuitive)
        for result in alpha_results:
            method = result['method']
            q_threshold = result['q_threshold']
            method_params = result.get('method_params', {})

            # Pick the per-candidate score once per method. The cost matrix
            # (and the optional project import) are loaded here instead of
            # inside the per-sample / per-class loops.
            if method == 'Standard_CP':
                # Plain 1 - p(y); an empty set is kept empty for this method.
                use_argmax_fallback = False

                def score_of(probs, y):
                    return 1.0 - probs[y]
            elif method == 'CNCRC_Max':
                # Currently not selected by the filter above; kept so the
                # method can be enabled by extending that filter.
                from src.cncrc.core.risk_weighted_score import calculate_risk_weighted_score
                cost_mat = np.array(data['cost_matrix'])
                use_argmax_fallback = True

                def score_of(probs, y):
                    return calculate_risk_weighted_score(probs, cost_mat, y)
            elif method == 'CNCRC_Sum':
                cost_mat = np.array(data['cost_matrix'])
                use_argmax_fallback = True

                def score_of(probs, y):
                    # Sum of probability-weighted costs over all labels other
                    # than the CANDIDATE y. The previous code indexed with
                    # the true label here, which made the score identical
                    # for every candidate and leaked the ground truth into
                    # the set construction.
                    # NOTE(review): row orientation cost_mat[y, :] follows
                    # the Cost_Aware branch; confirm it matches the score
                    # used when q_threshold was calibrated.
                    weighted_risks = probs * cost_mat[y, :]
                    weighted_risks[y] = 0.0
                    return float(np.sum(weighted_risks))
            elif method == 'Cost_Aware':
                cost_mat = np.array(data['cost_matrix'])
                lambda_param = method_params.get('lambda', 0.1)
                use_argmax_fallback = True

                def score_of(probs, y):
                    uncertainty = 1.0 - probs[y]
                    cost_penalty = cost_nc_vector[y] + np.max(cost_mat[y, :])
                    return uncertainty + lambda_param * cost_penalty
            else:
                # Unreachable with the filter above; guards future edits.
                raise ValueError(f"Unsupported method: {method}")

            # Calculate per-sample costs
            sample_costs = []
            for sample in test_samples:
                probs = np.array(sample['probs'])
                y_true = sample['y']

                pred_set = [y for y in range(len(probs))
                            if score_of(probs, y) <= q_threshold]
                if not pred_set and use_argmax_fallback:
                    # Never return an empty set: fall back to the top label.
                    pred_set = [np.argmax(probs)]

                sample_cost = 0.0 if y_true in pred_set else float(cost_nc_vector[y_true])
                sample_costs.append(sample_cost)

            # Step ECDF
            sample_costs = np.array(sample_costs)
            sorted_costs = np.sort(sample_costs)
            ecdf_y = np.arange(1, len(sorted_costs) + 1) / len(sorted_costs)

            ax.step(sorted_costs, ecdf_y,
                    where='post',
                    color=METHOD_COLORS.get(method, 'gray'),
                    linewidth=2.2,
                    label=f'{METHOD_LABELS.get(method, method)} (E[Cost]={np.mean(sample_costs):.3f})',
                    alpha=0.9)

        ax.set_xlabel('Per-Sample Non-Coverage Cost', fontsize=12)
        ax.set_ylabel('Cumulative Probability', fontsize=12)
        # NOTE(review): the title ends with a Chinese hint meaning "further
        # left and higher is better"; the serif fonts configured for this
        # module may lack these glyphs, so verify the rendered title.
        ax.set_title(f'Risk Control Distribution (ECDF) at α = {alpha_focus} (越往左、越往上越好)', fontweight='bold', fontsize=14)
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, 1)

        # Add reference lines
        ax.axhline(y=0.9, color='red', linestyle='--', alpha=0.4)
        ax.axhline(y=0.95, color='orange', linestyle='--', alpha=0.4)
        ax.text(ax.get_xlim()[0], 0.905, '90%', color='red', fontsize=10, va='bottom')
        ax.text(ax.get_xlim()[0], 0.955, '95%', color='orange', fontsize=10, va='bottom')

        plt.tight_layout()
        self._save_figure(fig, f'risk_control_distribution_alpha_{alpha_focus}')
        return fig

    def generate_all_plots(self, data_path: str) -> Dict[str, plt.Figure]:
        """
        Generate optimized paper plots focusing on risk analysis.

        Three figures are produced, in this order: the risk level-performance
        curves, the risk-risk tradeoff, and the score distributions at risk
        level 0.1. Each plotting method also saves its figure to the
        configured output directory. Exceptions raised by loading or by any
        plotting method propagate to the caller.

        Args:
            data_path: Path to alpha sweep results JSON file

        Returns:
            Dictionary mapping plot names ('risk_level_performance',
            'risk_risk_tradeoff', 'score_distributions') to Figure objects
        """
        print(f"Loading data from: {data_path}")
        data = self.load_alpha_sweep_data(data_path)

        figures = {}

        print("Generating risk level-performance curves...")
        figures['risk_level_performance'] = self.plot_risk_level_performance_curves(data)

        print("Generating risk-risk tradeoff...")
        figures['risk_risk_tradeoff'] = self.plot_risk_risk_tradeoff(data)

        # The former "Generating CNCRC mechanism explanation..." message was
        # dropped: no such figure is produced here, so it misreported progress.
        print("Generating improved score distributions...")
        figures['score_distributions'] = self.plot_score_distributions(data, risk_level_focus=0.1)

        print(f"\nAll optimized plots saved to: {self.config.output_dir}")
        return figures


def main():
    """Build the paper figures from the default alpha sweep results file.

    The generator is created first (which creates the output directory); if
    the results file is missing, a hint is printed and nothing is plotted.
    """
    results_file = "results/alpha_sweep/alpha_sweep_results.json"

    settings = PlotConfig(
        output_dir="results/figures",
        figure_format=['pdf', 'png'],
        dpi=300
    )
    plotter = PaperPlotGenerator(settings)

    # Guard clause: without the sweep output there is nothing to plot.
    if not os.path.exists(results_file):
        print(f"Data file not found: {results_file}")
        print("Please run the alpha sweep experiment first.")
        return

    produced = plotter.generate_all_plots(results_file)
    print(f"Generated {len(produced)} figures successfully!")


# Run the example only when executed as a script, not on import.
if __name__ == "__main__":
    main()