#!/usr/bin/env python
"""
This module provides functions to create publication-quality plots of regret vs. cost,
violation percentages, fidelity/iteration distributions and runtimes, plus a standalone legend.
"""

import os
import pandas as pd
import matplotlib.pyplot as plt
from typing import Optional, Dict, List, Tuple

# Publication typography: serif text, preferring Times New Roman and falling
# back to Liberation Serif, then DejaVu Serif, then the generic 'serif' family.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'Liberation Serif', 'DejaVu Serif', 'serif'],
    'mathtext.fontset': 'dejavuserif',
    # NOTE(review): matplotlib documents 'mathtext.rm' as applying when
    # 'mathtext.fontset' is 'custom'; with 'dejavuserif' this entry likely has
    # no visible effect on math text — confirm before relying on it.
    'mathtext.rm': 'Times New Roman',
})


def regret_plot(
    csv_file: str,
    method_labels: Optional[Dict[str, str]] = None,
    method_colors: Optional[Dict[str, str]] = None,
    method_linestyles: Optional[Dict[str, str]] = None,
    method_order: Optional[List[str]] = None,
    figsize: Tuple[float, float] = (2.3, 2.3),
    title: Optional[str] = None,
    xlabel: str = "Cost",
    ylabel: str = " ",
    title_fontsize: int = 12,
    label_fontsize: int = 12,
    tick_fontsize: int = 12,
    legend_fontsize: int = 13,
    x_tick_step: Optional[float] = None,
    y_tick_step: Optional[float] = None,
    linewidth: float = 2.5,
    alpha_fill: float = 0.2,
    use_step_plot: bool = True,
    grid: bool = True,
    grid_alpha: float = 0.3,
    show_legend: bool = True,
    legend_loc: str = "best",
    legend_ncol: int = 1,
    legend_frameon: bool = True,
    save_path: Optional[str] = "auto",
    dpi: int = 300,
    use_normalized_cost: bool = False,
    show_plot: bool = False,
    subplot_left: float = 0.27,
    subplot_right: float = 0.98,
    subplot_top: float = 0.88,
    subplot_bottom: float = 0.225,
):
    """
    Create publication-ready regret vs cost plot from CSV file.

    The CSV must contain a cost column ('cost', or 'cost_normalized' when
    ``use_normalized_cost`` is set and that column exists) and, for every
    method, a '<method>_mean' and a '<method>_std' column. Only methods present
    in both the CSV and ``method_order`` are drawn, in ``method_order`` order,
    each as a mean line with a shaded mean +/- std band.

    Args:
        csv_file: Name of CSV file (resolved relative to this script's directory)
        method_labels: Dict mapping method names to display labels
        method_colors: Dict mapping method names to colors
        method_linestyles: Dict mapping method names to line styles ('-', '--', '-.', ':')
        method_order: List of method names in desired plotting order
        figsize: Figure size (width, height)
        title: Plot title (None for no title)
        xlabel: X-axis label
        ylabel: Y-axis label
        title_fontsize: Font size for title
        label_fontsize: Font size for axis labels
        tick_fontsize: Font size for tick labels
        legend_fontsize: Font size for legend
        x_tick_step: Step size for x-axis ticks starting at 0 (e.g., 10 for ticks at
            0, 10, 20, ...); ticks beyond the largest cost value are dropped
        y_tick_step: Step size for y-axis ticks; only ticks inside the range spanned
            by the mean +/- std bands of the plotted methods are kept
        linewidth: Width of the mean line
        alpha_fill: Transparency for std band (0-1)
        use_step_plot: Use step plot instead of smooth line
        grid: Whether to show grid
        grid_alpha: Grid transparency (0-1)
        show_legend: Whether to show legend
        legend_loc: Legend location ('upper right', 'best', 'outside top', etc.)
        legend_ncol: Number of columns in legend
        legend_frameon: Whether to draw frame around legend
        save_path: Path to save figure. "auto" (default) saves to
            figs/<csv name>_plot.pdf next to this script; None or "" disables
            saving. Missing parent directories are created.
        dpi: DPI for saved figure
        use_normalized_cost: Use 'cost_normalized' column if available
        show_plot: Whether to call plt.show() after plotting
        subplot_left, subplot_right, subplot_top, subplot_bottom: Axes position as
            figure fractions (keeps PDF dimensions consistent across plots)

    Returns:
        The matplotlib Figure.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
    """
    import numpy as np

    # Default method labels
    if method_labels is None:
        method_labels = {
            "rescue": "RESCUE",
            "mf_gp_momf": "MOMF",
            "mf_gp_hvkg": "MF-HVKG",
            "mf_gp_qehvi": "MOBO-qEHVI",
        }

    # Default method colors (color-blind friendly)
    if method_colors is None:
        method_colors = {
            "rescue": "#D62728",      # Red (rescue - your method)
            "mf_gp_momf": "#1F77B4",  # Blue (matplotlib default blue)
            "mf_gp_hvkg": "#2CA02C",  # Green (matplotlib default green)
            "mf_gp_qehvi": "#FF7F0E", # Orange (matplotlib default orange)
        }

    # Default method linestyles
    if method_linestyles is None:
        method_linestyles = {
            "rescue": "--",
            "mf_gp_momf": ":",
            "mf_gp_hvkg": "-.",
            "mf_gp_qehvi": "-",
        }

    # Default method order
    if method_order is None:
        method_order = ["rescue", "mf_gp_hvkg", "mf_gp_momf", "mf_gp_qehvi"]

    # Load CSV
    script_dir = os.path.dirname(os.path.abspath(__file__))
    csv_path = os.path.join(script_dir, csv_file)

    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"CSV file not found: {csv_path}")

    # Auto-generate save path if set to "auto"
    if save_path == "auto":
        base_name = os.path.splitext(os.path.basename(csv_file))[0]
        save_path = os.path.join(script_dir, "figs", f"{base_name}_plot.pdf")

    df = pd.read_csv(csv_path)

    # Get cost column
    cost_col = "cost_normalized" if use_normalized_cost and "cost_normalized" in df.columns else "cost"
    cost = df[cost_col].values

    # Methods are detected by their '<method>_mean' column. Strip only the
    # trailing suffix so names that contain "_mean" elsewhere stay intact.
    suffix = "_mean"
    available = {col[:-len(suffix)] for col in df.columns if col.endswith(suffix)}
    methods = [m for m in method_order if m in available]

    # Create plot
    fig, ax = plt.subplots(figsize=figsize)

    # Track data bounds for y-axis ticks
    all_means = []
    all_stds = []

    for method in methods:
        mean = df[f"{method}_mean"].values
        std = df[f"{method}_std"].values
        all_means.append(mean)
        all_stds.append(std)

        label = method_labels.get(method, method)
        color = method_colors.get(method, None)
        linestyle = method_linestyles.get(method, '-')

        if use_step_plot:
            ax.step(cost, mean, label=label, color=color, linewidth=linewidth, where='post', linestyle=linestyle)
            ax.fill_between(cost, mean - std, mean + std, color=color, alpha=alpha_fill, step='post')
        else:
            ax.plot(cost, mean, label=label, color=color, linewidth=linewidth, linestyle=linestyle)
            ax.fill_between(cost, mean - std, mean + std, color=color, alpha=alpha_fill)

    # Format plot
    ax.set_xlabel(xlabel, fontsize=label_fontsize)
    ax.set_ylabel(ylabel, fontsize=label_fontsize)
    if title:
        ax.set_title(title, fontsize=title_fontsize)
    ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)

    # Set x-axis tick step if specified
    if x_tick_step is not None:
        _, x_max = ax.get_xlim()
        ticks = np.arange(0, x_max, x_tick_step)
        # Only include ticks up to the actual data max
        ticks = ticks[ticks <= cost.max()]
        ax.set_xticks(ticks)

    # Set y-axis tick step if specified. Skipped when nothing was plotted,
    # because np.concatenate raises on an empty list.
    if y_tick_step is not None and all_means:
        all_means_arr = np.concatenate(all_means)
        all_stds_arr = np.concatenate(all_stds)
        data_min = np.min(all_means_arr - all_stds_arr)
        data_max = np.max(all_means_arr + all_stds_arr)
        # Generate ticks on multiples of the step, bounded by the actual data
        ticks = np.arange(np.floor(data_min / y_tick_step) * y_tick_step, data_max + y_tick_step, y_tick_step)
        ticks = ticks[(ticks >= data_min) & (ticks <= data_max)]
        ax.set_yticks(ticks)

    if grid:
        ax.grid(alpha=grid_alpha, linestyle='-', linewidth=0.5)

    if show_legend:
        if legend_loc == "outside top":
            # NOTE(review): with loc='upper center' the legend's TOP edge is
            # anchored at y=1.02, so the legend body mostly overlaps the axes;
            # loc='lower center' would place it fully above. Kept as-is to
            # avoid changing existing figures — confirm the intended layout.
            ax.legend(
                loc='upper center',
                bbox_to_anchor=(0.5, 1.02),
                fontsize=legend_fontsize,
                ncol=legend_ncol,
                frameon=legend_frameon
            )
        else:
            ax.legend(loc=legend_loc, fontsize=legend_fontsize, ncol=legend_ncol, frameon=legend_frameon)

    # Fixed margins (instead of tight_layout) keep PDF dimensions consistent
    fig.subplots_adjust(left=subplot_left, right=subplot_right, top=subplot_top, bottom=subplot_bottom)

    if save_path:
        # The "auto" path points into a figs/ subdirectory that may not exist yet
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=dpi)
        print(f"Figure saved to: {save_path}")

    if show_plot:
        plt.show()
    return fig


