#!/usr/bin/env python3
"""
Replot ablation study results from CSV files.

This script loads saved CSV results (which avoid YAML numpy serialization issues)
and regenerates plots with optional baseline solver reference lines.

The aggregated CSV file contains columns:
- algorithm, problem_type, problem_size, decomposition_size, overlap_ratio
- overlap_absolute, overlap_percentage, n_runs
- hv_mean, hv_std, hv_min, hv_max, hv_median
- runtime_mean, runtime_std, runtime_min, runtime_max
- solutions_mean, solutions_std
- tour_length_mean, tour_length_std
- timestamp
"""

import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Optional
from collections import defaultdict


def setup_icml_style():
    """
    Apply the paper-style matplotlib/seaborn defaults used by every plot in this module.

    Selects the 'seaborn-v0_8-paper' style sheet and the "husl" seaborn palette,
    then overrides individual rcParams (fonts, label sizes, figure/savefig
    settings, layout, grid, spines, lines and legend appearance).
    The change is global: it mutates ``plt.rcParams`` for the running process.
    """
    plt.style.use('seaborn-v0_8-paper')
    sns.set_palette("husl")

    # Base font family and size.
    typography = {
        'font.family': 'sans-serif',
        'font.sans-serif': ['DejaVu Sans', 'Arial', 'Helvetica', 'Liberation Sans'],
        'font.size': 10,
    }

    # Axis labels and titles are bold and larger than the base font size.
    axis_text = {
        'axes.labelsize': 13,
        'axes.labelweight': 'bold',
        'axes.titlesize': 12,
        'axes.titleweight': 'bold',
    }

    # Tick labels and legend entries.
    tick_and_legend_text = {
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 9,
    }

    # Figure-level title, on-screen resolution and saved-file resolution/cropping.
    figure_output = {
        'figure.titlesize': 14,
        'figure.titleweight': 'bold',
        'figure.dpi': 150,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.05,
    }

    # Constrained layout is the default; autolayout (tight_layout) is disabled.
    layout = {
        'figure.constrained_layout.use': True,
        'figure.autolayout': False,
    }

    # Dashed, faint grid drawn below the data; top and right spines hidden.
    grid_and_spines = {
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.linewidth': 0.8,
    }

    # Default line width and marker size.
    line_defaults = {
        'lines.linewidth': 1.5,
        'lines.markersize': 6,
    }

    # Legend box appearance, plus keeping grid lines/ticks below plotted artists.
    legend_box_and_misc = {
        'legend.frameon': True,
        'legend.framealpha': 0.9,
        'legend.edgecolor': '0.8',
        'axes.axisbelow': True,
    }

    setting_groups = (
        typography,
        axis_text,
        tick_and_legend_text,
        figure_output,
        layout,
        grid_and_spines,
        line_defaults,
        legend_box_and_misc,
    )
    for settings in setting_groups:
        plt.rcParams.update(settings)


# Apply the plotting style once at import time so that every figure created
# through this module uses it. NOTE: this is a module-level side effect that
# mutates the global matplotlib rcParams as soon as the file is imported.
setup_icml_style()


