from io import StringIO

import matplotlib.colors
import matplotlib.font_manager as fm
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pypalettes import load_cmap
from scipy import stats

# Set global font to Roboto
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = ["Roboto"]
# If Roboto is not installed, try to download and register it. This is a
# best-effort step: any failure falls back to commonly available fonts.
try:
    import os

    import matplotlib as mpl
    from matplotlib import font_manager

    # Check if Roboto is already available
    roboto_available = any("Roboto" in f.name for f in font_manager.fontManager.ttflist)

    if not roboto_available:
        import tempfile
        import urllib.request
        import zipfile

        # URL for the Roboto font archive
        roboto_url = "https://fonts.google.com/download?family=Roboto"

        # Keep the extracted fonts in Matplotlib's cache directory rather than
        # a temporary directory: the font manager only records file paths, so
        # files that are deleted right after registration would presumably be
        # missing once Matplotlib opens them for rendering. A persistent
        # location also avoids a fresh download on every import.
        roboto_dir = os.path.join(mpl.get_cachedir(), "roboto_fonts")
        os.makedirs(roboto_dir, exist_ok=True)

        def _roboto_ttf_files(directory):
            """Return the paths of all .ttf files found below *directory*."""
            return [
                os.path.join(root, file)
                for root, dirs, files in os.walk(directory)
                for file in files
                if file.endswith(".ttf")
            ]

        font_files = _roboto_ttf_files(roboto_dir)

        if not font_files:
            # Only the downloaded archive is temporary; fonts are extracted
            # into the persistent directory.
            with tempfile.TemporaryDirectory() as tmpdirname:
                zip_path = os.path.join(tmpdirname, "roboto.zip")
                # Download the font
                urllib.request.urlretrieve(roboto_url, zip_path)

                # Extract the font
                with zipfile.ZipFile(zip_path, "r") as zip_ref:
                    zip_ref.extractall(roboto_dir)

            font_files = _roboto_ttf_files(roboto_dir)

        if not font_files:
            # Treat an archive without fonts like a failed download so the
            # fallback fonts below are configured.
            raise RuntimeError("no .ttf files found in the downloaded archive")

        # addfont() registers each file with the live font manager. The private
        # font_manager._rebuild() helper is intentionally not called: it is
        # absent from recent Matplotlib releases, and the resulting
        # AttributeError used to be reported as a failed download.
        for font_path in font_files:
            font_manager.fontManager.addfont(font_path)

        print(f"Added {len(font_files)} Roboto font files")
except Exception as e:
    print(f"Could not download Roboto font: {e}")
    print("Using default sans-serif font instead")
    plt.rcParams["font.sans-serif"] = [
        "DejaVu Sans",
        "Arial",
        "Helvetica",
        "sans-serif",
    ]

# Configure the default plot style: hide the top and right spines (seaborn-like)
# and draw a light dashed grid on every axes
plt.rcParams["axes.spines.top"] = False
plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.linewidth"] = 1.0
plt.rcParams["axes.grid"] = True
plt.rcParams["grid.alpha"] = 0.3
plt.rcParams["grid.linestyle"] = "--"
# Remove legend frames by default
plt.rcParams["legend.frameon"] = False


# Helper that applies the shared spine style to any axis
def apply_spine_style(ax):
    """Hide the top/right spines of *ax* and set the bottom/left spine width to 1.0.

    Args:
        ax: Axes-like object exposing a ``spines`` mapping keyed by side name.

    Returns:
        The same ``ax`` object, so calls can be chained.
    """
    for hidden_side in ("top", "right"):
        ax.spines[hidden_side].set_visible(False)
    for kept_side in ("bottom", "left"):
        ax.spines[kept_side].set_linewidth(1.0)
    return ax


# Create a custom palette function that returns a list of colors
def get_palette(n):
    """
    Sample n evenly spaced colors from Matplotlib's "rainbow" colormap.

    Args:
        n: Number of colors to return

    Returns:
        List of n RGBA tuples; an empty list when n is zero or negative
    """
    if n <= 0:
        return []

    rainbow = plt.get_cmap("rainbow")
    # Sample the colormap at n evenly spaced positions covering [0, 1]
    return [tuple(rainbow(position)) for position in np.linspace(0, 1, n)]


def adjusted_corrcoef(x, y, epsilon=1e-8):
    """Return a Pearson-style correlation of x and y, stabilised by epsilon.

    epsilon is added to each sum of squared deviations before the square
    root, so a constant input yields 0 instead of a division by zero.
    """
    xs = np.asarray(x)
    ys = np.asarray(y)

    # Deviations from the respective means
    dev_x = xs - np.mean(xs)
    dev_y = ys - np.mean(ys)

    numerator = np.sum(dev_x * dev_y)
    scale_x = np.sqrt(np.sum(dev_x**2) + epsilon)
    scale_y = np.sqrt(np.sum(dev_y**2) + epsilon)

    return numerator / (scale_x * scale_y)


def prepare_data(data, mmlu_mapping=None, delimiter="\t", min_behaviors=40):
    """
    Prepare data for plotting by standardizing column names, stripping
    whitespace, converting numeric columns, adding a plot label and dropping
    runs with too few behaviors.

    Args:
        data: Either a DataFrame, a string holding the delimited data itself
            (recognised by a header starting with "target_model" or
            "target model"), or a path to a delimited file
        mmlu_mapping: Optional dictionary mapping a model key to a dictionary
            of {column name: score}. Columns whose name contains "target" are
            filled on rows where the model is the target; columns whose name
            contains "attacker" on rows where it is the attacker
        delimiter: Delimiter used in the data (default: tab)
        min_behaviors: Rows with total_behaviors below this value are dropped
            (default: 40)

    Returns:
        A new DataFrame with cleaned and processed data; the input is not
        modified
    """
    # Read the data
    if isinstance(data, str):
        if data.strip().startswith(("target_model", "target model")):
            # Data is a string containing the actual data
            source = StringIO(data)
        else:
            # Data is a file path
            source = data
        df = pd.read_csv(source, sep=delimiter, skip_blank_lines=True)
    else:
        # Data is already a DataFrame
        df = data.copy()

    # Clean and standardize column names
    df.columns = df.columns.str.strip()

    # The first five columns are renamed positionally to the expected names;
    # any further columns keep their own names
    expected_columns = [
        "target_model_key",
        "attacker_model_key",
        "ASR",
        "judge_correlation",
        "total_behaviors",
    ]
    if len(df.columns) >= len(expected_columns):
        df.columns = expected_columns + list(df.columns[len(expected_columns) :])

    # Remove extra spaces from string columns. Both the classic object dtype
    # and pandas' dedicated string dtypes are covered.
    for col in df.columns:
        if pd.api.types.is_object_dtype(df[col]) or pd.api.types.is_string_dtype(
            df[col]
        ):
            df[col] = df[col].str.strip()

    # Add capability scores if a mapping is provided
    if mmlu_mapping is not None:
        for model_key, scores in mmlu_mapping.items():
            for score_column, score in scores.items():
                if "target" in score_column:
                    df.loc[df["target_model_key"] == model_key, score_column] = score
                elif "attacker" in score_column:
                    df.loc[df["attacker_model_key"] == model_key, score_column] = score

    # Convert numeric columns (errors='coerce' turns missing or invalid values
    # into NaN). total_behaviors is included so the threshold comparison below
    # also works when the column arrives as text.
    numeric_columns = ["ASR", "judge_correlation", "total_behaviors"]
    for col in numeric_columns:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce")

    # Create label column for plotting
    df["label"] = df["attacker_model_key"] + "→" + df["target_model_key"]

    # Return an independent copy so later column assignments by callers do not
    # operate on a slice of this frame
    return df[df["total_behaviors"] >= min_behaviors].copy()


