# Build a standalone legend figure showing the color used for each attack.
from pathlib import Path

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import scienceplots  # noqa: F401 -- imported for its side effect of registering the "science" style
from matplotlib.legend_handler import HandlerPatch

# Apply the SciencePlots "science" style, made available by `import scienceplots`.
# NOTE(review): this style is believed to enable text.usetex (requires a LaTeX
# installation) -- confirm for the installed SciencePlots version.
plt.style.use(["science"])

# Fixed hex color for each attack name; legend handles look colors up here.
attack_colors = {
    "AutoDAN": "#ff7f0e",        # orange
    "GCG": "#1f77b4",            # blue
    "REINFORCE-GCG": "#87CEEB",  # light blue
    "BEAST": "#2ca02c",          # green
    "PAIR": "#d62728",           # red
}

# Custom legend handler to control patch thickness explicitly
class HandlerThinPatch(HandlerPatch):
    def __init__(self, height_frac=0.18, **kwargs):
        super().__init__(**kwargs)
        self.height_frac = height_frac  # fraction of legend entry height

    def create_artists(
        self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans
    ):
        h = height * self.height_frac
        y = ydescent + (height - h) / 2.0
        x = xdescent
        w = width
        p = mpatches.Rectangle(
            (x, y), w, h,
            transform=trans,
            facecolor=orig_handle.get_facecolor(),
            edgecolor='none'
        )
        return [p]

# Attacks shown in each legend row, in display order.
first_row = ["AutoDAN", "BEAST", "PAIR"]
second_row = ["GCG", "REINFORCE-GCG"]

# Output location, relative to the current working directory.
legend_output_path = Path(
    "evaluate/multi_attack_non_cumulative_pareto_plots/legend_base.pdf"
)


def _attack_handles(attacks):
    """Return one proxy Rectangle per attack, filled with its ``attack_colors`` entry.

    The rectangles are never added to the axes. They only carry the facecolor
    and label that the legend reads.
    """
    return [
        mpatches.Rectangle((0, 0), 1, 1, facecolor=attack_colors[a], label=a)
        for a in attacks
    ]


handles_first = _attack_handles(first_row)
handles_second = _attack_handles(second_row)

# Blank figure/axes that exists only to host the two legends.
fig_legend = plt.figure(figsize=(4, 1.2))
ax_legend = fig_legend.add_subplot(111)
ax_legend.axis('off')

# A single handler instance draws every Rectangle handle as a thin bar.
handler_map = {mpatches.Rectangle: HandlerThinPatch(height_frac=0.18)}

# Settings shared by both rows so that text size and handle geometry match.
legend_common_kwargs = {
    "fontsize": 16,
    "frameon": False,
    "handletextpad": 0.35,
    "handlelength": 1.5,
    "handler_map": handler_map,
}

# NOTE(review): 'lower left' anchors this row at the bottom of the axes, so it
# renders *below* the second_row legend ('upper left'). Confirm that this
# vertical order is intended.
legend1 = ax_legend.legend(
    handles=handles_first,
    loc='lower left',
    ncol=len(handles_first),  # all entries on one row
    columnspacing=0.7,
    **legend_common_kwargs,
)

# columnspacing is hand-tuned so these entries line up with the columns of the
# other row. Re-tune it if fontsize, labels, or handlelength change.
legend2 = ax_legend.legend(
    handles=handles_second,
    loc='upper left',
    ncol=len(handles_second),  # all entries on one row
    columnspacing=3.075 * 0.91,
    **legend_common_kwargs,
)

# Axes.legend() tracks only the most recent legend, so re-attach the first one.
ax_legend.add_artist(legend1)

# Operate on fig_legend explicitly rather than on pyplot's "current figure".
fig_legend.tight_layout()
legend_output_path.parent.mkdir(parents=True, exist_ok=True)
fig_legend.savefig(
    legend_output_path,
    bbox_inches='tight',
    dpi=300,
    transparent=True,
)
plt.close(fig_legend)