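"""
Animate a recorded episode from a directory containing stepwise.npz and
episodic.json.

Usage:
  python analysis/visualize_episode.py --path tmp/2025-03-12-14-33-21-913071 --show

Pass --save to also write animation.gif into the same directory.
"""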

import json
from pathlib import Path

import numpy as np
from absl import flags, app
from matplotlib import animation
from matplotlib import pyplot as plt

import knotgym.utils as knot_utils
from knotgym.specs import KnotState, CONFIG_BASE_DIR

# Command-line flags. --path must point at a recorded episode directory.
flags.DEFINE_string(
  "path", None, "directory containing stepwise.npz and episodic.json"
)
flags.DEFINE_bool("show", False, "whether to open a window")
flags.DEFINE_bool("save", False, "whether to save the animation as a gif")

FLAGS = flags.FLAGS


def animate_episode(s, e, title=""):
  """Build an animation of one recorded episode.

  Args:
    s: mapping of per-step arrays (e.g. the loaded stepwise.npz). Reads the
      keys "obs/xpos", "obs/ctrl", "action", "reward", "done" and
      "truncated". Every array is read before this function returns, so the
      caller may close `s` afterwards.
    e: episodic metadata dict. e["spec"]["dir0"] / e["spec"]["dirg"] are
      loaded with KnotState.load relative to CONFIG_BASE_DIR, and
      e["spec"]["gc0"] / e["spec"]["gcg"] are included in the image titles.
    title: figure suptitle.

  Returns:
    A matplotlib FuncAnimation with one frame per entry of s["obs/xpos"].
  """
  xpos = s["obs/xpos"]
  # Append bead 0 after the last bead so the plotted curve is closed.
  # np.concatenate (rather than np.concat, NumPy >= 2.0 only) keeps this
  # working on older NumPy releases.
  xpos = np.concatenate([xpos, xpos[:, [0], :]], axis=1)
  # Sanity check: the first two recorded frames should not be identical.
  assert not np.allclose(xpos[0], xpos[1])
  reward = s["reward"]
  return_ = np.cumsum(reward)
  # ctrl[3:] is treated as the force vector (see _ctrl_to_xfrc below).
  frc = s["obs/ctrl"][:, 3:]
  frc_mag = np.linalg.norm(frc, axis=-1)
  done = s["done"]
  truncated = s["truncated"]

  def _ctrl_to_xfrc(ctrl, xpos):
    """Return (index of the bead nearest ctrl[:3], force vector ctrl[3:])."""
    frc_coord, bead_xfrc = ctrl[:3], ctrl[3:]
    # find closest bead
    frc_dist = np.linalg.norm(xpos - frc_coord, axis=-1)
    bead_index = np.argmin(frc_dist)
    return bead_index, bead_xfrc

  # Per-step index of the bead the action's force is attached to. The closing
  # duplicate of bead 0 ties with bead 0, and argmin returns the first match.
  frc_index = np.array(
    [
      _ctrl_to_xfrc(ctrl, xpos[i, :, :3])[0]
      for i, ctrl in enumerate(s["action"])
    ]
  )

  n_rows = 2
  n_cols = 4
  width = n_cols * 3
  height = n_rows * 3
  fig = plt.figure(figsize=(width, height))
  ax_proj3d = fig.add_subplot(n_rows, n_cols, 1, projection="3d")
  ax_proj2d = fig.add_subplot(n_rows, n_cols, 5)
  ax_frc_mag = fig.add_subplot(n_rows, n_cols, 2)
  ax_frc_index = fig.add_subplot(n_rows, n_cols, 6)
  ax_reward = fig.add_subplot(n_rows, n_cols, 3)
  ax_term = fig.add_subplot(n_rows, n_cols, 7)
  ax_obs0 = fig.add_subplot(n_rows, n_cols, 4)
  ax_obsg = fig.add_subplot(n_rows, n_cols, 8)

  # Static panels: initial and goal observations.
  state0 = KnotState.load(CONFIG_BASE_DIR / e["spec"]["dir0"])
  stateg = KnotState.load(CONFIG_BASE_DIR / e["spec"]["dirg"])
  ax_obs0.imshow(state0.obs)
  ax_obsg.imshow(stateg.obs)
  ax_obs0.axis("off")
  ax_obsg.axis("off")
  ax_obs0.set_title(f"Initial obs {e['spec']['gc0']}")
  ax_obsg.set_title(f"Target obs {e['spec']['gcg']}")

  # Plot the full series once so the axes autoscale to the whole episode;
  # update() then truncates each line to the current frame.
  (line_frc_mag,) = ax_frc_mag.plot(frc_mag, marker="o")
  (line_frc_index,) = ax_frc_index.plot(frc_index, linestyle="None", marker="x")
  (line_reward,) = ax_reward.plot(reward)
  (line_return,) = ax_reward.plot(return_)
  (line_done,) = ax_term.plot(done)
  (line_truncated,) = ax_term.plot(truncated)

  def update(num):
    """Redraw the knot views and reveal the time series up to step `num`."""
    ax_proj2d.clear()
    knot = knot_utils._create_knot(xpos[num])
    knot.plot_projection(mark_start=True, show=False, fig_ax=(fig, ax_proj2d))
    ax_proj2d.set_xlabel("X")
    ax_proj2d.set_ylabel("Y")
    ax_proj2d.set_title("Top view")
    ax_proj2d.set_aspect("equal", adjustable="datalim")

    ax_proj3d.clear()
    ax_proj3d.plot(xpos[num, :, 0], xpos[num, :, 1], xpos[num, :, 2])
    # Without normalize=True, quiver draws length * (u, v, w), which would make
    # the arrow |f|**2 long. Normalizing the direction gives an arrow of
    # length |f| starting at the bead the force is applied to.
    ax_proj3d.quiver(
      xpos[num, frc_index[num], 0],
      xpos[num, frc_index[num], 1],
      xpos[num, frc_index[num], 2],
      frc[num, 0],
      frc[num, 1],
      frc[num, 2],
      length=frc_mag[num],
      normalize=True,
      color="orange",
    )
    ax_proj3d.set_xlabel("X")
    ax_proj3d.set_ylabel("Y")
    ax_proj3d.set_zlabel("Z")  # type: ignore
    ax_proj3d.set_title("3D view")
    ax_proj3d.set_aspect("equal", adjustable="datalim")
    ax_proj3d.set_zlim(1.0, -1.0)  # Reverse Z limits  # type: ignore
    ax_proj3d.view_init(elev=60, azim=-90 + 20)  # type: ignore

    steps = range(num + 1)
    line_frc_mag.set_data(steps, frc_mag[: num + 1])
    line_frc_index.set_data(steps, frc_index[: num + 1])
    line_reward.set_data(steps, reward[: num + 1])
    line_return.set_data(steps, return_[: num + 1])
    line_done.set_data(steps, done[: num + 1])
    line_truncated.set_data(steps, truncated[: num + 1])

  ax_frc_mag.set_xlabel("Time step")
  ax_frc_mag.legend(["Force Magnitude"], loc="upper right")

  ax_frc_index.set_xlabel("Time step")
  ax_frc_index.legend(["Force Index"], loc="upper right")

  ax_reward.set_xlabel("Time step")
  ax_reward.legend(["Reward", "Reward acc."], loc="upper right")

  ax_term.set_xlabel("Time step")
  ax_term.legend(
    ["Terminated (done)", "Truncated (timeout)"], loc="upper right"
  )

  fig.tight_layout()
  fig.subplots_adjust(top=0.9)
  fig.suptitle(title)

  N = xpos.shape[0]

  # interval is per frame in ms, so the whole episode plays in about 1 second.
  ani = animation.FuncAnimation(fig, update, N, interval=1000 / N, blit=False)  # type: ignore
  return ani


def main(_):
  """Load the episode under --path, then save and/or show its animation.

  Raises:
    app.UsageError: if --path was not given.
  """
  if FLAGS.path is None:
    # Without this, the first path join fails with an opaque TypeError.
    raise app.UsageError("--path is required")
  folder = Path(FLAGS.path)
  # np.load on an .npz returns an NpzFile that keeps the archive open.
  # animate_episode reads every array it needs before returning, so the
  # archive can be closed before the animation is saved or shown.
  with np.load(folder / "stepwise.npz") as stepwise:
    with open(folder / "episodic.json", "r", encoding="utf-8") as file:
      episodic = json.load(file)
    ani = animate_episode(stepwise, episodic, title=FLAGS.path)
  if FLAGS.save:
    gif_path = str(folder / "animation.gif")
    ani.save(gif_path)
    print("Animation saved to", gif_path)
  if FLAGS.show:
    plt.show()


if __name__ == "__main__":
  # Let absl reject a missing --path with a usage message before main() runs.
  flags.mark_flag_as_required("path")
  app.run(main)