def get_model_colors(df, model_column="target_model_key"):
    """
    Build a stable model -> color mapping for every model in the dataframe.

    Models are sorted first, so the same set of models always receives the
    same colors.

    Args:
        df: DataFrame containing model names
        model_column: Column name containing model identifiers

    Returns:
        Dictionary mapping model names to RGBA color tuples
    """
    models = sorted(df[model_column].unique())

    # Combined "basel" + "Fun" palette, sampled at evenly spaced positions
    palette = load_cmap(["basel", "Fun"])
    positions = np.linspace(0, 1, len(models))

    return {model: tuple(palette(pos)) for model, pos in zip(models, positions)}


def _fit_asr_regression(subset):
    """Fit ASR against capability_diff on the rows of *subset* that have both values.

    Returns:
        ``(slope, intercept, r_squared, rmse)``, or ``None`` when fewer than
        two complete rows are available.
    """
    clean = subset.dropna(subset=["capability_diff", "ASR"])
    x = clean["capability_diff"].values
    y = clean["ASR"].values
    if len(x) < 2:
        return None

    slope, intercept, r_value, _, _ = stats.linregress(x, y)
    rmse = np.sqrt(np.mean((y - slope * x - intercept) ** 2))
    return slope, intercept, r_value**2, rmse


def plot_capability_diff_vs_asr(
    df,
    columns_for_diff,
    diff_column=None,
    title="Capability Diff. vs ASR",
    show_annotations=False,
    fit_lines=True,
    save_path=None,
    metric_name="MMLU-Pro",
    xlabel=None,
):
    """
    Scatter-plot ASR against the attacker-minus-target capability difference,
    with one color (and optionally one regression line) per target model.

    The input DataFrame is not modified; all derived columns are added to an
    internal copy.

    Args:
        df: DataFrame with "attacker_model_key", "target_model_key", "ASR" and
            the columns named in columns_for_diff
        columns_for_diff: Two column names, attacker score first and target
            score second. The first name must contain "attacker" and the
            second "target"
        diff_column: Optional column holding a precomputed capability
            difference; when given it is used instead of the subtraction
        title: Plot title
        show_annotations: If True, annotate every point with its label
        fit_lines: If True, draw a per-target regression line when R² > 0.1
        save_path: Optional path the figure is saved to
        metric_name: Metric name used in the default x-axis label
        xlabel: Optional x-axis label overriding the default

    Returns:
        Tuple (fig, ax) of the created Matplotlib figure and axes

    Raises:
        AssertionError: If columns_for_diff does not satisfy the constraints
            described above
    """
    assert len(columns_for_diff) == 2
    assert columns_for_diff[0] != columns_for_diff[1]
    assert "attacker" in columns_for_diff[0], "attacker must be in the first column"
    assert "target" in columns_for_diff[1], "target must be in the second column"

    # Copy before adding any column so the caller's DataFrame stays untouched
    df = df.copy()

    if diff_column is None:
        df["capability_diff"] = df[columns_for_diff[0]] - df[columns_for_diff[1]]
    else:
        df["capability_diff"] = df[diff_column]
        print(f"ATTENTION: Using {diff_column} as the capability diff column")

    # Create label if not present
    if "label" not in df.columns:
        df["label"] = df["attacker_model_key"] + "→" + df["target_model_key"]

    # Get unique targets and sort by score
    unique_targets = df["target_model_key"].unique()
    target_mmlu = df.groupby("target_model_key")[columns_for_diff[1]].mean()
    sorted_targets = sorted(unique_targets, key=lambda t: target_mmlu[t])

    model_colors = dict(zip(sorted_targets, get_palette(len(sorted_targets))))

    # Overall x-range: regression lines span it and the axis limits pad it
    x_min, x_max = df["capability_diff"].min(), df["capability_diff"].max()

    # Create the plot
    fig, ax = plt.subplots(figsize=(10, 6))
    apply_spine_style(ax)  # Apply our spine style
    ax.set_axisbelow(True)
    ax.grid(True, linestyle="-", color="gray", alpha=0.7, axis="y")

    per_target_rmse = []
    per_target_r2 = []
    for target in sorted_targets:
        subset = df[df["target_model_key"] == target]
        color = model_colors[target]

        # One regression per target feeds both the legend text and the line
        fit = None
        fit_failed = False
        try:
            fit = _fit_asr_regression(subset)
        except ValueError as e:
            # e.g. linregress rejects data whose x values are all identical
            fit_failed = True
            print(f"Error fitting line for {target}: {str(e)}")

        if fit is not None:
            slope, intercept, r_squared, rmse = fit
            per_target_rmse.append(rmse)
            per_target_r2.append(r_squared)
            label_text = f"{target} (R² = {r_squared:.2f}, RMSE = {rmse:.2f})"
        else:
            # Report how many complete points exist when no fit is available
            n_valid = int(subset[["capability_diff", "ASR"]].notna().all(axis=1).sum())
            label_text = f"{target} (n={n_valid})"

        ax.scatter(
            subset["capability_diff"],
            subset["ASR"],
            s=100,
            color=color,
            edgecolors="black",
            linewidth=0.8,
            zorder=5,
            label=label_text,
        )

        # Add annotations if requested
        if show_annotations:
            for _, row in subset.iterrows():
                ax.annotate(
                    row["label"],
                    (row["capability_diff"], row["ASR"]),
                    textcoords="offset points",
                    xytext=(-45, -5),
                    ha="left",
                    fontsize=8,
                    rotation=15,
                    bbox=dict(boxstyle="round,pad=0.2", fc=(1, 1, 1, 0.7), ec="none"),
                    zorder=10,
                )

        # Fit regression lines if requested
        if not fit_lines:
            continue
        if fit is None:
            if len(subset) > 1 and not fit_failed:
                print(f"Not enough valid data points for {target}")
            continue

        # Plot only if the fit makes sense (R² > 0.1)
        if r_squared > 0.1:
            x_fit = np.linspace(x_min, x_max, 100)
            y_fit = slope * x_fit + intercept
            # Plot line without legend
            ax.plot(
                x_fit,
                y_fit,
                color=color,
                linestyle="--",
                linewidth=1,
                alpha=0.5,
            )
        else:
            print(
                f"Skipping line for {target} due to poor fit (R² = {r_squared:.2f})"
            )

    # Set plot labels and limits
    ax.set_title(title, fontsize=14, fontweight="bold")
    if xlabel is None:
        xlabel = f"{metric_name} absolute difference (Attacker - Target)"
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel("ASR", fontsize=12)

    # Add reference lines
    ax.axvline(x=0, color="black", linestyle="-", linewidth=0.8, alpha=0.5)
    ax.axhline(y=0.5, color="black", linestyle="-", linewidth=0.8, alpha=0.5)

    # Set axis limits with padding
    x_offset = (x_max - x_min) * 0.05 if x_max != x_min else 5
    ax.set_xlim(x_min - x_offset, x_max + x_offset)
    ax.set_ylim(-0.05, 1.05)

    # Add legend; the right 15% of the figure is reserved for it below
    fig.legend(
        title="Target Model", loc="upper left", bbox_to_anchor=(0.85, 1), frameon=False
    )

    fig.tight_layout(rect=[0, 0, 0.85, 1])

    # Save the figure if a path is provided
    if save_path:
        fig.savefig(save_path)

    # Average the per-target metrics; report NaN when no target could be fitted
    mean_rmse = np.mean(per_target_rmse) if per_target_rmse else float("nan")
    mean_r2 = np.mean(per_target_r2) if per_target_r2 else float("nan")
    print(f"RMSE: {mean_rmse}")
    print(f"R²: {mean_r2}")
    return fig, ax