def violation_plot(
    csv_file: str,
    method_labels: Optional[Dict[str, str]] = None,
    method_colors: Optional[Dict[str, str]] = None,
    method_hatches: Optional[Dict[str, str]] = False,
    method_order: Optional[List[str]] = None,
    figsize: Tuple[float, float] = (3.2, 1.5),
    title: Optional[str] = None,
    ylabel: str = " ",
    title_fontsize: int = 12,
    label_fontsize: int = 14,
    tick_fontsize: int = 12,
    x_tick_fontsize: Optional[int] = None,
    y_tick_fontsize: Optional[int] = 14,
    datalabel_fontsize: Optional[int] = 14,
    y_tick_step: Optional[float] = 25,
    alpha: float = 0.5,
    edgecolor: str = 'black',
    linewidth: float = 1.2,
    errorbar_linewidth: float = 1,
    errorbar_capsize: float = 8,
    errorbar_capthick: float = 1,
    grid: bool = True,
    grid_alpha: float = 0.3,
    rotation: int = 0,
    show_xlabels: bool = True,
    save_path: Optional[str] = "auto",
    dpi: int = 300,
    show_plot: bool = False,
    subplot_left: float = 0.18,
    subplot_right: float = 0.99,
    subplot_top: float = 0.93,
    subplot_bottom: float = 0.18,
):
    """
    Create publication-ready violation bar plot from CSV file.

    The CSV has one row per method with columns 'method', 'mean_violations' and
    'std_violations'. Both values are multiplied by 100 and shown on a
    percentage scale. Bars are drawn from highest to lowest mean, so the sort
    overrides ``method_order``, which only selects methods. Each bar is
    labelled with its mean to one decimal place.

    Args:
        csv_file: Name of CSV file (resolved relative to this script's directory)
        method_labels: Dict mapping method names to display labels
        method_colors: Dict mapping method names to colors
        method_hatches: None -> built-in hatch pattern per method; False (default) ->
            no hatches; dict -> mapping of method names to hatch patterns
            ('/', '\\', '|', '-', '+', 'x', 'o', 'O', '.', '*')
        method_order: Method names to include; methods missing from the CSV are skipped
        figsize: Figure size (width, height)
        title: Plot title (None for no title)
        ylabel: Y-axis label
        title_fontsize: Font size for title
        label_fontsize: Font size for axis labels
        tick_fontsize: Font size for tick labels (used as default for x and y ticks)
        x_tick_fontsize: Font size for x-axis tick labels (None uses tick_fontsize)
        y_tick_fontsize: Font size for y-axis tick labels (None uses tick_fontsize)
        datalabel_fontsize: Font size for data labels on bars (None uses tick_fontsize-2)
        y_tick_step: Step size for y-axis ticks (ticks start at 0 and are capped at 100)
        alpha: Bar transparency (0-1)
        edgecolor: Color of bar edges
        linewidth: Width of bar edges
        errorbar_linewidth: Width of the error bar lines (default: 1)
        errorbar_capsize: Width of the error bar caps (default: 8)
        errorbar_capthick: Thickness of the error bar caps (default: 1)
        grid: Whether to show a horizontal grid
        grid_alpha: Grid transparency (0-1)
        rotation: Rotation angle for x-tick labels
        show_xlabels: Whether to show x-axis method labels
        save_path: Path to save figure. "auto" (default) saves to
            figs/<csv name>_plot.pdf next to this script; None or "" disables
            saving. Missing parent directories are created.
        dpi: DPI for saved figure
        show_plot: Whether to call plt.show() after plotting
        subplot_left, subplot_right, subplot_top, subplot_bottom: Axes position as
            figure fractions (keeps PDF dimensions consistent across plots)

    Returns:
        The matplotlib Figure.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        ValueError: If none of the methods in ``method_order`` appear in the CSV.
    """
    import numpy as np

    # Default method labels
    if method_labels is None:
        method_labels = {
            "rescue": "RESCUE",
            "mf_gp_momf": "MOMF",
            "mf_gp_hvkg": "HVKG",
            "mf_gp_qehvi": "qEHVI",
        }

    # Default method colors (same as regret_plot)
    if method_colors is None:
        method_colors = {
            "rescue": "#D62728",      # Red (rescue - your method)
            "mf_gp_momf": "#1F77B4",  # Blue
            "mf_gp_hvkg": "#2CA02C",  # Green
            "mf_gp_qehvi": "#FF7F0E", # Orange
        }

    # None -> default hatches per method; False -> no hatches at all
    if method_hatches is None:
        method_hatches = {
            "rescue": "",       # No hatch (solid)
            "mf_gp_momf": "//",    # Diagonal lines
            "mf_gp_hvkg": "\\\\",   # Reverse diagonal lines
            "mf_gp_qehvi": "xx",   # Crosshatch
        }
    elif method_hatches is False:
        method_hatches = {}

    # Default method order
    if method_order is None:
        method_order = ["rescue", "mf_gp_hvkg", "mf_gp_momf", "mf_gp_qehvi"]

    # Load CSV
    script_dir = os.path.dirname(os.path.abspath(__file__))
    csv_path = os.path.join(script_dir, csv_file)

    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"CSV file not found: {csv_path}")

    # Auto-generate save path if set to "auto"
    if save_path == "auto":
        base_name = os.path.splitext(os.path.basename(csv_file))[0]
        save_path = os.path.join(script_dir, "figs", f"{base_name}_plot.pdf")

    df = pd.read_csv(csv_path)

    # Keep only methods present in the CSV, in method_order order
    available_methods = set(df['method'])
    methods = [m for m in method_order if m in available_methods]
    if not methods:
        raise ValueError(f"None of the methods {method_order} found in {csv_path}")

    # One record per method: (mean %, std %, label, color, hatch)
    records = []
    for method in methods:
        method_row = df[df['method'] == method]
        records.append((
            method_row['mean_violations'].values[0] * 100,
            method_row['std_violations'].values[0] * 100,
            method_labels.get(method, method),
            method_colors.get(method, None),
            method_hatches.get(method, '') if method_hatches else '',
        ))

    # Sort by mean violations, highest first (stable for ties)
    records.sort(key=lambda rec: rec[0], reverse=True)
    means, stds, labels, colors, hatches = (list(col) for col in zip(*records))

    # Create plot
    fig, ax = plt.subplots(figsize=figsize)

    x_pos = np.arange(len(methods))
    bars = ax.bar(x_pos, means, yerr=stds, capsize=errorbar_capsize, alpha=alpha,
                  error_kw={'linewidth': errorbar_linewidth, 'capthick': errorbar_capthick},
                  color=colors, edgecolor=edgecolor, linewidth=linewidth)

    # Apply hatch patterns to bars
    for bar, hatch in zip(bars, hatches):
        bar.set_hatch(hatch)

    # A NaN std (e.g. from a single run) draws no error bar; treat it as 0 when
    # computing limits and label positions so set_ylim never receives NaN.
    finite_stds = [s if np.isfinite(s) else 0.0 for s in stds]
    tops = [m + s for m, s in zip(means, finite_stds) if np.isfinite(m + s)]
    y_max = max(tops, default=0.0)
    if y_max <= 0:
        # All-zero violations would otherwise give a degenerate ylim(0, 0)
        y_max = float(y_tick_step) if y_tick_step else 1.0

    # Set the y-limit BEFORE adding text labels: 25% headroom above the
    # highest error bar leaves room for the labels.
    ax.set_ylim(0, y_max * 1.25)

    # Add data labels just above each error bar
    label_font = datalabel_fontsize if datalabel_fontsize is not None else tick_fontsize - 2
    for bar, mean, std in zip(bars, means, finite_stds):
        if not np.isfinite(mean):
            continue
        label_y = bar.get_height() + std + (y_max * 0.015)
        ax.text(bar.get_x() + bar.get_width() / 2., label_y,
                f'{mean:.1f}',
                ha='center', va='bottom', fontsize=label_font, clip_on=True)

    # Format plot
    ax.set_ylabel(ylabel, fontsize=label_fontsize)
    if title:
        ax.set_title(title, fontsize=title_fontsize)

    if show_xlabels:
        x_label_fontsize = x_tick_fontsize if x_tick_fontsize is not None else tick_fontsize
        ax.set_xticks(x_pos)
        ax.set_xticklabels(labels, rotation=rotation, ha='center', fontsize=x_label_fontsize)
    else:
        ax.set_xticks([])

    y_label_fontsize = y_tick_fontsize if y_tick_fontsize is not None else tick_fontsize
    ax.tick_params(axis='y', which='major', labelsize=y_label_fontsize)

    # Set y-axis tick step if specified
    if y_tick_step is not None:
        _, y_max_limit = ax.get_ylim()
        # Cap ticks at 100% to avoid exceeding percentage scale
        ticks = np.arange(0, min(y_max_limit, 100 + y_tick_step), y_tick_step)
        ticks = ticks[ticks <= 100]
        ax.set_yticks(ticks)

    if grid:
        ax.grid(axis='y', alpha=grid_alpha, linestyle='-', linewidth=0.5)

    # Fixed margins (instead of tight_layout) keep PDF dimensions consistent
    fig.subplots_adjust(left=subplot_left, right=subplot_right, top=subplot_top, bottom=subplot_bottom)

    if save_path:
        # The "auto" path points into a figs/ subdirectory that may not exist yet
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=dpi)
        print(f"Figure saved to: {save_path}")

    if show_plot:
        plt.show()
    return fig


