"""Plot figures showing that solution quality and efficiency degrade once overlap exceeds 50%, using harm and relative measures."""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Panel ids drawn for each ``plot_selection`` value:
#   0 = computational redundancy, 1 = marginal efficiency,
#   2 = runtime, 3 = convergence iterations.
_PANEL_LAYOUTS = {
    'redundancy_efficiency': [0, 1],
    'runtime_convergence': [2, 3],
    'all': [0, 1, 2, 3],
}

# Redundancy ratio above which overlap is flagged as harmful.
_HARM_REDUNDANCY_THRESHOLD = 10

# Decomposition sizes that get a vertical marker at their first threshold
# crossing in the redundancy panel (sizes absent from the data are ignored).
_CROSSING_MARKER_SIZES = (10, 20, 30)

# Line colors used when ``use_pastel`` is True.
_PASTEL_COLORS = ['#0173B2', '#CC3311', '#029E73', '#EE7733', '#9933CC', '#33BBEE', '#EE3377']

# Publication-quality matplotlib settings applied before the figure is built.
_PUBLICATION_RC = {
    'font.size': 11,
    'axes.labelsize': 13,
    'axes.titlesize': 15,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 11,
    'figure.titlesize': 18,
    'axes.linewidth': 1.2,
    'grid.linewidth': 0.8,
    'lines.linewidth': 2.8,
    'lines.markersize': 8,
}


def _subset_for_size(df, decomp_size):
    """Return the rows of *df* for one decomposition size, sorted by overlap."""
    return df[df['decomposition_size'] == decomp_size].sort_values('overlap_percentage')


def _plot_mean_with_band(ax, df_agg, decomp_sizes, color_palette, metric,
                         line_alpha, band_alpha):
    """Draw one mean line plus a +/- one-std band per decomposition size."""
    for decomp_size in decomp_sizes:
        df_subset = _subset_for_size(df_agg, decomp_size)

        x = df_subset['overlap_percentage']
        y = df_subset[f'{metric}_mean']
        y_err = df_subset[f'{metric}_std']

        color = color_palette[decomp_size]
        ax.plot(x, y, 'o-', label=f'D={decomp_size}',
                color=color, linewidth=2.8, markersize=8, alpha=line_alpha)
        ax.fill_between(x, y - y_err, y + y_err, alpha=band_alpha, color=color)


def _add_half_overlap_marker(ax):
    """Draw the 50% overlap reference line and register its legend entry."""
    ax.axvline(50, color='dimgray', linestyle='--',
               linewidth=2.5, alpha=0.7, zorder=1)
    # Empty proxy line: gives the reference line a single legend entry.
    ax.plot([], [], color='dimgray', linestyle='--', linewidth=2.5,
            label='50% Overlap')


def _finish_axis(ax, ylabel, title, legend_loc, xlim):
    """Apply the shared labels, title, legend, grid and x-limits to *ax*."""
    ax.set_xlabel('Overlap (% of Decomposition Size)', fontsize=13)
    ax.set_ylabel(ylabel, fontsize=13)
    ax.set_title(title, fontsize=15, pad=12)
    ax.legend(frameon=True, framealpha=0.95, edgecolor='gray',
              loc=legend_loc, fontsize=11)
    ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.8)
    ax.set_xlim(*xlim)


def _draw_redundancy_panel(ax, df_agg, decomp_sizes, color_palette, show_harm_thresholds):
    """Draw panel 0: empirical redundancy versus overlap."""
    _plot_mean_with_band(ax, df_agg, decomp_sizes, color_palette,
                         'empirical_redundancy', line_alpha=0.7, band_alpha=0.25)

    if show_harm_thresholds:
        ax.axhline(_HARM_REDUNDANCY_THRESHOLD, color='crimson', linestyle='--',
                   linewidth=2.5, alpha=0.7,
                   label=f'Harm Threshold ({_HARM_REDUNDANCY_THRESHOLD}×)', zorder=1)

        # Mark the first overlap at which the mean redundancy exceeds the
        # harm threshold, for the selected decomposition sizes only.
        for decomp_size in _CROSSING_MARKER_SIZES:
            df_subset = _subset_for_size(df_agg, decomp_size)
            redundancy = df_subset['empirical_redundancy_mean'].values
            overlaps = df_subset['overlap_percentage'].values

            above = np.where(redundancy > _HARM_REDUNDANCY_THRESHOLD)[0]
            if len(above) > 0:
                ax.axvline(overlaps[above[0]], color=color_palette[decomp_size],
                           linestyle='--', linewidth=1.5, alpha=0.4)

    # The marker must be added BEFORE the legend is built in _finish_axis,
    # otherwise its '50% Overlap' entry is missing from the legend.
    _add_half_overlap_marker(ax)

    _finish_axis(ax,
                 ylabel='Redundancy Ratio (Total Work / Problem Size)',
                 title='Computational Redundancy (↓ better)',
                 legend_loc='upper left',
                 xlim=(-5, 105))