def plot_judge_correlation_vs_asr(
    df,
    title="Judge Correlation vs ASR",
    show_annotations=False,
    fit_lines=False,
    save_path=None,
    x_limits=(-0.05, 1.05),
    metric_column_name="tinyMMLU_attacker",
    show=None,
):
    """
    Scatter-plot ASR against judge correlation, colored by attacker model.

    Rows whose judge_correlation is exactly 0 are excluded. The input
    DataFrame is not modified.

    Args:
        df: DataFrame with "attacker_model_key", "target_model_key",
            "judge_correlation", "ASR" and the column named by
            metric_column_name
        title: Plot title
        show_annotations: If True, annotate the plotted points
        fit_lines: If True, add least-squares lines (one overall line in
            aggregated mode, one per attacker otherwise); rows with missing
            values are ignored for the fit
        save_path: Optional path the figure is saved to
        x_limits: x-axis limits; a falsy value leaves the limits automatic
        metric_column_name: Column whose per-attacker mean orders the
            attackers (and therefore their colors)
        show: Optional aggregation passed to DataFrame.agg (for example
            "mean"). When set, one aggregated point per attacker is plotted
            instead of one point per row

    Returns:
        Tuple (fig, ax) of the created Matplotlib figure and axes
    """
    # Create a copy of the dataframe to avoid modifying the original
    df = df.copy()

    # Create label if not present
    if "label" not in df.columns:
        df["label"] = df["attacker_model_key"] + "→" + df["target_model_key"]

    # Get unique attackers and sort by their mean metric score
    unique_attackers = df["attacker_model_key"].unique()
    attacker_mmlu = df.groupby("attacker_model_key")[metric_column_name].mean()
    sorted_attackers = sorted(unique_attackers, key=lambda a: attacker_mmlu[a])

    # Get colors for each attacker model
    attacker_colors = dict(zip(sorted_attackers, get_palette(len(sorted_attackers))))

    # Filter out zero correlation values if present
    if "judge_correlation" in df.columns:
        local_df = df[df["judge_correlation"] != 0]
    else:
        local_df = df

    # Create the plot
    fig, ax = plt.subplots(figsize=(10, 6))
    apply_spine_style(ax)  # Apply our spine style
    ax.set_axisbelow(True)
    ax.grid(True, linestyle="--", color="gray", alpha=0.7)

    def annotate_point(text, x_value, y_value):
        """Attach a small rotated text box to the point (x_value, y_value)."""
        ax.annotate(
            text,
            (x_value, y_value),
            textcoords="offset points",
            xytext=(-45, -5),
            ha="left",
            fontsize=8,
            rotation=15,
            bbox=dict(boxstyle="round,pad=0.2", fc=(1, 1, 1, 0.7), ec="none"),
            zorder=10,
        )

    if show:
        # Aggregate correlation and ASR per attacker model
        mean_df = (
            local_df.groupby("attacker_model_key")
            .agg({"judge_correlation": show, "ASR": show})
            .reset_index()
        )

        # Plot one point per model with the aggregated values
        for _, row in mean_df.iterrows():
            attacker = row["attacker_model_key"]
            ax.scatter(
                row["judge_correlation"],
                row["ASR"],
                s=100,
                color=attacker_colors[attacker],
                edgecolors="black",
                linewidth=0.8,
                zorder=5,
                label=f"{attacker}",
            )

            # Add annotations if requested
            if show_annotations:
                annotate_point(attacker, row["judge_correlation"], row["ASR"])

        # Fit regression line if requested; NaN rows would break polyfit
        if fit_lines:
            fit_df = mean_df.dropna(subset=["judge_correlation", "ASR"])
            if len(fit_df) > 1:
                x = fit_df["judge_correlation"]
                y = fit_df["ASR"]
                slope, intercept = np.polyfit(x, y, 1)
                x_fit = np.linspace(x.min(), x.max(), 100)
                ax.plot(
                    x_fit,
                    slope * x_fit + intercept,
                    color="black",
                    linestyle="--",
                    linewidth=1,
                    alpha=0.5,
                    label=f"Fit (slope={slope:.2f})",
                )
    else:
        for attacker in sorted_attackers:
            subset = local_df[local_df["attacker_model_key"] == attacker]
            if subset.empty:
                continue

            color = attacker_colors[attacker]
            # Only complete rows take part in the correlation and the fit
            valid = subset.dropna(subset=["judge_correlation", "ASR"])
            can_fit = fit_lines and len(valid) > 1

            if can_fit:
                corr_val = adjusted_corrcoef(valid["judge_correlation"], valid["ASR"])
                label_text = f"{attacker} (r = {corr_val:.2f})"
            else:
                label_text = attacker

            ax.scatter(
                subset["judge_correlation"],
                subset["ASR"],
                s=100,
                color=color,
                edgecolors="black",
                linewidth=0.8,
                zorder=5,
                label=label_text,
            )

            # Add annotations if requested
            if show_annotations:
                for _, row in subset.iterrows():
                    annotate_point(row["label"], row["judge_correlation"], row["ASR"])

            # Fit regression lines if requested
            if can_fit:
                x = valid["judge_correlation"]
                y = valid["ASR"]
                slope, intercept = np.polyfit(x, y, 1)
                x_fit = np.linspace(x.min(), x.max(), 100)
                ax.plot(
                    x_fit,
                    slope * x_fit + intercept,
                    color=color,
                    linestyle="--",
                    linewidth=1,
                    alpha=0.5,
                )

    # Reference diagonal from (0, 0) to (1, 1)
    ax.plot([0, 1], [0, 1], color="black", linestyle="--", linewidth=0.8, alpha=0.5)
    # Set plot labels and limits
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xlabel("PAIR score to HarmBench Judge Correlation", fontsize=12)
    ax.set_ylabel("ASR", fontsize=12)

    # Set x-axis limits if provided
    if x_limits:
        ax.set_xlim(x_limits)

    # Mention the aggregation only when one is applied, so the title never
    # reads "(agg. None)"
    legend_title = "Attacker-Judge Model"
    if show:
        legend_title = f"Attacker-Judge Model (agg. {show})"

    fig.legend(
        title=legend_title,
        loc="upper left",
        bbox_to_anchor=(0.85, 1),
        frameon=False,
    )

    fig.tight_layout(rect=[0, 0, 0.85, 1])

    # Save the figure if a path is provided
    if save_path:
        fig.savefig(save_path)

    return fig, ax