def plot_legend(
    rows: int,
    columns: int,
    methods: Optional[List[str]] = None,
    save_path: str = "legend.pdf",
    figsize: Tuple[float, float] = (3, 3),
    legend_fontsize: int = 11,
    dpi: int = 300,
    use_bar_style: bool = False,
    bar_alpha: float = 0.5,
    show_hatches: bool = False,
    frame_edgecolor: str = 'gray',
    frame_linewidth: float = 0.5,
    frame_linestyle: str = '-'
):
    """
    Create and save a standalone legend (no axes) to ``save_path``.

    Args:
        rows: Height multiplier for the figure (final height is figsize[1] * rows).
            The legend's own row count follows from len(methods) and ``columns``.
        columns: Number of legend columns; also the width multiplier for the figure
            (final width is figsize[0] * columns).
        methods: Optional list of method names to include (default: all four built-in methods).
        save_path: Output file path; the file format follows its extension. Missing
            parent directories are created.
        figsize: Base figure size (width, height), scaled by ``columns`` and ``rows``.
        legend_fontsize: Font size for legend text.
        dpi: DPI for saved figure.
        use_bar_style: If True, use bar patches with method colors (for fidelity_iteration_plot style);
            otherwise use colored lines with per-method line styles (regret_plot style).
        bar_alpha: Alpha transparency for bar patches when use_bar_style=True
        show_hatches: Whether to show hatch patterns on bar-style patches
        frame_edgecolor: Color of the legend frame border
        frame_linewidth: Width of the legend frame border
        frame_linestyle: Style of the legend frame border ('-', '--', '-.', ':')
    """
    from matplotlib.patches import Patch

    # Default method labels, colors, linestyles and hatches
    method_labels = {
        "rescue": "RESCUE",
        "mf_gp_momf": "MOMF",
        "mf_gp_hvkg": "HVKG",
        "mf_gp_qehvi": "qEHVI",
    }
    method_colors = {
        "rescue": "#D62728",      # Red (rescue - your method)
        "mf_gp_momf": "#1F77B4",  # Blue (matplotlib default blue)
        "mf_gp_hvkg": "#2CA02C",  # Green (matplotlib default green)
        "mf_gp_qehvi": "#FF7F0E", # Orange (matplotlib default orange)
    }
    method_linestyles = {
        "rescue": "--",
        "mf_gp_momf": ":",
        "mf_gp_hvkg": "-.",
        "mf_gp_qehvi": "-",
    }
    method_hatches = {
        "rescue": "",       # No hatch (solid)
        "mf_gp_momf": "//",    # Diagonal lines
        "mf_gp_hvkg": "\\\\",   # Reverse diagonal lines
        "mf_gp_qehvi": "xx",   # Crosshatch
    }

    # Use all methods if none are specified
    if methods is None:
        methods = list(method_labels.keys())

    # Dummy figure that only hosts the legend
    fig, ax = plt.subplots(figsize=figsize)
    try:
        if use_bar_style:
            # Bar/patch handles, matching fidelity_iteration_plot
            legend_elements = [Patch(facecolor=method_colors.get(m, '#1f77b4'),
                                     label=method_labels.get(m, m),
                                     edgecolor='black', linewidth=1.2,
                                     alpha=bar_alpha,
                                     hatch=method_hatches.get(m, '') if show_hatches else '')
                               for m in methods]
            legend = ax.legend(handles=legend_elements,
                               loc='center',
                               fontsize=legend_fontsize,
                               ncol=columns,
                               frameon=True)
        else:
            # Empty line artists carry the label/color/linestyle for regret plots
            for method in methods:
                ax.plot([], [],
                        label=method_labels.get(method, method),
                        color=method_colors.get(method, None),
                        linestyle=method_linestyles.get(method, '-'))
            legend = ax.legend(
                loc='center',
                fontsize=legend_fontsize,
                ncol=columns,
                frameon=True
            )

        # Customize legend frame
        frame = legend.get_frame()
        frame.set_edgecolor(frame_edgecolor)
        frame.set_linewidth(frame_linewidth)
        frame.set_linestyle(frame_linestyle)

        ax.axis('off')  # Only the legend should be visible

        # Scale the figure by the legend grid size
        fig.set_size_inches(figsize[0] * columns, figsize[1] * rows)

        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=dpi)
        print(f"Legend saved to: {save_path}")
    finally:
        # Always release the figure, even if legend creation or saving fails
        plt.close(fig)