def load_and_replot_from_csv(
    csv_file: str,
    plot_metrics: Optional[List[str]] = None,
    baseline_solvers: Optional[Dict[str, Dict[str, float]]] = None,
    output_dir: Optional[str] = None,
    output_suffix: str = 'replot'
) -> pd.DataFrame:
    """
    Load saved results from CSV file and regenerate plots.

    This function uses the aggregated CSV file (not individual runs) which contains
    mean/std statistics for each configuration. The problem type/size are read
    from the first data row, and plotting is delegated to
    ``plot_from_csv_dataframe``.

    Parameters:
    -----------
    csv_file : str
        Path to the aggregated CSV file (e.g., 'ablation_aggregated_BiTSP_medium_*.csv')
    plot_metrics : List[str], optional
        Metrics to plot. Options: ['hypervolume', 'runtime', 'solutions', 'tour_length']
        Default: ['hypervolume', 'runtime', 'tour_length']
    baseline_solvers : Dict[str, Dict[str, float]], optional
        Baseline solver results to show as reference lines.
        Example: {'WS-LKH': {'hypervolume': 0.625, 'runtime': 1.8}}
    output_dir : str, optional
        Directory to save new plots. If None, uses same directory as CSV file
        (or the current directory when the CSV path has no directory part).
        The directory is created if it does not exist.
    output_suffix : str
        Suffix to add to plot filenames (default: 'replot')

    Returns:
    --------
    pd.DataFrame : The loaded data

    Raises:
    -------
    KeyError
        If the CSV lacks one of the columns this function reads
        (problem_type, problem_size, algorithm, decomposition_size, overlap_ratio).
    ValueError
        If the CSV has a header but no data rows.

    Example:
    --------
    >>> baseline_solvers = {
    ...     'WS-LKH': {'hypervolume': 0.625, 'runtime': 1.8},
    ...     'NSGA-II': {'hypervolume': 0.48, 'runtime': 8.5}
    ... }
    >>> load_and_replot_from_csv(
    ...     'ablation_results/ablation_aggregated_BiTSP_small_20251228-133548.csv',
    ...     plot_metrics=['hypervolume', 'runtime', 'tour_length'],
    ...     baseline_solvers=baseline_solvers
    ... )
    """

    print("\n" + "="*80)
    print("LOADING CSV RESULTS AND REGENERATING PLOTS")
    print("="*80)

    # Default metrics
    if plot_metrics is None:
        plot_metrics = ['hypervolume', 'runtime', 'tour_length']

    # Load CSV
    print(f"\nLoading: {csv_file}")
    df = pd.read_csv(csv_file)

    # Fail early with an actionable message instead of an opaque KeyError /
    # IndexError from the metadata extraction below.
    required_columns = [
        'problem_type', 'problem_size', 'algorithm',
        'decomposition_size', 'overlap_ratio',
    ]
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise KeyError(
            f"CSV file {csv_file} is missing required column(s): {missing_columns}"
        )
    if df.empty:
        raise ValueError(
            f"CSV file {csv_file} contains no data rows; nothing to plot"
        )

    # Extract metadata from data (the first row is taken as representative)
    problem_type = df['problem_type'].iloc[0]
    problem_size = df['problem_size'].iloc[0]
    algorithms = df['algorithm'].unique().tolist()
    decomposition_sizes = sorted(df['decomposition_size'].unique().tolist())
    overlap_ratios = sorted(df['overlap_ratio'].unique().tolist())

    # Infer actual problem size from decomposition sizes (rough estimate)
    actual_size = max(decomposition_sizes) * 2  # Approximate

    print(f"\nProblem: {problem_type} ({problem_size})")
    print(f"Algorithms: {algorithms}")
    print(f"Decomposition sizes: {decomposition_sizes}")
    print(f"Overlap ratios: {overlap_ratios}")
    print(f"Plotting metrics: {plot_metrics}")

    if baseline_solvers:
        print(f"Baseline solvers: {list(baseline_solvers.keys())}")

    # Determine output directory: default to the CSV's own directory, falling
    # back to the current directory for a bare filename.
    if output_dir is None:
        output_dir = os.path.dirname(csv_file) or '.'

    os.makedirs(output_dir, exist_ok=True)

    # Tag the new plots with the current time plus the caller's suffix so
    # they do not overwrite earlier plots.
    timestamp = f"{time.strftime('%Y%m%d-%H%M%S')}_{output_suffix}"

    # Generate plots
    print("\nRegenerating plots...")
    plot_from_csv_dataframe(
        df, decomposition_sizes, overlap_ratios,
        problem_type, problem_size, output_dir, timestamp, actual_size,
        plot_metrics, baseline_solvers
    )

    print("\n" + "="*80)
    print("REPLOTTING COMPLETE!")
    print(f"Check {output_dir}/ for new plots")
    print("="*80)

    return df


def _decomposition_palette(n_colors: int) -> np.ndarray:
    """Return one RGBA colour per decomposition size from a colormap sized to the count."""
    if n_colors <= 10:
        cmap = plt.cm.tab10
    elif n_colors <= 20:
        cmap = plt.cm.tab20
    else:
        cmap = plt.cm.viridis
    return cmap(np.linspace(0, 1, n_colors))


def _plot_metric_panel(ax, df_alg, config, decomposition_sizes, colors) -> None:
    """Draw, on *ax*, one mean line with a +/- one-std band per decomposition size."""
    mean_col = config['mean_col']
    std_col = config['std_col']
    has_std = std_col in df_alg.columns

    for color, decomp_size in zip(colors, decomposition_sizes):
        df_decomp = df_alg[df_alg['decomposition_size'] == decomp_size]
        if df_decomp.empty:
            continue

        # Sort by overlap percentage so the line is drawn left to right
        df_decomp = df_decomp.sort_values('overlap_percentage')

        overlaps = df_decomp['overlap_percentage'].values
        means = np.asarray(df_decomp[mean_col], dtype=float)
        if has_std:
            stds = np.asarray(df_decomp[std_col], dtype=float)
        else:
            # No spread information available: draw the line without a band.
            stds = np.zeros_like(means)

        # Drop configurations whose mean is NaN; skip the line if none remain
        valid_mask = ~np.isnan(means)
        if not valid_mask.any():
            continue

        overlaps = overlaps[valid_mask]
        means = means[valid_mask]
        stds = np.nan_to_num(stds[valid_mask], nan=0.0)  # Replace NaN std with 0

        ax.plot(overlaps, means,
                marker=config['marker'], linewidth=2, markersize=6,
                color=color, label=f'D={decomp_size}', alpha=0.85)

        # Error band of one standard deviation around the mean
        ax.fill_between(overlaps, means - stds, means + stds,
                        alpha=0.12, color=color)