def print_weak_strong_stats(
    df, columns_for_diff=("tinyMMLU_attacker", "tinyMMLU_target")
):
    """
    Print the mean ASR of weak-to-strong and strong-to-weak attacks.

    An attack counts as weak-to-strong when the capability difference
    (attacker minus target) is negative and as strong-to-weak when it is
    positive; rows with a difference of exactly zero are ignored.

    Args:
        df: DataFrame with an "ASR" column and either a "capability_diff"
            column or the two columns named in columns_for_diff. The frame is
            not modified
        columns_for_diff: Sequence whose first two entries name the attacker
            and target score columns. Only used when df has no
            "capability_diff" column

    Raises:
        ValueError: If df has no "capability_diff" column and
            columns_for_diff is None
    """
    if "capability_diff" in df.columns:
        capability_diff = df["capability_diff"]
    elif columns_for_diff is not None:
        attacker_column, target_column = columns_for_diff[0], columns_for_diff[1]
        # Column-wise arithmetic instead of positional indexing into each row,
        # which pandas no longer supports for label-indexed rows
        capability_diff = (df[attacker_column] - df[target_column]) * 100
    else:
        raise ValueError(
            "capability_diff not found in df and columns_for_diff is not provided. How to calculate capability difference?"
        )

    print("Weak-to-Strong: ", df.loc[capability_diff < 0, "ASR"].mean())
    print("Strong-to-Weak: ", df.loc[capability_diff > 0, "ASR"].mean())


def plot_judge_correlation_vs_metric(
    df,
    title="Judge Correlation vs Metric",
    metric_column="tinyMMLU_attacker",
    save_path=None,
):
    """
    Draw one box plot of judge correlation per attacker model, positioned on
    the x-axis at the attacker's mean metric value.

    Rows whose judge_correlation is exactly 0 are excluded, and missing
    correlation values are ignored. Attackers without any remaining value get
    neither a box nor a legend entry. The input DataFrame is not modified.

    Args:
        df: DataFrame with "attacker_model_key", "target_model_key",
            "judge_correlation" and the column named by metric_column
        title: Plot title
        metric_column: Column whose per-attacker mean gives the box position;
            the x-axis is pinned to [0, 1]
        save_path: Optional path the figure is saved to

    Returns:
        Tuple (fig, ax) of the created Matplotlib figure and axes
    """
    # Create a copy of the dataframe to avoid modifying the original
    df = df.copy()

    # Create label if not present
    if "label" not in df.columns:
        df["label"] = df["attacker_model_key"] + "→" + df["target_model_key"]

    # Get unique attackers and sort by the mean metric value
    unique_attackers = df["attacker_model_key"].unique()
    attacker_scores = df.groupby("attacker_model_key")[metric_column].mean()
    sorted_attackers = sorted(unique_attackers, key=lambda a: attacker_scores[a])

    # Get colors for each attacker model
    attacker_colors = dict(zip(sorted_attackers, get_palette(len(sorted_attackers))))

    # Filter out rows with zero correlation if applicable
    if "judge_correlation" in df.columns:
        local_df = df[df["judge_correlation"] != 0]
    else:
        local_df = df

    # Prepare data for boxplots using the metric means as x positions.
    # plotted_attackers records which attacker each box belongs to, so the
    # colors stay aligned when an attacker is skipped.
    plotted_attackers = []
    data_to_plot = []
    metric_means = []
    for attacker in sorted_attackers:
        subset = local_df[local_df["attacker_model_key"] == attacker]
        correlations = subset["judge_correlation"].dropna().values
        position = subset[metric_column].mean()
        # Skip attackers with nothing to draw or no usable x position
        if len(correlations) == 0 or pd.isna(position):
            continue
        plotted_attackers.append(attacker)
        data_to_plot.append(correlations)
        metric_means.append(position)

    # Create the plot
    fig, ax = plt.subplots(figsize=(12, 6))
    apply_spine_style(ax)  # Apply our spine style
    ax.set_axisbelow(True)
    ax.grid(True, linestyle="--", color="gray", alpha=0.7)

    if data_to_plot:
        # Boxes are vertical by default; positions are the mean metric values
        box_plots = ax.boxplot(
            data_to_plot,
            positions=metric_means,
            patch_artist=True,
            widths=0.05,
            notch=False,
            showfliers=True,
            medianprops=dict(color="black", linewidth=1.5),
            flierprops=dict(marker="o", markerfacecolor="gray", markersize=4),
            whiskerprops=dict(linewidth=1.5),
            capprops=dict(linewidth=1.5),
        )

        # Set box colors for each attacker
        for patch, attacker in zip(box_plots["boxes"], plotted_attackers):
            patch.set_facecolor(attacker_colors[attacker])
            patch.set_edgecolor("black")
            patch.set_linewidth(1.5)
            patch.set_alpha(0.7)

    # Set plot title and labels
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xlabel(metric_column, fontsize=12)
    ax.set_ylabel("Judge Correlation", fontsize=12)

    # Set y-axis ticks from 0 to 1 with step 0.2
    ax.set_yticks(np.arange(0, 1.1, 0.2))
    ax.set_ylim(-0.05, 1.05)

    # Pin the x-axis to [0, 1] with custom tick labels (boxplot would
    # otherwise label the ticks with the box positions)
    ax.set_xlim(0, 1)
    ticks = np.linspace(0, 1, 6)
    ax.set_xticks(ticks)
    ax.set_xticklabels([f"{x:.2f}" for x in ticks], rotation=45, ha="right")

    # Create a custom legend with colored rectangle patches for each plotted
    # attacker
    legend_handles = [
        plt.Rectangle((0, 0), 1, 1, fc=attacker_colors[attacker], ec="black", lw=0.8)
        for attacker in plotted_attackers
    ]
    ax.legend(
        handles=legend_handles,
        labels=plotted_attackers,
        title="Attacker Model",
        loc="upper left",
        bbox_to_anchor=(1, 1),
        borderaxespad=0,
        frameon=False,
    )

    fig.tight_layout()

    # Save the figure if a path is provided
    if save_path:
        fig.savefig(save_path, bbox_inches="tight")

    return fig, ax


