# %%

from knotgym.specs import KnotTaskSpecFactory, KnotTask

import matplotlib.pyplot as plt


def plot_knot_states(lg, *, nrows=5, ncols=5):
  """Plot the observation image of each knot state in a grid of subplots.

  Every state in ``lg`` must expose ``obs`` (drawn with ``Axes.imshow``) and
  ``n_crossings``. All states must report the same ``n_crossings``, which is
  shown in the figure's super-title.

  Args:
    lg: Non-empty sequence of knot states.
    nrows: Number of subplot rows (keyword-only, default 5).
    ncols: Number of subplot columns (keyword-only, default 5).

  Returns:
    The created matplotlib ``Figure``.

  Raises:
    ValueError: If ``lg`` is empty or its states disagree on ``n_crossings``.

  Only the first ``nrows * ncols`` states are drawn; any extra states are
  still validated but not plotted.
  """
  # Validate before creating the figure so bad input does not leave an
  # orphaned figure registered with pyplot.
  if len(lg) == 0:
    raise ValueError("lg must contain at least one knot state")
  nx = lg[0].n_crossings
  if any(s.n_crossings != nx for s in lg):
    raise ValueError(
      f"all knot states must share n_crossings={nx}; "
      f"got {[s.n_crossings for s in lg]}"
    )

  # 4 inches per cell, so the default 5x5 grid is 20x20 inches.
  # squeeze=False keeps `axs` 2-D even for a 1x1 grid so `.flat` always works.
  fig, axs = plt.subplots(
    nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False
  )
  for i, ax in enumerate(axs.flat):
    # Hide axes on every cell, including unused trailing ones.
    ax.axis("off")
    if i < len(lg):
      ax.imshow(lg[i].obs)
      ax.set_title(f"Sample {i + 1}")
  fig.suptitle(f"Number of Crossings: #nx={nx}", fontsize=30, y=1.02)
  fig.tight_layout()
  return fig


# %%

# Plot the knot states of the TIE_UNKNOT spec (split "tr", max_n_crossings=1).
_ = plot_knot_states(
  lg=KnotTaskSpecFactory(
    KnotTask.TIE_UNKNOT, max_n_crossings=1, split="tr"
  ).lg
)

# %%

# Plot the knot states of the TIE_UNKNOT spec (split "tr", max_n_crossings=2).
_ = plot_knot_states(
  lg=KnotTaskSpecFactory(
    KnotTask.TIE_UNKNOT, max_n_crossings=2, split="tr"
  ).lg
)

# %%

# Plot the knot states of the TIE_UNKNOT spec (split "tr", max_n_crossings=3).
_ = plot_knot_states(
  lg=KnotTaskSpecFactory(
    KnotTask.TIE_UNKNOT, max_n_crossings=3, split="tr"
  ).lg
)

# %%
# Plot the knot states of the TIE_UNKNOT spec (split "tr", max_n_crossings=4).
_ = plot_knot_states(
  lg=KnotTaskSpecFactory(
    KnotTask.TIE_UNKNOT, max_n_crossings=4, split="tr"
  ).lg
)
