#!/usr/bin/env python3
"""
修正的论文图表生成器
移除概念错误的函数，保留科学正确的可视化
"""

import json
import os
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.ticker import MultipleLocator
import seaborn as sns
from dataclasses import dataclass

# Paper-quality matplotlib defaults, applied once at import time.
# The settings are grouped by the part of the figure they affect; together
# they set exactly the same keys, in the same order, as one combined update.

# Typography: serif text with the listed font candidates, no LaTeX rendering.
plt.rcParams.update({
    'font.size': 12,
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'DejaVu Serif'],
    'text.usetex': False,
})

# Figure geometry and export settings (300 dpi, tight bounding box).
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.1,
})

# Axes: no top/right spines, light grid switched on by default.
plt.rcParams.update({
    'axes.linewidth': 1.2,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.grid': True,
    'grid.alpha': 0.3,
})

# Legend: framed, rounded, shadowed and slightly translucent.
plt.rcParams.update({
    'legend.frameon': True,
    'legend.fancybox': True,
    'legend.shadow': True,
    'legend.framealpha': 0.9,
})

# Per-method plot styling, kept in one table so that the colour, legend label
# and marker of a method cannot drift apart.
# Columns: result key, line colour (hex), legend label, marker symbol.
# NOTE(review): the palette was described as colorblind-friendly by the
# original author; that property is not checked anywhere in this file.
_METHOD_STYLE_ROWS = (
    ('CNCRC_Sum',   '#2E8B57', 'CNCRC-Sum',     'o'),   # Sea Green
    ('CNCRC_Max',   '#4682B4', 'CNCRC-Max',     's'),   # Steel Blue
    ('Standard_CP', '#DC143C', 'Standard CP',   '^'),   # Crimson
    ('Cost_Aware',  '#FF8C00', 'Cost-Aware CP', 'D'),   # Dark Orange
    ('CRC',         '#9932CC', 'CRC',           'v'),   # Dark Violet
)

# Line colour for each method key.
METHOD_COLORS = {key: colour for key, colour, _label, _marker in _METHOD_STYLE_ROWS}

# Human-readable legend/title label for each method key.
METHOD_LABELS = {key: label for key, _colour, label, _marker in _METHOD_STYLE_ROWS}

# Matplotlib marker symbol for each method key.
METHOD_MARKERS = {key: marker for key, _colour, _label, marker in _METHOD_STYLE_ROWS}