def plot_asr_vs_metric(
    df,
    title="ASR vs Metric",
    metric_column="tinyMMLU_attacker",
    save_path=None,
):
    """
    Draw one box plot of ASR per attacker model, positioned on the x-axis at
    the attacker's mean metric value.

    Missing ASR values are ignored. Attackers without any ASR value or without
    a usable metric mean get neither a box nor a legend entry. The input
    DataFrame is not modified.

    Args:
        df: DataFrame with "attacker_model_key", "target_model_key", "ASR" and
            the column named by metric_column
        title: Plot title
        metric_column: Column whose per-attacker mean gives the box position
        save_path: Optional path the figure is saved to

    Returns:
        Tuple (fig, ax) of the created Matplotlib figure and axes
    """
    # Create a copy of the dataframe to avoid modifying the original
    df = df.copy()

    # Create label if not present
    if "label" not in df.columns:
        df["label"] = df["attacker_model_key"] + "→" + df["target_model_key"]

    # Get unique attackers and sort them by the mean metric value
    unique_attackers = df["attacker_model_key"].unique()
    attacker_metric = df.groupby("attacker_model_key")[metric_column].mean()
    sorted_attackers = sorted(unique_attackers, key=lambda a: attacker_metric[a])

    # Get colors for each attacker model
    attacker_colors = dict(zip(sorted_attackers, get_palette(len(sorted_attackers))))

    # Prepare data for box plots and corresponding x positions based on metric
    # values. plotted_attackers records which attacker each box belongs to, so
    # the colors stay aligned when an attacker is skipped.
    plotted_attackers = []
    data_to_plot = []
    metric_means = []
    for attacker in sorted_attackers:
        subset = df[df["attacker_model_key"] == attacker]
        asr_values = subset["ASR"].dropna().values
        position = subset[metric_column].mean()
        # Skip attackers with nothing to draw or no usable x position
        if len(asr_values) == 0 or pd.isna(position):
            continue
        plotted_attackers.append(attacker)
        data_to_plot.append(asr_values)
        metric_means.append(position)

    # Create the plot
    fig, ax = plt.subplots(figsize=(12, 6))
    apply_spine_style(ax)  # Apply our spine style
    ax.set_axisbelow(True)
    ax.grid(True, linestyle="--", color="gray", alpha=0.7)

    if data_to_plot:
        # Boxes are vertical by default; positions are the mean metric values
        box_plots = ax.boxplot(
            data_to_plot,
            positions=metric_means,
            patch_artist=True,
            widths=0.02,
            notch=False,
            showfliers=True,
            medianprops=dict(color="black", linewidth=1.5),
            flierprops=dict(marker="o", markerfacecolor="gray", markersize=4),
            whiskerprops=dict(linewidth=1.5),
            capprops=dict(linewidth=1.5),
        )

        # Color each box corresponding to its attacker
        for patch, attacker in zip(box_plots["boxes"], plotted_attackers):
            patch.set_facecolor(attacker_colors[attacker])
            patch.set_edgecolor("black")
            patch.set_linewidth(1.5)
            patch.set_alpha(0.7)

    # Set plot title and axis labels
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xlabel(metric_column, fontsize=12)
    ax.set_ylabel("ASR", fontsize=12)

    # Set y-axis limits (ASR is typically between 0 and 1)
    ax.set_ylim(-0.05, 1.05)

    if metric_means:
        # Span the x-axis over the plotted positions with a 0.1 margin and
        # place six evenly spaced ticks between the smallest and largest one
        lowest, highest = min(metric_means), max(metric_means)
        ax.set_xlim(lowest - 0.1, highest + 0.1)
        ticks = np.linspace(lowest, highest, 6)
        ax.set_xticks(ticks)
        ax.set_xticklabels([f"{x:.2f}" for x in ticks], rotation=45, ha="right")

    # Create a custom legend using colored rectangle patches for each plotted
    # attacker
    legend_handles = [
        plt.Rectangle((0, 0), 1, 1, fc=attacker_colors[attacker], ec="black", lw=0.8)
        for attacker in plotted_attackers
    ]
    ax.legend(
        handles=legend_handles,
        labels=plotted_attackers,
        title="Attacker Model",
        loc="upper left",
        bbox_to_anchor=(1, 1),
        borderaxespad=0,
        frameon=False,
    )

    fig.tight_layout(rect=[0, 0, 0.85, 1])

    # Save the figure if a path is provided
    if save_path:
        fig.savefig(save_path, bbox_inches="tight")

    return fig, ax


def _asr_heatmap_capability_order(models_on_axis, original_df, axis_role):
    """Return ``models_on_axis`` sorted ascending by mean capability score.

    Scores are the per-model means of ``mmlu_pro_exact_match_custom_attacker``
    (grouped by ``attacker_model_key``) and ``mmlu_pro_exact_match_custom_target``
    (grouped by ``target_model_key``). The attacker-side score takes priority;
    the target-side score is the fallback. Models without a usable (non-NaN)
    score are placed last. If neither column exists, a warning is printed and
    the models are returned in plain sorted order.
    """
    cap_attacker_col = "mmlu_pro_exact_match_custom_attacker"
    cap_target_col = "mmlu_pro_exact_match_custom_target"
    has_attacker_cap = cap_attacker_col in original_df.columns
    has_target_cap = cap_target_col in original_df.columns

    if not has_attacker_cap and not has_target_cap:
        print(
            f"Warning: Capability columns not found for 'capability' sort. Defaulting to alphabetical for axis: {axis_role}."
        )
        return sorted(models_on_axis)

    attacker_scores = (
        original_df.groupby("attacker_model_key")[cap_attacker_col].mean()
        if has_attacker_cap
        else pd.Series(dtype=float)
    )
    target_scores = (
        original_df.groupby("target_model_key")[cap_target_col].mean()
        if has_target_cap
        else pd.Series(dtype=float)
    )

    model_scores = {}
    for model in models_on_axis:
        # Attacker score first, target score as fallback. A NaN mean counts as
        # "no score": NaN sort keys would make the resulting order undefined.
        for series in (attacker_scores, target_scores):
            score = series.get(model)
            if score is not None and pd.notna(score):
                model_scores[model] = score
                break

    # Ascending: lowest capability first, unscored models last.
    return sorted(models_on_axis, key=lambda m: model_scores.get(m, np.inf))