def fidelity_iteration_plot(
    csv_file: str,
    method_labels: Optional[Dict[str, str]] = None,
    method_colors: Optional[Dict[str, str]] = None,
    method_hatches: Optional[Dict[str, str]] = False,
    method_order: Optional[List[str]] = None,
    figsize: Tuple[float, float] = (3.5, 2.2),
    title: Optional[str] = None,
    iteration_ylabel: str = "Iterations",
    fidelity_ylabel: str = "Fidelity",
    colormap: str = "autumn",
    iteration_bar_alpha: float = 0.5,
    scatter_size: float = 3,
    scatter_alpha: float = 0.3,
    violin_linewidth: float = 1,
    violin_line_color: str = 'black',
    violin_body_linewidth: float = 1.0,
    violin_body_alpha: float = 0.5,
    violin_body_facecolor: str = 'tab:purple',
    violin_widths: float = 0.8,
    errorbar_linewidth: float = 1,
    errorbar_capsize: float = 10,
    errorbar_capthick: float = 1,
    grid: bool = True,
    grid_alpha: float = 0.3,
    label_fontsize: int = 16,
    tick_fontsize: int = 16,
    fidelity_x_tick_fontsize: Optional[int] = 12.5,
    fidelity_y_tick_fontsize: Optional[int] = None,
    iteration_datalabel_fontsize: Optional[int] = 15,
    show_xlabels: bool = True,
    color_violins_by_method: bool = True,
    fidelity_tick_step: Optional[float] = 0.5,
    iteration_tick_step: Optional[float] = 15,
    title_fontsize: int = 14,
    legend_fontsize: int = 10,
    show_legend: bool = False,
    legend_loc: str = "best",
    iteration_subplot_height: float = 1.0,
    xlabel_rotation: int = 0,
    save_path: Optional[str] = "auto",
    dpi: int = 300,
    show_plot: bool = False,
    show_scatter: bool = False,
    show_colorbar: bool = True,
    subplot_left: float = 0.15,
    subplot_right: float = 0.99,
    subplot_top: float = 0.97,
    subplot_bottom: float = 0.125,
    subplot_hspace: float = 0.15,
    discrete_fidelities: Optional[List[float]] = None,
):
    """
    Plot per-method iteration counts (top panel) above per-method fidelity distributions (bottom panel).

    The top panel is a bar chart of each method's iteration mean with std error
    bars and a value label. The bottom panel is a violin plot of the method's
    raw fidelity values, optionally overlaid with jittered scatter points
    colored by iteration. Both panels share the x-axis (one slot per method).

    Args:
        csv_file: CSV (resolved relative to this script's directory) with columns
            'method', 'fidelity', 'iteration_mean', 'iteration_std' and optionally
            'iteration' (used to color scatter points). The first non-NaN
            iteration_mean / iteration_std of each method is used.
        method_labels: Dict mapping method names to display labels
        method_colors: Dict mapping method names to colors
        method_hatches: None -> built-in hatch pattern per method; False (default) ->
            no hatches; dict -> mapping of method names to hatch patterns
        method_order: List of method names in desired plotting order; methods with no
            iteration_mean value in the CSV are skipped
        figsize: Figure size (width, height)
        title: Title drawn above the iteration panel (None for no title)
        iteration_ylabel: Y-axis label for iteration subplot (top)
        fidelity_ylabel: Y-axis label for fidelity subplot (bottom)
        colormap: Colormap name for scatter points (e.g., 'autumn', 'viridis', 'plasma')
        iteration_bar_alpha: Transparency for iteration bars
        scatter_size: Size of scatter points
        scatter_alpha: Transparency for scatter points (0-1)
        violin_linewidth: Width of the violin plot statistical lines (bars, mins, maxes, means)
        violin_line_color: Color of the violin plot statistical lines (bars, mins, maxes, means)
        violin_body_linewidth: Width of the violin body (face) outline
        violin_body_alpha: Transparency for violin body (face) fill (0-1)
        violin_body_facecolor: Fill color used when color_violins_by_method is False
        violin_widths: Width of the violin plots (default: 0.8)
        errorbar_linewidth: Width of the error bar lines on iteration bars (default: 1)
        errorbar_capsize: Width of the error bar caps on iteration bars (default: 10)
        errorbar_capthick: Thickness of the error bar caps on iteration bars (default: 1)
        grid: Whether to show grid
        grid_alpha: Grid transparency (0-1)
        label_fontsize: Font size for axis labels
        tick_fontsize: Font size for tick labels (used as default)
        fidelity_x_tick_fontsize: Font size for fidelity x-axis tick labels (None uses tick_fontsize)
        fidelity_y_tick_fontsize: Font size for fidelity y-axis tick labels (None uses tick_fontsize)
        iteration_datalabel_fontsize: Font size for iteration bar data labels (None uses tick_fontsize-1)
        show_xlabels: Whether to show x-axis method labels
        color_violins_by_method: If True, color violin bodies by method colors (matches bar colors)
        fidelity_tick_step: Step for fidelity y-ticks starting at 0 (1.0 is always included);
            ignored when discrete_fidelities is given
        iteration_tick_step: Step for iteration y-ticks starting at 0
        title_fontsize: Font size for title
        legend_fontsize: Font size for legend
        show_legend: Whether to show legend
        legend_loc: Legend location
        iteration_subplot_height: Relative height of the top panel; the bottom panel's
            relative height is figsize[1] - iteration_subplot_height (so this must be
            smaller than figsize[1])
        xlabel_rotation: Rotation of the method labels on the x-axis
        save_path: Path to save figure. "auto" (default) saves to
            figs/<csv name>_plot.pdf next to this script; None or "" disables
            saving. Missing parent directories are created.
        dpi: DPI for saved figure
        show_plot: Whether to display plot
        show_scatter: Whether to show scatter points on violin plot
        show_colorbar: Whether to show colorbar for scatter points
        subplot_left, subplot_right, subplot_top, subplot_bottom: GridSpec margins as
            figure fractions
        subplot_hspace: Vertical space between the two subplots (default: 0.15)
        discrete_fidelities: Optional list of discrete fidelity values. If provided, only these values will be shown as y-axis ticks.

    Returns:
        The matplotlib Figure.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        ValueError: If none of the methods in ``method_order`` has iteration data.
    """
    import numpy as np
    import matplotlib.cm as cm
    from matplotlib.colors import Normalize
    from matplotlib.patches import Patch
    from matplotlib.ticker import FuncFormatter

    # Default method labels
    if method_labels is None:
        method_labels = {
            "rescue": "RESCUE",
            "mf_gp_momf": "MOMF",
            "mf_gp_hvkg": "HVKG",
            "mf_gp_qehvi": "qEHVI",
        }

    # Default method colors (same as regret_plot)
    if method_colors is None:
        method_colors = {
            "rescue": "#D62728",      # Red
            "mf_gp_momf": "#1F77B4",  # Blue
            "mf_gp_hvkg": "#2CA02C",  # Green
            "mf_gp_qehvi": "#FF7F0E", # Orange
        }

    # None -> default hatches per method; False -> no hatches at all
    if method_hatches is None:
        method_hatches = {
            "rescue": "",       # No hatch (solid)
            "mf_gp_momf": "//",    # Diagonal lines
            "mf_gp_hvkg": "\\\\",   # Reverse diagonal lines
            "mf_gp_qehvi": "xx",   # Crosshatch
        }
    elif method_hatches is False:
        method_hatches = {}

    # Default method order
    if method_order is None:
        method_order = ["rescue", "mf_gp_hvkg", "mf_gp_momf", "mf_gp_qehvi"]

    # Load data (same explicit check as the sibling plot functions)
    script_dir = os.path.dirname(os.path.abspath(__file__))
    csv_path = os.path.join(script_dir, csv_file)
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"CSV file not found: {csv_path}")

    # Auto-generate save path if set to "auto"
    if save_path == "auto":
        base_name = os.path.splitext(os.path.basename(csv_file))[0]
        save_path = os.path.join(script_dir, "figs", f"{base_name}_plot.pdf")

    df = pd.read_csv(csv_path)

    # Methods that have iteration statistics, in method_order order
    available_methods = set(df.loc[df['iteration_mean'].notna(), 'method'])
    methods = [m for m in method_order if m in available_methods]
    if not methods:
        raise ValueError(f"None of the methods {method_order} has iteration data in {csv_path}")

    # Per-method sub-frames, reused for violin, bar and scatter data
    method_frames = [df[df['method'] == m] for m in methods]

    # Violin data: all raw fidelity values per method
    fidelity_data = [frame['fidelity'].dropna().values for frame in method_frames]

    # Iteration stats (the first non-NaN value of each method is used)
    iter_means = [frame['iteration_mean'].dropna().iloc[0] for frame in method_frames]
    iter_stds = [frame['iteration_std'].dropna().iloc[0] for frame in method_frames]

    # Two stacked panels sharing x: iteration bars on top, fidelity violins below
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(2, 1, height_ratios=[iteration_subplot_height, figsize[1] - iteration_subplot_height],
                          hspace=subplot_hspace, left=subplot_left, right=subplot_right,
                          top=subplot_top, bottom=subplot_bottom)

    ax_iter = fig.add_subplot(gs[0])
    ax_fid = fig.add_subplot(gs[1], sharex=ax_iter)

    # Top subplot: Bar plot for iterations
    x_pos = np.arange(1, len(methods) + 1)
    colors_list = [method_colors.get(m, '#1f77b4') for m in methods]
    hatches_list = [method_hatches.get(m, '') if method_hatches else '' for m in methods]
    bars = ax_iter.bar(x_pos, iter_means, yerr=iter_stds, capsize=errorbar_capsize,
                       alpha=iteration_bar_alpha, error_kw={'linewidth': errorbar_linewidth, 'capthick': errorbar_capthick},
                       color=colors_list, edgecolor='black', linewidth=1.2)

    # Apply hatch patterns to iteration bars
    for bar, hatch in zip(bars, hatches_list):
        bar.set_hatch(hatch)

    # A NaN std draws no error bar; treat it as 0 for limits and label
    # positions so set_ylim never receives NaN.
    finite_stds = [s if np.isfinite(s) else 0.0 for s in iter_stds]
    tops = [m + s for m, s in zip(iter_means, finite_stds) if np.isfinite(m + s)]
    y_max = max(tops, default=0.0)
    if y_max <= 0:
        # Avoid a degenerate ylim(0, 0)
        y_max = 1.0

    # Set the y-limit BEFORE adding text labels: 50% headroom above the
    # highest error bar leaves room for the labels.
    ax_iter.set_ylim(0, y_max * 1.5)

    # Add data labels just above each error bar
    iter_label_font = iteration_datalabel_fontsize if iteration_datalabel_fontsize is not None else tick_fontsize - 1
    for bar, mean, std in zip(bars, iter_means, finite_stds):
        label_y = bar.get_height() + std + (y_max * 0.02)
        ax_iter.text(bar.get_x() + bar.get_width() / 2., label_y,
                     f'{mean:.1f}', ha='center', va='bottom', fontsize=iter_label_font, clip_on=False)

    ax_iter.set_ylabel(iteration_ylabel, fontsize=label_fontsize)
    ax_iter.tick_params(axis='y', which='major', labelsize=tick_fontsize)
    ax_iter.tick_params(axis='x', which='both', bottom=False, labelbottom=False)

    # Set y-tick step for iterations if specified
    if iteration_tick_step is not None:
        _, y_max_limit = ax_iter.get_ylim()
        ax_iter.set_yticks(np.arange(0, y_max_limit, iteration_tick_step))

    # Iteration tick labels with one decimal place
    ax_iter.yaxis.set_major_formatter(FuncFormatter(lambda x, p: f'{x:.1f}'))

    if grid:
        ax_iter.grid(axis='y', alpha=grid_alpha, linestyle='-', linewidth=0.5)

    # Bottom subplot: Violin plot for fidelity
    parts = ax_fid.violinplot(fidelity_data, positions=x_pos, widths=violin_widths,
                              showmeans=True, showmedians=False)

    # Color violin bodies either by method or with a single color, and apply hatches
    for i, pc in enumerate(parts['bodies']):
        pc.set_facecolor(colors_list[i] if color_violins_by_method else violin_body_facecolor)
        pc.set_alpha(violin_body_alpha)
        pc.set_linewidth(violin_body_linewidth)
        pc.set_edgecolor('black')  # Ensure edge color is visible
        pc.set_hatch(hatches_list[i])

    # Set linewidth and color for the statistic lines (cbars, cmins, cmaxes, cmeans)
    for partname in ('cbars', 'cmins', 'cmaxes', 'cmeans'):
        if partname in parts:
            parts[partname].set_linewidth(violin_linewidth)
            parts[partname].set_edgecolor(violin_line_color)

    # Add x-axis labels
    if show_xlabels:
        fid_x_tick_font = fidelity_x_tick_fontsize if fidelity_x_tick_fontsize is not None else tick_fontsize
        ax_fid.set_xticks(x_pos)
        ax_fid.set_xticklabels([method_labels.get(m, m) for m in methods],
                               fontsize=fid_x_tick_font, rotation=xlabel_rotation, ha='center')
    else:
        ax_fid.set_xticks([])
    ax_fid.set_ylabel(fidelity_ylabel, fontsize=label_fontsize)

    # Set y-axis limits and ticks based on discrete_fidelities
    if discrete_fidelities is not None:
        sorted_fidelities = sorted(discrete_fidelities)
        min_fid = sorted_fidelities[0]
        max_fid = sorted_fidelities[-1]
        padding = (max_fid - min_fid) * 0.05 if max_fid > min_fid else 0.05
        ax_fid.set_ylim(min_fid - padding, max_fid + padding)
        ax_fid.set_yticks(sorted_fidelities)
    else:
        ax_fid.set_ylim(-0.05, 1.08)
        if fidelity_tick_step is not None:
            ticks = np.arange(0, 1.0, fidelity_tick_step)
            # Always include 1.0 as the top tick. np.arange can already end at a
            # float-drifted value such as 0.9999999999999999, so compare with a
            # tolerance to avoid a duplicate "1.00" tick.
            if not np.isclose(ticks, 1.0).any():
                ticks = np.append(ticks, 1.0)
            ax_fid.set_yticks(ticks)

    fid_y_tick_font = fidelity_y_tick_fontsize if fidelity_y_tick_fontsize is not None else tick_fontsize
    ax_fid.tick_params(axis='y', which='major', labelsize=fid_y_tick_font)

    # Fidelity tick labels with two decimal places
    ax_fid.yaxis.set_major_formatter(FuncFormatter(lambda x, p: f'{x:.2f}'))

    # Scatter points (and colorbar) only if requested
    if show_scatter:
        has_iteration_col = 'iteration' in df.columns

        # (fidelity values, color values) per method, kept row-aligned
        scatter_sets = []
        for frame in method_frames:
            if has_iteration_col:
                # Drop rows missing either value so each point keeps its own
                # iteration color (independent dropna calls can misalign them).
                points = frame.dropna(subset=['fidelity', 'iteration'])
                fid_vals = points['fidelity'].values
                iter_values = points['iteration'].values
            else:
                # Fallback: sequential indices starting from 1
                fid_vals = frame['fidelity'].dropna().values
                iter_values = np.arange(1, len(fid_vals) + 1)
            scatter_sets.append((fid_vals, iter_values))

        all_iter_values = np.concatenate([iv for _, iv in scatter_sets])
        if all_iter_values.size:
            # Shared color range so colors are comparable across methods
            min_iter = all_iter_values.min()
            max_iter = all_iter_values.max()
            cmap = plt.colormaps[colormap]

            # Local seeded generator: reproducible jitter without resetting the
            # global NumPy RNG (same draws as the previous np.random.seed(42)).
            rng = np.random.RandomState(42)
            for i, (fid_vals, iter_values) in enumerate(scatter_sets):
                x = rng.normal(i + 1, 0.04, size=len(fid_vals))  # Jitter around the violin center
                ax_fid.scatter(x, fid_vals, c=iter_values, cmap=cmap, alpha=scatter_alpha, s=scatter_size, zorder=3,
                               vmin=min_iter, vmax=max_iter)

            # Colorbar attached to both axes so the panels keep equal widths
            if show_colorbar:
                norm = Normalize(vmin=min_iter, vmax=max_iter)
                sm = cm.ScalarMappable(cmap=cmap, norm=norm)
                sm.set_array([])
                cbar = plt.colorbar(sm, ax=[ax_iter, ax_fid], pad=0.02, aspect=30)
                cbar.set_label('Iteration', fontsize=label_fontsize)
                cbar.ax.tick_params(labelsize=tick_fontsize)

    if grid:
        ax_fid.grid(alpha=grid_alpha, linestyle='-', linewidth=0.5)

    if title:
        ax_iter.set_title(title, fontsize=title_fontsize, pad=10)

    # Add legend for methods
    if show_legend:
        legend_elements = [Patch(facecolor=method_colors.get(m, '#1f77b4'),
                                 label=method_labels.get(m, m), alpha=0.7)
                           for m in methods]
        ax_fid.legend(handles=legend_elements, loc=legend_loc, fontsize=legend_fontsize, frameon=True)

    # Layout is already fixed via the gridspec margins (left, right, top, bottom)

    if save_path:
        # The "auto" path points into a figs/ subdirectory that may not exist yet
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=dpi)
        print(f"Figure saved to: {save_path}")

    if show_plot:
        plt.show()

    return fig


