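"""Visualization utilities for budgeted optimization experiments.

This module provides publication-ready figures for:
1. HV/HVI convergence curves with bootstrap or standard-error confidence bands
2. Pareto front comparisons across methods
3. Molecule galleries of representative Pareto-optimal structures
4. Pareto front density plots (KDE over fronts pooled across seeds)
5. Best/worst individual-run examples
6. Runtime comparisons with bootstrap error bars
"""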

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Sequence

import matplotlib.pyplot as plt
import numpy as np

from moltenflow.eval.pareto import pareto_front
from moltenflow.utils.logging import get_logger

if TYPE_CHECKING:
    import optuna

from .logger import load_experiment_logs
from .summary import aggregate_by_method, compute_method_summary

# Module-level logger from the project's logging helper; used by every plot
# function below for "no logs found" warnings and "saved to" messages.
logger = get_logger(__name__)


# Matplotlib rcParams overrides for publication-ready figures; installed
# globally by ``setup_style()``.
FIGURE_STYLE = {
    # Fonts
    "font.family": "serif",
    "font.size": 10,
    "axes.labelsize": 11,
    "axes.titlesize": 12,
    "legend.fontsize": 9,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    # On-screen vs. saved resolution, with tight cropping on save
    "figure.dpi": 150,
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
    # Light dashed background grid
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linestyle": "--",
}

# Per-method plotting attributes, one row per method:
# (method key, color, legend display name, scatter marker).
_METHOD_STYLE_TABLE = (
    ("bo_2gp", "#2563eb", "BO (2-GP)", "s"),  # blue, square
    ("bo_mogp", "#16a34a", "BO (MOGP)", "^"),  # green, triangle
    ("moltenflow", "#ea580c", "MoltenFlow", "o"),  # orange, circle
    ("gradient_ascent", "#dc2626", "Gradient Ascent", "d"),  # red, diamond
)

# Lookup tables derived from the rows above; insertion order follows the table.
METHOD_COLORS = {key: color for key, color, _, _ in _METHOD_STYLE_TABLE}
METHOD_NAMES = {key: name for key, _, name, _ in _METHOD_STYLE_TABLE}
METHOD_MARKERS = {key: marker for key, _, _, marker in _METHOD_STYLE_TABLE}

# Preferred display order used by ``sort_methods``; unlisted methods sort last.
METHOD_ORDER = ["moltenflow", "gradient_ascent", "bo_mogp", "bo_2gp"]


def setup_style() -> None:
    """Install the module's publication-ready rcParams into Matplotlib.

    This mutates ``plt.rcParams`` in place, so the settings apply to every
    figure created afterwards in the process.
    """
    rc_params = plt.rcParams
    rc_params.update(FIGURE_STYLE)


def get_method_color(method: str) -> str:
    """Return the plot color for *method*, or neutral gray ``#666666`` if unknown."""
    if method in METHOD_COLORS:
        return METHOD_COLORS[method]
    return "#666666"


def get_method_name(method: str) -> str:
    """Return the legend display name for *method*; unknown keys are returned as-is."""
    if method in METHOD_NAMES:
        return METHOD_NAMES[method]
    return method


def get_method_marker(method: str) -> str:
    """Return the scatter marker for *method*, defaulting to a circle (``"o"``)."""
    if method in METHOD_MARKERS:
        return METHOD_MARKERS[method]
    return "o"


def sort_methods(methods: Sequence[str]) -> list[str]:
    """Order method names for display according to ``METHOD_ORDER``.

    Methods listed in ``METHOD_ORDER`` come first, in that order; all other
    methods follow in alphabetical order.

    Args:
        methods: Method names to order

    Returns:
        A new list with the same names in display order
    """
    unranked = len(METHOD_ORDER)
    rank: dict[str, int] = {}
    for position, name in enumerate(METHOD_ORDER):
        # setdefault keeps the first position, matching list.index semantics
        rank.setdefault(name, position)

    return sorted(methods, key=lambda name: (rank.get(name, unranked), name))


def _pad_with_last(values: Sequence[float] | np.ndarray, length: int) -> np.ndarray:
    """Return *values* as a float array of exactly *length* elements.

    Shorter inputs are right-padded by repeating their last element (a run that
    finished early keeps its final value); longer inputs are truncated. An empty
    input yields an all-NaN array, so nothing is drawn for it.
    """
    arr = np.asarray(values, dtype=float)
    if len(arr) >= length:
        return arr[:length]
    if len(arr) == 0:
        return np.full(length, np.nan)
    return np.concatenate([arr, np.full(length - len(arr), arr[-1])])


def plot_hv_convergence(
    log_dir: str | Path,
    output_path: str | Path | None = None,
    methods: Sequence[str] | None = None,
    metric: str = "hv",
    n_bootstrap: int = 1000,
    confidence: float = 0.95,
    ci_type: str = "bootstrap",
    figsize: tuple[float, float] = (6, 4),
    seed: int = 42,
) -> plt.Figure:
    """Plot HV or HVI convergence curves with confidence bands.

    Args:
        log_dir: Directory containing optimization logs
        output_path: Path to save figure (optional)
        methods: Methods to include (default: all)
        metric: "hv" for hypervolume or "hvi" for improvement (HV minus the
            mean HV at the first step)
        n_bootstrap: Bootstrap samples for CI (only used if ci_type="bootstrap")
        confidence: Confidence level (e.g., 0.95 for 95% CI)
        ci_type: Type of confidence band:
            - "bootstrap": Bootstrap percentile CI (wider, statistically rigorous)
            - "se": Standard error bands (mean +/- SE, tighter)
            - "se2": 2x standard error bands (mean +/- 2*SE, ~95% for normal)
            Unknown values log a warning and fall back to "bootstrap".
        figsize: Figure size in inches
        seed: Random seed

    Returns:
        matplotlib Figure

    Raises:
        ValueError: If ``metric`` is neither "hv" nor "hvi".

    Note:
        The width of confidence bands depends on:
        - Variance across seeds (high variance = wide bands)
        - Number of seeds (more seeds = tighter bands via sqrt(n))
        - Confidence level (lower confidence = tighter bands)
        - CI type ("se" is tightest, "bootstrap" with 95% is widest)

        With a single seed, "se"/"se2" bands have zero width (the sample
        standard deviation is undefined).

        For publication, "bootstrap" with 95% CI is most rigorous.
        For clearer visualization, "se" or "se2" may be preferable.
    """
    if metric not in ("hv", "hvi"):
        raise ValueError(f"metric must be 'hv' or 'hvi', got {metric!r}")
    if ci_type not in ("bootstrap", "se", "se2"):
        logger.warning(f"Unknown ci_type {ci_type!r}; using bootstrap CI")

    setup_style()

    # Load and aggregate logs
    logs = load_experiment_logs(log_dir, methods=methods)
    if not logs:
        logger.warning(f"No logs found in {log_dir}")
        return plt.figure()

    method_dfs = aggregate_by_method(logs)
    # Chosen up front so the axis label exists even if no method gets drawn.
    ylabel = "Hypervolume Improvement" if metric == "hvi" else "Hypervolume"

    fig, ax = plt.subplots(figsize=figsize)

    # Sort methods for consistent legend order
    for method in sort_methods(list(method_dfs.keys())):
        df = method_dfs[method]
        seeds = df["seed"].unique()
        n_seeds = len(seeds)
        if n_seeds == 0:
            continue

        # One HV curve per seed, ordered by step. HVI is derived from HV
        # below, so the raw "hv" column is read for both metrics.
        per_seed = [df[df["seed"] == s].sort_values("step")["hv"].to_numpy() for s in seeds]
        max_steps = max(len(curve) for curve in per_seed)
        curves = np.vstack([_pad_with_last(curve, max_steps) for curve in per_seed])
        steps = np.arange(max_steps)
        mean_curve = curves.mean(axis=0)

        if ci_type in ("se", "se2"):
            # The sample SD (ddof=1) is undefined for one seed; draw a
            # zero-width band instead of propagating NaNs.
            if n_seeds > 1:
                se = curves.std(axis=0, ddof=1) / np.sqrt(n_seeds)
            else:
                se = np.zeros(max_steps)
            half_width = 2 * se if ci_type == "se2" else se
            ci_lower = mean_curve - half_width
            ci_upper = mean_curve + half_width
        else:
            # Bootstrap CI (default). The summary may cover a different number
            # of steps than the padded curves, so align each bound to max_steps.
            summary = compute_method_summary(df, n_bootstrap, confidence, seed)
            ci_lower = _pad_with_last(summary.hv_curve_ci_lower, max_steps)
            ci_upper = _pad_with_last(summary.hv_curve_ci_upper, max_steps)

        if metric == "hvi":
            # Improvement relative to the mean HV at the first step.
            initial_hv = mean_curve[0]
            mean_curve = mean_curve - initial_hv
            ci_lower = ci_lower - initial_hv
            ci_upper = ci_upper - initial_hv

        color = get_method_color(method)
        label = f"{get_method_name(method)} ({mean_curve[-1]:.3f})"

        ax.plot(steps, mean_curve, color=color, linewidth=1.5, label=label)
        ax.fill_between(steps, ci_lower, ci_upper, color=color, alpha=0.2, linewidth=0)

    ax.set_xlabel("Oracle Calls")
    ax.set_ylabel(ylabel)
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc="lower right")

    # Get init info from first log
    first_records = next(iter(logs.values()))
    init_method = first_records[0]["init"] if first_records else "unknown"
    budget = max(r["step"] for r in first_records) + 1 if first_records else 0
    ax.set_title(f"Optimization Convergence (init={init_method}, B={budget})")

    fig.tight_layout()

    if output_path:
        fig.savefig(output_path)
        logger.info(f"Saved HV convergence plot to {output_path}")

    return fig


def plot_pareto_comparison(
    log_dir: str | Path,
    output_path: str | Path | None = None,
    methods: Sequence[str] | None = None,
    show_initial: bool = True,
    ref_point: Sequence[float] = (0.0, -10.0),
    figsize: tuple[float, float] = (6, 5),
    representative_seed: int | None = None,
) -> plt.Figure:
    """Plot final Pareto front comparison across methods.

    Each method's non-dominated (QED, -SA) points (both objectives maximized)
    are drawn as markers joined by a staircase that traces the boundary of the
    region they dominate. Methods are drawn in ``METHOD_ORDER``.

    Args:
        log_dir: Directory containing optimization logs
        output_path: Path to save figure
        methods: Methods to include
        show_initial: Whether to show initial molecules (valid molecules with
            step < 20 from the first method, assumed shared by all methods)
        ref_point: Reference point to mark
        figsize: Figure size
        representative_seed: If set, use only this seed; otherwise combine all

    Returns:
        matplotlib Figure
    """
    setup_style()

    # Load logs
    logs = load_experiment_logs(log_dir, methods=methods)
    if not logs:
        logger.warning(f"No logs found in {log_dir}")
        return plt.figure()

    method_dfs = aggregate_by_method(logs)

    fig, ax = plt.subplots(figsize=figsize)

    if show_initial and method_dfs:
        # Same initial set for all methods, so the first method suffices.
        first_df = next(iter(method_dfs.values()))
        # step < 20 approximates the initial design (n_init is not known here).
        # Filtering on step directly does not rely on row order.
        initial = first_df[(first_df["step"] < 20) & first_df["valid"].astype(bool)]
        if len(initial) > 0:
            ax.scatter(
                initial["qed"],
                initial["neg_sa"],
                c="lightgray",
                s=20,
                alpha=0.5,
                label="Initial",
                zorder=1,
            )

    # Plot final Pareto fronts per method, in the standard display order
    for method in sort_methods(list(method_dfs.keys())):
        df = method_dfs[method]
        if representative_seed is not None:
            df = df[df["seed"] == representative_seed]

        # All valid molecules of the chosen seed (or pooled across seeds)
        valid_df = df[df["valid"]]
        if len(valid_df) == 0:
            continue

        objectives = valid_df[["qed", "neg_sa"]].values
        mask = pareto_front(objectives, sense=["max", "max"])
        pareto_pts = objectives[mask]
        # Sorted by QED, so -SA decreases along a max/max front
        pareto_pts = pareto_pts[np.argsort(pareto_pts[:, 0])]

        color = get_method_color(method)

        ax.scatter(
            pareto_pts[:, 0],
            pareto_pts[:, 1],
            c=color,
            s=50,
            marker=get_method_marker(method),
            label=get_method_name(method),
            zorder=3,
            edgecolor="white",
            linewidth=0.5,
        )

        if len(pareto_pts) > 1:
            # where="pre" puts each corner at (x_i, y_{i+1}), on the dominated
            # side of the front. Corners at (x_{i+1}, y_i) would dominate both
            # neighbours and are never attained.
            ax.step(
                pareto_pts[:, 0],
                pareto_pts[:, 1],
                where="pre",
                color=color,
                linewidth=1,
                alpha=0.7,
                zorder=2,
            )

    # Mark reference point
    ax.axvline(ref_point[0], color="gray", linestyle="--", alpha=0.5, linewidth=0.8)
    ax.axhline(ref_point[1], color="gray", linestyle="--", alpha=0.5, linewidth=0.8)

    ax.set_xlabel("QED (Drug-likeness)")
    ax.set_ylabel("-SA (Synthetic Accessibility)")
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc="lower right")
    ax.set_title("Final Pareto Front Comparison")

    # QED lies in [0, 1] and SA in [1, 10], so -SA lies in [-10, -1]
    ax.set_xlim(0, 1)
    ax.set_ylim(-10, -1)

    fig.tight_layout()

    if output_path:
        fig.savefig(output_path)
        logger.info(f"Saved Pareto comparison plot to {output_path}")

    return fig


def _gallery_pareto_molecules(df) -> tuple[list[str], list[list[float]]]:
    """Return SMILES and ``[qed, neg_sa]`` pairs of the valid Pareto-optimal rows of *df*."""
    valid_df = df[df["valid"]]
    if len(valid_df) == 0:
        return [], []
    objectives = valid_df[["qed", "neg_sa"]].values
    pareto_idx = np.where(pareto_front(objectives, sense=["max", "max"]))[0]
    smiles = valid_df["smiles"].tolist()
    return [smiles[i] for i in pareto_idx], objectives[pareto_idx].tolist()


def _gallery_spread_indices(qed_values: np.ndarray, k: int) -> list[int]:
    """Return up to *k* indices into *qed_values*, ordered by increasing QED.

    With more than *k* candidates, the picks are evenly spaced over the
    QED-sorted order. For k >= 2 they always include the lowest- and
    highest-QED entries.
    """
    order = np.argsort(qed_values, kind="stable")
    if len(order) <= k:
        return order.tolist()
    positions = np.unique(np.round(np.linspace(0, len(order) - 1, k)).astype(int))
    return order[positions].tolist()


def plot_molecule_gallery(
    log_dir: str | Path,
    output_path: str | Path | None = None,
    methods: Sequence[str] | None = None,
    n_molecules: int = 4,
    figsize: tuple[float, float] | None = None,
    seed: int | None = None,
    sample_across_seeds: bool = False,
) -> plt.Figure:
    """Plot gallery of representative molecules from each method.

    Requires RDKit for structure rendering. Each row is one method, in
    ``METHOD_ORDER``. The columns show Pareto-optimal molecules chosen so that
    the row spans the method's QED range from low to high. A SMILES that
    appears several times, e.g. on several seeds' fronts, is shown only once.

    Args:
        log_dir: Directory containing optimization logs
        output_path: Path to save figure
        methods: Methods to include
        n_molecules: Number of molecules per method
        figsize: Figure size (auto if None)
        seed: Specific seed to use (ignored if sample_across_seeds=True)
        sample_across_seeds: Sample molecules from multiple seeds for diversity

    Returns:
        matplotlib Figure (an empty figure if RDKit is missing, no logs are
        found, or there is nothing to draw)
    """
    try:
        from rdkit import Chem
        from rdkit.Chem import Draw
    except ImportError:
        logger.error("RDKit required for molecule gallery. Install with: pip install rdkit")
        return plt.figure()

    setup_style()

    # Load logs
    logs = load_experiment_logs(log_dir, methods=methods)
    if not logs:
        logger.warning(f"No logs found in {log_dir}")
        return plt.figure()

    method_dfs = aggregate_by_method(logs)

    method_list = sort_methods(list(method_dfs.keys()))
    n_methods = len(method_list)
    if n_methods == 0 or n_molecules < 1:
        logger.warning("Nothing to draw in molecule gallery")
        return plt.figure()

    if figsize is None:
        figsize = (3 * n_molecules, 3.5 * n_methods)

    # squeeze=False always yields a 2-D Axes array. Without it, a 1x1 grid
    # returns a bare Axes, which has no .reshape.
    fig, axes = plt.subplots(n_methods, n_molecules, figsize=figsize, squeeze=False)

    for row, method in enumerate(method_list):
        df = method_dfs[method]

        if sample_across_seeds:
            # Pool the per-seed Pareto fronts for diversity
            run_dfs = [df[df["seed"] == s] for s in df["seed"].unique()]
        else:
            # Use the requested seed, or the first seed present in the log
            chosen_seed = seed
            if chosen_seed is None and len(df) > 0:
                chosen_seed = df["seed"].iloc[0]
            run_dfs = [df[df["seed"] == chosen_seed]]

        smiles_list: list[str] = []
        objective_rows: list[list[float]] = []
        seen: set[str] = set()
        for run_df in run_dfs:
            for smi, obj in zip(*_gallery_pareto_molecules(run_df)):
                if smi not in seen:
                    seen.add(smi)
                    smiles_list.append(smi)
                    objective_rows.append(obj)

        row_axes = axes[row]

        if not smiles_list:
            for ax in row_axes:
                ax.text(0.5, 0.5, "No valid\nmolecules", ha="center", va="center")
                ax.axis("off")
        else:
            objectives = np.asarray(objective_rows, dtype=float)
            selected_idx = _gallery_spread_indices(objectives[:, 0], n_molecules)

            for col, ax in enumerate(row_axes):
                if col < len(selected_idx):
                    idx = selected_idx[col]
                    qed, neg_sa = objectives[idx]

                    # Render molecule
                    mol = Chem.MolFromSmiles(smiles_list[idx])
                    if mol is not None:
                        ax.imshow(Draw.MolToImage(mol, size=(300, 300)))
                        ax.set_title(f"QED={qed:.2f}, SA={-neg_sa:.1f}", fontsize=9)
                    else:
                        ax.text(0.5, 0.5, "Invalid\nSMILES", ha="center", va="center", fontsize=8)
                else:
                    ax.text(0.5, 0.5, "N/A", ha="center", va="center", fontsize=8)

                ax.axis("off")

        # Row label left of the first column, in axes coordinates. It is drawn
        # for every row, including rows without molecules.
        first_ax = row_axes[0]
        first_ax.text(
            -0.3,
            0.5,
            get_method_name(method),
            transform=first_ax.transAxes,
            fontsize=11,
            fontweight="bold",
            va="center",
            ha="right",
            rotation=0,
        )

    # Add overall title
    fig.suptitle(
        "Representative Pareto-Optimal Molecules by Method",
        fontsize=13,
        fontweight="bold",
        y=0.995,
    )

    fig.tight_layout(rect=[0, 0, 1, 0.99])

    if output_path:
        fig.savefig(output_path, bbox_inches="tight")
        logger.info(f"Saved molecule gallery to {output_path}")

    return fig


def plot_pareto_density(
    log_dir: str | Path,
    output_path: str | Path | None = None,
    methods: Sequence[str] | None = None,
    grid_resolution: int = 50,
    show_initial: bool = True,
    figsize: tuple[float, float] = (8, 6),
) -> plt.Figure:
    """Plot 2D density of Pareto fronts across seeds using KDE.

    Each method's per-seed Pareto points are pooled and drawn as KDE contour
    lines. Methods are drawn in ``METHOD_ORDER``. A method whose KDE cannot be
    fitted, e.g. because its points are degenerate, is skipped with a warning.

    Args:
        log_dir: Directory containing optimization logs
        output_path: Path to save figure
        methods: Methods to include
        grid_resolution: Grid resolution for KDE
        show_initial: Whether to show initial molecules (filled grey density of
            the first min(20, n_rows // 5) valid rows per seed, taken from the
            first method)
        figsize: Figure size

    Returns:
        matplotlib Figure
    """
    try:
        from scipy.stats import gaussian_kde
    except ImportError:
        logger.error("scipy required for density plots. Install with: pip install scipy")
        return plt.figure()

    setup_style()

    # Load logs
    logs = load_experiment_logs(log_dir, methods=methods)
    if not logs:
        logger.warning(f"No logs found in {log_dir}")
        return plt.figure()

    method_dfs = aggregate_by_method(logs)

    fig, ax = plt.subplots(figsize=figsize)

    # Shared KDE evaluation grid over the plotted QED / -SA window
    qed_axis = np.linspace(0, 1, grid_resolution)
    sa_axis = np.linspace(-10, -1, grid_resolution)
    QED, SA = np.meshgrid(qed_axis, sa_axis)
    grid_coords = np.vstack([QED.ravel(), SA.ravel()])

    if show_initial and method_dfs:
        # Same initial set for all methods, so the first method suffices.
        first_df = next(iter(method_dfs.values()))
        initial_qed: list[float] = []
        initial_neg_sa: list[float] = []
        for s in first_df["seed"].unique():
            # Sort by step so head() really returns the earliest rows
            seed_df = first_df[first_df["seed"] == s].sort_values("step")
            n_init = min(20, len(seed_df) // 5)  # approximate initial design size
            initial = seed_df.head(n_init)
            valid_initial = initial[initial["valid"]]
            initial_qed.extend(valid_initial["qed"].tolist())
            initial_neg_sa.extend(valid_initial["neg_sa"].tolist())

        if len(initial_qed) > 10:
            try:
                kde_initial = gaussian_kde(np.vstack([initial_qed, initial_neg_sa]))
                density_initial = kde_initial(grid_coords).reshape(QED.shape)
            except (np.linalg.LinAlgError, ValueError) as e:
                # e.g. singular covariance when the points are (near-)collinear
                logger.warning(f"Failed to compute initial density: {e}")
            else:
                ax.contourf(QED, SA, density_initial, levels=5, cmap="Greys", alpha=0.3, zorder=1)

    # Plot final Pareto density per method
    for method in sort_methods(list(method_dfs.keys())):
        df = method_dfs[method]
        pareto_qed: list[float] = []
        pareto_neg_sa: list[float] = []

        for s in df["seed"].unique():
            seed_df = df[df["seed"] == s]
            valid_df = seed_df[seed_df["valid"]]
            if len(valid_df) == 0:
                continue

            objectives = valid_df[["qed", "neg_sa"]].values
            mask = pareto_front(objectives, sense=["max", "max"])
            front = objectives[mask]
            pareto_qed.extend(front[:, 0].tolist())
            pareto_neg_sa.extend(front[:, 1].tolist())

        if len(pareto_qed) < 5:
            logger.warning(f"Not enough Pareto points for {method} density plot")
            continue

        try:
            kde = gaussian_kde(np.vstack([pareto_qed, pareto_neg_sa]))
            density = kde(grid_coords).reshape(QED.shape)
        except (np.linalg.LinAlgError, ValueError) as e:
            logger.warning(f"Failed to compute density for {method}: {e}")
            continue

        color = get_method_color(method)
        ax.contour(
            QED, SA, density, levels=5, colors=[color], linewidths=1.5, alpha=0.8, zorder=3
        )
        # ContourSet.collections is deprecated since Matplotlib 3.8 and removed
        # in later releases, so register the legend entry via an empty proxy line.
        ax.plot([], [], color=color, linewidth=1.5, alpha=0.8, label=get_method_name(method))

    ax.set_xlabel("QED (Drug-likeness)")
    ax.set_ylabel("-SA (Synthetic Accessibility)")
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc="lower right")
    ax.set_title("Pareto Front Density (across all seeds)")

    # Set axis limits (same window as the KDE grid)
    ax.set_xlim(0, 1)
    ax.set_ylim(-10, -1)

    fig.tight_layout()

    if output_path:
        fig.savefig(output_path)
        logger.info(f"Saved Pareto density plot to {output_path}")

    return fig


def plot_example_runs(
    log_dir: str | Path,
    output_path: str | Path | None = None,
    methods: Sequence[str] | None = None,
    n_best: int = 2,
    n_worst: int = 2,
    ref_point: Sequence[float] = (0.0, -10.0),
    figsize: tuple[float, float] | None = None,
) -> plt.Figure:
    """Plot grid of individual run examples (best and worst by HVI).

    There is one row per method, in ``METHOD_ORDER``. The first panels show the
    runs with the highest final HVI, followed by the runs with the lowest final
    HVI. A run is never shown twice, so with fewer seeds than
    ``n_best + n_worst`` the trailing panels are left empty.

    Args:
        log_dir: Directory containing optimization logs
        output_path: Path to save figure
        methods: Methods to include
        n_best: Number of best runs per method
        n_worst: Number of worst runs per method (0 shows none)
        ref_point: Reference point to mark
        figsize: Figure size (auto if None)

    Returns:
        matplotlib Figure
    """
    setup_style()

    # Load logs
    logs = load_experiment_logs(log_dir, methods=methods)
    if not logs:
        logger.warning(f"No logs found in {log_dir}")
        return plt.figure()

    method_dfs = aggregate_by_method(logs)

    method_list = sort_methods(list(method_dfs.keys()))
    n_methods = len(method_list)
    n_cols = n_best + n_worst
    if n_methods == 0 or n_cols < 1:
        logger.warning("Nothing to draw for example runs")
        return plt.figure()

    if figsize is None:
        figsize = (4 * n_cols, 3.5 * n_methods)

    # squeeze=False keeps axes 2-D for every grid shape, including 1x1.
    fig, axes = plt.subplots(n_methods, n_cols, figsize=figsize, squeeze=False)

    for row, method in enumerate(method_list):
        df = method_dfs[method]
        color = get_method_color(method)
        marker = get_method_marker(method)

        # Each run is sorted by step: its last row holds the final HVI and its
        # first rows approximate the initial design.
        runs = {s: df[df["seed"] == s].sort_values("step") for s in df["seed"].unique()}
        final_hvi = {s: run["hvi"].iloc[-1] for s, run in runs.items() if len(run) > 0}

        ranked = sorted(final_hvi, key=final_hvi.get, reverse=True)
        best_seeds = ranked[:n_best]
        # Worst runs come from seeds not already shown as best. The slice is
        # explicit because remaining[-0:] would select every seed.
        remaining = ranked[n_best:]
        worst_seeds = remaining[max(0, len(remaining) - n_worst) :]
        panels = [("Best", s) for s in best_seeds] + [("Worst", s) for s in worst_seeds]

        for col, ax in enumerate(axes[row]):
            if col >= len(panels):
                ax.axis("off")  # fewer distinct runs than columns
                continue

            rank, run_seed = panels[col]
            run_df = runs[run_seed]

            # Split into initial and later molecules
            n_init = min(20, len(run_df) // 5)
            initial_df = run_df.iloc[:n_init]
            final_df = run_df.iloc[n_init:]

            # Plot initial
            initial_valid = initial_df[initial_df["valid"]]
            if len(initial_valid) > 0:
                ax.scatter(
                    initial_valid["qed"],
                    initial_valid["neg_sa"],
                    c="lightgray",
                    s=20,
                    alpha=0.5,
                    label="Initial",
                    zorder=1,
                )

            # Plot final Pareto
            valid_final = final_df[final_df["valid"]]
            if len(valid_final) > 0:
                objectives = valid_final[["qed", "neg_sa"]].values
                mask = pareto_front(objectives, sense=["max", "max"])
                pareto_pts = objectives[mask]
                pareto_pts = pareto_pts[np.argsort(pareto_pts[:, 0])]

                ax.scatter(
                    pareto_pts[:, 0],
                    pareto_pts[:, 1],
                    c=color,
                    s=40,
                    marker=marker,
                    label="Final Pareto",
                    zorder=3,
                    edgecolor="white",
                    linewidth=0.5,
                )

                if len(pareto_pts) > 1:
                    # where="pre": corners at (x_i, y_{i+1}), i.e. the boundary
                    # of the dominated region for a max/max front.
                    ax.step(
                        pareto_pts[:, 0],
                        pareto_pts[:, 1],
                        where="pre",
                        color=color,
                        linewidth=1,
                        alpha=0.5,
                        zorder=2,
                    )

            # Mark reference point
            ax.axvline(ref_point[0], color="gray", linestyle="--", alpha=0.3, linewidth=0.8)
            ax.axhline(ref_point[1], color="gray", linestyle="--", alpha=0.3, linewidth=0.8)

            ax.set_xlim(0, 1)
            ax.set_ylim(-10, -1)

            ax.set_title(f"{rank} (seed={run_seed}, HVI={final_hvi[run_seed]:.3f})", fontsize=9)

            if col == 0:
                ax.set_ylabel(get_method_name(method), fontsize=10)
            if row == n_methods - 1:
                ax.set_xlabel("QED", fontsize=9)

            # Only show legend on first subplot (and only if it has labeled artists)
            if row == 0 and col == 0 and ax.get_legend_handles_labels()[0]:
                ax.legend(fontsize=7, loc="lower right")

    fig.tight_layout()

    if output_path:
        fig.savefig(output_path)
        logger.info(f"Saved example runs plot to {output_path}")

    return fig


def plot_runtime_comparison(
    log_dir: str | Path,
    output_path: str | Path | None = None,
    methods: Sequence[str] | None = None,
    n_bootstrap: int = 1000,
    confidence: float = 0.95,
    figsize: tuple[float, float] = (6, 4),
    seed: int = 42,
) -> plt.Figure:
    """Plot bar chart comparing runtime across methods with error bars.

    Bars show the bootstrap mean runtime per method, sorted from fastest to
    slowest. Error bars span the bootstrap confidence interval.

    Args:
        log_dir: Directory containing optimization logs
        output_path: Path to save figure (optional)
        methods: Methods to include (default: all)
        n_bootstrap: Bootstrap samples for CI
        confidence: Confidence level
        figsize: Figure size in inches
        seed: Random seed

    Returns:
        matplotlib Figure
    """
    from .summary import bootstrap_ci, load_runtime_data

    setup_style()

    # Load runtime data
    runtime_data = load_runtime_data(log_dir)

    if not runtime_data:
        logger.warning(f"No timing data found in {log_dir}")
        return plt.figure()

    # Filter methods if specified
    if methods:
        runtime_data = {m: v for m, v in runtime_data.items() if m in methods}

    if not runtime_data:
        logger.warning("No matching methods found in timing data")
        return plt.figure()

    # Compute statistics for each method
    method_stats = []

    for method, runtimes in runtime_data.items():
        if len(runtimes) == 0:
            continue

        ci = bootstrap_ci(np.array(runtimes), n_bootstrap, confidence, seed)

        method_stats.append(
            {
                "method": method,
                "method_name": get_method_name(method),
                "mean": ci.mean,
                # Error-bar half-lengths, clipped at 0. Matplotlib rejects
                # negative yerr, and floating rounding can put the mean a hair
                # outside the CI (e.g. when all runtimes are identical).
                "ci_lower": max(0.0, ci.mean - ci.ci_lower),
                "ci_upper": max(0.0, ci.ci_upper - ci.mean),
                "color": get_method_color(method),
            }
        )

    if not method_stats:
        logger.warning("No valid runtime data to plot")
        return plt.figure()

    # Sort by mean runtime (least to greatest)
    method_stats = sorted(method_stats, key=lambda x: x["mean"])

    method_names = [s["method_name"] for s in method_stats]
    means = [s["mean"] for s in method_stats]
    ci_lowers = [s["ci_lower"] for s in method_stats]
    ci_uppers = [s["ci_upper"] for s in method_stats]
    colors = [s["color"] for s in method_stats]

    fig, ax = plt.subplots(figsize=figsize)

    x = np.arange(len(method_names))
    bars = ax.bar(x, means, color=colors, alpha=0.8, edgecolor="black", linewidth=0.5)

    # Asymmetric error bars: [below-mean lengths, above-mean lengths]
    ax.errorbar(
        x,
        means,
        yerr=[ci_lowers, ci_uppers],
        fmt="none",
        ecolor="black",
        capsize=5,
        capthick=1.5,
        linewidth=1.5,
    )

    ax.set_xticks(x)
    ax.set_xticklabels(method_names, rotation=15, ha="right")
    ax.set_ylabel("Runtime (seconds)")
    # ":g" avoids int() truncation, e.g. int(0.57 * 100) == 56
    ax.set_title(f"Runtime Comparison ({confidence * 100:g}% CI)")

    # Add value labels on bars
    for bar, mean in zip(bars, means):
        height = bar.get_height()
        ax.annotate(
            f"{mean:.1f}s",
            xy=(bar.get_x() + bar.get_width() / 2, height),
            xytext=(0, 3),
            textcoords="offset points",
            ha="center",
            va="bottom",
            fontsize=9,
        )

    # Runtimes are non-negative; anchor the y-axis at 0
    ax.set_ylim(bottom=0)

    fig.tight_layout()

    if output_path:
        fig.savefig(output_path)
        logger.info(f"Saved runtime comparison plot to {output_path}")

    return fig


def plot_all_figures(
    log_dir: str | Path,
    output_dir: str | Path,
    init: str = "random",
    budget: int | None = None,
    methods: Sequence[str] | None = None,
    format: str = "pdf",
    n_bootstrap: int = 1000,
    confidence: float = 0.95,
    ci_type: str = "bootstrap",
) -> dict[str, Path]:
    """Generate all standard figures from experiment logs.

    Produces HV and HVI convergence curves, a runtime comparison (skipped with
    a warning if it raises), a Pareto-front comparison and a molecule gallery.
    Each figure is drawn once; when several formats are requested the same
    figure object is re-saved, so all formats show identical content. Figures
    returned by the plotting helpers are closed after saving so repeated calls
    do not accumulate open pyplot figures.

    Args:
        log_dir: Directory containing optimization logs
        output_dir: Directory for output figures (created if missing)
        init: Label appended to filenames as ``_<init>`` (skipped if empty).
            NOTE(review): it is not forwarded to the plotting functions, so the
            figures themselves are not filtered by it here.
        budget: Appended to filenames as ``_B<budget>`` when truthy; like
            ``init`` it is not forwarded to the plotting functions.
        methods: Methods to include
        format: Output file extension (e.g. "pdf" or "png"); "both" writes a
            PDF and a PNG. The molecule gallery is always written as PNG.
        n_bootstrap: Bootstrap samples for CI computation
        confidence: Confidence level for CI bands
        ci_type: "bootstrap" for bootstrap CI, "se" for standard error bands

    Returns:
        Dictionary mapping figure name to output path. With ``format="both"``
        the plain name maps to the PDF and ``"<name>_png"`` to the PNG copy.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # "both" is not a real file extension; expand it into the formats it means.
    formats = ["pdf", "png"] if format == "both" else [format]

    outputs: dict[str, Path] = {}
    suffix = f"_{init}" if init else ""
    if budget:
        suffix += f"_B{budget}"

    def _close(fig: object) -> None:
        # After this function returns, only pyplot's registry references these
        # figures, so close them instead of leaking them.
        if isinstance(fig, plt.Figure):
            plt.close(fig)

    def _render(name: str, draw, extensions: Sequence[str]) -> None:
        """Draw figure ``name`` via ``draw(path)``, save every extension, record paths."""
        paths = [output_dir / f"{name}{suffix}.{ext}" for ext in extensions]
        fig = draw(paths[0])
        for extra_path in paths[1:]:
            if isinstance(fig, plt.Figure):
                # Re-save the already drawn figure rather than re-running the
                # plotting code, so every format shows exactly the same data.
                fig.savefig(extra_path)
            else:
                _close(draw(extra_path))
        _close(fig)
        outputs[name] = paths[0]
        for ext, extra_path in zip(extensions[1:], paths[1:]):
            outputs[f"{name}_{ext}"] = extra_path

    # HV and HVI convergence
    for metric in ("hv", "hvi"):
        _render(
            f"{metric}_convergence",
            lambda path, metric=metric: plot_hv_convergence(
                log_dir,
                path,
                methods=methods,
                metric=metric,
                n_bootstrap=n_bootstrap,
                confidence=confidence,
                ci_type=ci_type,
            ),
            formats,
        )

    # Runtime comparison (best effort: timing data may be absent)
    try:
        _render(
            "runtime_comparison",
            lambda path: plot_runtime_comparison(
                log_dir,
                path,
                methods=methods,
                n_bootstrap=n_bootstrap,
                confidence=confidence,
            ),
            formats,
        )
    except Exception as e:
        logger.warning(f"Could not generate runtime comparison (no timing data?): {e}")

    # Pareto comparison
    _render(
        "pareto_comparison",
        lambda path: plot_pareto_comparison(log_dir, path, methods=methods),
        formats,
    )

    # Molecule gallery: always PNG, independent of ``format``.
    _render(
        "molecule_gallery",
        lambda path: plot_molecule_gallery(log_dir, path, methods=methods),
        ["png"],
    )

    logger.info(f"Generated {len(outputs)} figures in {output_dir}")

    return outputs


# =============================================================================
# Hyperparameter Optimization Plots
# =============================================================================


def plot_hpo_history(
    study: "optuna.Study",
    output_path: str | Path | None = None,
    method: str = "method",
    figsize: tuple[float, float] = (8, 5),
) -> plt.Figure:
    """Plot each finished trial's objective (mean HVI) against its trial number.

    Trials whose ``value`` is None are skipped. Besides the per-trial scatter,
    a running-maximum line ("Best so far") is drawn, and the overall best
    trial gets a dashed horizontal reference line plus a text label.

    Args:
        study: Optuna study whose ``trials`` are plotted
        output_path: If given, the figure is saved there
        method: Method key; selects the color and the title's display name
        figsize: Figure size in inches

    Returns:
        matplotlib Figure. If no trial has a value, a warning is logged and
        the empty figure is returned without being saved.
    """
    setup_style()

    fig, ax = plt.subplots(figsize=figsize)

    finished = [trial for trial in study.trials if trial.value is not None]
    xs = [trial.number for trial in finished]
    ys = [trial.value for trial in finished]

    if not xs:
        logger.warning("No completed trials to plot")
        return fig

    line_color = get_method_color(method)
    ax.scatter(xs, ys, c=line_color, alpha=0.6, s=30, label="Trials")

    # Running maximum: the best objective seen up to and including each trial.
    incumbent = float("-inf")
    running_best = []
    for y in ys:
        incumbent = max(incumbent, y)
        running_best.append(incumbent)
    ax.plot(xs, running_best, c=line_color, linewidth=2, label="Best so far")

    ax.set_xlabel("Trial Number")
    ax.set_ylabel("Mean HVI")
    ax.set_title(f"Optimization History - {get_method_name(method)}")
    ax.legend(loc="lower right")

    # Mark the single best trial with a reference line and a value label.
    top = np.argmax(ys)
    top_value = ys[top]
    ax.axhline(top_value, linestyle="--", color="gray", alpha=0.5)
    ax.annotate(
        f"Best: {top_value:.4f}",
        xy=(xs[top], top_value),
        xytext=(10, 10),
        textcoords="offset points",
        fontsize=9,
    )

    plt.tight_layout()

    if output_path:
        fig.savefig(output_path)
        logger.info(f"Saved HPO history plot to {output_path}")

    return fig


def plot_hpo_importance(
    study: "optuna.Study",
    output_path: str | Path | None = None,
    method: str = "method",
    figsize: tuple[float, float] = (8, 5),
) -> plt.Figure:
    """Draw a horizontal bar chart of Optuna parameter importances.

    Importances come from ``optuna.importance.get_param_importances`` and are
    shown most-important first, each bar labelled with its value.

    Args:
        study: Optuna study to analyse
        output_path: If given, the figure is saved there
        method: Method key; selects the bar color and the title's display name
        figsize: Figure size in inches

    Returns:
        matplotlib Figure. If optuna cannot be imported, an empty figure is
        returned. If importances cannot be computed or are empty, a figure with
        an explanatory message is returned without being saved.
    """
    try:
        from optuna.importance import get_param_importances
    except ImportError:
        logger.error("Optuna required for importance plot")
        return plt.figure()

    setup_style()

    fig, ax = plt.subplots(figsize=figsize)

    def _placeholder(message):
        # Centered message in axes coordinates, used when there is nothing to plot.
        ax.text(0.5, 0.5, message, ha="center", va="center", transform=ax.transAxes)
        return fig

    try:
        importances = get_param_importances(study)
    except Exception as e:
        logger.warning(f"Could not compute importances: {e}")
        return _placeholder("Insufficient data for importance analysis")

    if not importances:
        return _placeholder("No importance data available")

    # Rank parameters from most to least important.
    names = list(importances.keys())
    scores = list(importances.values())
    order = np.argsort(scores)[::-1]
    ranked_names = [names[k] for k in order]
    ranked_scores = [scores[k] for k in order]

    rows = np.arange(len(ranked_names))
    ax.barh(rows, ranked_scores, color=get_method_color(method), alpha=0.8)
    ax.set_yticks(rows)
    ax.set_yticklabels(ranked_names)
    ax.set_xlabel("Importance")
    ax.set_title(f"Parameter Importance - {get_method_name(method)}")
    # Put the most important parameter at the top.
    ax.invert_yaxis()

    for row, score in enumerate(ranked_scores):
        ax.text(score + 0.01, row, f"{score:.3f}", va="center", fontsize=9)

    plt.tight_layout()

    if output_path:
        fig.savefig(output_path)
        logger.info(f"Saved HPO importance plot to {output_path}")

    return fig


def _parallel_axis_scaler(values):
    """Build a function mapping one observed parameter value onto [0, 1].

    ``values`` are the non-None values a parameter took across trials. If every
    value is a (non-bool) int/float, the scaler is min-max over them (a constant
    parameter maps to 0). Otherwise the parameter is treated as categorical:
    distinct values are sorted by their string form and spread evenly over
    [0, 1]. Sorting makes the axis order reproducible between runs, unlike
    ``set`` iteration order for strings. Returns None when ``values`` is empty.
    """
    if not values:
        return None
    if all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in values):
        lo, hi = min(values), max(values)
        span = hi - lo if hi > lo else 1.0
        return lambda v: (v - lo) / span
    levels = sorted(set(values), key=str)
    step = max(len(levels) - 1, 1)
    return lambda v: levels.index(v) / step


def plot_hpo_parallel_coordinate(
    study: "optuna.Study",
    output_path: str | Path | None = None,
    method: str = "method",
    figsize: tuple[float, float] = (10, 6),
) -> plt.Figure:
    """Plot parallel coordinate visualization of trials.

    Each completed trial (non-None value and at least one parameter) is drawn
    as a polyline over one axis per parameter plus a final "HVI" axis for the
    objective. Lines are colored by objective value (viridis). Every axis is
    rescaled to [0, 1]: numeric parameters by min-max over the trials, and
    categorical ones by their position among the sorted observed choices. A
    parameter a trial did not sample is drawn at 0.5. The axes cover the union
    of parameters over all trials, so conditional search spaces are handled.
    Only matplotlib is used; plotly is not required.

    Args:
        study: Optuna study object
        output_path: Path to save figure
        method: Method name for title
        figsize: Figure size in inches

    Returns:
        matplotlib Figure. With no completed trials, a placeholder is returned
        without being saved. If drawing raises, a placeholder showing the error
        is returned (and saved when ``output_path`` is set).
    """
    setup_style()

    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        trials = [t for t in study.trials if t.value is not None and t.params]
        if not trials:
            ax.text(
                0.5, 0.5, "No completed trials", ha="center", va="center", transform=ax.transAxes
            )
            return fig

        # Union of parameter names in first-seen order: a later trial may have
        # sampled parameters the first trial did not.
        param_names = list(dict.fromkeys(name for t in trials for name in t.params))
        scalers = {
            name: _parallel_axis_scaler(
                [t.params[name] for t in trials if t.params.get(name) is not None]
            )
            for name in param_names
        }

        # One x position per parameter, plus one for the objective value.
        x_coords = np.arange(len(param_names) + 1)

        obj_values = [t.value for t in trials]
        obj_min, obj_max = min(obj_values), max(obj_values)
        obj_range = obj_max - obj_min if obj_max > obj_min else 1.0

        cmap = plt.cm.viridis
        norm = plt.Normalize(vmin=obj_min, vmax=obj_max)

        for trial in trials:
            y_coords = []
            for name in param_names:
                scale = scalers[name]
                val = trial.params.get(name)
                y_coords.append(0.5 if scale is None or val is None else scale(val))
            y_coords.append((trial.value - obj_min) / obj_range)

            ax.plot(x_coords, y_coords, c=cmap(norm(trial.value)), alpha=0.5, linewidth=1)

        ax.set_xticks(x_coords)
        ax.set_xticklabels(param_names + ["HVI"], rotation=45, ha="right")
        ax.set_ylabel("Normalized Value")
        ax.set_title(f"Parallel Coordinates - {get_method_name(method)}")

        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        fig.colorbar(sm, ax=ax, label="HVI")

        fig.tight_layout()

    except Exception as e:
        logger.warning(f"Failed to create parallel coordinate plot: {e}")
        if fig is not None:
            # Discard the half-drawn figure instead of leaving it open in pyplot.
            plt.close(fig)
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(
            0.5,
            0.5,
            f"Plot generation failed: {e}",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )

    if output_path:
        fig.savefig(output_path)
        logger.info(f"Saved HPO parallel coordinate plot to {output_path}")

    return fig


def plot_hpo_contour(
    study: "optuna.Study",
    output_path: str | Path | None = None,
    method: str = "method",
    param1: str = "gamma",
    param2: str = "sigma",
    figsize: tuple[float, float] = (8, 6),
) -> plt.Figure:
    """Plot 2D contour of parameter interactions.

    Completed trials whose ``param1`` and ``param2`` are both int/float are
    drawn as a scatter colored by objective value. With at least 10 such
    trials, a cubic ``scipy.interpolate.griddata`` surface is added as filled
    and line contours *underneath* the points. The fill uses the scatter's
    color range, so it agrees with the colorbar. The trial with the highest
    objective is marked with a red star.

    Args:
        study: Optuna study object
        output_path: Path to save figure
        method: Method name for title
        param1: First parameter name (x axis)
        param2: Second parameter name (y axis)
        figsize: Figure size in inches

    Returns:
        matplotlib Figure. With no completed trials or fewer than 3 usable
        points, a placeholder figure is returned without being saved.
    """
    setup_style()

    fig, ax = plt.subplots(figsize=figsize)

    trials = [t for t in study.trials if t.value is not None and t.params]

    if not trials:
        ax.text(0.5, 0.5, "No completed trials", ha="center", va="center", transform=ax.transAxes)
        return fig

    # Keep only trials where both parameters can be placed on continuous axes.
    p1_vals, p2_vals, obj_vals = [], [], []
    for trial in trials:
        p1 = trial.params.get(param1)
        p2 = trial.params.get(param2)
        if isinstance(p1, (int, float)) and isinstance(p2, (int, float)):
            p1_vals.append(p1)
            p2_vals.append(p2)
            obj_vals.append(trial.value)

    if len(p1_vals) < 3:
        ax.text(
            0.5,
            0.5,
            f"Insufficient data for {param1} vs {param2}",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        return fig

    p1_vals = np.array(p1_vals)
    p2_vals = np.array(p2_vals)
    obj_vals = np.array(obj_vals)

    scatter = ax.scatter(
        p1_vals,
        p2_vals,
        c=obj_vals,
        cmap="viridis",
        s=50,
        alpha=0.8,
        edgecolors="black",
        linewidth=0.5,
    )

    if len(p1_vals) >= 10:
        try:
            from scipy.interpolate import griddata

            p1_grid = np.linspace(p1_vals.min(), p1_vals.max(), 50)
            p2_grid = np.linspace(p2_vals.min(), p2_vals.max(), 50)
            P1, P2 = np.meshgrid(p1_grid, p2_grid)

            # NaN outside the convex hull of the observed points.
            Z = griddata((p1_vals, p2_vals), obj_vals, (P1, P2), method="cubic")

            # Low zorders keep the surface beneath the scatter markers (zorder 1);
            # otherwise the later-added fill would be painted over the points.
            ax.contour(
                P1, P2, Z, levels=10, colors="gray", alpha=0.5, linewidths=0.5, zorder=0.5
            )
            ax.contourf(
                P1,
                P2,
                Z,
                levels=20,
                cmap="viridis",
                alpha=0.3,
                vmin=scatter.norm.vmin,
                vmax=scatter.norm.vmax,
                zorder=0,
            )

        except Exception as e:
            logger.debug(f"Could not create contours: {e}")

    best_idx = np.argmax(obj_vals)
    ax.scatter(
        [p1_vals[best_idx]],
        [p2_vals[best_idx]],
        c="red",
        s=200,
        marker="*",
        edgecolors="black",
        linewidth=1,
        label=f"Best: {obj_vals[best_idx]:.4f}",
        zorder=10,
    )

    ax.set_xlabel(param1)
    ax.set_ylabel(param2)
    ax.set_title(f"Parameter Contour - {get_method_name(method)}")
    ax.legend(loc="upper right")

    fig.colorbar(scatter, ax=ax, label="HVI")

    fig.tight_layout()

    if output_path:
        fig.savefig(output_path)
        logger.info(f"Saved HPO contour plot to {output_path}")

    return fig


# =============================================================================
# Gamma Sweep Plots
# =============================================================================


def plot_gamma_sweep_curve(
    gammas: np.ndarray,
    values: np.ndarray,
    errors: np.ndarray | None = None,
    output_path: str | Path | None = None,
    ylabel: str = "Penalized HVI",
    title: str = "Gamma Sweep",
    figsize: tuple[float, float] = (8, 5),
    log_scale: bool = True,
    highlight_best: bool = True,
) -> plt.Figure:
    """Draw a metric against gamma, optionally with error bars.

    Uses the MoltenFlow method color. With ``highlight_best`` the gamma at the
    largest value (``np.argmax``) gets a dashed vertical line and a label.

    Args:
        gammas: Gamma value for each point
        values: Metric value (mean per gamma) for each point
        errors: Optional symmetric error-bar sizes (e.g. standard deviations)
        output_path: If given, the figure is saved there
        ylabel: Y-axis label, also used as the legend entry
        title: Plot title
        figsize: Figure size in inches
        log_scale: If True, the x axis is logarithmic
        highlight_best: If True, mark the best gamma

    Returns:
        matplotlib Figure
    """
    setup_style()

    fig, ax = plt.subplots(figsize=figsize)
    line_style = {
        "linewidth": 2,
        "markersize": 8,
        "color": get_method_color("moltenflow"),
        "label": ylabel,
    }

    if errors is None:
        ax.plot(gammas, values, "o-", **line_style)
    else:
        ax.errorbar(gammas, values, yerr=errors, fmt="o-", capsize=5, capthick=2, **line_style)

    if log_scale:
        ax.set_xscale("log")

    ax.set_xlabel("Gamma ($\\gamma$)", fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.grid(True, alpha=0.3, linestyle="--")
    ax.legend(loc="best")

    if highlight_best and len(values) > 0:
        top = np.argmax(values)
        gamma_star, value_star = gammas[top], values[top]
        ax.axvline(gamma_star, linestyle="--", color="gray", alpha=0.5)
        ax.annotate(
            f"Best: $\\gamma$={gamma_star:.2g}",
            xy=(gamma_star, value_star),
            xytext=(10, 10),
            textcoords="offset points",
            fontsize=10,
        )

    plt.tight_layout()

    if output_path:
        fig.savefig(output_path)
        logger.info(f"Saved gamma sweep curve to {output_path}")

    return fig


def plot_gamma_dual_axis(
    gammas: np.ndarray,
    y1_values: np.ndarray,
    y2_values: np.ndarray,
    y1_label: str = "HVI",
    y2_label: str = "Validity",
    y1_errors: np.ndarray | None = None,
    y2_errors: np.ndarray | None = None,
    output_path: str | Path | None = None,
    title: str = "Gamma Sweep: HVI vs Quality Metric",
    figsize: tuple[float, float] = (10, 6),
    log_scale: bool = True,
    y1_color: str = "#1f77b4",
    y2_color: str = "#d62728",
) -> plt.Figure:
    """Create dual-axis plot showing HVI vs another metric across gamma values.

    The primary series (solid, circles) uses the left y axis; the secondary
    series (dashed, squares) uses a ``twinx`` right y axis. A single legend
    covers both series.

    Args:
        gammas: Array of gamma values
        y1_values: Primary y-axis values (typically HVI)
        y2_values: Secondary y-axis values (quality metric)
        y1_label: Label for primary y-axis
        y2_label: Label for secondary y-axis
        y1_errors: Optional error bars for primary axis
        y2_errors: Optional error bars for secondary axis
        output_path: Path to save figure
        title: Plot title
        figsize: Figure size in inches
        log_scale: Whether to use log scale for x-axis
        y1_color: Color for primary axis
        y2_color: Color for secondary axis

    Returns:
        matplotlib Figure
    """
    setup_style()

    fig, ax1 = plt.subplots(figsize=figsize)

    # Primary y-axis
    ax1.set_xlabel("Gamma ($\\gamma$)", fontsize=12)
    ax1.set_ylabel(y1_label, fontsize=12, color=y1_color)

    if y1_errors is not None:
        line1 = ax1.errorbar(
            gammas,
            y1_values,
            yerr=y1_errors,
            fmt="o-",
            capsize=5,
            capthick=2,
            linewidth=2,
            markersize=8,
            color=y1_color,
            label=y1_label,
        )
    else:
        (line1,) = ax1.plot(
            gammas, y1_values, "o-", linewidth=2, markersize=8, color=y1_color, label=y1_label
        )

    ax1.tick_params(axis="y", labelcolor=y1_color)
    if log_scale:
        ax1.set_xscale("log")
    ax1.grid(True, alpha=0.3, linestyle="--")

    # Secondary y-axis (shares the x axis with ax1)
    ax2 = ax1.twinx()
    ax2.set_ylabel(y2_label, fontsize=12, color=y2_color)

    if y2_errors is not None:
        line2 = ax2.errorbar(
            gammas,
            y2_values,
            yerr=y2_errors,
            fmt="s--",
            capsize=5,
            capthick=2,
            linewidth=2,
            markersize=8,
            color=y2_color,
            label=y2_label,
        )
    else:
        (line2,) = ax2.plot(
            gammas, y2_values, "s--", linewidth=2, markersize=8, color=y2_color, label=y2_label
        )

    ax2.tick_params(axis="y", labelcolor=y2_color)

    # The combined legend lives on ax2: the twin axes is drawn after ax1, so a
    # legend attached to ax1 could be painted over by the secondary series.
    ax2.legend([line1, line2], [y1_label, y2_label], loc="best")

    # Set the title on an explicit axes rather than via pyplot's current axes.
    ax1.set_title(title, fontsize=14, fontweight="bold")
    fig.tight_layout()

    if output_path:
        fig.savefig(output_path)
        logger.info(f"Saved gamma dual-axis plot to {output_path}")

    return fig


def plot_gamma_heatmap(
    gamma_seed_values: np.ndarray,
    gammas: Sequence[float],
    seeds: Sequence[int],
    output_path: str | Path | None = None,
    title: str = "Penalized HVI: Gamma vs Seed",
    figsize: tuple[float, float] = (10, 6),
    cmap: str = "viridis",
    cbar_label: str = "Penalized HVI",
) -> plt.Figure:
    """Plot heatmap of metric values across gamma and seed combinations.

    Each non-NaN cell is annotated with its value (3 decimals). The text color
    is chosen from the luminance of the cell's actual color, so labels stay
    readable for any ``cmap`` (including reversed maps) and when some cells are
    NaN. NaN cells (e.g. missing runs) are not annotated.

    Args:
        gamma_seed_values: 2D array of shape (n_gammas, n_seeds)
        gammas: List of gamma values (row labels)
        seeds: List of seed values (column labels)
        output_path: Path to save figure
        title: Plot title
        figsize: Figure size in inches
        cmap: Colormap name
        cbar_label: Colorbar label

    Returns:
        matplotlib Figure
    """
    setup_style()

    grid = np.asarray(gamma_seed_values, dtype=float)

    fig, ax = plt.subplots(figsize=figsize)

    im = ax.imshow(grid, aspect="auto", cmap=cmap)

    ax.set_xticks(range(len(seeds)))
    ax.set_xticklabels([f"seed={s}" for s in seeds])
    ax.set_yticks(range(len(gammas)))
    ax.set_yticklabels([f"{g:.4g}" for g in gammas])

    ax.set_xlabel("Seed", fontsize=12)
    ax.set_ylabel("Gamma", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")

    for i in range(len(gammas)):
        for j in range(len(seeds)):
            val = grid[i, j]
            if np.isnan(val):
                continue
            # Perceived brightness of the rendered cell decides light vs dark text.
            r, g, b, _ = im.cmap(im.norm(val))
            luminance = 0.299 * r + 0.587 * g + 0.114 * b
            text_color = "white" if luminance < 0.5 else "black"
            ax.text(j, i, f"{val:.3f}", ha="center", va="center", color=text_color, fontsize=8)

    fig.colorbar(im, ax=ax, label=cbar_label)
    fig.tight_layout()

    if output_path:
        fig.savefig(output_path)
        logger.info(f"Saved gamma heatmap to {output_path}")

    return fig


def plot_gamma_metrics_grid(
    gammas: np.ndarray,
    metrics_dict: dict[str, tuple[np.ndarray, np.ndarray | None]],
    output_path: str | Path | None = None,
    title: str = "Gamma Sweep: All Metrics",
    figsize: tuple[float, float] | None = None,
    log_scale: bool = True,
) -> plt.Figure:
    """Plot grid of all metrics vs gamma.

    One subplot per metric, laid out at most 3 per row; unused grid cells are
    hidden. Metric names are shown with underscores replaced by spaces in
    title case.

    Args:
        gammas: Array of gamma values
        metrics_dict: Dict mapping metric name to (values, errors) tuple;
            ``errors`` may be None to draw a plain line
        output_path: Path to save figure
        title: Overall title
        figsize: Figure size (defaults to 5x4 inches per subplot)
        log_scale: Whether to use log scale for x-axis

    Returns:
        matplotlib Figure. If ``metrics_dict`` is empty, a warning is logged
        and an empty figure is returned.
    """
    setup_style()

    n_metrics = len(metrics_dict)
    if n_metrics == 0:
        # Without this guard n_cols would be 0 and the row count below would
        # raise ZeroDivisionError.
        logger.warning("No metrics to plot")
        return plt.figure()

    n_cols = min(3, n_metrics)
    n_rows = (n_metrics + n_cols - 1) // n_cols

    if figsize is None:
        figsize = (5 * n_cols, 4 * n_rows)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    colors = plt.cm.tab10.colors

    for i, (metric_name, (values, errors)) in enumerate(metrics_dict.items()):
        ax = axes[i]
        color = colors[i % len(colors)]
        pretty_name = metric_name.replace("_", " ").title()

        if errors is not None:
            ax.errorbar(
                gammas,
                values,
                yerr=errors,
                fmt="o-",
                capsize=4,
                linewidth=1.5,
                markersize=6,
                color=color,
            )
        else:
            ax.plot(gammas, values, "o-", linewidth=1.5, markersize=6, color=color)

        if log_scale:
            ax.set_xscale("log")

        ax.set_xlabel("Gamma ($\\gamma$)", fontsize=10)
        ax.set_ylabel(pretty_name, fontsize=10)
        ax.set_title(pretty_name, fontsize=11)
        ax.grid(True, alpha=0.3, linestyle="--")

    for ax in axes[n_metrics:]:
        ax.set_visible(False)

    fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
    fig.tight_layout()

    if output_path:
        fig.savefig(output_path, bbox_inches="tight")
        logger.info(f"Saved gamma metrics grid to {output_path}")

    return fig