class CorrectedPaperPlotGenerator:
    """Corrected paper plot generator.

    Contains only the plots the authors consider scientifically correct,
    i.e. those driven by the measured non-coverage risk (R_NC) values:
    a risk-risk tradeoff curve and R_NC-aligned score distributions.
    Every figure is written to ``output_dir`` as both PDF and PNG.
    """

    # Methods drawn in both figures, in plotting / subplot order.
    _PLOTTED_METHODS = ('CNCRC_Sum', 'CNCRC_Max', 'Standard_CP', 'Cost_Aware')

    # Fixed caption text and caption-box colour shown beside each score panel.
    # NOTE(review): these captions (e.g. "Lower cost", "Larger accept region")
    # are hard-coded text; they are not computed from the data being plotted.
    _PANEL_NOTES = {
        'CNCRC_Sum': (
            '\n\n• Risk-aware\n  scoring\n• Lower cost\n• Larger accept\n  region',
            'lightgreen',
        ),
        'Standard_CP': (
            '\n\n• Probability-\n  based only\n• Higher cost\n• Smaller accept\n  region',
            'lightblue',
        ),
        'CNCRC_Max': (
            '\n\n• Max-based\n  risk scoring\n• Alternative\n  aggregation',
            'lightcyan',
        ),
        'Cost_Aware': (
            '\n\n• Simple cost\n  awareness\n• Moderate\n  performance',
            'lightyellow',
        ),
    }

    def __init__(self, output_dir: str = "results/figures"):
        """Store the output directory and create it (with parents) if missing."""
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def _save_figure(self, fig: plt.Figure, name: str) -> None:
        """Save ``fig`` as ``<name>.pdf`` and ``<name>.png`` in the output directory."""
        pdf_path = self.output_dir / f"{name}.pdf"
        png_path = self.output_dir / f"{name}.png"

        fig.savefig(pdf_path, format='pdf')
        fig.savefig(png_path, format='png')

        print(f"Saved: {pdf_path}")
        print(f"Saved: {png_path}")

    def plot_risk_risk_tradeoff(self, data: Dict[str, Any], alpha_cut: float = 0.15) -> plt.Figure:
        """Plot non-coverage risk against ambiguity cost, one trajectory per method.

        Args:
            data: Mapping with a ``'results'`` list. Each entry is read for the
                keys ``'method'``, ``'alpha'``, ``'non_coverage_risk'`` and
                ``'ambiguity_cost_max'``.
            alpha_cut: Only entries with ``alpha <= alpha_cut`` are drawn.

        Returns:
            The figure, which is also saved as ``risk_risk_tradeoff.pdf/png``.
        """
        results = data['results']

        fig, ax = plt.subplots(figsize=(10, 8))

        # Coordinates of every point actually drawn; used below for the axis
        # limits so that methods which are not plotted cannot stretch the axes.
        plotted_nc: List[float] = []
        plotted_amb: List[float] = []

        for method in self._PLOTTED_METHODS:
            # Ordering by alpha makes the line trace the sweep as one trajectory.
            method_results = sorted(
                (r for r in results if r['method'] == method and r['alpha'] <= alpha_cut),
                key=lambda r: r['alpha'],
            )
            if not method_results:
                continue

            nc_risks = [r['non_coverage_risk'] for r in method_results]
            amb_costs = [r['ambiguity_cost_max'] for r in method_results]
            plotted_nc.extend(nc_risks)
            plotted_amb.extend(amb_costs)

            ax.plot(nc_risks, amb_costs,
                    color=METHOD_COLORS.get(method, 'black'),
                    marker=METHOD_MARKERS.get(method, 'o'),
                    linewidth=2.5, markersize=8,
                    label=METHOD_LABELS.get(method, method),
                    markerfacecolor='white', markeredgewidth=2,
                    alpha=0.8)

        ax.set_xlabel('Non-Coverage Risk (R_NC)', fontsize=12)
        ax.set_ylabel('Ambiguity Cost (AmbCost)', fontsize=12)
        ax.set_title('Risk-Risk Tradeoff: Non-Coverage vs Ambiguity', fontweight='bold', fontsize=14)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.grid(True, alpha=0.3)

        if plotted_nc and plotted_amb:
            min_nc_risk, max_nc_risk = min(plotted_nc), max(plotted_nc)
            min_amb_cost, max_amb_cost = min(plotted_amb), max(plotted_amb)

            # 5% padding around the plotted data, clamped at zero on the low side.
            nc_margin = (max_nc_risk - min_nc_risk) * 0.05
            amb_margin = (max_amb_cost - min_amb_cost) * 0.05

            ax.set_xlim(max(0, min_nc_risk - nc_margin), max_nc_risk + nc_margin)
            ax.set_ylim(max(0, min_amb_cost - amb_margin), max_amb_cost + amb_margin)

            # Arrow points at the lower-left corner of the data, i.e. toward the origin.
            ax.annotate('Closer to origin is better',
                        xy=(min_nc_risk, min_amb_cost),
                        xytext=(min_nc_risk + nc_margin * 3, min_amb_cost + amb_margin * 3),
                        arrowprops=dict(arrowstyle='->', color='black', lw=1.5),
                        fontsize=11, color='black',
                        bbox=dict(boxstyle='round,pad=0.3', facecolor='lightgray', alpha=0.8))

        # Lay out this figure explicitly rather than whichever figure is current.
        fig.tight_layout()
        self._save_figure(fig, 'risk_risk_tradeoff')
        return fig

    @staticmethod
    def _select_results_near_rnc(results: List[Dict[str, Any]], methods: Tuple[str, ...],
                                 target_rnc: float, tolerance: float) -> Dict[str, Dict[str, Any]]:
        """Return, per method, the result whose R_NC is nearest ``target_rnc`` if within ``tolerance``."""
        selected: Dict[str, Dict[str, Any]] = {}
        for method in methods:
            candidates = [r for r in results if r['method'] == method]

            # A method absent from the data has no nearest result; report it the
            # same way as a method whose nearest result is too far away.
            nearest = min(candidates,
                          key=lambda r: abs(r['non_coverage_risk'] - target_rnc),
                          default=None)

            if nearest is not None and abs(nearest['non_coverage_risk'] - target_rnc) <= tolerance:
                selected[method] = nearest
            else:
                print(f"Warning: {method} has no result close to R_NC={target_rnc}")
        return selected

    def _draw_score_panel(self, ax: Any, method: str, scores: Any, result: Dict[str, Any]) -> None:
        """Draw one method's score histogram, threshold, accept region and caption on ``ax``."""
        threshold = result['q_threshold']
        actual_rnc = result['non_coverage_risk']
        ambiguity_cost = result['ambiguity_cost_max']

        ax.hist(scores, bins=30, alpha=0.7, density=True,
                color=METHOD_COLORS[method],
                edgecolor='black', linewidth=0.5)

        ax.axvline(threshold, color='red', linestyle='--', linewidth=2,
                   label=f'Threshold q = {threshold:.3f}')

        # Accept region spans from the smallest score up to the threshold.
        # NOTE(review): the CNCRC_Sum panel is always labelled "LARGER!"
        # regardless of the measured accept ratios.
        if method == 'CNCRC_Sum':
            ax.axvspan(min(scores), threshold, color='green', alpha=0.25,
                       label='Accept region (score ≤ q) - LARGER!')
        else:
            ax.axvspan(min(scores), threshold, color='green', alpha=0.12,
                       label='Accept region (score ≤ q)')

        mean_score = np.mean(scores)
        ax.axvline(mean_score, color='blue', linestyle=':', linewidth=2,
                   label=f'Mean = {mean_score:.3f}')

        # Fraction of calibration scores at or below the threshold.
        accept_ratio = np.mean(np.array(scores) <= threshold)

        method_note, box_color = self._PANEL_NOTES[method]
        annotation_text = (
            f'R_NC={actual_rnc:.3f}\nAmbCost={ambiguity_cost:.3f}\nAccept={accept_ratio:.1%}'
            + method_note
        )

        # Caption box sits just outside the right edge of the axes.
        ax.text(1.02, 0.5, annotation_text,
                transform=ax.transAxes, ha='left', va='center', fontsize=9,
                bbox=dict(boxstyle='round,pad=0.4', facecolor=box_color, alpha=0.9))

        ax.set_xlabel('Non-conformity Score (lower is better)')
        ax.set_ylabel('Density')
        ax.set_title(f'{METHOD_LABELS[method]} Score Distribution\n(R_NC ≈ {actual_rnc:.3f})',
                     fontweight='bold')

        ax.legend()
        ax.grid(True, alpha=0.3)

    def create_rnc_aligned_score_distributions(self, data: Dict[str, Any], target_rnc: float = 0.08,
                                               tolerance: float = 0.03) -> Optional[plt.Figure]:
        """Plot each method's calibration-score histogram at a comparable R_NC level.

        For every method the sweep result whose ``non_coverage_risk`` is nearest
        to ``target_rnc`` is chosen, so the panels are compared at similar
        measured risk rather than at the same alpha.

        Args:
            data: Mapping with ``'results'`` (entries read for ``'method'``,
                ``'alpha'``, ``'non_coverage_risk'``, ``'ambiguity_cost_max'``
                and ``'q_threshold'``) and ``'calibration_scores'`` (one score
                collection per method key).
            target_rnc: Non-coverage risk level the panels are aligned to.
            tolerance: Largest allowed ``|R_NC - target_rnc|`` for a result to
                count as aligned.

        Returns:
            The figure, also saved as ``score_distributions_rnc_aligned.pdf/png``,
            or ``None`` when at least one method has no result within tolerance.
        """
        results = data['results']
        calibration_scores = data['calibration_scores']
        methods = self._PLOTTED_METHODS

        selected_results = self._select_results_near_rnc(results, methods, target_rnc, tolerance)

        # The comparison needs every method; bail out before creating a figure.
        if len(selected_results) < len(methods):
            print("Error: Not enough methods have results close to target R_NC")
            return None

        fig, axes = plt.subplots(2, 2, figsize=(20, 12))
        axes = axes.flatten()

        for ax, method in zip(axes, methods):
            self._draw_score_panel(ax, method, calibration_scores[method], selected_results[method])

        # Share one y-range so densities are visually comparable across panels.
        max_ylim = max(ax.get_ylim()[1] for ax in axes)
        for ax in axes:
            ax.set_ylim(0, max_ylim)

        fig.suptitle(f'Score Distributions at Similar Non-Coverage Risk Levels\n(Target R_NC ≈ {target_rnc:.2f})',
                     fontsize=16, fontweight='bold', y=1.02)

        # Lay out this figure explicitly rather than whichever figure is current.
        fig.tight_layout()
        self._save_figure(fig, 'score_distributions_rnc_aligned')

        print(f"Selected results for R_NC ≈ {target_rnc}:")
        for method, result in selected_results.items():
            print(f"  {method}: α={result['alpha']:.3f}, R_NC={result['non_coverage_risk']:.3f}")

        return fig

    def generate_corrected_plots(self, data_path: str) -> Dict[str, Optional[plt.Figure]]:
        """Load the sweep results from ``data_path`` (JSON) and generate both figures.

        Returns:
            Mapping with keys ``'risk_risk_tradeoff'`` and ``'score_distributions'``.
            The latter is ``None`` when the aligned score plot could not be built.
        """
        print(f"Loading data from: {data_path}")
        # JSON is read as UTF-8 explicitly so the result does not depend on the
        # platform's default text encoding.
        with open(data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        figures: Dict[str, Optional[plt.Figure]] = {}

        print("Generating risk-risk tradeoff...")
        figures['risk_risk_tradeoff'] = self.plot_risk_risk_tradeoff(data)

        print("Generating R_NC-aligned score distributions...")
        figures['score_distributions'] = self.create_rnc_aligned_score_distributions(data)

        print(f"\nAll corrected plots saved to: {self.output_dir}")
        return figures

def main():
    """Generate the corrected paper plots from the default results file.

    Prints a summary listing only the plots that were actually produced; a
    plot whose generator returned ``None`` is reported as not generated
    instead of being announced as a success.
    """
    generator = CorrectedPaperPlotGenerator()

    data_path = 'results/alpha_sweep_optimized/alpha_sweep_results_optimized.json'
    figures = generator.generate_corrected_plots(data_path)

    # Summary line for each figure key returned by generate_corrected_plots.
    descriptions = {
        'risk_risk_tradeoff':
            "risk_risk_tradeoff.pdf/png - Risk-risk tradeoff (对比 CNCRC-Sum/Max, Standard CP, Cost-Aware CP)",
        'score_distributions':
            "score_distributions_rnc_aligned.pdf/png - R_NC-aligned score distributions (4种方法对比)",
    }

    generated = [name for name, fig in figures.items() if fig is not None]
    skipped = [name for name, fig in figures.items() if fig is None]

    if skipped:
        print(f"\n⚠️ Some plots could not be generated: {', '.join(skipped)}")
    else:
        print("\n✅ Corrected paper plots generated successfully!")

    print("📊 Generated plots:")
    for index, name in enumerate(generated, start=1):
        print(f"  {index}. {descriptions.get(name, name)}")

if __name__ == "__main__":
    main()