def runtime_boxplot(
    csv_file: str,
    method_labels: Optional[Dict[str, str]] = None,
    method_order: Optional[List[str]] = None,
    figsize: Tuple[float, float] = (2.5, 2),
    title: Optional[str] = None,
    ylabel: str = "Wall-clock time (s)",
    title_fontsize: int = 12,
    label_fontsize: int = 12,
    tick_fontsize: int = 12,
    grid: bool = True,
    grid_alpha: float = 0.3,
    save_path: str = "auto",
    dpi: int = 300,
    show_plot: bool = False,
    subplot_left: float = 0.2,
    subplot_right: float = 0.98,
    subplot_top: float = 0.98,
    subplot_bottom: float = 0.32,
    boxplot_width: float = 0.5,
    median_linewidth: float = 1.0,
    box_linewidth: float = 1.0,
    whisker_linewidth: float = 1.0,
    median_color: str = 'black',
    box_edgecolor: str = 'black',
    xtick_fontsize: int = 10,
    xtick_rotation: float = 45,
):
    """
    Create publication-ready runtime boxplot from CSV file (no colors).

    One white box is drawn per method. The horizontal line inside each box is
    the MEAN runtime; the median line is hidden.

    Args:
        csv_file: CSV file with columns 'method' and 'runtime', resolved
            relative to the directory of this script (an absolute path is
            used as-is)
        method_labels: Dict mapping method names to display labels
        method_order: List of method names in desired plotting order; methods
            not listed are omitted
        figsize: Figure size (width, height)
        title: Plot title (None for no title)
        ylabel: Y-axis label
        title_fontsize: Font size for title
        label_fontsize: Font size for axis labels
        tick_fontsize: Font size for y-axis tick labels (x-axis tick labels
            use ``xtick_fontsize``)
        grid: Whether to show horizontal grid lines
        grid_alpha: Grid transparency (0-1)
        save_path: Path to save figure ('auto' -> <script_dir>/figs/<csv name>_plot.pdf;
            empty string or None to skip saving). Missing parent directories are created.
        dpi: DPI for saved figure
        show_plot: Whether to display the plot
        subplot_left/right/top/bottom: Subplot positioning
        boxplot_width: Width of boxes (0-1)
        median_linewidth: Width of the central line (this is the MEAN line;
            name kept for backward compatibility)
        box_linewidth: Width of box edges
        whisker_linewidth: Width of whiskers and caps
        median_color: Color of the central (mean) line
        box_edgecolor: Color of box edges, whiskers and caps
        xtick_fontsize: Font size of the (rotated) method labels on the x-axis
        xtick_rotation: Rotation in degrees of the method labels on the x-axis

    Returns:
        The created matplotlib Figure.
    """
    # Resolve the CSV relative to this script so the function works from any CWD.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    csv_path = os.path.join(script_dir, csv_file)
    df = pd.read_csv(csv_path)

    # Auto-generate save path from the CSV file name (without directories).
    if save_path == "auto":
        base_name = os.path.splitext(os.path.basename(csv_file))[0]
        save_path = os.path.join(script_dir, "figs", f"{base_name}_plot.pdf")

    # Default method labels
    if method_labels is None:
        method_labels = {
            "rescue": "RESCUE",
            "mf_gp_momf": "MOMF",
            "mf_gp_hvkg": "HVKG",
            "mf_gp_qehvi": "qEHVI",
        }

    # Methods in order of first appearance, optionally filtered/reordered.
    methods = df['method'].unique().tolist()
    if method_order is not None:
        methods = [m for m in method_order if m in methods]

    # One runtime array per method. Missing runtimes are dropped because NaNs
    # would otherwise propagate into the quartile/mean statistics of the box.
    runtime_data = [
        df.loc[df['method'] == method, 'runtime'].dropna().to_numpy()
        for method in methods
    ]
    labels = [method_labels.get(method, method) for method in methods]

    fig, ax = plt.subplots(figsize=figsize)

    # Grayscale boxplot: the median is hidden and the mean is drawn as a solid
    # line styled by median_linewidth / median_color.
    ax.boxplot(
        runtime_data,
        tick_labels=labels,
        widths=boxplot_width,
        patch_artist=True,
        showmeans=True,
        meanline=True,
        boxprops=dict(linewidth=box_linewidth, edgecolor=box_edgecolor, facecolor='white'),
        whiskerprops=dict(linewidth=whisker_linewidth, color=box_edgecolor),
        capprops=dict(linewidth=whisker_linewidth, color=box_edgecolor),
        medianprops=dict(linewidth=0, color='none'),
        meanprops=dict(linewidth=median_linewidth, color=median_color, linestyle='-'),
        flierprops=dict(marker='o', markerfacecolor='gray', markersize=4, alpha=0.5, markeredgecolor='none')
    )

    ax.set_ylabel(ylabel, fontsize=label_fontsize)
    if title:
        ax.set_title(title, fontsize=title_fontsize)

    # Tick formatting: the x-axis method labels are rotated and get their own
    # font size, which overrides tick_fontsize on that axis.
    ax.tick_params(axis='both', labelsize=tick_fontsize)
    plt.setp(ax.get_xticklabels(), rotation=xtick_rotation, ha='right', fontsize=xtick_fontsize)

    if grid:
        ax.grid(axis='y', alpha=grid_alpha, linestyle='-', linewidth=0.5)

    # Remove top and right spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    fig.subplots_adjust(left=subplot_left, right=subplot_right, top=subplot_top, bottom=subplot_bottom)

    if save_path:
        # savefig() does not create missing directories (e.g. a fresh figs/).
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=dpi)
        print(f"Figure saved to: {save_path}")

    if show_plot:
        plt.show()

    return fig


