"""
File to make plots. By default, it assumes the `optimize-then-sample` schedule.
To compare different schedules, use the `make_plots_schedule.py` file.
"""
import hashlib
import logging
import math
import os

from matplotlib.colors import LogNorm, PowerNorm
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
import pandas as pd
from scipy.interpolate import griddata, interp1d
import seaborn as sns

import scienceplots
from config import ATTACKS, MODELS, METRIC, GROUP_BY, DATASET_IDX
from src.io_utils import collect_results, get_filtered_and_grouped_paths

# Apply the "science" Matplotlib style to every figure created by this module
# (presumably registered by the `scienceplots` import above — confirm).
plt.style.use("science")

# Route INFO-level (and higher) log records through the root logger.
logging.basicConfig(level=logging.INFO)

# Print pandas objects untruncated: full cell contents, every column, and no
# wrapping of wide frames across several lines.
pd.set_option("display.max_colwidth", None)
pd.set_option("display.max_columns", None)
pd.set_option("display.expand_frame_repr", False)

def generate_sample_sizes(total_samples: int) -> tuple[int, ...]:
    """Return the 1-2-5 ladder of sample counts, capped at `total_samples`.

    The ladder is 1, 2, 5, 10, 20, 50, 100, ...  Every ladder value below
    `total_samples` is included, and `total_samples` itself is always the
    final entry, whether or not it is a ladder value.

    Parameters
    ----------
    total_samples : int
        Largest sample count to include.

    Returns
    -------
    tuple[int, ...]
        Ascending sample counts; empty when `total_samples` is below 1.
    """
    if total_samples < 1:
        return ()

    sizes: list[int] = []
    scale = 1
    while True:
        for base in (1, 2, 5):
            value = base * scale
            if value >= total_samples:
                # Reaching or overshooting the cap ends the ladder with the
                # cap itself, so the cap appears exactly once.
                sizes.append(total_samples)
                return tuple(sizes)
            sizes.append(value)
        scale *= 10

