# %%
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from common import force_tick_font
from matplotlib.colors import LinearSegmentedColormap

# Rendering setup. Flip `is_preview` to True to skip the full LaTeX pipeline
# (text.usetex) while iterating on the layout.
# NOTE(review): the titles/labels in this script use LaTeX-only markup
# (\texttt, \#, \%), which will show up verbatim when is_preview is True.
is_preview = False
plt.rcParams.update(
  {
    "text.usetex": not is_preview,
    "font.family": "serif",
    "text.latex.preamble": r"\usepackage{times, amsmath, bm}",
    "font.size": 35,
    "mathtext.default": "regular",
  }
)

# Anchor colours of the heatmap gradient; from_list maps the first entry to
# the low end (0) and the last entry to the high end (1) of the colour scale.
colors = [
  "#264653",
  "#2a9d8f",
  "#e76f51",
  "#f4a261",
]
cmap = LinearSegmentedColormap.from_list("custom_gradient", colors, N=1024)

# Source CSV (read below for its task / train_nx / test_nx / success_rate
# columns) and the destination of the rendered figure.
in_path = "inputs/matrices3.csv"
out_path = "outputs/matrices.pdf"

# Panel titles keyed by task id. The dict's insertion order doubles as the
# left-to-right panel order.
task_legend = {
  "knot_unknot": r"\texttt{unknot}",
  "knot_tie_unknot": r"\texttt{tie}",
  "knot_eq1": r"\texttt{convert}",
}
task_order = list(task_legend)

# %%
df = pd.read_csv(in_path)

# The figure draws exactly one panel per entry of task_order, so the CSV must
# contain exactly those tasks. An explicit raise (unlike `assert`) still fires
# under `python -O` and reports which tasks are off.
found_tasks = set(df["task"].unique())
expected_tasks = set(task_order)
if found_tasks != expected_tasks:
  raise ValueError(
    "Task order mismatch: "
    f"missing from CSV={sorted(expected_tasks - found_tasks, key=str)}, "
    f"not in task_order={sorted(found_tasks - expected_tasks, key=str)}"
  )

# #X values shown on both heatmap axes. The reindex below drops any other
# train_nx/test_nx values and leaves NaN where the CSV has no rows for a pair.
nx_levels = [2, 3, 4]

# Pivot each task's rows into a train_nx (rows) x test_nx (columns) grid of
# mean success_rate (rows sharing a cell are averaged).
tasks = df["task"].unique()
heatmaps = {}

for task in tasks:
  sub_df = df[df["task"] == task]
  pivot = sub_df.pivot_table(
    index="train_nx",
    columns="test_nx",
    values="success_rate",
    aggfunc="mean",
  ).reindex(index=nx_levels, columns=nx_levels)
  heatmaps[task] = pivot

# %%
# One heatmap panel per task (4 inches wide each), all on the same 0..1 colour
# scale so a single colorbar serves every panel. squeeze=False + [0] keeps
# `axes` a 1-D array of Axes even when there is only one task.
n_panels = len(task_order)
fig, axes = plt.subplots(
  1,
  n_panels,
  figsize=(4 * n_panels, 5),
  constrained_layout=True,
  squeeze=False,
)
axes = axes[0]

for panel, task in enumerate(task_order):
  title = task_legend[task]
  matrix = heatmaps[task]
  values = matrix.to_numpy()
  ax = axes[panel]
  im = ax.imshow(values, vmin=0, vmax=1, cmap=cmap)
  ax.set_title(title, pad=20)
  ax.set_xlabel(r"Test \#X")

  # Only the leftmost panel carries the train-axis label and tick labels.
  if panel == 0:
    ax.set_ylabel(r"Train \#X")
    ax.set_yticks(range(len(matrix.index)))
    ax.set_yticklabels(matrix.index)
  else:
    ax.set_yticks([])

  # One tick per test #X column, labelled with the column value.
  ax.set_xticks(range(len(matrix.columns)))
  ax.set_xticklabels(matrix.columns)

  # Hide tick marks and the frame for a clean heatmap look.
  ax.tick_params(
    axis="both",
    length=0,
    pad=10,
    which="major",
  )
  for spine in ax.spines.values():
    spine.set_visible(False)

  # Print each cell's success rate as a percentage. The row/col loop names are
  # distinct from `panel` so the panel index is not clobbered. NaN cells (no
  # data for that train/test pair) stay blank instead of reading "nan%".
  for row in range(values.shape[0]):
    for col in range(values.shape[1]):
      val = values[row, col]
      if pd.isna(val):
        continue
      ax.text(
        col,
        row,
        rf"{val * 100:.0f}\%",
        ha="center",
        va="center",
        color="white",
        fontsize=30,
      )

# Single colorbar to the right of the last panel. Every panel uses the same
# cmap and vmin/vmax, so the last image is representative of all of them.
cbar = fig.colorbar(
  im,
  ax=axes[-1],
  shrink=0.72,
  location="right",
  pad=0.03,
)
cbar.ax.tick_params(
  axis="both",
  length=0,
  pad=5,
  which="major",
)
cbar.ax.set_yticks([0, 0.5, 1])
# Project helper from `common`; the y formatter renders ticks as percentages.
force_tick_font(
  cbar.ax,
  xfmt=lambda x: f"{x:.0f}",
  yfmt=lambda y: rf"{y * 100:.0f}\%",
)

for spine in cbar.ax.spines.values():
  spine.set_visible(False)

# %%

# Refuse to clobber an existing figure. An explicit raise (unlike `assert`)
# still fires under `python -O`. Create the output directory if it is missing
# so savefig does not fail on a fresh checkout.
out_file = Path(out_path)
if out_file.exists():
  raise FileExistsError(f"Refusing to overwrite existing figure: {out_file}")
out_file.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(
  out_file,
  bbox_inches="tight",
  pad_inches=0.1,
  dpi=300,
)