def _add_baseline_lines(ax, metric_name, baseline_solvers) -> None:
    """Draw a horizontal reference line for each baseline solver that reports *metric_name*."""
    # Only these plotted metrics have a baseline counterpart.
    baseline_metric_map = {
        'hypervolume': 'hypervolume',
        'runtime': 'runtime'
    }
    baseline_key = baseline_metric_map.get(metric_name)
    if not baseline_solvers or baseline_key is None:
        return

    baseline_colors = ['red', 'orange', 'purple', 'brown', 'green']
    baseline_linestyles = ['--', '-.', ':', '--', '-.']

    # The index advances only for solvers actually drawn, so colours and
    # line styles are assigned in order and cycle when they run out.
    baseline_idx = 0
    for solver_name, solver_metrics in baseline_solvers.items():
        if baseline_key not in solver_metrics:
            continue
        ax.axhline(y=solver_metrics[baseline_key],
                   color=baseline_colors[baseline_idx % len(baseline_colors)],
                   linestyle=baseline_linestyles[baseline_idx % len(baseline_linestyles)],
                   linewidth=2, alpha=0.7,
                   label=f'{solver_name}')
        baseline_idx += 1


def plot_from_csv_dataframe(
    df: pd.DataFrame,
    decomposition_sizes: List[int],
    overlap_ratios: List[float],
    problem_type: str,
    problem_size: str,
    output_dir: str,
    timestamp: str,
    actual_size: int,
    plot_metrics: List[str],
    baseline_solvers: Optional[Dict[str, Dict[str, float]]] = None
) -> None:
    """
    Create plots from the aggregated CSV DataFrame.

    One figure is produced per algorithm, with one panel per selected metric
    (at most three; extra metrics are ignored with a warning). Each panel shows
    the metric mean against overlap percentage, one line per decomposition
    size, with a +/- one-std band and optional baseline reference lines.
    Each figure is saved as PNG and PDF in ``output_dir``.

    The DataFrame should have columns:
    - algorithm, decomposition_size, overlap_ratio, overlap_percentage
    - hv_mean, hv_std, runtime_mean, runtime_std, solutions_mean, solutions_std
    - tour_length_mean, tour_length_std

    Parameters:
    -----------
    df : pd.DataFrame
        Aggregated results, one row per configuration
    decomposition_sizes : List[int]
        Decomposition sizes to draw, one line each
    overlap_ratios : List[float]
        Not used by this function; kept for interface compatibility
    problem_type : str
        Used in the output filename; if it contains 'TSP' the 'tour_length'
        metric is labelled as a tour length where lower is better
    problem_size : str
        Used in the output filename
    output_dir : str
        Existing directory that receives the PNG/PDF files
    timestamp : str
        Tag appended to the output filename
    actual_size : int
        Not used by this function; kept for interface compatibility
    plot_metrics : List[str]
        Metrics to plot, from 'hypervolume', 'runtime', 'solutions', 'tour_length'.
        Unknown metrics, and metrics whose mean column is absent from ``df``,
        are skipped with a warning. If nothing remains, no figure is created.
    baseline_solvers : Dict[str, Dict[str, float]], optional
        Baseline values drawn as horizontal lines on the 'hypervolume' and
        'runtime' panels
    """

    print("\nGenerating ICML-style plots from CSV...")

    is_tsp = 'TSP' in problem_type

    # Define metric configurations
    metric_configs = {
        'hypervolume': {
            'mean_col': 'hv_mean',
            'std_col': 'hv_std',
            'ylabel': 'Hypervolume',
            'marker': 'o',
            'higher_better': True
        },
        'runtime': {
            'mean_col': 'runtime_mean',
            'std_col': 'runtime_std',
            'ylabel': 'Runtime (seconds)',
            'marker': 's',
            'higher_better': False
        },
        'solutions': {
            'mean_col': 'solutions_mean',
            'std_col': 'solutions_std',
            'ylabel': 'Number of Solutions',
            'marker': '^',
            'higher_better': True
        },
        'tour_length': {
            'mean_col': 'tour_length_mean',
            'std_col': 'tour_length_std',
            'ylabel': 'Average Tour Length' if is_tsp else 'Average Value',
            'marker': 'd',
            'higher_better': not is_tsp
        },
    }

    # Resolve the panels before creating any figure, so that unknown or
    # unavailable metrics do not leave blank subplots and an empty selection
    # does not crash plt.subplots with zero columns.
    max_panels = 3
    requested = list(plot_metrics or [])
    if len(requested) > max_panels:
        print(f"Warning: only the first {max_panels} metrics are plotted, "
              f"ignoring {requested[max_panels:]}")

    selected = []
    for metric_name in requested[:max_panels]:
        config = metric_configs.get(metric_name)
        if config is None:
            print(f"Warning: Unknown metric '{metric_name}', skipping...")
            continue
        mean_col = config['mean_col']
        if mean_col not in df.columns:
            print(f"Warning: Column '{mean_col}' not found for metric "
                  f"'{metric_name}', skipping...")
            continue
        selected.append((metric_name, config))

    if not selected:
        print("No plottable metrics selected, nothing to plot.")
        return

    n_plots = len(selected)

    # The palette depends only on the number of decomposition sizes, so it is
    # computed once and shared by every algorithm's figure.
    colors = _decomposition_palette(len(decomposition_sizes))

    for alg_name in df['algorithm'].unique():
        df_alg = df[df['algorithm'] == alg_name]

        if df_alg.empty:
            print(f"No data for {alg_name}, skipping...")
            continue

        # squeeze=False always returns a 2-D axes array, so a single panel
        # needs no special case.
        fig, axes = plt.subplots(1, n_plots, figsize=(5 * n_plots, 4.2),
                                 constrained_layout=False, squeeze=False)
        try:
            for ax, (metric_name, config) in zip(axes[0], selected):
                _plot_metric_panel(ax, df_alg, config, decomposition_sizes, colors)
                _add_baseline_lines(ax, metric_name, baseline_solvers)

                # Axis labels (using global rcParams for size and weight)
                ax.set_xlabel('Overlap (% of Decomposition Size)')
                ax.set_ylabel(config['ylabel'])

                # Title with direction indicator
                direction = "↑ better" if config['higher_better'] else "↓ better"
                ax.set_title(f'{config["ylabel"]} ({direction})', pad=10)

                # Legend: switch to two columns once there are many entries;
                # skip it entirely when nothing was drawn.
                _, labels = ax.get_legend_handles_labels()
                if labels:
                    ax.legend(frameon=True, framealpha=0.9, edgecolor='0.8',
                              loc='best', ncol=1 if len(labels) <= 6 else 2)

                # Grid
                ax.grid(True, alpha=0.3, linestyle='--')

            # Manual spacing adjustment (constrained layout is disabled above)
            fig.subplots_adjust(top=0.88, bottom=0.12, left=0.08, right=0.98,
                                wspace=0.25)

            # Algorithm names may contain path separators (e.g. "MOEA/D"),
            # which would otherwise point the file into a missing directory.
            safe_alg_name = str(alg_name).replace(os.sep, '_').replace('/', '_')
            plot_base = os.path.join(
                output_dir,
                f'ablation_overlap_{safe_alg_name}_{problem_type}_{problem_size}_{timestamp}'
            )
            plot_file = plot_base + '.png'

            # Build both paths from the shared base name rather than with
            # str.replace, which would also rewrite a '.png' occurring
            # earlier in the path.
            fig.savefig(plot_file, dpi=300, bbox_inches='tight', pad_inches=0.1)
            fig.savefig(plot_base + '.pdf', bbox_inches='tight', pad_inches=0.1)
            print(f"Saved: {plot_file}")
        finally:
            # Release the figure even if plotting or saving failed.
            plt.close(fig)

    print("\nICML-style plots generated from CSV!")


