# %%
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Render all figure text through LaTeX, using a serif family backed by the
# `times` package, at a base font size of 12.
plt.rcParams.update(
  {
    "text.usetex": True,
    "font.family": "serif",
    "text.latex.preamble": r"\usepackage{times}",
    "font.size": 12,
  }
)

# %%
# Aggregated results; the first CSV column is used as the index.
# Every row of the first table has its "split" set to "tr".
df_train = pd.read_csv("group2agg.csv", index_col=0).assign(split="tr")
df_random = pd.read_csv("group2agg_random.csv", index_col=0)

# Stack both tables row-wise into the single frame used by the rest of the script.
frames = [df_train, df_random]
df = pd.concat(frames)


# Task-name prefix -> name shown in the tick label.
_family_labels = {"unknot": "unknot", "tie_unknot": "tie", "eq1": "convert"}
# Values of the "nx" suffix (rendered as "#X" in the tick labels).
_nx_values = (2, 3, 4)

# Left-to-right order of tasks on the x axis: grouped by prefix, then by nx.
task_order = [
  f"{family}_nx{nx}" for family in _family_labels for nx in _nx_values
]

# Legend text for each method key.
method_legend = dict(
  ppo="PPO",
  tdmpc="TD-MPC2",
  dreamer="DreamerV3",
  random="Random",
)

# Two-line LaTeX tick label per task: the shown name in \texttt, then "#X=<nx>".
task_legend = {
  f"{family}_nx{nx}": rf"\texttt{{{shown}}}" + "\n" + rf"\#X={nx}"
  for family, shown in _family_labels.items()
  for nx in _nx_values
}

# Path the finished figure is saved to.
outpath = "outputs/task_bars.pdf"
# Bar colours, picked by a method's position in the plotting loop.
colors = [
  # "#264653",
  "#2a9d8f",
  "#e9c46a",
  "#f4a261",
  "#e76f51",
  "#e0fbfc",
]

# %%
# Plotting

# Keep only rows whose task name ends in "s20" and whose split is "tr".
# na=False treats a missing task name as "no match" instead of failing.
mask = df["task"].str.endswith("s20", na=False) & (df["split"] == "tr")
# Explicit copy: the column assignment below must write to an independent
# frame, not to a slice of `df` (which would raise SettingWithCopyWarning).
runs = df[mask].copy()
# Strip the "_s20" marker so task names match the keys of `task_order`.
# regex=False makes this a literal replacement on every pandas version.
runs["task"] = runs["task"].str.replace("_s20", "", regex=False)
# Task name -> position on the x axis.
task_ord = {k: i for i, k in enumerate(task_order)}

# The "random" method is drawn separately from the other methods.
data_random = runs[runs["method"] == "random"]
data = runs[runs["method"] != "random"]

# draw on canvas
fig, ax = plt.subplots(figsize=(8, 2))

# One cluster of bars per task, centred on consecutive integers.
tasks = task_order
centers = np.arange(len(tasks))
width = 0.7  # total width of one task's cluster of bars

methods = data["method"].unique()
n_methods = len(methods)
for index, (method, group) in enumerate(data.groupby("method")):
  # Align this method's rows to the x-axis slots explicitly. A task the method
  # has no row for becomes NaN (no bar drawn) and a task outside `tasks` is
  # dropped, so the arrays always have one entry per slot. Duplicate task rows
  # for one method still raise, since they cannot be mapped to a single bar.
  group = group.set_index("task").reindex(tasks)
  # Plain arrays keep matplotlib from depending on the pandas index.
  heights = group["mean"].to_numpy()
  err_below = heights - group["ci_lower"].to_numpy()
  err_above = group["ci_upper"].to_numpy() - heights
  # Shift each method's bar so the cluster as a whole is centred on the slot.
  pos = centers + width * (0.5 / n_methods + index / n_methods - 0.5)
  label = method_legend[method]
  ax.bar(pos, heights, width / n_methods, label=label, color=colors[index])
  # Asymmetric error bars: distance from the mean down to ci_lower and up to
  # ci_upper; fmt="none" draws the error bars without markers.
  ax.errorbar(
    pos,
    heights,
    yerr=[err_below, err_above],
    fmt="none",
    color="black",
  )

# Draw the "random" method's mean as a dashed horizontal segment spanning each
# task's bar cluster.
random_in_legend = False
for task, group in data_random.groupby("task"):
  if task not in task_ord:
    # No slot on the x axis for this task, so there is nowhere to draw it.
    continue
  index = task_ord[task]
  # .item() requires exactly one "random" row per task and raises otherwise.
  level = group["mean"].item()
  pos = np.array([centers[index] + width * 0.5, centers[index] - width * 0.5])
  heights = np.full_like(pos, level)
  ax.plot(
    pos,
    heights,
    linestyle="--",
    color="black",
    # Label only the first segment drawn, so the legend gets exactly one
    # entry regardless of which tasks have a "random" row.
    label=None if random_in_legend else method_legend["random"],
  )
  random_in_legend = True

# Hide every spine except the left one and make the x tick marks zero-width.
for side in ("top", "right", "bottom"):
  ax.spines[side].set_visible(False)
ax.tick_params(axis="x", which="both", width=0, length=0.8, direction="inout")

# Pad both ends of the x axis by twice the gap left between bar clusters.
margin = 2 * (1 - width)
ax.set_xlim(centers[0] - margin, centers[-1] + margin)

# One tick per task, labelled with its LaTeX display name.
names = [task_legend[task_name] for task_name in tasks]
ax.set_xticks(centers + 0.0)
ax.set_xticklabels(names, ha="center", fontsize=12)

# y axis: three ticks, shown as percentages (the % sign is escaped for LaTeX).
ax.set_ylabel(r"Success Rate", fontsize=12)
ax.set_yticks([0.0, 0.5, 1.0])
percent_labels = [rf"{int(tick * 100)}\%" for tick in ax.get_yticks()]
ax.set_yticklabels(percent_labels, fontsize=12)

# Fit the axes into the lower 95% of the figure, then put a single-row,
# frameless legend at the top centre of the figure.
fig.tight_layout(rect=(0, 0, 1, 0.95))
legend_style = dict(
  loc="upper center",
  ncol=10,
  frameon=False,
  borderpad=0,
  borderaxespad=0,
  fontsize=12,
)
fig.legend(**legend_style)

# Create the output directory if it does not exist yet, then write the figure.
out_dir = Path(outpath).parent
out_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(outpath)
print(f"Saved {outpath}")

# %%