def _pareto_frontier(xs: np.ndarray, ys: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Return the cost-ordered "best value so far" frontier of a point cloud.

    Points are visited in order of increasing cost (x). A point is kept only
    when its value (y) is strictly greater than every value seen at lower
    cost, i.e. the frontier trades *lower* cost against *higher* y. Values
    that are not strictly positive never enter the frontier because the
    running best starts at 0.

    Two synthetic points are always added so the result can be used as a step
    function: the origin (0, 0) at the front, and a final point at the largest
    input cost carrying the best value reached.

    Parameters
    ----------
    xs, ys : 1-D arrays of equal length
        Coordinates of the candidate points.

    Returns
    -------
    frontier_xs, frontier_ys : 1-D arrays
        Coordinates of the frontier, sorted by xs ascending. For empty input
        only the origin anchor is returned.
    """
    order = np.argsort(xs)
    xs_sorted, ys_sorted = xs[order], ys[order]

    if xs_sorted.size == 0:
        # Nothing to trace: keep only the anchor instead of failing on the
        # lookup of the last cost below.
        return np.zeros(1), np.zeros(1)

    frontier_x, frontier_y = [0], [0]
    best_y_so_far = 0
    for x_val, y_val in zip(xs_sorted, ys_sorted):
        if y_val > best_y_so_far:
            frontier_x.append(x_val)
            frontier_y.append(y_val)
            best_y_so_far = y_val

    # Extend the last plateau to the most expensive point so the step function
    # is defined over the whole observed cost range.
    frontier_x.append(xs_sorted[-1])
    frontier_y.append(best_y_so_far)
    return np.asarray(frontier_x), np.asarray(frontier_y)

def _non_cumulative_dominance_frontier(xs: np.ndarray, ys: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Sort the points by cost and prepend the origin; nothing is filtered out.

    Unlike a dominance-filtered frontier, every input point survives, so the
    resulting curve may go down as well as up.

    Parameters
    ----------
    xs, ys : 1-D arrays of equal length
        Coordinates of the candidate points.

    Returns
    -------
    frontier_xs, frontier_ys : 1-D arrays
        The point (0, 0) followed by all input points in ascending xs order.
    """
    ascending = np.argsort(xs)

    curve_x = [0]
    curve_x.extend(xs[ascending])
    curve_y = [0]
    curve_y.extend(ys[ascending])

    return np.asarray(curve_x), np.asarray(curve_y)

DATA_CACHE = {}


def fetch_data(model: str, attack: str, attack_params: dict, dataset_idx: list[int], group_by: set[str]):
    """Return the collected results for one model/attack setup, memoised in `DATA_CACHE`.

    The cache key is the SHA-256 hex digest of the concatenated string forms
    of all five arguments. On a cache hit the stored object is returned
    directly; on a miss the matching paths are looked up, their results are
    collected (with `infer_sampling_flops=True`), and the single resulting
    entry is stored and returned.

    Raises an AssertionError when the lookup does not yield exactly one result
    group.
    """
    key_material = "".join((model, attack, str(attack_params), str(dataset_idx), str(group_by)))
    hash_key = hashlib.sha256(key_material.encode()).hexdigest()

    try:
        return DATA_CACHE[hash_key]
    except KeyError:
        pass

    filter_by = {
        "model": model,
        "attack": attack,
        "attack_params": attack_params,
        "dataset_params": {"idx": dataset_idx},
    }
    paths = get_filtered_and_grouped_paths(filter_by, group_by, force_reload=False)
    results = collect_results(paths, infer_sampling_flops=True)
    assert len(results) == 1, f"Should only have exactly one type of result, got {len(results)}, {list(results.keys())}"

    only_result = list(results.values())[0]
    DATA_CACHE[hash_key] = only_result
    return only_result

def preprocess_data(results: dict[str, np.ndarray], metric: tuple[str, ...], threshold: float | None) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Common data preprocessing logic.

    Returns:
        y: np.ndarray
            (B, n_steps, n_samples)
        flops_optimization: np.ndarray
            (B, n_steps)
        flops_sampling_prefill_cache: np.ndarray
            (B, n_steps)
        flops_sampling_generation: np.ndarray
            (B, n_steps)
    """

    y = np.array(results[metric], dtype=np.float32)
    if threshold is not None:
        y = y > threshold

    flops_sampling_prefill_cache = np.array(results["flops_sampling_prefill_cache"])
    flops_sampling_generation = np.array(results["flops_sampling_generation"])
    flops_optimization = np.array(results["flops"])

    return y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation

def subsample_and_aggregate(step_idx: int, sample_count: int, cumulative: bool, y: np.ndarray,
                            opt_flops: np.ndarray, sampling_prefill_flops: np.ndarray,
                            sampling_generation_flops: np.ndarray, rng: np.random.Generator,
                            return_ratio: bool = False, n_smoothing: int = 1):
    """
    Estimate cost and best-of-`sample_count` score for one (step, sample count) pair.

    From the (B, T, S) score array this draws `sample_count` of the S samples
    at `step_idx` without replacement, takes the per-run maximum, and averages
    it over runs. The draw is repeated `n_smoothing` times and the estimates
    are averaged.

    Parameters
    ----------
    step_idx : int
        Step index to aggregate over
    sample_count : int
        Number of samples drawn at `step_idx`
    cumulative : bool
        If True and `step_idx > 0`, each run's value is additionally maxed
        with one randomly chosen sample column taken over all earlier steps
    y : np.ndarray
        (B, n_steps, n_samples)
    opt_flops : np.ndarray
        (B, n_steps)
    sampling_prefill_flops : np.ndarray
        (B, n_steps)
    sampling_generation_flops : np.ndarray
        (B, n_steps)
    rng : np.random.Generator
        Not used: a fresh generator seeded with `sample_count + iteration`
        replaces it on every smoothing iteration, so results are deterministic
    return_ratio : bool
        Accepted but never read by this function
    n_smoothing : int
        Number of smoothing iterations for variance reduction

    Returns
    -------
    tuple
        (total_flop, step_idx, sample_count, mean_value)
    """
    _n_runs, _n_steps, n_total_samples = y.shape

    # Cost model, every term averaged over runs: optimisation FLOPs summed up
    # to and including this step, one prefill, and `sample_count` generations.
    opt_flop = opt_flops[:, :step_idx + 1].sum(axis=1).mean()
    generation_flop = sampling_generation_flops[:, step_idx].mean()
    prefill_flop = sampling_prefill_flops[:, step_idx].mean()
    sampling_flop = generation_flop * sample_count + prefill_flop
    total_flop = opt_flop + sampling_flop

    include_history = cumulative and step_idx > 0
    estimates = []
    for smoothing_idx in range(n_smoothing):
        rng = np.random.default_rng(sample_count + smoothing_idx)
        if include_history:
            # Draw order matters for reproducibility: history column first,
            # then the columns used at the current step.
            history_column = rng.choice(n_total_samples, size=1, replace=False)
            step_columns = rng.choice(n_total_samples, size=sample_count, replace=False)
            best_before = y[:, :step_idx, history_column][..., 0].max(axis=1)
            best_at_step = y[:, step_idx, step_columns].max(axis=-1)
            per_run = np.maximum(best_before, best_at_step)
        else:
            step_columns = rng.choice(n_total_samples, size=sample_count, replace=False)
            per_run = y[:, step_idx, step_columns].max(axis=-1)
        estimates.append(per_run.mean(axis=0))

    return (total_flop, step_idx, sample_count, np.mean(estimates))


def get_points(y: np.ndarray, opt_flops: np.ndarray, sampling_prefill_flops: np.ndarray,
               sampling_generation_flops: np.ndarray,
               n_smoothing: int = 1, cumulative: bool = False):
    """Evaluate every (sample count, step) combination of the design space.

    Parameters
    ----------
    y : np.ndarray
        (B, n_steps, n_samples) scores
    opt_flops, sampling_prefill_flops, sampling_generation_flops : np.ndarray
        (B, n_steps) FLOP arrays forwarded to `subsample_and_aggregate`
    n_smoothing : int
        Number of smoothing iterations used for each point
    cumulative : bool
        Forwarded to `subsample_and_aggregate`

    Returns
    -------
    np.ndarray
        One row `(total_flop, step_idx, sample_count, mean_value)` per
        combination, ordered by sample count first and step second.
    """
    _, n_steps, total_samples = y.shape
    rng = np.random.default_rng(42)

    pts = [
        subsample_and_aggregate(
            step, count, cumulative, y, opt_flops, sampling_prefill_flops,
            sampling_generation_flops, rng,
            # Must be passed by keyword: the next positional slot of
            # `subsample_and_aggregate` is `return_ratio`, not `n_smoothing`.
            n_smoothing=n_smoothing,
        )
        for count in range(1, total_samples + 1)
        for step in range(n_steps)
    ]
    return np.asarray(pts)


def setup_color_normalization(color_scale: str, values: np.ndarray):
    """Build a colour normalisation spanning the min and max of `values`.

    `"log"` gives a `LogNorm`, `"sqrt"` a `PowerNorm` with gamma 0.5, and any
    other value a linear `plt.Normalize`.
    """
    lowest, highest = values.min(), values.max()
    if color_scale == "log":
        return LogNorm(lowest, highest)
    if color_scale == "sqrt":
        return PowerNorm(gamma=0.5, vmin=lowest, vmax=highest)
    return plt.Normalize(lowest, highest)


def _mean_frontier_on_grid(
    sample_count: int,
    n_passes: int,
    x_grid: np.ndarray,
    y: np.ndarray,
    flops_optimization: np.ndarray,
    flops_sampling_prefill_cache: np.ndarray,
    flops_sampling_generation: np.ndarray,
    cumulative: bool,
    rng: np.random.Generator,
) -> np.ndarray:
    """Average the step-function frontier for `sample_count` samples over `n_passes` passes on `x_grid`."""
    n_steps = y.shape[1]
    curves = []
    for _ in range(n_passes):
        pts = np.asarray([
            subsample_and_aggregate(step, sample_count, cumulative, y, flops_optimization,
                                    flops_sampling_prefill_cache, flops_sampling_generation, rng)
            for step in range(n_steps)
        ])
        cost, _steps, _counts, mean_p = pts.T
        fx, fy = _pareto_frontier(cost, mean_p)
        # "previous" turns the frontier into a right-continuous step function;
        # grid points outside the frontier's cost range become NaN.
        curves.append(interp1d(fx, fy, kind="previous", bounds_error=False, fill_value=np.nan)(x_grid))
    return np.nanmean(curves, axis=0)


def _save_baseline_bar_charts(
    frontier_data: dict,
    baseline_frontier_data: dict,
    sample_levels_to_plot: tuple[int, ...],
    threshold: float | None,
    title: str,
) -> None:
    """Draw and save the three bar charts comparing each sampling level with the baseline."""
    bar_chart_margin_multiplier = 5
    target_asr = baseline_frontier_data["max_asr"]
    baseline_x = baseline_frontier_data["x"]
    baseline_y = baseline_frontier_data["y"]
    has_baseline_points = baseline_x.size > 0

    plt.figure(figsize=(6, 5))

    # ---------- FLOPs Efficiency to Reach Baseline ASR ----------
    methods_flops = []
    flops_required = []
    colors_flops = []
    for j in sample_levels_to_plot:
        if j not in frontier_data:
            continue
        x_vals = frontier_data[j]["x"]
        y_vals = frontier_data[j]["y"]
        reached = y_vals >= target_asr
        if np.any(reached):
            # Cheapest grid point at which this level matches the baseline.
            methods_flops.append(f"{j} samples")
            flops_required.append(np.min(x_vals[reached]))
            colors_flops.append(frontier_data[j]["color"])

    if has_baseline_points:
        reached = baseline_y >= target_asr
        # Fall back to the cheapest baseline point if none reaches the target.
        baseline_flops = np.min(baseline_x[reached]) if np.any(reached) else np.min(baseline_x)
        methods_flops.insert(0, "Baseline")
        flops_required.insert(0, baseline_flops)
        colors_flops.insert(0, "red")

    plt.subplot2grid((2, 2), (0, 1))
    plt.title("FLOPs to match baseline", fontsize=14)
    plt.bar(methods_flops, flops_required, color=colors_flops, alpha=0.7, edgecolor='black')
    plt.ylabel("FLOPs", fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yscale('log')
    plt.grid(True, alpha=0.3, axis='y')
    ymin, ymax = plt.ylim()
    if ymin > 0 and ymax > 0:
        # Head-room proportional to the number of decades shown.
        margin = (math.log10(ymax) - math.log10(ymin)) * 0.2
        plt.ylim(ymin, ymax * (1 + margin))

    # ---------- Speedup vs Baseline ----------
    ax_speedup = plt.subplot2grid((2, 2), (1, 0))
    if has_baseline_points:
        # Entry 0 is the baseline itself; speedup = baseline FLOPs / method FLOPs.
        baseline_flops = flops_required[0]
        speedup_methods = methods_flops[1:]
        speedup_colors = colors_flops[1:]
        speedups = [baseline_flops / flops if flops > 0 else 0 for flops in flops_required[1:]]

        if speedup_methods:
            bars = plt.bar(speedup_methods, speedups, color=speedup_colors, alpha=0.7, edgecolor='black')
            plt.ylabel("Speedup (FLOPs)", fontsize=12)
            plt.xticks(rotation=45, ha='right')
            plt.grid(True, alpha=0.3, axis='y')

            # Reference line: parity with the baseline.
            plt.axhline(y=1, color='red', linestyle='--', alpha=0.7, linewidth=1)

            ymin, ymax = plt.ylim()
            margin = (ymax - ymin) * 0.05 * bar_chart_margin_multiplier
            plt.ylim(max(0, ymin - margin), ymax + margin)

            for bar, value in zip(bars, speedups):
                ax_speedup.annotate(f'{value:.1f}x',
                                    xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                                    xytext=(0, 5),
                                    textcoords='offset points',
                                    ha='center', va='bottom', fontsize=10)

    # --------- ASR @ matched FLOPs ----------
    ax_asr = plt.subplot2grid((2, 2), (1, 1))
    # The baseline's delta from itself is 0 by definition.
    methods = ["Baseline"]
    delta_asrs = [0.0]
    colors = ["red"]
    if has_baseline_points:
        baseline_max_flops = baseline_x[-1]
        for j in sample_levels_to_plot:
            if j not in frontier_data:
                continue
            x_vals = frontier_data[j]["x"]
            y_vals = frontier_data[j]["y"]
            if x_vals.size == 0:
                continue
            # Last frontier point that fits in the baseline's FLOP budget;
            # a level with no such point has not achieved anything yet.
            last_within_budget = np.searchsorted(x_vals, baseline_max_flops, side="right") - 1
            asr_at_budget = y_vals[last_within_budget] if last_within_budget >= 0 else 0.0
            # Label, value and colour are appended together so they stay aligned.
            methods.append(f"{j} samples" if j != 1 else "1 sample")
            delta_asrs.append(asr_at_budget - target_asr)
            colors.append(frontier_data[j]["color"])

    bars = plt.bar(methods, delta_asrs, color=colors, alpha=0.7, edgecolor='black')
    if threshold is None:
        plt.ylabel(r"$\Delta$ $\mathcal{H}_b$", fontsize=14)
    else:
        plt.ylabel(r"$\Delta \text{ASR}_b$", fontsize=14)

    plt.xticks(rotation=45, ha='right')
    plt.title("ASR at matched FLOPs", fontsize=14)
    plt.grid(True, alpha=0.3, axis='y')

    ymin, ymax = plt.ylim()
    margin = (ymax - ymin) * 0.03 * bar_chart_margin_multiplier
    plt.ylim(ymin - margin, ymax + margin)

    for bar, value in zip(bars, delta_asrs):
        offset_pt = 4
        va = 'bottom' if value >= 0 else 'top'
        offset = (0, offset_pt if value >= 0 else -offset_pt)
        ax_asr.annotate(f'{value:.2f}',
                        xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                        xytext=offset,
                        textcoords='offset points',
                        ha='center', va=va, fontsize=10)

    plt.tight_layout()
    os.makedirs("evaluate/bar_charts", exist_ok=True)
    if threshold is None:
        plt.savefig(f"evaluate/bar_charts/{title.replace(' ', '_')}.pdf")
    else:
        plt.savefig(f"evaluate/bar_charts/{title.replace(' ', '_')}_t={threshold}.pdf")
    plt.close()


def pareto_plot(
    results: dict[str,np.ndarray],
    baseline: dict[str,np.ndarray] | None = None,
    title: str = "Pareto Frontier",
    sample_levels_to_plot: tuple[int, ...]|None = None,
    metric: tuple[str, ...] = ('scores', 'strong_reject', 'p_harmful'),
    plot_points: bool = True,
    plot_frontiers: bool = True,
    plot_envelope: bool = False,
    verbose: bool = True,
    cumulative: bool = False,
    n_x_points: int = 1000,
    x_scale="log",
    threshold: float|None = None,
    color_scale: str = "linear",
):
    """
    Scatter the full design-space AND overlay Pareto frontiers
    for selected sampling counts.

    Up to three PDFs are written below `evaluate/`:

    * `pareto_plots/<title>[_t=<threshold>].pdf` - score against total FLOPs,
      with the legend in a separate panel on the left;
    * `bar_charts/<title>[_t=<threshold>].pdf` - comparison with the baseline
      (only when a baseline frontier was computed);
    * `pareto_plots/legend_<n_total_samples>.pdf` - a standalone legend.

    Parameters
    ----------
    results : dict
        Results of the sampled runs, as accepted by `preprocess_data`.
    baseline : dict, optional
        Results of the baseline runs; must contain exactly one sample per
        step. Without it no baseline curve and no bar charts are produced.
    title : str
        Used (spaces replaced by underscores) in the output file names.
    sample_levels_to_plot : tuple of int, optional
        Sample counts that get a frontier; defaults to
        `generate_sample_sizes(n_total_samples)`.
    metric : tuple of str
        Key of the score array inside `results` / `baseline`.
    plot_points, plot_frontiers, plot_envelope : bool
        Toggle the scatter of all points, the per-level frontiers, and the
        upper envelope over all sample counts.
    verbose : bool
        Log run counts and skipped outputs.
    cumulative : bool
        Forwarded to `subsample_and_aggregate`.
    n_x_points : int
        Number of points of the common FLOP grid.
    x_scale : str
        `"log"` uses a logarithmic grid starting at 1e11 FLOPs, anything else
        a linear grid starting at 0; the value is also passed to `plt.xscale`.
    threshold : float, optional
        If given, scores are binarised to `score > threshold`.
    color_scale : str
        Colour normalisation over sample counts, see
        `setup_color_normalization`.
    """
    y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocess_data(
        results, metric, threshold
    )
    n_runs, n_steps, n_total_samples = y.shape
    if sample_levels_to_plot is None:
        sample_levels_to_plot = generate_sample_sizes(n_total_samples)

    pts = get_points(y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation,
                     cumulative=cumulative)
    cost, _, n_samp, mean_p = pts.T
    max_cost = cost.max()
    # Common x grid on which all frontiers are resampled; it extends slightly
    # past the most expensive point.
    if x_scale == "log":
        x_interp = np.logspace(11, np.log10(max_cost) + 0.001, n_x_points)
    else:
        x_interp = np.linspace(0, max_cost + 1, n_x_points)

    plt.figure(figsize=(5.4, 2.4))
    plt.subplot2grid((2, 3), (0, 1), colspan=2, rowspan=2)

    color_norm = setup_color_normalization(color_scale, n_samp)
    if plot_points:
        if len(cost) > 1000:
            # Thin the cloud to at most 1000 points, evenly spaced in cost rank.
            log_cost = np.log10(cost + 1e-10)
            log_indices = np.argsort(log_cost)
            step = len(log_indices) // 1000
            subsample_indices = log_indices[::step][:1000]

            cost_sub = cost[subsample_indices]
            mean_p_sub = mean_p[subsample_indices]
            n_samp_sub = n_samp[subsample_indices]
        else:
            cost_sub = cost
            mean_p_sub = mean_p
            n_samp_sub = n_samp

        plt.scatter(cost_sub, mean_p_sub, c=n_samp_sub, cmap="viridis", alpha=0.15, s=3, norm=color_norm)

    cmap = plt.get_cmap("viridis")
    rng = np.random.default_rng()

    # NOTE(review): `subsample_and_aggregate` currently seeds its own generator
    # on every call, so repeated passes may well be identical - confirm whether
    # the smoothing passes are still needed.
    n_smoothing = 50
    frontier_data: dict[int, dict[str, np.ndarray | float | tuple[float, float, float, float]]] = {}

    if plot_frontiers:
        for j in sample_levels_to_plot:
            # Using every available sample leaves nothing to subsample, so a
            # single pass suffices for that level only.
            n_passes = 1 if j == n_total_samples else n_smoothing
            y_mean = _mean_frontier_on_grid(
                j, n_passes, x_interp, y, flops_optimization, flops_sampling_prefill_cache,
                flops_sampling_generation, cumulative, rng,
            )
            valid_mask = ~np.isnan(y_mean) & (y_mean > 0)
            x_pts = x_interp[valid_mask]
            y_pts = y_mean[valid_mask]
            color = cmap(color_norm(j))
            plt.plot(
                x_pts,
                y_pts,
                marker="o",
                linewidth=1.8,
                markersize=2,
                label=f"{j} samples",
                color=color,
            )
            frontier_data[j] = {
                "x": x_pts,
                "y": y_pts,
                "color": color,
                "max_asr": float(np.max(y_pts)) if np.any(valid_mask) else 0.0,
            }

    if plot_envelope:
        # Row k holds the mean frontier for k + 1 samples.
        y_interps = np.array([
            _mean_frontier_on_grid(
                j, n_total_samples, x_interp, y, flops_optimization, flops_sampling_prefill_cache,
                flops_sampling_generation, cumulative, rng,
            )
            for j in range(1, n_total_samples + 1)
        ])
        # Grid points beyond every frontier are NaN in all rows; mask them
        # instead of letting a NaN-aware argmax raise on them.
        all_nan = np.isnan(y_interps).all(axis=0)
        filled = np.where(np.isnan(y_interps), -np.inf, y_interps)
        # +1 converts the row index into the sample count it stands for; the
        # running maximum keeps the winning sample count monotone in cost.
        best_level = np.maximum.accumulate(filled.argmax(axis=0) + 1)
        y_envelope = np.where(all_nan, np.nan, filled.max(axis=0))

        valid_mask = ~all_nan & (y_envelope > 0)
        envelope_colors = [cmap(color_norm(level)) for level in best_level[valid_mask]]
        plt.scatter(x_interp[valid_mask], y_envelope[valid_mask], c=envelope_colors, s=2)

    baseline_max_asr = 0.0
    baseline_frontier_data: dict[str, np.ndarray | float] | None = None
    baseline_plotted = False

    if baseline is not None:
        y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, baseline_flops_sampling_generation = preprocess_data(
            baseline, metric, threshold
        )

        if verbose:
            logging.info(f"{n_runs} for main")
            logging.info(f"{y_baseline.shape[0]} for baseline")
        _, _, n_total_samples_baseline = y_baseline.shape
        assert n_total_samples_baseline == 1

        pts = get_points(y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache,
                         baseline_flops_sampling_generation, cumulative=cumulative)
        cost_baseline, _, n_samp_baseline, mean_p_baseline = pts.T
        max_cost_baseline = cost_baseline.max()

        if plot_frontiers or plot_envelope:
            mask = n_samp_baseline == 1
            fx, fy = _pareto_frontier(cost_baseline[mask], mean_p_baseline[mask])
            y_interp_baseline = interp1d(fx, fy, kind="previous", bounds_error=False, fill_value=np.nan)(x_interp)
            if max_cost_baseline / max_cost < 0.7:
                max_cost_baseline = max_cost
            valid_mask_baseline = ~np.isnan(y_interp_baseline) & (y_interp_baseline > 0) & (x_interp < max_cost_baseline)
            baseline_x = x_interp[valid_mask_baseline]
            baseline_y = y_interp_baseline[valid_mask_baseline]
            # Stored for the bar charts.
            baseline_max_asr = float(np.max(baseline_y)) if baseline_y.size else 0.0
            baseline_frontier_data = {
                'x': baseline_x,
                'y': baseline_y,
                'max_asr': baseline_max_asr,
            }
            plt.plot(
                baseline_x,
                baseline_y,
                marker="o",
                linewidth=1.8,
                markersize=2,
                label="Baseline (greedy)",
                color="r",
            )
            baseline_plotted = True

    plt.xlabel("Total FLOPs", fontsize=13)
    if threshold is None:
        plt.ylabel(r"$\mathcal{H}_b$", fontsize=18)
    else:
        plt.ylabel(r"$\text{ASR}_b$", fontsize=18)
    plt.grid(True, alpha=0.3)
    plt.xscale(x_scale)

    handles, labels = plt.gca().get_legend_handles_labels()
    # The legend lives in its own axis-free panel to the left of the plot.
    ax0 = plt.subplot2grid((2, 3), (0, 0), colspan=1, rowspan=2)
    ax0.axis('off')
    if baseline_plotted:
        # Largest sample count first, baseline (drawn last) at the bottom with
        # its " (greedy)" suffix dropped.
        handles = [*handles[:-1][::-1], handles[-1]]
        labels = [*labels[:-1][::-1], labels[-1].removesuffix(" (greedy)")]
    else:
        handles = handles[::-1]
        labels = labels[::-1]
    ax0.legend(handles, labels, loc='center', fontsize=12)
    plt.tight_layout()
    os.makedirs("evaluate/pareto_plots", exist_ok=True)
    if threshold is None:
        plt.savefig(f"evaluate/pareto_plots/{title.replace(' ', '_')}.pdf")
    else:
        plt.savefig(f"evaluate/pareto_plots/{title.replace(' ', '_')}_t={threshold}.pdf")
    plt.close()

    # The bar charts are all relative to the baseline frontier, so they are
    # only meaningful (and only computable) when one exists.
    if baseline_frontier_data is None:
        if verbose:
            logging.info("No baseline frontier available; skipping bar charts")
    else:
        _save_baseline_bar_charts(frontier_data, baseline_frontier_data, sample_levels_to_plot, threshold, title)

    # Create a separate figure for just the legend
    fig_legend = plt.figure(figsize=(4, 1))
    ax_legend = fig_legend.add_subplot(111)
    ax_legend.axis('off')

    legend_elements = [plt.Line2D([0], [0], color="red", linewidth=2, label="Baseline (Greedy)")]
    for j in sample_levels_to_plot:
        if j in frontier_data:
            # Reuse the colour the frontier was drawn with so legend and plot agree.
            legend_elements.append(plt.Line2D([0], [0], color=frontier_data[j]["color"], linewidth=2,
                                              label=f"{j} samples"))

    # Create horizontal legend
    ax_legend.legend(handles=legend_elements, loc='center', ncol=len(legend_elements),
        fontsize=10, frameon=False, columnspacing=1.0, handletextpad=0.5)

    plt.tight_layout()
    plt.savefig(f"evaluate/pareto_plots/legend_{n_total_samples}.pdf", bbox_inches='tight')
    plt.close()



def multi_attack_non_cumulative_pareto_plot(
    attacks_data: dict,
    model_title: str,
    title: str = "Multi-Attack Non-Cumulative Pareto",
    metric: tuple[str, ...] = ('scores', 'strong_reject', 'p_harmful'),
    threshold: float|None = None,
    n_x_points: int = 1000,
    verbose: bool = True,
    target_samples: int = 50,
):
    """
    Create a non-cumulative Pareto plot showing multiple attacks on the same axes.

    For every attack the best-of-`target_samples` score (capped at the number
    of available samples) is computed at each optimisation step and plotted
    against optimisation progress in percent, as the change relative to the
    first value of the curve. Only the attacks named in `desired_attacks`
    below are drawn.

    Two PDFs are written to `evaluate/multi_attack_non_cumulative_pareto_plots/`:
    the plot itself and a standalone legend.

    Parameters
    ----------
    attacks_data : dict
        Maps a configuration key to a `(results, config)` pair.
    model_title : str
        Plot title.
    title : str
        Used (spaces replaced by underscores) in the output file names.
    metric : tuple of str
        Key of the score array inside each `results` mapping.
    threshold : float, optional
        If given, scores are binarised to `score > threshold`.
    n_x_points : int
        Number of points of the 0-100 % progress grid.
    verbose : bool
        Log a message once the files have been written.
    target_samples : int
        Number of samples per step used for the best-of-n score.
    """

    attack_colors = {
        "gcg": "#1f77b4",
        "gcg_reinforce": "#87CEEB",
        "autodan": "#ff7f0e",
        "beast": "#2ca02c",
        "pair": "#d62728",
        "bon": "#9467bd",
        "direct": "#8c564b",
    }

    desired_attacks = {"PAIR", "BEAST", "AutoDAN", "GCG", "REINFORCE_GCG"}
    attacks_data = {key: value for key, value in attacks_data.items() if key in desired_attacks}

    if not attacks_data:
        logging.warning("No desired attacks found in data")
        return

    # Created only after the early return above so no figure is left open.
    plt.figure(figsize=(4, 3))

    x_interp = np.linspace(0, 100, n_x_points)

    rng = np.random.default_rng(42)

    for config_key, (results, _config) in attacks_data.items():
        y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocess_data(
            results, metric, threshold
        )
        _, n_steps, n_total_samples = y.shape

        # Map the display key back to the attack name used for colouring.
        original_attack_name = next(
            (atk_name for atk_name, atk_cfg in ATTACKS if atk_cfg.get('title_suffix') == config_key),
            None,
        )
        color = attack_colors.get(original_attack_name, "black")

        sample_count = min(target_samples, n_total_samples)
        pts = np.asarray([
            subsample_and_aggregate(step, sample_count, False, y,
                                    flops_optimization, flops_sampling_prefill_cache,
                                    flops_sampling_generation, rng)
            for step in range(n_steps)
        ])
        _, step_idx_pts, _, mean_p = pts.T

        # A single-step run sits at 0 % instead of dividing by zero.
        step_percentages = (step_idx_pts / max(n_steps - 1, 1)) * 100

        fx, fy = _non_cumulative_dominance_frontier(step_percentages, mean_p)
        y_curve = interp1d(fx, fy, kind="previous", bounds_error=False,
                           fill_value=np.nan)(x_interp)

        valid_mask = ~np.isnan(y_curve)
        if not np.any(valid_mask):
            continue

        # Plot the change relative to the first defined value of the curve.
        y_delta = y_curve - y_curve[valid_mask][0]
        label = f"{config_key}".replace("_", " ").replace(" SR", "")
        plt.plot(x_interp[valid_mask], y_delta[valid_mask],
                 linewidth=1.2,
                 label=label, color=color)

    plt.xlabel(r"Optimization Progress (\%)", fontsize=15)
    quantity = r"\mathcal{H}" if threshold is None else r"\text{ASR}"
    # Only the two standard settings are spelled out; anything else is "@n".
    at_label = target_samples if target_samples in (1, 50) else "n"
    plt.ylabel(rf"$\Delta$ ${quantity}_{{q}}@{at_label}$", fontsize=16)

    plt.grid(True, alpha=0.3)
    plt.xlim(0, 100)
    plt.axhline(y=0, color='black', linestyle='--', alpha=0.5, linewidth=1)
    plt.title(f"{model_title}", fontsize=15)

    # Legend entries in alphabetical label order; each handle carries its label.
    handles, labels = plt.gca().get_legend_handles_labels()
    handles = [handle for handle, _ in sorted(zip(handles, labels), key=lambda pair: pair[1])]

    plt.tight_layout()
    os.makedirs("evaluate/multi_attack_non_cumulative_pareto_plots", exist_ok=True)
    if threshold is None:
        plt.savefig(f"evaluate/multi_attack_non_cumulative_pareto_plots/{title.replace(' ', '_')}_{target_samples}.pdf")
    else:
        plt.savefig(f"evaluate/multi_attack_non_cumulative_pareto_plots/{title.replace(' ', '_')}_t={threshold}_{target_samples}.pdf")
    plt.close()

    fig_legend = plt.figure(figsize=(5, 1))
    ax_legend = fig_legend.add_subplot(111)
    ax_legend.axis('off')

    ax_legend.legend(handles=handles, loc='center', ncol=5,
        fontsize=12, frameon=False, columnspacing=1.0, handletextpad=0.5)

    plt.tight_layout()
    if threshold is None:
        plt.savefig(f"evaluate/multi_attack_non_cumulative_pareto_plots/legend_{title.replace(' ', '_')}.pdf", bbox_inches='tight')
    else:
        plt.savefig(f"evaluate/multi_attack_non_cumulative_pareto_plots/legend_{title.replace(' ', '_')}_t={threshold}.pdf", bbox_inches='tight')
    plt.close()

    if verbose:
        logging.info(f"Multi-attack non-cumulative Pareto plot saved for {model_title}")
        logging.info("Legend saved separately")

def flops_breakdown_plot(
    results: dict[str,np.ndarray],
    title: str = "FLOPs Breakdown Analysis",
    sample_levels_to_plot: tuple[int, ...]|None = None,
    metric: tuple[str, ...] = ('scores', 'strong_reject', 'p_harmful'),
    cumulative: bool = False,
    threshold: float|None = None,
    verbose: bool = True,
):
    """
    Plot the metric as a filled contour over (total FLOPs, sampling FLOPs).

    For every (step ``i``, number of samples ``j``) pair one point is computed:

    * x: run-averaged optimization FLOPs summed over steps ``0..i`` plus the
      sampling FLOPs of that point (i.e. the total FLOPs),
    * y: the sampling FLOPs alone (``j`` times the mean generation FLOPs of
      step ``i`` plus the mean prefill-cache FLOPs of step ``i``),
    * value: the run-averaged maximum over ``j`` randomly chosen samples.

    The scattered points are interpolated onto a 100x100 grid, drawn with
    ``contourf`` and overlaid with a dashed "Compute Optimal Frontier" (the
    points that, in order of increasing total FLOPs, set a new best value).
    The figure is saved as a PDF under ``evaluate/flops_breakdown/``.

    Parameters
    ----------
    results : dict
        Raw results, forwarded to ``preprocess_data``.
    title : str
        Output file stem (spaces are replaced by underscores).
    sample_levels_to_plot : tuple[int, ...] | None
        Accepted for interface compatibility; this plot does not use it.
    metric : tuple[str, ...]
        Metric key forwarded to ``preprocess_data``.
    cumulative : bool
        If True (and ``i > 0``), the value of a point is additionally maxed
        with one randomly chosen sample index taken over all earlier steps.
    threshold : float | None
        Forwarded to ``preprocess_data``; also selects the colorbar label and
        adds a ``_t=<threshold>`` suffix to the file name.
    verbose : bool
        If True, log a message when linear interpolation fails.
    """
    y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocess_data(
        results, metric, threshold
    )
    _, n_steps, n_total_samples = y.shape

    total_flops = []      # optimization + sampling FLOPs (x-axis)
    sampling_flops = []   # sampling FLOPs only (y-axis)
    p_harmful_vals = []

    # Fixed seed so the random sample subsets (and thus the plot) are reproducible.
    rng = np.random.default_rng(42)

    for j in range(1, n_total_samples + 1):
        for i in range(n_steps):
            opt_flop = np.mean(flops_optimization[:, :i+1].sum(axis=1))
            sampling_flop = np.mean(flops_sampling_generation[:, i]) * j + np.mean(flops_sampling_prefill_cache[:, i])
            p_vals = []
            # Number of random subsets averaged per point; 1 means no smoothing.
            n_smoothing = 1
            for _ in range(n_smoothing):
                if cumulative and i > 0:
                    samples_up_to_now = y[:, :i, rng.choice(n_total_samples, size=1, replace=False)].max(axis=1)[:, 0]
                    samples_at_end = y[:, i, rng.choice(n_total_samples, size=j, replace=False)].max(axis=-1)
                    p_val = np.stack([samples_up_to_now, samples_at_end], axis=1).max(axis=1).mean(axis=0)
                else:
                    p_val = y[:, i, rng.choice(n_total_samples, size=j, replace=False)].max(axis=-1).mean(axis=0)
                p_vals.append(p_val)

            total_flops.append(opt_flop + sampling_flop)
            sampling_flops.append(sampling_flop)
            p_harmful_vals.append(np.mean(p_vals))

    total_flops = np.array(total_flops)
    sampling_flops = np.array(sampling_flops)
    p_harmful_vals = np.array(p_harmful_vals)

    plt.figure(figsize=(4, 2.8))

    sampling_min, sampling_max = sampling_flops.min(), sampling_flops.max()
    total_min, total_max = total_flops.min(), total_flops.max()

    # Sampling FLOPs live on the y-axis (see the contourf call below), so a wide
    # sampling range switches the *y* scale to log. The `min > 0` guard avoids
    # log10(0) / division by zero.
    if sampling_min > 0 and sampling_max / sampling_min > 100:
        sampling_grid = np.logspace(np.log10(sampling_min), np.log10(sampling_max), 100)
        plt.yscale('log')
    else:
        sampling_grid = np.linspace(sampling_min, sampling_max, 100)

    # Total FLOPs live on the x-axis.
    if total_min > 0 and total_max / total_min > 100:
        total_grid = np.logspace(np.log10(total_min), np.log10(total_max), 100)
        plt.xscale('log')
    else:
        total_grid = np.linspace(total_min, total_max, 100)

    Sampling_grid, Total_grid = np.meshgrid(sampling_grid, total_grid)

    try:
        p_harmful_grid = griddata(
            (sampling_flops, total_flops),
            p_harmful_vals,
            (Sampling_grid, Total_grid),
            method='linear',
            rescale=True
        )
        if np.isnan(p_harmful_grid).sum() > 0:
            # Linear interpolation leaves NaN outside the convex hull of the
            # data; fill those cells from the nearest data point instead.
            p_harmful_grid_nearest = griddata(
                (sampling_flops, total_flops),
                p_harmful_vals,
                (Sampling_grid, Total_grid),
                method='nearest',
                fill_value=0,
                rescale=True
            )
            fill_mask = np.isnan(p_harmful_grid)
            # Blank out cells where sampling FLOPs plus the smallest observed
            # total exceed the cell's total.
            # NOTE(review): the second term is algebraically `Total_grid > total_max`
            # (up to rounding), which the grid never exceeds -- confirm intent.
            impossible_mask = ((Sampling_grid + total_min) > Total_grid) | ((Total_grid-Sampling_grid) > total_max-Sampling_grid)
            p_harmful_grid[fill_mask] = p_harmful_grid_nearest[fill_mask]
            p_harmful_grid[impossible_mask] = np.nan
    except Exception as e:
        if verbose:
            logging.info(f"Linear interpolation failed: {e}, trying nearest neighbor")
        p_harmful_grid = griddata(
            (sampling_flops, total_flops),
            p_harmful_vals,
            (Sampling_grid, Total_grid),
            method='nearest',
            fill_value=0,
            rescale=True
        )

    levels = np.linspace(np.nanmin(p_harmful_vals), np.nanmax(p_harmful_vals), 50)

    # x = total FLOPs, y = sampling FLOPs.
    contour = plt.contourf(Total_grid, Sampling_grid, p_harmful_grid, levels=levels,
                          cmap='viridis', extend='both')

    cbar = plt.colorbar(contour, ticks=np.linspace(np.nanmin(p_harmful_vals), np.nanmax(p_harmful_vals), 5))
    cbar.ax.set_yticklabels([f'{tick:.2f}' for tick in cbar.get_ticks()])
    if threshold is None:
        cbar.set_label(r"$\mathcal{H}$", fontsize=17)
    else:
        cbar.set_label(r"$ASR$", fontsize=17)

    # Compute-optimal frontier: walk the points by increasing *total* FLOPs
    # (which already include sampling FLOPs -- do not add them again) and keep
    # every point that sets a new best value.
    sort_idx = np.argsort(total_flops)
    max_asr_points = []
    max_asr_seen = -np.inf
    for total, sampling, current_asr in zip(
        total_flops[sort_idx], sampling_flops[sort_idx], p_harmful_vals[sort_idx]
    ):
        if current_asr > max_asr_seen:
            max_asr_seen = current_asr
            max_asr_points.append((total, sampling))

    if max_asr_points:
        max_asr_points = np.array(max_asr_points)

        plt.plot(max_asr_points[:, 0], max_asr_points[:, 1],
                color='black', linewidth=2, linestyle="--",alpha=0.8, label="Compute Optimal Frontier")

    plt.xlabel("Total FLOPs", fontsize=14)
    plt.ylabel("Sampling FLOPs", fontsize=14)
    plt.tick_params(axis='both', which='major', labelsize=12)
    plt.grid(True, alpha=0.3)


    plt.legend(loc='lower left', fontsize=13, bbox_to_anchor=(-0.1, 0.95))
    plt.tight_layout()
    os.makedirs(f"evaluate/flops_breakdown", exist_ok=True)
    if threshold is None:
        plt.savefig(f"evaluate/flops_breakdown/{title.replace(' ', '_')}.pdf", bbox_inches='tight')
    else:
        plt.savefig(f"evaluate/flops_breakdown/{title.replace(' ', '_')}_t={threshold}.pdf", bbox_inches='tight')
    plt.close()

def ridge_plot(
    sampled_data: dict[str,np.ndarray],
    model_title: str,
    cfg: dict,
    threshold: float|None = None,
):
    """
    Draw a ridge plot and a ratio plot from the p_harmful scores.

    Reads ``sampled_data[("scores", "strong_reject", "p_harmful")]`` (indexed
    as ``[:, step, :]``) and produces two PDFs:

    * ``evaluate/ridge_plots/``: one KDE per selected step, stacked as ridges,
      with vertical markers for the median (black, dashed), the 95th
      percentile (blue, dashed) and the mean (red, solid).
    * ``evaluate/ratio_plots/``: for every step, the fraction of scores in
      ``[0.1, 1]`` among those in ``[0, 1]`` and the fraction of scores in
      ``(0.5, 1]`` among those in ``(0.1, 1]``.

    ``threshold`` only affects the output file names (``_t=<threshold>``
    suffix); ``cfg['title_suffix']`` is part of both file names.
    """
    sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0), 'figure.figsize': (3, 3)})

    scores = np.array(sampled_data[("scores", "strong_reject", "p_harmful")])

    def _pick_steps(n_cols: int, k: int = 4) -> list[int]:
        """
        Choose roughly log-spaced step indices in [0, n_cols-1].

        All indices are returned when n_cols <= k. Otherwise the result is the
        sorted union of 0, n_cols-1 and k geometrically spaced integers, so it
        can hold up to k+1 entries.
        """
        if n_cols <= k:
            return list(range(n_cols))

        last = n_cols - 1
        geometric = np.geomspace(1, last, num=k, dtype=int)
        picked = np.unique(np.concatenate(([0], geometric, [last])))
        if picked.size < k:
            # Too many duplicates after integer truncation: top up with
            # linearly spaced indices and cut back to k.
            uniform = np.linspace(0, last, num=k, dtype=int)
            picked = np.unique(np.concatenate((picked, uniform)))[:k]

        return picked.tolist()

    # Long-format table: one row per individual score at each selected step.
    df = pd.DataFrame([
        {'step': f'Step {step_idx}', r"$h(Y)$": value}
        for step_idx in _pick_steps(scores.shape[1], 4)
        for value in scores[:, step_idx, :].flatten()
    ])

    def _step_number(label: str) -> int:
        return int(label.split()[1])

    unique_steps = sorted(df['step'].unique(), key=_step_number)
    pal = sns.cubehelix_palette(len(unique_steps), rot=-.25, light=.7)

    g = sns.FacetGrid(df, row="step", hue="step", aspect=5, height=.4, palette=pal,
                        row_order=unique_steps)

    # Filled density first, then a white outline underneath to separate ridges.
    g.map(sns.kdeplot, r"$h(Y)$", bw_adjust=0.5, clip=(0, 1), fill=True, alpha=1, linewidth=0, zorder=1)
    g.map(sns.kdeplot, r"$h(Y)$", bw_adjust=0.5, clip=(0, 1), color="w", lw=3, zorder=0)

    def _mark_summary_stats(x, **kwargs):
        """Draw median, 95th-percentile and mean markers on the current axes."""
        axes = plt.gca()
        center_mean = np.mean(x)
        center_median = np.median(x)
        upper_tail = np.percentile(x, 95)
        axes.axvline(center_median, color='black', linestyle='--', alpha=0.7, linewidth=1, ymax=0.5)
        axes.axvline(upper_tail, color='blue', linestyle='--', alpha=0.7, linewidth=1, ymax=0.5)
        axes.axvline(center_mean, color='red', linestyle='-', alpha=0.7, linewidth=1, ymax=0.5)

    g.map(_mark_summary_stats, r"$h(Y)$")

    g.refline(y=0, linewidth=1, linestyle="-", color=None, clip_on=False)

    # Negative spacing makes neighbouring ridges overlap.
    g.figure.subplots_adjust(hspace=-.4)

    g.set_titles("")
    g.set(yticks=[], ylabel="")
    g.despine(bottom=True, left=True)
    g.set_xlabels(r"$h(Y)$", fontsize=14)
    g.set(xlim=(0, 1))
    # Switch back to the file-wide style after seaborn's theme.
    plt.style.use("science")

    g.figure.suptitle(f"{model_title}", fontsize=14, y=0.9, va="top")

    os.makedirs("evaluate/ridge_plots", exist_ok=True)
    filename = (
        f"evaluate/ridge_plots/{model_title}_{cfg['title_suffix']}.pdf"
        if threshold is None
        else f"evaluate/ridge_plots/{model_title}_{cfg['title_suffix']}_t={threshold}.pdf"
    )
    g.figure.savefig(filename.replace(' ', '_'), bbox_inches='tight')
    plt.close(g.figure)

    # ---- per-step ratio plot ----
    num_steps = scores.shape[1]
    harmful_given_nonrefusal = []   # share of (0.5, 1] among (0.1, 1]
    nonrefusal = []                 # share of [0.1, 1] among [0, 1]

    for step in range(num_steps):
        vals = scores[:, step, :].flatten()
        at_most_one = vals <= 1.0

        above_half = np.sum((vals > 0.50) & at_most_one)
        above_tenth = np.sum((vals > 0.10) & at_most_one)
        harmful_given_nonrefusal.append(above_half / above_tenth if above_tenth else np.nan)

        from_tenth = np.sum((vals >= 0.10) & at_most_one)
        from_zero = np.sum((vals >= 0.00) & at_most_one)
        nonrefusal.append(from_tenth / from_zero if from_zero else np.nan)

    _, ax = plt.subplots(1, 1, figsize=(6.5, 2.65))
    plt.style.use("science")

    steps_axis = np.arange(num_steps)
    # A single step would draw no visible line, so fall back to markers.
    single_step = num_steps == 1
    sns.lineplot(x=steps_axis, y=nonrefusal, label=r"$P(\text{¬refusal})$", ax=ax, marker="o" if single_step else None)
    sns.lineplot(x=steps_axis, y=harmful_given_nonrefusal, linestyle="--", label=r"$P(\text{harmful} \mid \text{¬refusal})$", ax=ax, marker="x" if single_step else None)
    ax.yaxis.set_major_locator(MaxNLocator(nbins="auto", integer=False))
    ax.set_xlabel("Step")
    ax.set_ylabel("Frequency")
    ax.set_title(f"{model_title}")

    ax.tick_params(axis='both', which='both', top=False, right=False, left=True, bottom=True)
    ax.grid(True, alpha=0.3)

    ax.legend(bbox_to_anchor=(-0.3, 0.5), loc='center right')

    plt.tight_layout()
    os.makedirs("evaluate/ratio_plots", exist_ok=True)
    ratio_path = (
        f"evaluate/ratio_plots/{model_title.replace(' ', '_')}_{cfg['title_suffix'].replace(' ', '_')}.pdf"
        if threshold is None
        else f"evaluate/ratio_plots/{model_title.replace(' ', '_')}_{cfg['title_suffix'].replace(' ', '_')}_t={threshold}.pdf"
    )
    plt.savefig(ratio_path, bbox_inches="tight")
    plt.close()

# Analysis types handled by `run_analysis`.
SUPPORTED_ANALYSES = {"pareto", "ridge", "flops_breakdown"}


def run_analysis(
    model: str,
    model_title: str,
    atk_name: str,
    cfg: dict,
    analysis_type: str = "pareto",
):
    """
    Load the data for one (model, attack config) pair and draw one analysis.

    Parameters
    ----------
    model : str
        Model key passed to ``fetch_data``.
    model_title : str
        Human-readable model name used in plot titles / file names.
    atk_name : str
        Attack name; used for loading unless ``cfg`` overrides it.
    cfg : dict
        Attack configuration. Keys read here: ``"sample_params"`` and
        ``"baseline_params"`` (zero-argument callables), ``"title_suffix"``,
        ``"cumulative"``, and the optional ``"attack_override"``,
        ``"postprocess"`` and ``"baseline_attack"``.
    analysis_type : str
        One of ``SUPPORTED_ANALYSES``.

    Raises
    ------
    ValueError
        If ``analysis_type`` is not in ``SUPPORTED_ANALYSES``.
    """
    if analysis_type not in SUPPORTED_ANALYSES:
        raise ValueError(f"Unsupported analysis type: {analysis_type}")

    logging.info(f"{analysis_type.title()} Analysis: {atk_name} {cfg.get('title_suffix', '')}")

    sampled_data = fetch_data(
        model,
        cfg.get("attack_override", atk_name),
        cfg["sample_params"](),
        DATASET_IDX,
        GROUP_BY,
    )

    # The hook's return value is ignored; it receives the loaded data and METRIC.
    if post := cfg.get("postprocess"):
        post(sampled_data, METRIC)

    if analysis_type == "flops_breakdown":
        # This analysis never uses baseline data, so do not load it: a missing
        # baseline must not abort (or slow down) the FLOPs breakdown.
        flops_breakdown_plot(
            sampled_data,
            title=f"{model_title} {cfg['title_suffix']} FLOPs Breakdown",
            cumulative=cfg["cumulative"],
            metric=METRIC,
            threshold=None,
        )
        return

    baseline_data = fetch_data(
        model,
        cfg.get("baseline_attack", atk_name),
        cfg["baseline_params"](),
        DATASET_IDX,
        GROUP_BY,
    )

    if analysis_type == "pareto":
        # Same plot twice: thresholded at 0.5 first, then without a threshold.
        for threshold in (0.5, None):
            pareto_plot(
                sampled_data,
                baseline_data,
                title=f"{model_title} {cfg['title_suffix']}",
                cumulative=cfg["cumulative"],
                metric=METRIC,
                threshold=threshold,
                color_scale="sqrt",
            )
    elif analysis_type == "ridge":
        ridge_plot(
            sampled_data,
            model_title,
            cfg,
            threshold=None,
        )
        if baseline_data is not None:
            ridge_plot(
                baseline_data,
                model_title + " Greedy",
                cfg,
                threshold=None,
            )


def run_multi_attack_non_cumulative_pareto(model: str, model_title: str) -> None:
    """
    Load every attack configuration for ``model`` and draw the combined
    non-cumulative Pareto plots.

    Each entry of ``ATTACKS`` is loaded with ``fetch_data`` (and post-processed
    if its config defines ``"postprocess"``); entries that fail to load are
    logged as warnings and skipped. Loaded data is keyed by the config's
    ``"title_suffix"``. For each threshold in ``(None, 0.5)`` the plot is drawn
    twice: once with the plot function's default ``target_samples`` and once
    with ``target_samples=1``. Nothing is drawn if no attack could be loaded.
    """
    logging.info(f"Multi Attack Non Cumulative Pareto Analysis: {model_title}")

    attacks_data = {}
    for atk_name, cfg in ATTACKS:
        try:
            attack_to_load = cfg.get("attack_override", atk_name)
            loaded = fetch_data(
                model,
                attack_to_load,
                cfg["sample_params"](),
                DATASET_IDX,
                GROUP_BY,
            )
            postprocess = cfg.get("postprocess")
            if postprocess:
                postprocess(loaded, METRIC)

            attacks_data[cfg['title_suffix']] = (loaded, cfg)
        except Exception as e:
            # Best effort: one broken configuration must not block the others.
            logging.warning(
                f"Could not load data for {atk_name} ({cfg.get('title_suffix', 'unknown config')}): {e}"
            )

    if len(attacks_data) == 0:
        logging.warning("No attacks loaded; skipping multi-attack plot")
        return

    # First variant relies on the plot's default target_samples, second pins it to 1.
    variants = ({}, {"target_samples": 1})
    for threshold in (None, 0.5):
        for extra_kwargs in variants:
            multi_attack_non_cumulative_pareto_plot(
                attacks_data=attacks_data,
                model_title=model_title,
                title=f"{model_title}",
                metric=METRIC,
                threshold=threshold,
                **extra_kwargs,
            )


# Analyses run by `main` when none are requested explicitly.
DEFAULT_ANALYSES = ("pareto", "ridge", "flops_breakdown", "multi_attack_non_cumulative_pareto")


def main(fail: bool = False, analysis_types=None):
    """
    Generate the requested plots for every model in ``MODELS``.

    ``"multi_attack_non_cumulative_pareto"`` is run once per model; every other
    analysis type is run once per (model, entry of ``ATTACKS``) through
    ``run_analysis``.

    Parameters
    ----------
    fail : bool
        If True, the first exception is re-raised. If False, the exception is
        logged (with traceback) and processing continues with the next item.
    analysis_types : iterable of str | None
        Analysis types to run, in order. ``None`` selects ``DEFAULT_ANALYSES``.
    """
    if analysis_types is None:
        analysis_types = list(DEFAULT_ANALYSES)
    for analysis_type in analysis_types:
        logging.info("\n" + "="*80)
        logging.info(f"GENERATING {analysis_type.upper().replace('_', ' ')} PLOTS")
        logging.info("="*80)

        for model_key, model_title in MODELS.items():
            logging.info(f"Model: {model_key}")
            if analysis_type == "multi_attack_non_cumulative_pareto":
                try:
                    run_multi_attack_non_cumulative_pareto(model_key, model_title)
                except Exception as e:
                    if fail:
                        # Bare raise keeps the original traceback intact.
                        raise
                    # logging.exception records the traceback at ERROR level so
                    # swallowed failures remain diagnosable.
                    logging.exception(f"Error running {analysis_type} analysis for {model_title}: {e}")
            else:
                for atk_name, atk_cfg in ATTACKS:
                    try:
                        run_analysis(model_key, model_title, atk_name, atk_cfg, analysis_type)
                    except Exception as e:
                        if fail:
                            raise
                        logging.exception(f"Error running {analysis_type} analysis for {atk_name}, "
                            f"cfg: {atk_cfg.get('title_suffix', 'unknown')}: {e}")

if __name__ == "__main__":
    # Command-line entry point: parse the flags and hand them to main().
    import argparse

    cli = argparse.ArgumentParser(description='Generate plots')
    # --fail: re-raise the first error instead of logging it and continuing.
    cli.add_argument('--fail', action='store_true', help='Override flag to fail')
    # --analysis_types / -p: one or more analysis names; omitted -> main() defaults.
    cli.add_argument('--analysis_types', "-p", nargs='+', help='Analysis types to run')
    options = cli.parse_args()

    main(fail=options.fail, analysis_types=options.analysis_types)