def _draw_marginal_efficiency_panel(ax, df_marginal, decomp_sizes, color_palette):
    """Draw panel 1: marginal efficiency versus overlap on a log y-axis."""
    # An empty marginal table may carry no columns at all; skip the per-size
    # lookup in that case instead of failing on a missing column.
    if not df_marginal.empty:
        for decomp_size in decomp_sizes:
            df_subset = _subset_for_size(df_marginal, decomp_size)
            if df_subset.empty:
                continue

            ax.plot(df_subset['overlap_percentage'], df_subset['marginal_efficiency'],
                    'o-', label=f'D={decomp_size}', color=color_palette[decomp_size],
                    linewidth=2.8, markersize=8, alpha=0.7)

    _add_half_overlap_marker(ax)

    _finish_axis(ax,
                 ylabel='Marginal Efficiency (ΔHV / ΔRuntime)',
                 title='Marginal Benefit Analysis (↑ better)',
                 legend_loc='best',
                 xlim=(3, 105))
    ax.set_yscale('log')

    # NOTE(review): a log axis cannot display y <= 0, so this zero line and any
    # non-positive efficiencies are presumably not visible - consider 'symlog'
    # if those points matter. Kept to preserve the existing figure.
    ax.axhline(0, color='gray', linestyle='-', linewidth=1.2, alpha=0.4)