if __name__ == "__main__":
    # Batch-generate all figures. The plotting functions resolve CSV paths
    # relative to this script; auto save paths (as in runtime_boxplot and
    # fidelity_iteration_plot) point into <script_dir>/figs/.
    _script_dir = os.path.dirname(os.path.abspath(__file__))
    _figs_dir = os.path.join(_script_dir, "figs")
    # savefig() does not create missing directories, so ensure figs/ exists.
    os.makedirs(_figs_dir, exist_ok=True)

    def _run(plot_fn, csv_file, **kwargs):
        """Render one figure with plot_fn, then close all open figures."""
        plot_fn(csv_file, **kwargs)
        # pyplot keeps every figure alive until closed (and warns past 20);
        # this script produces ~28 of them.
        plt.close("all")

    # ---------------------------------------------------------------- Regret
    # To render the shared legend separately:
    # plot_legend(rows=1, columns=4, figsize=(1.42, 0.31),
    #             save_path=os.path.join(_figs_dir, "legend.pdf"))
    regret_runs = [
        ("data/BraninCurrin_best_nsga2_regret.csv",
         dict(x_tick_step=10, ylabel="Best Log Regret", title=r"$d=2, M=2$")),
        ("data/Park4D_best_nsga2_regret.csv",
         dict(x_tick_step=10, y_tick_step=0.8, title=r"$d=4, M=2$")),
        ("data/HPOXGBoost_best_nsga2_regret.csv",
         dict(x_tick_step=10, y_tick_step=0.5, title=r"$d=13, M=2$")),
        ("data/HPORanger_best_nsga2_regret.csv",
         dict(x_tick_step=10, y_tick_step=0.2, title=r"$d=5, M=3$")),
        ("data/Health_best_nsga2_regret.csv",
         dict(x_tick_step=10, y_tick_step=0.8, title=r"$d=3, M=2, Q=1$")),
        ("data/AGVNavigation_best_nsga2_regret.csv",
         dict(x_tick_step=10, y_tick_step=0.4, ylabel="Best Log Regret",
              title=r"$d=26, M=2, Q=2$")),
        ("data/HPORangerConstrained_best_nsga2_regret.csv",
         dict(x_tick_step=10, y_tick_step=0.5, title=r"$d=5, M=2, Q=1$")),
        ("data/HPOXGBoostConstrained_best_nsga2_regret.csv",
         dict(x_tick_step=10, y_tick_step=0.5, title=r"$d=13, M=2, Q=1$")),
    ]
    for csv_file, kwargs in regret_runs:
        _run(regret_plot, csv_file, show_legend=False, **kwargs)

    # ------------------------------------------------------------- Violation
    # To render the legend separately:
    # plot_legend(rows=1, columns=4, frame_linewidth=0.2, legend_fontsize=16,
    #             show_hatches=True, use_bar_style=True, figsize=(2.1, 0.5),
    #             save_path=os.path.join(_figs_dir, "legend_vio.pdf"))
    violation_runs = [
        ("data/HPORangerConstrained_violations.csv", {}),
        ("data/AGVNavigation_violations.csv",
         dict(ylabel="Violation Rate %", subplot_left=0.2)),
        ("data/HPOXGBoostConstrained_violations.csv",
         dict(ylabel="Violation Rate %", subplot_left=0.2)),
        ("data/Health_violations.csv", {}),
    ]
    for csv_file, kwargs in violation_runs:
        _run(violation_plot, csv_file, **kwargs)

    # ------------------------------------------------- Fidelity / iteration
    # To render the legend separately:
    # plot_legend(use_bar_style=True, show_hatches=True, rows=1, columns=4,
    #             figsize=(1.5, 0.31),
    #             save_path=os.path.join(_figs_dir, "legend_bar.pdf"))
    fidelity_runs = [
        ("data/BraninCurrin_fidelity_iterations.csv", dict(subplot_left=0.22)),
        ("data/AGVNavigation_fidelity_iterations.csv",
         dict(discrete_fidelities=[0.2, 0.5, 1.0], subplot_left=0.22)),
    ]
    # These panels are drawn with blank y-axis labels.
    for name in ("Park4D", "HPOXGBoostConstrained", "HPOXGBoost",
                 "HPORanger", "HPORangerConstrained", "Health"):
        fidelity_runs.append((f"data/{name}_fidelity_iterations.csv",
                              dict(fidelity_ylabel=" ", iteration_ylabel=" ")))
    for csv_file, kwargs in fidelity_runs:
        _run(fidelity_iteration_plot, csv_file, **kwargs)

    # --------------------------------------------------------------- Runtime
    runtime_runs = [
        ("BraninCurrin", dict(subplot_left=0.27)),
        ("Park4D", dict(ylabel=" ")),
        ("HPOXGBoost", dict(ylabel=" ")),
        ("HPORanger", dict(ylabel=" ")),
        ("AGVNavigation", dict(subplot_left=0.27)),
        ("Health", dict(ylabel=" ")),
        ("HPOXGBoostConstrained", dict(ylabel=" ")),
        ("HPORangerConstrained", dict(ylabel=" ")),
    ]
    for name, kwargs in runtime_runs:
        _run(runtime_boxplot, f"data/{name}_runtime.csv", **kwargs)