def _asr_heatmap_avg_asr_order(models_on_axis, heatmap, axis_role, is_row_axis):
    """Return ``models_on_axis`` sorted descending by mean value in ``heatmap``.

    Row means are used when ``is_row_axis`` is true, column means otherwise.
    Models whose mean is missing or NaN are placed last. The means and the
    resulting order are printed to stdout. An empty ``heatmap`` yields plain
    sorted order without printing.
    """
    if heatmap.empty:
        return sorted(models_on_axis)

    if is_row_axis:
        asr_means = heatmap.mean(axis=1)
        header = f"\nRow ASR means for {axis_role}:"
    else:
        asr_means = heatmap.mean(axis=0)
        header = f"\nColumn ASR means for {axis_role}:"

    def report(heading, models):
        print(heading)
        for model in models:
            mean_val = asr_means.get(model, np.nan)
            if not pd.isna(mean_val):
                print(f"{model}: {mean_val:.4f}")
            else:
                print(f"{model}: N/A")

    def sort_key(model):
        mean_val = asr_means.get(model, np.nan)
        if pd.isna(mean_val):
            # Constant second element keeps the order of missing models stable.
            return (True, 0.0)
        return (False, -mean_val)

    report(header, models_on_axis)
    sorted_models = sorted(models_on_axis, key=sort_key)
    report(f"\nSorted order for {axis_role}:", sorted_models)
    return sorted_models


def _asr_heatmap_axis_order(
    models_on_axis, instruction, axis_role, original_df, heatmap, is_row_axis
):
    """Order one heatmap axis according to ``instruction``.

    ``"capability"`` and ``"avg_asr"`` dispatch to the dedicated helpers; any
    other value (including ``None``) falls back to plain sorted order.
    """
    if instruction == "capability":
        return _asr_heatmap_capability_order(models_on_axis, original_df, axis_role)
    if instruction == "avg_asr":
        return _asr_heatmap_avg_asr_order(
            models_on_axis, heatmap, axis_role, is_row_axis
        )
    return sorted(models_on_axis)


def plot_asr_heatmap(
    df,
    title="Heatmap of ASR (Attacker vs Target)",
    value_column="ASR",
    cmap="YlGnBu",
    annot=True,
    fmt=".2f",
    figsize=(10, 8),
    save_path=None,
    sort_by_capability=True,
    granulized=False,
    n_bins=10,
    transpose=False,
    custom_sort_order=None,
    symmetric=True,
    title_fontsize=14,
    xlabel_fontsize=15,
    ylabel_fontsize=15,
    annot_fontsize=9,
    tick_labelsize=15,
    avg_xlabel="Avg. Target ASR",
    avg_ylabel="Avg. Attacker ASR",
    annotate_values=False,
    sort_by_roles: dict | None = None,
    symmetric_axes: bool = True,
):
    """Plot an attacker-vs-target heatmap of ``value_column`` with average margins.

    The data is pivoted (mean per attacker/target pair), ordered, stripped of
    all-NaN rows and columns, and extended with one extra column holding the
    row means and one extra row holding the column means (both rounded to two
    decimals). Black lines separate the averages from the data cells.

    Args:
        df: DataFrame with ``attacker_model_key``, ``target_model_key`` and
            ``value_column``. The ``"capability"`` sort additionally reads
            ``mmlu_pro_exact_match_custom_attacker`` and/or
            ``mmlu_pro_exact_match_custom_target`` when present. ``df`` is not
            modified.
        title: Plot title.
        value_column: Column averaged into the heatmap cells.
        cmap: Colormap name; discretised into ``n_bins`` colors when
            ``granulized`` is true.
        annot: Whether to annotate cells (non-granulized mode only).
        fmt: Annotation format string (non-granulized mode only).
        figsize: Figure size passed to ``plt.subplots``.
        save_path: If truthy, the figure is saved there with ``dpi=1000``.
        sort_by_capability: Currently unused; kept for backward compatibility.
        granulized: If true, draw with a discrete colormap on a fixed 0..1
            range and tick the colorbar at the bin edges.
        n_bins: Number of discrete color bins in granulized mode.
        transpose: If true, targets are rows and attackers are columns.
        custom_sort_order: Models to display, in this order, on both axes.
            Models not listed are omitted; duplicates are ignored. When given,
            ``sort_by_roles`` and ``symmetric_axes`` have no effect.
        symmetric: Currently unused; kept for backward compatibility.
        title_fontsize: Font size of the title.
        xlabel_fontsize: Font size of the x-axis label.
        ylabel_fontsize: Font size of the y-axis label (also used for the
            colorbar label in non-granulized mode).
        annot_fontsize: Font size of cell annotations.
        tick_labelsize: Font size of tick labels on both axes.
        avg_xlabel: Label of the appended column containing the row means.
        avg_ylabel: Label of the prepended row containing the column means.
        annotate_values: Whether to annotate cells in granulized mode.
        sort_by_roles: Optional dict with keys ``"attacker"`` and/or
            ``"target"`` whose values are ``"capability"`` (ascending) or
            ``"avg_asr"`` (descending); anything else sorts alphabetically.
        symmetric_axes: If true, the column order copies the row order;
            otherwise the columns are ordered by their own role's instruction.

    Returns:
        Tuple ``(fig, ax)``. If nothing remains to plot, a warning is printed
        and an empty titled figure is returned (and saved if requested).
    """
    if not transpose:
        # Default: attackers on the Y-axis (rows), targets on the X-axis (columns).
        row_key, col_key = "attacker_model_key", "target_model_key"
        row_role, col_role = "attacker", "target"
    else:
        row_key, col_key = "target_model_key", "attacker_model_key"
        row_role, col_role = "target", "attacker"

    if custom_sort_order is not None:
        # The custom order is used verbatim; de-duplicate (keeping first
        # occurrences) so the reindex below never produces duplicate labels.
        all_models = list(dict.fromkeys(custom_sort_order))
    else:
        all_models = sorted(
            set(df["target_model_key"].unique())
            | set(df["attacker_model_key"].unique())
        )

    heatmap_data = df.pivot_table(
        index=row_key,
        columns=col_key,
        values=value_column,
        aggfunc="mean",
    )
    # Put every model on both axes (missing ones become NaN) in the base order.
    heatmap_data = heatmap_data.reindex(index=all_models, columns=all_models)

    if custom_sort_order is None:
        roles = sort_by_roles if isinstance(sort_by_roles, dict) else {}

        final_row_order = _asr_heatmap_axis_order(
            list(heatmap_data.index),
            roles.get(row_role),
            row_role,
            df,
            heatmap_data,
            is_row_axis=True,
        )
        if symmetric_axes:
            final_col_order = list(final_row_order)
        else:
            final_col_order = _asr_heatmap_axis_order(
                list(heatmap_data.columns),
                roles.get(col_role),
                col_role,
                df,
                heatmap_data,
                is_row_axis=False,
            )
        heatmap_data = heatmap_data.reindex(
            index=final_row_order, columns=final_col_order
        )

    # Drop all-NaN rows/columns only after ordering, so the order is preserved.
    heatmap_data = heatmap_data.dropna(axis=0, how="all")
    heatmap_data = heatmap_data.dropna(axis=1, how="all")

    if heatmap_data.empty:
        print(
            "Warning: Heatmap data is empty after filtering and sorting. Skipping mean calculation and plot."
        )
        fig, ax = plt.subplots(figsize=figsize)
        ax.set_title(title, fontsize=title_fontsize, fontweight="bold", y=1.02)
        if save_path:
            plt.savefig(save_path, dpi=1000)
        return fig, ax

    # Remember the number of data columns before the averages column is added;
    # it is the x position of the vertical separator line.
    num_data_columns = len(heatmap_data.columns)
    row_means = heatmap_data.mean(axis=1).round(2)
    col_means = heatmap_data.mean(axis=0).round(2)

    heatmap_data[avg_xlabel] = row_means
    # The averages row is the first row of the frame; its cell under the
    # averages column has no value and stays NaN.
    mean_row = pd.DataFrame([col_means], columns=heatmap_data.columns)
    mean_row.index = [avg_ylabel]
    heatmap_data = pd.concat([mean_row, heatmap_data])

    fig, ax = plt.subplots(figsize=figsize)
    apply_spine_style(ax)

    if granulized:
        bins = np.linspace(0, 1, n_bins + 1)
        discrete_cmap = plt.get_cmap(cmap, n_bins)
        sns.heatmap(
            heatmap_data,
            annot=annotate_values,
            cmap=discrete_cmap,
            vmin=0,
            vmax=1,
            ax=ax,
            annot_kws={"fontsize": annot_fontsize},
            # Colorbar ticks sit on the bin edges.
            cbar_kws={
                "ticks": bins,
                "format": lambda x, _: f"{x:.1f}",
            },
        )
        cbar_label_y, cbar_label_fontsize = -0.01, 12
    else:
        sns.heatmap(
            heatmap_data,
            annot=annot,
            cmap=cmap,
            fmt=fmt,
            ax=ax,
            annot_kws={"fontsize": annot_fontsize},
            cbar_kws={"label": ""},
        )
        cbar_label_y, cbar_label_fontsize = -0.08, ylabel_fontsize

    # Replace the default side label of the colorbar with an "ASR" label below it.
    cbar = ax.collections[0].colorbar
    if cbar:
        cbar.set_label("")
        cbar.ax.text(
            0.5,
            cbar_label_y,
            "ASR",
            transform=cbar.ax.transAxes,
            ha="center",
            va="top",
            fontsize=cbar_label_fontsize,
        )

    # seaborn draws the first row at the top; inverting puts the averages row
    # at the bottom and the last model in the row order at the top.
    ax.invert_yaxis()

    ax.set_title(title, fontsize=title_fontsize, fontweight="bold", y=1.02)

    if not transpose:
        ax.set_xlabel("Target Model", fontsize=xlabel_fontsize)
        ax.set_ylabel("Attacker Model", fontsize=ylabel_fontsize)
    else:
        ax.set_xlabel("Attacker Model", fontsize=xlabel_fontsize)
        ax.set_ylabel("Target Model", fontsize=ylabel_fontsize)

    # Separate the averages row (first row) and averages column (last column).
    ax.axhline(y=1, color="black", linewidth=2)
    ax.axvline(x=num_data_columns, color="black", linewidth=2)

    ax.tick_params(axis="x", labelsize=tick_labelsize)
    ax.tick_params(axis="y", labelsize=tick_labelsize)

    ax.grid(False)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=1000)

    return fig, ax


