# %%

# plot training curves of different tasks

import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from common import style_tr


# Typeset all figure text through LaTeX with a serif (Times) font at size 30.
plt.rcParams.update(
  {
    "text.usetex": True,
    "font.family": "serif",
    "text.latex.preamble": r"\usepackage{times}",
    "font.size": 30,
  }
)

# Training metric logged by each method; the aggregated curves for a method
# are read from "<method>/<metric>/<run>/agg.csv".
metrics = dict(
  dreamer="seed_rolling/episode/score",  # this is actually success rate too
  tdmpc="seed_rolling/train/episode_success",
  ppo="seed_agg/rollout/success_rate",
)

agg_dirs = {name: Path(name) / metric for name, metric in metrics.items()}


# Canonical task ordering: each task family with 2, 3 and 4 crossings
# ("nx" = task_max_n_crossings).
task_order = [
  f"{family}_nx{n_crossings}"
  for family in ("unknot", "tie_unknot", "eq1")
  for n_crossings in (2, 3, 4)
]

# Display name of each method in plot legends.
method_legend = dict(
  ppo="PPO",
  tdmpc="TD-MPC2",
  dreamer="DreamerV3",
  random="Random",
)

# Default method ordering; its positions also index into `colors`.
method_order = "dreamer ppo tdmpc".split()

# LaTeX panel titles keyed by "<task>_nx<crossings>": the family's display
# name in typewriter font, a newline, then the crossing count "#X=<n>".
task_legend = {
  f"{family}_nx{n}": rf"\texttt{{{label}}}" + "\n" + rf"\#X={n}"
  for family, label in (
    ("unknot", "unknot"),
    ("tie_unknot", "tie"),
    ("eq1", "convert"),
  )
  for n in (2, 3, 4)
}

# Destination of the main training-curve figure.
outpath = "outputs/task_train.pdf"

# Palette for the single-axis plots; entries are picked by position.
colors = [
  "#2a9d8f",
  "#e9c46a",
  "#f4a261",
  "#e76f51",
  "#e0fbfc",
]

# Task groups; later cells read the "task", "task_max_n_crossings" and
# "task_max_n_states" columns of each row.
df_groups = pd.read_csv(Path("inputs") / "task_groups.csv")

# %%


def plot_tr(ax, traj, color, label):
  """Draw one training curve on `ax`: a mean line plus a shaded band.

  `traj` must provide "step", "mean", "ci_lower" and "ci_upper" columns.
  The line is 3 wide and carries `label`; the band spans ci_lower..ci_upper
  in the same `color` at alpha 0.2.
  """
  x = traj["step"]
  ax.plot(x, traj["mean"], label=label, color=color, linewidth=3)
  lo, hi = traj["ci_lower"], traj["ci_upper"]
  ax.fill_between(x, lo, hi, alpha=0.2, color=color)

def smoothen_traj(
  traj, window=200, *, columns=("mean", "ci_lower", "ci_upper")
):
  """Return a copy of `traj` with trailing rolling means over `columns`.

  Args:
    traj: DataFrame containing every column named in `columns` (by default
      the "mean", "ci_lower" and "ci_upper" curves that `plot_tr` draws).
    window: Number of consecutive rows averaged per output value. With
      pandas' default `min_periods` for an integer window, the first
      `window - 1` rows of each smoothed column are NaN.
    columns: Names of the columns to smooth. All other columns (e.g.
      "step") are copied unchanged.

  Returns:
    A new DataFrame. The input is left untouched; the previous version
    overwrote the caller's columns in place.
  """
  smoothed = traj.copy()
  for col in columns:
    smoothed[col] = smoothed[col].rolling(window).mean()
  return smoothed

# %%

# Main figure: 3x3 grid of training curves, one panel per task group with
# task_max_n_states == 20, comparing the methods in `new_method_order`.
new_colors = [
  "#2a9d8f",
  "#f4a261",
  "#264653",
  "#e9c46a",
  "#e76f51",
  "#e0fbfc",
]
new_method_order = [
  "dreamer",
  "tdmpc",
  "ppo",
]

