"""Plotting utilities for training visualization.

This module provides functions for creating training curve plots
and other visualizations used in experiment analysis.
"""

from __future__ import annotations

import os
from typing import Any, Dict, List, Sequence

import matplotlib.pyplot as plt
import numpy as np


def plot_training_curves(
    history: List[Dict[str, Any]],
    save_path: str,
    title: str = "Training Curves",
    train_key: str = "train_loss",
    val_key: str = "val_loss",
    figsize: tuple[int, int] = (10, 6),
) -> None:
    """Plot training and validation loss curves and save the figure.

    Entries whose ``train_key`` / ``val_key`` value is missing or ``None`` are
    skipped for that curve only, so the two curves may span different epochs.
    An empty ``history`` still produces a labelled "(no data)" figure so that a
    file always exists at ``save_path``.

    The y-axis switches to log scale when the finite plotted values are all
    strictly positive and span more than two orders of magnitude.

    Args:
        history: List of dicts with per-epoch metrics. Each dict may contain
            'epoch' (defaults to the 1-based position in ``history``),
            ``train_key`` and ``val_key``.
        save_path: Path to save the figure; missing parent directories are
            created.
        title: Plot title.
        train_key: Key for training loss in history dicts.
        val_key: Key for validation loss in history dicts.
        figsize: Figure size (width, height).
    """
    fig, ax = plt.subplots(figsize=figsize)

    if not history:
        ax.set_title(f"{title} (no data)")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
    else:
        epochs = [h.get("epoch", i + 1) for i, h in enumerate(history)]

        def _points(key: str) -> list[tuple[Any, Any]]:
            # (epoch, value) pairs for the entries where ``key`` holds a value.
            return [
                (epoch, h[key])
                for epoch, h in zip(epochs, history)
                if h.get(key) is not None
            ]

        train_valid = _points(train_key)
        val_valid = _points(val_key)

        if train_valid:
            train_epochs, train_vals = zip(*train_valid)
            ax.plot(train_epochs, train_vals, "b-", label="Train", linewidth=2)
        if val_valid:
            val_epochs, val_vals = zip(*val_valid)
            ax.plot(val_epochs, val_vals, "r-", label="Validation", linewidth=2)

        ax.set_xlabel("Epoch", fontsize=12)
        ax.set_ylabel("Loss", fontsize=12)
        ax.set_title(title, fontsize=14)
        # Only add a legend when a labelled line exists; otherwise matplotlib
        # warns that no artists with labels were found.
        if train_valid or val_valid:
            ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

        # A log axis cannot display zero or negative values. The old
        # ``min + 1e-10`` guard let a zero loss force log scale, so require
        # every finite value to be strictly positive before switching.
        finite_vals = [v for _, v in train_valid + val_valid if np.isfinite(v)]
        if len(finite_vals) > 1:
            lo, hi = min(finite_vals), max(finite_vals)
            if lo > 0 and hi / lo > 100:
                ax.set_yscale("log")

    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def plot_multi_stage_curves(
    histories: Dict[str, List[Dict[str, Any]]],
    save_path: str,
    title: str = "Multi-Stage Training Curves",
    figsize: tuple[int, int] = (14, 5),
) -> None:
    """Plot training curves for multiple pipeline stages side by side.

    Each stage gets its own subplot. Each curve uses the first key in its
    candidate list that holds a non-``None`` value in any epoch of that
    stage's history:

    - training curve: ``train_loss`` or ``loss``
    - validation curve: ``val_loss``, ``val_mse`` or ``mse``

    Epochs missing a value are skipped for that curve only.

    Args:
        histories: Dict mapping stage names to their training histories
            (lists of per-epoch metric dicts; 'epoch' defaults to the 1-based
            position in the list).
        save_path: Path to save the figure; missing parent directories are
            created.
        title: Overall figure title.
        figsize: Figure size (width, height).
    """
    stages = list(histories.keys())

    if not stages:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.set_title(f"{title} (no data)")
        ax.text(0.5, 0.5, "No training data available", ha="center", va="center")
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        return

    # (candidate keys in priority order, line format, legend label)
    curve_specs = (
        (("train_loss", "loss"), "b-", "Train"),
        (("val_loss", "val_mse", "mse"), "r-", "Val"),
    )

    # squeeze=False always yields a 2-D array of axes, so a single stage
    # needs no special case.
    fig, axes = plt.subplots(1, len(stages), figsize=figsize, squeeze=False)

    for ax, stage in zip(axes[0], stages):
        history = histories[stage]

        if not history:
            ax.set_title(f"{stage} (no data)")
            ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
            continue

        # Shared x values for every curve of this stage. Each curve filters
        # into its own local so the validation curve is never paired with
        # the training curve's filtered epochs.
        epochs = [h.get("epoch", i + 1) for i, h in enumerate(history)]

        drew_curve = False
        for candidate_keys, fmt, label in curve_specs:
            # Look at every epoch, not just the first: a metric such as
            # validation loss may only appear after the first epoch.
            key = next(
                (k for k in candidate_keys if any(h.get(k) is not None for h in history)),
                None,
            )
            if key is None:
                continue
            curve_epochs, curve_vals = zip(
                *[(e, h[key]) for e, h in zip(epochs, history) if h.get(key) is not None]
            )
            ax.plot(curve_epochs, curve_vals, fmt, label=label, linewidth=2)
            drew_curve = True

        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title(stage.capitalize())
        # Avoid matplotlib's "no artists with labels" warning on empty subplots.
        if drew_curve:
            ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)

    fig.suptitle(title, fontsize=14)
    plt.tight_layout()

    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def plot_property_scatter(
    actual: np.ndarray,
    predicted: np.ndarray,
    property_names: Sequence[str],
    save_path: str,
    title: str = "Property Prediction",
    figsize: tuple[int, int] | None = None,
) -> None:
    """Plot actual vs predicted scatter plots, one subplot per property.

    Each subplot:

    - drops pairs where either value is NaN;
    - draws a dashed y = x reference line;
    - shows R2 in its title, computed as ``1 - SS_res / SS_tot``, or 0 when
      the actual values are constant.

    Args:
        actual: Array-like of actual values, shape (N, n_props) or (N,).
            A 1-D array is used for every property.
        predicted: Array-like of predicted values, same layout as ``actual``.
        property_names: Names of the properties, one subplot each. If empty,
            a "(no data)" placeholder figure is saved instead.
        save_path: Path to save the figure; missing parent directories are
            created.
        title: Overall figure title.
        figsize: Figure size (defaults to 5 inches per property by 5 inches).
    """
    # Accept lists/tuples as well as arrays; float dtype keeps np.isnan valid.
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    n_props = len(property_names)
    if figsize is None:
        figsize = (5 * max(n_props, 1), 5)

    if n_props == 0:
        # plt.subplots(1, 0) raises, so save a placeholder figure instead.
        fig, ax = plt.subplots(figsize=figsize)
        ax.set_title(f"{title} (no data)")
        ax.text(
            0.5, 0.5, "No properties to plot", ha="center", va="center", transform=ax.transAxes
        )
    else:
        # squeeze=False always yields a 2-D array of axes.
        fig, axes = plt.subplots(1, n_props, figsize=figsize, squeeze=False)

        for i, (ax, name) in enumerate(zip(axes[0], property_names)):
            y_true = actual[:, i] if actual.ndim > 1 else actual
            y_pred = predicted[:, i] if predicted.ndim > 1 else predicted

            # Keep only pairs where both values are present.
            keep = ~(np.isnan(y_true) | np.isnan(y_pred))
            y_true, y_pred = y_true[keep], y_pred[keep]

            if y_true.size == 0:
                ax.text(0.5, 0.5, "No valid data", ha="center", va="center", transform=ax.transAxes)
                ax.set_title(name)
                continue

            ax.scatter(y_true, y_pred, alpha=0.5, s=20)

            # y = x reference spanning both axes' data range.
            lo = min(y_true.min(), y_pred.min())
            hi = max(y_true.max(), y_pred.max())
            ax.plot([lo, hi], [lo, hi], "k--", alpha=0.5, linewidth=1)

            ss_res = np.sum((y_true - y_pred) ** 2)
            ss_tot = np.sum((y_true - y_true.mean()) ** 2)
            r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0

            ax.set_xlabel("Actual")
            ax.set_ylabel("Predicted")
            ax.set_title(f"{name} (R2={r2:.3f})")
            ax.grid(True, alpha=0.3)

        fig.suptitle(title, fontsize=14)
        plt.tight_layout()

    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def plot_pareto_optimization(
    candidates: np.ndarray,
    optimized: np.ndarray | None,
    pareto_mask: np.ndarray,
    prop_names: Sequence[str],
    save_path: str,
    prop_indices: tuple[int, int] = (0, 1),
    title: str = "Pareto Front Optimization",
    figsize: tuple[int, int] = (10, 8),
    show_arrows: bool = True,
    arrow_alpha: float = 0.6,
    near_pareto_mask: np.ndarray | None = None,
    selection_mask: np.ndarray | None = None,
    valid_mask: np.ndarray | None = None,
) -> None:
    """Plot Pareto front with optimization arrows.

    Creates a scatter plot showing:
    - All candidate points (light gray)
    - Near-Pareto points (orange squares, if provided)
    - True Pareto front highlighted (red circles)
    - Arrows from selected candidates to optimized outputs, when
      ``show_arrows`` is True. Arrows from Pareto points are blue (valid) or
      red (invalid); arrows from other selected points are teal (valid) or
      orange (invalid).
    - Optimized points: blue circles for valid and red X's for invalid when
      ``valid_mask`` is given, otherwise blue circles.

    Args:
        candidates: Array of shape (N, D) with candidate property values.
        optimized: Array of shape (N, D) or (M, D) with optimized property
            values. If it has N rows, row i is paired with candidate i:
            arrows are drawn pairwise and only selected rows are plotted.
            Otherwise all rows are plotted and no arrows are drawn. If None,
            neither arrows nor optimized points are drawn.
        pareto_mask: Boolean mask of shape (N,) indicating true Pareto-optimal
            candidates.
        prop_names: List of property names (at least 2).
        save_path: Path to save the figure; missing parent directories are
            created.
        prop_indices: Tuple of (x_index, y_index) for which properties to plot.
        title: Plot title.
        figsize: Figure size.
        show_arrows: Whether to draw optimization arrows. Optimized points are
            plotted either way.
        arrow_alpha: Transparency of arrows.
        near_pareto_mask: Optional boolean mask for near-Pareto points
            (K-neighbors), shown with different markers than true Pareto.
        selection_mask: Optional boolean mask for all selected points
            (Pareto + neighbors). Used to decide which points get arrows and
            which paired optimized rows are plotted. Defaults to
            ``pareto_mask``.
        valid_mask: Optional boolean mask indicating which optimized molecules
            are valid (RDKit). It is indexed like the rows of ``optimized``.
            If provided, valid (blue circles) and invalid (red X's) molecules
            are distinguished in the plot.

    Raises:
        ValueError: If fewer than 2 property names are given.

    Example:
        >>> candidates = np.random.randn(100, 3)  # 100 samples, 3 properties
        >>> optimized = candidates + np.random.randn(100, 3) * 0.1
        >>> pareto_mask = pareto_front(candidates[:, :2], sense=["max", "min"])
        >>> plot_pareto_optimization(
        ...     candidates, optimized, pareto_mask,
        ...     prop_names=["QED", "SAS", "pLogP"],
        ...     save_path="pareto_opt.png",
        ...     prop_indices=(0, 1)  # Plot QED vs SAS
        ... )
    """
    x_idx, y_idx = prop_indices

    if len(prop_names) < 2:
        raise ValueError("At least 2 property names required")

    x_name = prop_names[x_idx] if x_idx < len(prop_names) else f"Property {x_idx}"
    y_name = prop_names[y_idx] if y_idx < len(prop_names) else f"Property {y_idx}"

    candidates = np.asarray(candidates)
    # Coerce masks to boolean arrays so that ``~`` is a logical NOT. On a
    # list it raises; on an int array it is a bitwise NOT.
    pareto_mask = np.asarray(pareto_mask, dtype=bool)
    if selection_mask is None:
        selection_mask = pareto_mask
    else:
        selection_mask = np.asarray(selection_mask, dtype=bool)
    if near_pareto_mask is not None:
        near_pareto_mask = np.asarray(near_pareto_mask, dtype=bool)
    if valid_mask is not None:
        valid_mask = np.asarray(valid_mask, dtype=bool)

    fig, ax = plt.subplots(figsize=figsize)

    cand_x = candidates[:, x_idx]
    cand_y = candidates[:, y_idx]

    # "Other" points are neither Pareto nor near-Pareto.
    if near_pareto_mask is not None:
        other_mask = ~pareto_mask & ~near_pareto_mask
    else:
        other_mask = ~pareto_mask

    if other_mask.any():
        ax.scatter(
            cand_x[other_mask],
            cand_y[other_mask],
            c="lightgray",
            alpha=0.5,
            s=30,
            label="Non-Pareto",
            zorder=1,
        )

    if near_pareto_mask is not None and near_pareto_mask.any():
        ax.scatter(
            cand_x[near_pareto_mask],
            cand_y[near_pareto_mask],
            c="orange",
            s=70,
            marker="s",
            edgecolors="darkorange",
            linewidths=1.0,
            label=f"Near-Pareto (n={near_pareto_mask.sum()})",
            zorder=2,
        )

    ax.scatter(
        cand_x[pareto_mask],
        cand_y[pareto_mask],
        c="red",
        s=100,
        marker="o",
        edgecolors="darkred",
        linewidths=1.5,
        label=f"Pareto Front (n={pareto_mask.sum()})",
        zorder=3,
    )

    if optimized is not None:
        optimized = np.asarray(optimized)
        opt_x = optimized[:, x_idx]
        opt_y = optimized[:, y_idx]

        if optimized.shape[0] == candidates.shape[0]:
            # Rows are paired with candidates. ``show_arrows`` gates only the
            # arrows; the optimized points below are always drawn.
            if show_arrows:
                for i in np.flatnonzero(selection_mask):
                    dx = opt_x[i] - cand_x[i]
                    dy = opt_y[i] - cand_y[i]

                    # Skip points that did not meaningfully move.
                    if abs(dx) > 1e-6 or abs(dy) > 1e-6:
                        is_valid = valid_mask[i] if valid_mask is not None else True
                        if pareto_mask[i]:
                            arrow_color = "blue" if is_valid else "red"
                        else:
                            arrow_color = "teal" if is_valid else "orange"

                        ax.annotate(
                            "",
                            xy=(opt_x[i], opt_y[i]),
                            xytext=(cand_x[i], cand_y[i]),
                            arrowprops=dict(
                                arrowstyle="->",
                                color=arrow_color,
                                alpha=arrow_alpha,
                                linewidth=1.5,
                            ),
                            zorder=2,
                        )

            _scatter_optimized_points(
                ax,
                opt_x[selection_mask],
                opt_y[selection_mask],
                None if valid_mask is None else valid_mask[selection_mask],
            )
        else:
            # Unpaired rows: plot every optimized point.
            _scatter_optimized_points(ax, opt_x, opt_y, valid_mask)

    ax.set_xlabel(x_name, fontsize=12)
    ax.set_ylabel(y_name, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend(loc="best", fontsize=10)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def _scatter_optimized_points(
    ax: Any,
    xs: np.ndarray,
    ys: np.ndarray,
    valid: np.ndarray | None,
) -> None:
    """Scatter optimized points on ``ax``.

    If ``valid`` (a boolean mask aligned with ``xs``/``ys``) is given, valid
    points are drawn as blue circles and invalid ones as red X's, each with a
    count in its legend label. Otherwise all points are drawn as blue circles
    labelled "Optimized".
    """
    if valid is None:
        ax.scatter(
            xs,
            ys,
            c="blue",
            s=70,
            marker="o",
            edgecolors="darkblue",
            linewidths=1,
            label="Optimized",
            zorder=4,
        )
        return

    if valid.any():
        ax.scatter(
            xs[valid],
            ys[valid],
            c="blue",
            s=70,
            marker="o",
            edgecolors="darkblue",
            linewidths=1,
            label=f"Optimized Valid (n={valid.sum()})",
            zorder=4,
        )
    invalid = ~valid
    if invalid.any():
        ax.scatter(
            xs[invalid],
            ys[invalid],
            c="red",
            s=100,
            marker="x",
            linewidths=2.5,
            label=f"Optimized Invalid (n={invalid.sum()})",
            zorder=4,
        )


def plot_pareto_front_2d(
    points: np.ndarray,
    pareto_mask: np.ndarray,
    prop_names: Sequence[str],
    save_path: str,
    prop_indices: tuple[int, int] = (0, 1),
    title: str = "Pareto Front",
    figsize: tuple[int, int] = (8, 6),
    highlight_color: str = "red",
    background_color: str = "lightgray",
    highlight_edgecolor: str = "darkred",
) -> None:
    """Simple 2D Pareto front visualization without optimization arrows.

    Non-Pareto points are drawn in ``background_color`` (omitted entirely when
    there are none, so the legend has no empty entry). Pareto points are
    highlighted and joined by a dashed line in order of increasing x.

    Args:
        points: Array of shape (N, D) with objective values.
        pareto_mask: Boolean mask of shape (N,) indicating Pareto-optimal
            points.
        prop_names: List of property names; missing names fall back to
            "Property <index>".
        save_path: Path to save figure; missing parent directories are created.
        prop_indices: Which properties to plot on x and y axes.
        title: Plot title.
        figsize: Figure size.
        highlight_color: Color for Pareto front points and connecting line.
        background_color: Color for non-Pareto points.
        highlight_edgecolor: Edge color for Pareto front markers.
    """
    points = np.asarray(points)
    # A boolean array makes ``~`` a logical NOT, even for list input.
    pareto_mask = np.asarray(pareto_mask, dtype=bool)

    x_idx, y_idx = prop_indices
    x_name = prop_names[x_idx] if x_idx < len(prop_names) else f"Property {x_idx}"
    y_name = prop_names[y_idx] if y_idx < len(prop_names) else f"Property {y_idx}"

    fig, ax = plt.subplots(figsize=figsize)

    x = points[:, x_idx]
    y = points[:, y_idx]

    non_pareto = ~pareto_mask
    if non_pareto.any():
        ax.scatter(
            x[non_pareto], y[non_pareto], c=background_color, alpha=0.5, s=30, label="Non-Pareto"
        )

    ax.scatter(
        x[pareto_mask],
        y[pareto_mask],
        c=highlight_color,
        s=80,
        marker="o",
        edgecolors=highlight_edgecolor,
        linewidths=1.5,
        label=f"Pareto Front (n={pareto_mask.sum()})",
    )

    # Join Pareto points left-to-right to trace the front.
    pareto_pts = points[pareto_mask]
    if len(pareto_pts) > 1:
        order = np.argsort(pareto_pts[:, x_idx])
        ax.plot(
            pareto_pts[order, x_idx],
            pareto_pts[order, y_idx],
            c=highlight_color,
            alpha=0.5,
            linestyle="--",
            linewidth=1,
        )

    ax.set_xlabel(x_name, fontsize=12)
    ax.set_ylabel(y_name, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend(loc="best", fontsize=10)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def plot_ablation_comparison(
    results: Dict[str, Dict[str, float]],
    metric_names: Sequence[str],
    save_path: str,
    title: str = "Ablation Study Comparison",
    figsize: tuple[int, int] | None = None,
    colors: Sequence[str] | None = None,
) -> None:
    """Plot a grouped bar chart comparing metrics across ablation methods.

    Each metric gets one group of bars, with one bar per method in the group.
    Every bar is annotated with its value to three decimals. A metric missing
    from a method's results is plotted as 0. If ``results`` is empty, a
    "(no data)" placeholder figure is saved.

    Args:
        results: Dict mapping method name -> dict of metric name -> value
            e.g., {"mean": {"R2": 0.9, "MSE": 0.1}, "attention": {"R2": 0.95, "MSE": 0.08}}
        metric_names: List of metric names to plot.
        save_path: Path to save the figure; missing parent directories are
            created.
        title: Plot title.
        figsize: Figure size (auto-calculated if None).
        colors: Optional list of colors for each method (cycled if shorter
            than the number of methods).
    """
    methods = list(results.keys())
    n_methods = len(methods)
    n_metrics = len(metric_names)

    if figsize is None:
        figsize = (max(8, n_metrics * 2), 6)

    if colors is None:
        # Use a colorblind-friendly palette
        colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B3", "#CCB974", "#64B5CD"]

    fig, ax = plt.subplots(figsize=figsize)

    if n_methods == 0:
        # Nothing to compare; this also avoids dividing the group width by 0.
        ax.set_title(f"{title} (no data)")
        ax.text(0.5, 0.5, "No results available", ha="center", va="center", transform=ax.transAxes)
    else:
        x = np.arange(n_metrics)
        width = 0.8 / n_methods  # Width of each bar; a group spans 0.8 units

        for i, method in enumerate(methods):
            values = [results[method].get(m, 0) for m in metric_names]
            # Center the group of bars on each metric's tick.
            offset = (i - n_methods / 2 + 0.5) * width
            bars = ax.bar(
                x + offset,
                values,
                width,
                label=method,
                color=colors[i % len(colors)],
                edgecolor="black",
                linewidth=0.5,
            )
            for bar, val in zip(bars, values):
                height = bar.get_height()
                # Place the label just past the bar's end: above positive
                # bars, below negative ones (instead of inside the bar).
                above = height >= 0
                ax.annotate(
                    f"{val:.3f}",
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3 if above else -3),
                    textcoords="offset points",
                    ha="center",
                    va="bottom" if above else "top",
                    fontsize=8,
                )

        ax.set_xlabel("Metric", fontsize=12)
        ax.set_ylabel("Value", fontsize=12)
        ax.set_title(title, fontsize=14)
        ax.set_xticks(x)
        ax.set_xticklabels(metric_names, fontsize=10)
        ax.legend(title="Method", fontsize=9)
        ax.grid(True, alpha=0.3, axis="y")

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def plot_prediction_scatter_comparison(
    method_results: Dict[str, tuple[np.ndarray, np.ndarray]],
    property_names: Sequence[str],
    save_path: str,
    title: str = "Prediction Comparison Across Methods",
    figsize: tuple[int, int] | None = None,
) -> None:
    """Plot actual vs predicted scatter plots comparing multiple methods.

    Creates a grid of scatter plots with methods as rows and properties as
    columns. Each cell:

    - drops pairs where either value is NaN;
    - draws a dashed y = x reference line;
    - annotates R2 as computed by ``compute_r2``.

    If there are no methods or no properties, a "(no data)" placeholder
    figure is saved instead.

    Args:
        method_results: Dict mapping method name -> (y_true, y_pred) tuple.
            Both arrays should have shape (N, n_props); a 1-D array is used
            for every property column.
        property_names: Names of the properties.
        save_path: Path to save the figure; missing parent directories are
            created.
        title: Overall figure title.
        figsize: Figure size (auto-calculated if None).
    """
    from moltenflow.eval.prediction_metrics import compute_r2

    methods = list(method_results.keys())
    n_methods = len(methods)
    n_props = len(property_names)

    if n_methods == 0 or n_props == 0:
        # plt.subplots() rejects a zero-sized grid, so save a placeholder.
        fig, ax = plt.subplots(figsize=figsize or (8, 6))
        ax.set_title(f"{title} (no data)")
        ax.text(
            0.5,
            0.5,
            "No prediction results available",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        return

    if figsize is None:
        figsize = (4 * n_props, 3.5 * n_methods)

    fig, axes = plt.subplots(n_methods, n_props, figsize=figsize, squeeze=False)

    for i, method in enumerate(methods):
        # Accept lists as well as arrays; float dtype keeps np.isnan valid.
        y_true, y_pred = (np.asarray(a, dtype=float) for a in method_results[method])

        for j, prop_name in enumerate(property_names):
            ax = axes[i, j]

            yt = y_true[:, j] if y_true.ndim > 1 else y_true
            yp = y_pred[:, j] if y_pred.ndim > 1 else y_pred

            # Keep only pairs where both values are present.
            mask = ~(np.isnan(yt) | np.isnan(yp))
            yt_valid = yt[mask]
            yp_valid = yp[mask]

            if len(yt_valid) == 0:
                ax.text(0.5, 0.5, "No valid data", ha="center", va="center", transform=ax.transAxes)
            else:
                ax.scatter(yt_valid, yp_valid, alpha=0.4, s=15)

                # y = x reference spanning both axes' data range.
                lims = [
                    min(yt_valid.min(), yp_valid.min()),
                    max(yt_valid.max(), yp_valid.max()),
                ]
                ax.plot(lims, lims, "k--", alpha=0.5, linewidth=1)

                r2 = compute_r2(yt_valid, yp_valid)
                ax.text(
                    0.05,
                    0.95,
                    f"R2={r2:.3f}",
                    transform=ax.transAxes,
                    fontsize=9,
                    verticalalignment="top",
                    bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
                )

            # Axis labels only on the outer edge of the grid.
            if i == n_methods - 1:
                ax.set_xlabel("Actual", fontsize=10)
            if j == 0:
                ax.set_ylabel(f"{method}\nPredicted", fontsize=10)

            if i == 0:
                ax.set_title(prop_name, fontsize=11)

            ax.grid(True, alpha=0.3)

    fig.suptitle(title, fontsize=14, y=1.02)
    plt.tight_layout()

    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def plot_ablation_summary_table(
    results: Dict[str, Dict[str, Any]],
    save_path: str,
    title: str = "Ablation Study Results",
    metric_columns: Sequence[str] | None = None,
    figsize: tuple[int, int] = (12, 6),
) -> None:
    """Render ablation study results as a table figure and save it.

    Each row is one method. Cell formatting:

    - floating-point values (including NumPy floats) are shown to four
      decimals;
    - other values are shown via ``str``;
    - missing metrics are shown as "-".

    If ``results`` is empty, a "No results available" message is drawn
    instead of the table.

    Args:
        results: Dict mapping method name -> dict of metrics.
        save_path: Path to save the figure; missing parent directories are
            created.
        title: Table title.
        metric_columns: Specific metrics to include, in order. If None, all
            metric keys across all methods are used, sorted.
        figsize: Figure size.
    """
    methods = list(results.keys())

    if metric_columns is None:
        metric_columns = sorted({key for m_results in results.values() for key in m_results})
    else:
        # Materialize once so a generator is not exhausted by the row loop.
        metric_columns = list(metric_columns)

    def _format_cell(val: Any) -> str:
        # np.float32 is not a ``float`` subclass, so check np.floating too.
        if isinstance(val, (float, np.floating)):
            return f"{val:.4f}"
        return str(val)

    table_data = [
        [method] + [_format_cell(results[method].get(col, "-")) for col in metric_columns]
        for method in methods
    ]

    fig, ax = plt.subplots(figsize=figsize)
    ax.axis("off")

    if table_data:
        col_labels = ["Method"] + metric_columns
        table = ax.table(
            cellText=table_data,
            colLabels=col_labels,
            loc="center",
            cellLoc="center",
        )
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1.2, 1.5)

        # Row 0 holds the column labels; make it bold on a gray background.
        for j in range(len(col_labels)):
            cell = table[(0, j)]
            cell.set_text_props(weight="bold")
            cell.set_facecolor("#E6E6E6")
    else:
        # ax.table() indexes cellText[0], so an empty result set would raise.
        ax.text(0.5, 0.5, "No results available", ha="center", va="center", transform=ax.transAxes)

    ax.set_title(title, fontsize=14, pad=20)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