def find_csv_files(directory: str, pattern: str = 'ablation_aggregated') -> List[str]:
    """
    Find all aggregated CSV files in a directory.

    Only the given directory is searched (no recursion). The directory name is
    matched literally, so directories containing glob metacharacters such as
    '[' or '*' are handled correctly.

    Parameters:
    -----------
    directory : str
        Directory to search
    pattern : str
        Pattern to match in filename (default: 'ablation_aggregated').
        It is inserted into the glob expression '*{pattern}*.csv' as-is.

    Returns:
    --------
    List[str] : Sorted list of matching file paths (empty if none match or
        the directory does not exist)
    """
    import glob

    # Escape the directory so that characters like '[' are not interpreted
    # as glob syntax; only the filename part is meant to be a wildcard.
    search_expr = os.path.join(glob.escape(directory), f'*{pattern}*.csv')
    return sorted(glob.glob(search_expr))


# ============================================================================
# MAIN SCRIPT
# ============================================================================

if __name__ == "__main__":
    import sys
    import traceback

    # ------------------------------------------------------------------
    # User configuration: edit the values below before running.
    # ------------------------------------------------------------------

    # Aggregated CSV to replot (files are named like
    # ablation_aggregated_BiTSP_small_*.csv).
    # Alternative small-instance file: ablation_aggregated_BiTSP_small_20251228-133548.csv
    CSV_FILE = './ablation_results_complete_overlap_UCB/ablation_aggregated_BiTSP_large_20260103-035609.csv'

    # Any of: 'hypervolume', 'runtime', 'solutions', 'tour_length'.
    PLOT_METRICS = ['hypervolume', 'runtime', 'tour_length']

    # Reference lines for baseline solvers; use None to draw no baselines.
    # Values used for bitsp20: 'WS-LKH': {'hypervolume': 0.625, 'runtime': 1.8}
    # Other solvers that can be added:
    #   'NSGA-II': {'hypervolume': 0.48, 'runtime': 8.5}
    #   'MOEA/D': {'hypervolume': 0.49, 'runtime': 11.0}
    BASELINE_SOLVERS = {
        'WS-LKH': {'hypervolume': 0.69, 'runtime': 900},
    }

    # Where the plots are written; None means next to the CSV file.
    OUTPUT_DIR = './replot_large'

    # Appended to the timestamp in every output filename.
    OUTPUT_SUFFIX = 'with_baselines'

    # ------------------------------------------------------------------
    # Run
    # ------------------------------------------------------------------

    heavy_rule = "=" * 80
    light_rule = "-" * 60

    print("\n" + heavy_rule)
    print("CSV-BASED REPLOTTING TOOL FOR ABLATION STUDIES")
    print(heavy_rule)
    print(f"\nCSV File: {CSV_FILE}")
    print(f"Metrics: {PLOT_METRICS}")
    if BASELINE_SOLVERS:
        print(f"Baselines: {list(BASELINE_SOLVERS)}")
    print(f"Output Suffix: {OUTPUT_SUFFIX}")

    # If the configured file is missing, list the candidates found in the
    # same directory to help the user fix CSV_FILE, then exit with an error.
    if not os.path.exists(CSV_FILE):
        print(f"\n✗ ERROR: Could not find CSV file: {CSV_FILE}")
        print("\nSearching for available CSV files...")

        search_dir = os.path.dirname(CSV_FILE) or '.'
        if os.path.exists(search_dir):
            candidates = find_csv_files(search_dir)
            if not candidates:
                print(f"No aggregated CSV files found in {search_dir}")
            else:
                print(f"\nFound {len(candidates)} aggregated CSV file(s):")
                for candidate in candidates:
                    print(f"  - {candidate}")
                print("\nPlease update CSV_FILE to one of these paths.")
        sys.exit(1)

    try:
        df = load_and_replot_from_csv(
            csv_file=CSV_FILE,
            plot_metrics=PLOT_METRICS,
            baseline_solvers=BASELINE_SOLVERS,
            output_dir=OUTPUT_DIR,
            output_suffix=OUTPUT_SUFFIX
        )

        print("\n✓ Replotting completed successfully!")

        # Short summary of what the CSV contained.
        hv_means = df['hv_mean']
        runtime_means = df['runtime_mean']

        print("\n" + light_rule)
        print("DATA SUMMARY")
        print(light_rule)
        print(f"Total configurations: {len(df)}")
        print(f"Algorithms: {df['algorithm'].unique().tolist()}")
        print(f"Decomposition sizes: {sorted(df['decomposition_size'].unique().tolist())}")
        print(f"Overlap ratios: {sorted(df['overlap_ratio'].unique().tolist())}")
        print(f"HV range: [{hv_means.min():.4f}, {hv_means.max():.4f}]")
        print(f"Runtime range: [{runtime_means.min():.2f}, {runtime_means.max():.2f}] seconds")

    except Exception as e:
        # Top-level boundary of the script: report the failure with a full
        # traceback and signal it through the exit code.
        print(f"\n✗ ERROR: {e}")
        traceback.print_exc()
        sys.exit(1)