df_groups2 = df_groups[df_groups["task_max_n_states"] == 20]
assert len(df_groups2) == 9, f"expected 9 task groups, got {len(df_groups2)}"


def _panel_key(group):
  """Return the `task_order` / `task_legend` key for a task-group row."""
  return f"{group['task']}_nx{group['task_max_n_crossings']}"


# Lay the panels out row-major in `task_order` rather than in CSV row order,
# so each grid row holds one task family and each column one crossing count.
panel_rows = sorted(
  (group for _, group in df_groups2.iterrows()),
  key=lambda group: task_order.index(_panel_key(group)),
)

fig, axes = plt.subplots(3, 3, figsize=(20, 12))
for i, row in enumerate(panel_rows):
  task = row["task"]
  nx = row["task_max_n_crossings"]
  nn = row["task_max_n_states"]
  run_dir = f"{task}_nx{nx}_s{nn}"
  ax = axes[i // 3, i % 3]
  # Read exactly the methods that are plotted, each paired with its color.
  for color, m in zip(new_colors, new_method_order):
    traj = pd.read_csv(agg_dirs[m] / run_dir / "agg.csv")
    # PPO curves are drawn raw; the other methods are rolling-averaged.
    if m != "ppo":
      traj = smoothen_traj(traj)
    plot_tr(ax, traj, color=color, label=method_legend[m])
  style_tr(ax)
  # Only the bottom row keeps x tick labels.
  if i // 3 != 2:
    ax.set_xticklabels([])
  # Only the left column keeps y tick labels.
  if i % 3 != 0:
    ax.set_yticklabels([])
  # One shared legend, drawn inside the last panel.
  if i == len(panel_rows) - 1:
    ax.legend(
      loc="right",
      bbox_to_anchor=(0.8, 0.5),
      ncol=1,
      frameon=False,
      borderpad=0,
      borderaxespad=0,
      handlelength=1.2,
    )
  # Put the two-line legend title on one line (LaTeX forced space).
  title = task_legend[_panel_key(row)].replace("\n", r"\ ")
  ax.set_title(title, fontsize=30)

outpath = "outputs/task_train.pdf"
# savefig does not create missing directories.
Path(outpath).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(outpath, bbox_inches="tight")

# %%
# Quick look: unsmoothed curves of every method on the first task group.
row = df_groups.iloc[0]

subtask = "{}_nx{}_s{}".format(
  row["task"], row["task_max_n_crossings"], row["task_max_n_states"]
)

trajs = {m: pd.read_csv(agg_dirs[m] / subtask / "agg.csv") for m in agg_dirs}

fig, ax = plt.subplots(1, 1, figsize=(6, 4))
for m, traj in trajs.items():
  # Color is chosen by the method's position in `method_order`.
  plot_tr(ax, traj, colors[method_order.index(m)], method_legend[m])

style_tr(ax)


# %%


# Ablation: DreamerV3 on tie_unknot_nx3 for several task_max_n_states values
# (the "_s<N>" run suffix), each curve rolling-averaged before plotting.
trajs = {
  x: smoothen_traj(
    pd.read_csv(agg_dirs["dreamer"] / f"tie_unknot_nx3_s{x}" / "agg.csv")
  )
  for x in (1, 5, 10, 20)
}

fig, ax = plt.subplots(1, 1, figsize=(6, 4))
for color, (x, traj) in zip(colors, trajs.items()):
  plot_tr(ax, traj, color=color, label=str(x))
style_tr(ax)

# Legend sits to the right of the axes, outside the plotting area.
legend_kwargs = dict(
  loc="right",
  bbox_to_anchor=(1.3, 0.5),
  ncol=1,
  frameon=False,
  borderpad=0,
  borderaxespad=-0.2,
  handlelength=1.2,
  handletextpad=0.5,
)
ax.legend(**legend_kwargs)
ax.set_xlabel(r"Steps", fontsize=30)
ax.set_ylabel(r"Success Rate", fontsize=30)
outpath = "outputs/nstates_tie_unknot_nx3.pdf"
fig.savefig(outpath, bbox_inches="tight")