def plot_harm_detection_from_csv(
    detailed_csv: str,
    marginal_csv: str = None,
    output_dir: str = 'figures',
    plot_selection: str = 'redundancy_efficiency',
    figsize: tuple = (14, 5),
    dpi: int = 300,
    use_pastel: bool = True,
    show_harm_thresholds: bool = True,
    problem_type: str = 'BiTSP',
    problem_size: str = 'large',
    save_formats: list = None
):
    """
    Create publication-ready harm detection plots from CSV data.

    The detailed CSV is aggregated (mean/std) per
    ``(decomposition_size, overlap_percentage)`` pair, the selected panels are
    drawn side by side, and the figure is written once per requested format to
    ``<output_dir>/harm_detection_<plot_selection>_<problem_type>_<problem_size>.<fmt>``.

    Parameters:
    -----------
    detailed_csv : str
        Path to the detailed CSV file (harm_detection_detailed_*.csv)
    marginal_csv : str, optional
        Path to marginal efficiency CSV (marginal_efficiency_*.csv).
        If None, or if the path does not exist, marginal efficiency is
        calculated from the aggregated detailed data instead.
    output_dir : str
        Directory to save output figures (created if missing)
    plot_selection : str
        Which plots to include:
        - 'redundancy_efficiency': Redundancy + Marginal efficiency (RECOMMENDED)
        - 'runtime_convergence': Runtime + Convergence
        - 'all': All four plots
    figsize : tuple
        Figure size (width, height) in inches
    dpi : int
        Resolution for saved figures
    use_pastel : bool
        Use the fixed hex color list; otherwise sample matplotlib's tab10
    show_harm_thresholds : bool
        Show dashed lines for harm thresholds (redundancy panel)
    problem_type : str
        Problem type used in the output file name (e.g., 'BiTSP', 'BiKP')
    problem_size : str or int
        Problem size used in the output file name (e.g., 'large', 100)
    save_formats : list, optional
        List of formats to save (e.g. ['png', 'pdf', 'svg']).
        Defaults to ['png', 'pdf'] when None.

    Returns:
    --------
    fig : matplotlib.figure.Figure
        The created figure object

    Raises:
    -------
    ValueError
        If ``plot_selection`` is not one of the supported values. Raised
        before any directory is created or file is read.
    """
    # Fail fast on bad arguments, before touching the filesystem.
    if plot_selection not in _PANEL_LAYOUTS:
        raise ValueError(f"Invalid plot_selection: {plot_selection}")
    plot_indices = _PANEL_LAYOUTS[plot_selection]

    # Resolve the default here rather than in the signature so that no list
    # object is shared between calls.
    if save_formats is None:
        save_formats = ['png', 'pdf']

    # Create output directory
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Read detailed CSV
    print(f"Reading detailed data from: {detailed_csv}")
    df_detailed = pd.read_csv(detailed_csv)

    # Aggregate statistics from detailed data
    print("Aggregating statistics...")
    df_agg = df_detailed.groupby(['decomposition_size', 'overlap_percentage']).agg({
        'hypervolume': ['mean', 'std'],
        'runtime': ['mean', 'std'],
        'wall_time': ['mean', 'std'],
        'num_solutions': ['mean', 'std'],
        'tour_length': ['mean', 'std'],
        'convergence_iter': ['mean', 'std'],
        'empirical_redundancy': ['mean', 'std'],
        'theoretical_redundancy': ['mean', 'std'],
        'n_subproblems': 'mean',
        'total_work': ['mean', 'std']
    }).reset_index()

    # Flatten the two-level column index: ('runtime', 'mean') -> 'runtime_mean';
    # the group keys have an empty second level, hence the strip.
    df_agg.columns = ['_'.join(col).strip('_') for col in df_agg.columns.values]

    # Read or calculate marginal efficiency
    if marginal_csv is not None and Path(marginal_csv).exists():
        print(f"Reading marginal efficiency from: {marginal_csv}")
        df_marginal = pd.read_csv(marginal_csv)
    else:
        print("Calculating marginal efficiency from detailed data...")
        df_marginal = calculate_marginal_efficiency(df_agg)

    # One color per decomposition size, stable across panels.
    decomp_sizes = sorted(df_agg['decomposition_size'].unique())
    if use_pastel:
        color_palette = {size: _PASTEL_COLORS[i % len(_PASTEL_COLORS)]
                         for i, size in enumerate(decomp_sizes)}
    else:
        tab10_colors = plt.cm.tab10(np.linspace(0, 1, len(decomp_sizes)))
        color_palette = {size: tab10_colors[i] for i, size in enumerate(decomp_sizes)}

    # Configure matplotlib BEFORE creating the figure so that settings read at
    # axes-creation time (e.g. 'axes.linewidth') apply to this figure too.
    plt.rcParams.update(_PUBLICATION_RC)

    fig, axes = plt.subplots(1, len(plot_indices), figsize=figsize, constrained_layout=True)
    axes = np.atleast_1d(axes)  # uniform iteration even for a single panel

    for ax, panel in zip(axes, plot_indices):
        if panel == 0:
            _draw_redundancy_panel(ax, df_agg, decomp_sizes, color_palette,
                                   show_harm_thresholds)
        elif panel == 1:
            _draw_marginal_efficiency_panel(ax, df_marginal, decomp_sizes, color_palette)
        elif panel == 2:
            # Runtime (absolute, no baseline)
            _plot_mean_with_band(ax, df_agg, decomp_sizes, color_palette,
                                 'runtime', line_alpha=0.9, band_alpha=0.15)
            _finish_axis(ax,
                         ylabel='Runtime (seconds)',
                         title='Runtime Cost (↓ better)',
                         legend_loc='upper left',
                         xlim=(-5, 105))
        elif panel == 3:
            # Convergence (absolute, no baseline)
            _plot_mean_with_band(ax, df_agg, decomp_sizes, color_palette,
                                 'convergence_iter', line_alpha=0.9, band_alpha=0.15)
            _finish_axis(ax,
                         ylabel='Convergence Iterations',
                         title='Convergence Behavior (↓ better)',
                         legend_loc='best',
                         xlim=(-5, 105))

    # Save figure once per requested format
    for fmt in save_formats:
        output_file = Path(output_dir) / f'harm_detection_{plot_selection}_{problem_type}_{problem_size}.{fmt}'
        fig.savefig(output_file, dpi=dpi, bbox_inches='tight', format=fmt)
        print(f"✓ Saved: {output_file}")

    return fig