def _pairwise_capability_order(candidate_frames, models):
    """Order ``models`` by descending mean capability score.

    The first frame in ``candidate_frames`` that contains both columns of one
    of the known (attacker, target) capability column pairs supplies the
    scores. A model's score is the mean of its attacker-side and target-side
    group means (NaN values ignored). Ties are broken by model name; models
    without any score follow in sorted order. If no frame has a matching
    column pair, a warning is printed and plain sorted order is returned.
    """
    capability_column_pairs = (
        (
            "mmlu_pro_exact_match_custom_attacker",
            "mmlu_pro_exact_match_custom_target",
        ),
        ("avg_elo_attacker", "avg_elo_target"),
        ("capability_attacker", "capability_target"),
        ("attacker_capability", "target_capability"),
        ("attacker_score", "target_score"),
    )

    # Frames are tried in the given order; within a frame, pairs in list order.
    match = next(
        (
            (frame, attacker_col, target_col)
            for frame in candidate_frames
            for attacker_col, target_col in capability_column_pairs
            if attacker_col in frame.columns and target_col in frame.columns
        ),
        None,
    )
    if match is None:
        print(
            "Warning: Could not find common capability columns in df1 or df2. "
            "Using alphabetical sort for models."
        )
        return sorted(models)

    frame, attacker_col, target_col = match
    collected = {model: [] for model in models}
    for key_col, cap_col in (
        ("attacker_model_key", attacker_col),
        ("target_model_key", target_col),
    ):
        for model, score in frame.groupby(key_col)[cap_col].mean().items():
            if model in collected and pd.notna(score):
                collected[model].append(score)

    scores = {
        model: float(np.mean(values)) for model, values in collected.items() if values
    }
    # Highest score first; the name tie-break keeps the order deterministic.
    ranked = sorted(scores, key=lambda m: (-scores[m], m))
    return ranked + sorted(set(models) - set(scores))


def plot_pairwise_win_heatmap(
    df1,
    df2,
    title="Pairwise ASR Win Comparison",
    df1_name="DataFrame 1",
    df2_name="DataFrame 2",
    color1="lightblue",  # Color for df1 wins
    color2="lightcoral",  # Color for df2 wins
    tie_color="lightgrey",  # Color for tie or missing in both
    figsize=(12, 10),
    save_path=None,
    custom_sort_order=None,
    sort_by_capability=True,
    transpose=False,
    title_fontsize=14,
    xlabel_fontsize=12,
    ylabel_fontsize=12,
    tick_labelsize=10,
    legend_fontsize=10,
):
    """
    Generates a heatmap indicating which of two DataFrames (df1 or df2)
    has a higher mean ASR for each attacker-target model pair.

    A pair that has data in only one DataFrame counts as a win for that
    DataFrame. Equal values, and pairs missing from both, are shown as ties.
    Models with no data at all in either DataFrame are left off the
    corresponding axis. Win/tie counts are printed to stdout.

    Args:
        df1: First DataFrame, must contain 'attacker_model_key', 'target_model_key', 'ASR'.
        df2: Second DataFrame, must contain 'attacker_model_key', 'target_model_key', 'ASR'.
        title: Title of the plot.
        df1_name: Name for df1, used in legend and printed counts.
        df2_name: Name for df2, used in legend and printed counts.
        color1: Color indicating df1 has higher ASR.
        color2: Color indicating df2 has higher ASR.
        tie_color: Color for ties or when data is missing in both.
        figsize: Figure size.
        save_path: Path to save the figure. If None, figure is not saved.
        custom_sort_order: List of model names to define the order on axes.
            Models present in the data but not listed are appended in sorted
            order; duplicates are ignored.
        sort_by_capability: If True and custom_sort_order is not given, sort
            models by descending capability score taken from the first
            DataFrame (df1, then df2) that has a known pair of capability
            columns; falls back to alphabetical order with a warning.
        transpose: If True, transpose axes (target on Y, attacker on X).
        title_fontsize: Font size for the plot title.
        xlabel_fontsize: Font size for the X-axis label.
        ylabel_fontsize: Font size for the Y-axis label.
        tick_labelsize: Font size for tick labels.
        legend_fontsize: Font size for the legend.
    Returns:
        matplotlib.figure.Figure, matplotlib.axes.Axes
    Raises:
        ValueError: If either DataFrame lacks one of the required columns.
    """
    required_cols = ["attacker_model_key", "target_model_key", "ASR"]
    for i, frame in enumerate([df1, df2]):
        if not all(col in frame.columns for col in required_cols):
            raise ValueError(
                f"DataFrame {i + 1} must contain columns: {', '.join(required_cols)}"
            )

    if not transpose:
        pivot_index_key = "attacker_model_key"
        pivot_columns_key = "target_model_key"
        x_label_text = "Target Model"
        y_label_text = "Attacker Model"
    else:
        pivot_index_key = "target_model_key"
        pivot_columns_key = "attacker_model_key"
        x_label_text = "Attacker Model"
        y_label_text = "Target Model"

    union_models = (
        set(df1[pivot_index_key].unique())
        | set(df1[pivot_columns_key].unique())
        | set(df2[pivot_index_key].unique())
        | set(df2[pivot_columns_key].unique())
    )

    if custom_sort_order:
        # De-duplicate (keeping first occurrences) so axis labels stay unique.
        listed = [m for m in dict.fromkeys(custom_sort_order) if m in union_models]
        final_model_order = listed + sorted(union_models - set(listed))
    elif sort_by_capability:
        final_model_order = _pairwise_capability_order((df1, df2), union_models)
    else:
        final_model_order = sorted(union_models)

    if not final_model_order:
        print("Warning: No models found to plot initially.")
        fig, ax = plt.subplots(figsize=figsize)
        ax.set_title(title, fontsize=title_fontsize, fontweight="bold")
        if save_path:
            plt.savefig(save_path, bbox_inches="tight")
        return fig, ax

    def mean_asr_grid(frame):
        # Mean ASR per pair on the full model grid; absent pairs are NaN.
        return frame.pivot_table(
            index=pivot_index_key,
            columns=pivot_columns_key,
            values="ASR",
            aggfunc="mean",
        ).reindex(index=final_model_order, columns=final_model_order)

    asr1_pivot = mean_asr_grid(df1)
    asr2_pivot = mean_asr_grid(df2)

    # A row/column is "active" if at least one of its cells has data in either frame.
    has_data = asr1_pivot.notna() | asr2_pivot.notna()
    row_has_data = has_data.any(axis=1)
    col_has_data = has_data.any(axis=0)
    active_row_models = row_has_data[row_has_data].index.tolist()
    active_col_models = col_has_data[col_has_data].index.tolist()

    if not active_row_models or not active_col_models:
        print(
            "Warning: No active models found for rows or columns after filtering. Nothing to plot."
        )
        fig, ax = plt.subplots(figsize=figsize)
        ax.set_title(title, fontsize=title_fontsize, fontweight="bold")
        if save_path:
            plt.savefig(save_path, bbox_inches="tight")
        return fig, ax

    # Missing pairs become -inf so that any real value beats a missing one and
    # two missing values compare equal (tie).
    asr1_values = (
        asr1_pivot.loc[active_row_models, active_col_models]
        .fillna(-np.inf)
        .to_numpy()
    )
    asr2_values = (
        asr2_pivot.loc[active_row_models, active_col_models]
        .fillna(-np.inf)
        .to_numpy()
    )
    # Cell codes: 0 = df1 wins, 1 = df2 wins, 2 = tie / missing in both.
    winner_codes = np.select(
        [asr1_values > asr2_values, asr2_values > asr1_values], [0, 1], default=2
    )
    winner_matrix = pd.DataFrame(
        winner_codes, index=active_row_models, columns=active_col_models
    )

    df1_wins = (winner_codes == 0).sum()
    df2_wins = (winner_codes == 1).sum()
    ties = (winner_codes == 2).sum()
    print(f"Debugging win counts for {df1_name} vs {df2_name}:")
    print(f"  {df1_name} wins: {df1_wins}")
    print(f"  {df2_name} wins: {df2_wins}")
    print(f"  Ties / Pair data missing: {ties}")

    fig, ax = plt.subplots(figsize=figsize)
    apply_spine_style(ax)

    custom_cmap = matplotlib.colors.ListedColormap([color1, color2, tie_color])
    # BoundaryNorm maps the discrete codes 0, 1, 2 to the three colors.
    norm = matplotlib.colors.BoundaryNorm([-0.5, 0.5, 1.5, 2.5], custom_cmap.N)

    sns.heatmap(
        winner_matrix,
        annot=False,
        cmap=custom_cmap,
        norm=norm,
        cbar=False,
        ax=ax,
        # Thin black borders between cells.
        linewidths=0.5,
        linecolor="black",
    )

    # The title is shifted left to leave room for the legend placed beside it.
    ax.set_title(title, fontsize=title_fontsize, fontweight="bold", y=1.0, x=0.2)
    ax.set_xlabel(x_label_text, fontsize=xlabel_fontsize)
    ax.set_ylabel(y_label_text, fontsize=ylabel_fontsize)

    ax.tick_params(axis="x", labelsize=tick_labelsize, rotation=90)
    ax.tick_params(axis="y", labelsize=tick_labelsize)
    ax.grid(False)

    # seaborn draws the first row at the top; inverting puts the first model
    # of the order at the bottom.
    # NOTE(review): with the descending capability sort this places the most
    # capable model at the bottom of the Y-axis - confirm this is the
    # intended orientation.
    if sort_by_capability or custom_sort_order:
        ax.invert_yaxis()

    patches = [
        mpatches.Patch(color=color1, label=f"{df1_name} wins"),
        mpatches.Patch(color=color2, label=f"{df2_name} wins"),
        mpatches.Patch(color=tie_color, label="Tie"),
    ]
    ax.legend(
        handles=patches,
        bbox_to_anchor=(0.7, 1.0),  # Anchored above the plot, right of the title
        loc="lower center",
        ncol=3,
        borderaxespad=0.0,
        fontsize=legend_fontsize,
        frameon=False,
    )

    # Leave room above the axes for the title and legend.
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    if save_path:
        plt.savefig(save_path, bbox_inches="tight", dpi=300)

    return fig, ax