def calculate_marginal_efficiency(df_agg):
    """Calculate marginal efficiency from aggregated data.

    For each decomposition size the rows are ordered by overlap, and every
    consecutive pair of overlap levels yields one record whose efficiency is
    the change in mean hypervolume divided by the change in mean runtime.

    Parameters:
    -----------
    df_agg : pandas.DataFrame
        Aggregated data with the columns 'decomposition_size',
        'overlap_percentage', 'hypervolume_mean' and 'runtime_mean'.

    Returns:
    --------
    pandas.DataFrame
        One row per consecutive overlap pair with the columns
        'decomposition_size', 'overlap_percentage' (the higher overlap of the
        pair), 'marginal_efficiency', 'delta_hypervolume' and 'delta_runtime'.
        The columns are always present, even when there are no rows, so
        callers can filter on them without checking for emptiness first.
    """
    columns = [
        'decomposition_size',
        'overlap_percentage',
        'marginal_efficiency',
        'delta_hypervolume',
        'delta_runtime',
    ]
    records = []

    for decomp_size in df_agg['decomposition_size'].unique():
        df_subset = df_agg[df_agg['decomposition_size'] == decomp_size].sort_values('overlap_percentage')

        overlaps = df_subset['overlap_percentage'].values
        hv = df_subset['hypervolume_mean'].values
        runtime = df_subset['runtime_mean'].values

        # Differences between each overlap level and the one before it.
        hv_steps = hv[1:] - hv[:-1]
        runtime_steps = runtime[1:] - runtime[:-1]

        for overlap, delta_hv, delta_runtime in zip(overlaps[1:], hv_steps, runtime_steps):
            # A flat or decreasing runtime gives no meaningful ratio; report 0
            # rather than dividing by (near) zero or a negative step.
            if delta_runtime > 1e-10:
                efficiency = delta_hv / delta_runtime
            else:
                efficiency = 0.0

            records.append({
                'decomposition_size': decomp_size,
                'overlap_percentage': overlap,
                'marginal_efficiency': efficiency,
                'delta_hypervolume': delta_hv,
                'delta_runtime': delta_runtime
            })

    # Passing the columns explicitly keeps the schema stable for empty results.
    return pd.DataFrame(records, columns=columns)


# ============================================================================
# EXAMPLE USAGE
# ============================================================================

if __name__ == "__main__":
    # Imported here because it is only needed when run as a script.
    import argparse

    # The CSV paths come from the command line; previously DETAILED_CSV and
    # MARGINAL_CSV were referenced without ever being defined.
    parser = argparse.ArgumentParser(
        description="Create harm detection figures from experiment CSV files."
    )
    parser.add_argument(
        'detailed_csv',
        help="Path to the detailed CSV file (harm_detection_detailed_*.csv)"
    )
    parser.add_argument(
        '--marginal-csv',
        default=None,
        help="Optional path to the marginal efficiency CSV; "
             "calculated from the detailed data if omitted"
    )
    cli_args = parser.parse_args()

    DETAILED_CSV = cli_args.detailed_csv
    MARGINAL_CSV = cli_args.marginal_csv

    # Example 1: Plot Redundancy + Efficiency (RECOMMENDED FOR PAPER)
    print("\n" + "="*80)
    print("EXAMPLE 1: Redundancy + Marginal Efficiency (Recommended for Paper)")
    print("="*80)

    fig1 = plot_harm_detection_from_csv(
        detailed_csv=DETAILED_CSV,
        marginal_csv=MARGINAL_CSV,  # Optional, will calculate if not provided
        output_dir='/Users/eshasingh/project_MOCO/harm_detection_results/paper_figures',
        plot_selection='redundancy_efficiency',
        figsize=(14, 5),
        dpi=300,
        use_pastel=True,
        show_harm_thresholds=True,
        problem_type='BiTSP',
        problem_size=100,  # or 'large'
        save_formats=['png', 'pdf']
    )

    # Example 2: Plot Runtime + Convergence (Alternative)
    # print("\n" + "="*80)
    # print("EXAMPLE 2: Runtime + Convergence (Alternative)")
    # print("="*80)

    # fig2 = plot_harm_detection_from_csv(
    #     detailed_csv=DETAILED_CSV,
    #     marginal_csv=MARGINAL_CSV,
    #     output_dir='paper_figures',
    #     plot_selection='runtime_convergence',
    #     figsize=(14, 5),
    #     dpi=300,
    #     use_pastel=True,
    #     show_harm_thresholds=False,
    #     problem_type='BiTSP',
    #     problem_size=100,
    #     save_formats=['png', 'pdf']
    # )

    # Example 3: All four plots (For appendix or presentation)
    # print("\n" + "="*80)
    # print("EXAMPLE 3: All Four Plots (For Appendix)")
    # print("="*80)

    # fig3 = plot_harm_detection_from_csv(
    #     detailed_csv=DETAILED_CSV,
    #     marginal_csv=MARGINAL_CSV,
    #     output_dir='paper_figures',
    #     plot_selection='all',
    #     figsize=(20, 5),
    #     dpi=300,
    #     use_pastel=True,
    #     show_harm_thresholds=True,
    #     problem_type='BiTSP',
    #     problem_size=100,
    #     save_formats=['png', 'pdf', 'svg']
    # )

    # plt.show()