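"""Post-process sampled knot configurations.

For every directory under ``knotgym.specs.CONFIG_BASE_DIR`` whose name matches
``RE_DIR``, optionally write the Gauss code (--gc), a projection image
(--proj), an interactive 3D plot (--html) and a MuJoCo render (--render).
--summary prints per-crossing-number statistics to stdout.

Examples:
  python mjc_post_collect.py --proj
  python mjc_post_collect.py --summary > knotgym/knotgym/assets/configurations/summary.txt
"""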

from collections import Counter, defaultdict
import os
from pprint import pprint
# from pathlib import Path

import mediapy as media
import numpy as np
from absl import app, flags
from tqdm import tqdm

import knotgym.specs
import knotgym.utils as knot_utils
from knotgym.specs import (
  # CONFIG_BASE_DIR,
  RE_DIR,
  KnotState,
)
from knotgym.utils import colorful
from qol import safe_write

# Per-directory artifacts: each flag enables one output written by `fn`.
flags.DEFINE_bool(name="gc", default=False, help="add gauss code")
flags.DEFINE_bool(name="proj", default=False, help="add projection")
flags.DEFINE_bool(name="html", default=False, help="add 3d plot html")
flags.DEFINE_bool(
  name="render", default=False, help="add render.npy and render.png"
)
# Dataset-wide report, printed once after the directories are collected.
flags.DEFINE_bool(
  name="summary", default=False, help="print summary of configurations"
)

# Uncomment (together with the pathlib import) to read configurations from a
# different root directory:
# knotgym.specs.CONFIG_BASE_DIR = Path("sampled_configurations")

FLAGS = flags.FLAGS


def fn(dir: str, mj_model=None, mj_data=None, renderer=None):
  """Write the artifacts enabled by the command-line flags into one directory.

  Depending on the flags, reads ``xpos.txt`` and/or ``qpos.txt`` from ``dir``
  and writes:
    --gc:     ``gc.txt``, the Gauss code of the ``xpos`` curve.
    --proj:   ``proj.png``, the knot projection with the start marked.
    --html:   ``plot.html``, an interactive 3D plot of the curve.
    --render: ``render.png`` and ``render.npy``, a render of ``qpos`` from the
              ``track`` camera.

  Args:
    dir: configuration directory. Paths are joined with ``os.path.join``, so
      it may be given with or without a trailing separator. (The name shadows
      the builtin but is kept for keyword-argument compatibility.)
    mj_model: MuJoCo model; required when --render is set.
    mj_data: MuJoCo data for ``mj_model``; required when --render is set. It
      is reset and overwritten on every call.
    renderer: MuJoCo renderer for ``mj_model``; required when --render is set.

  Raises:
    ValueError: if --render is set but a MuJoCo object is missing.
  """

  def path(name):
    return os.path.join(dir, name)

  # Only the curve-based outputs need xpos.txt; --render reads qpos.txt, so
  # xpos.txt is neither parsed nor required when it is the only flag.
  if FLAGS.gc or FLAGS.proj or FLAGS.html:
    arr = np.loadtxt(path("xpos.txt"), converters=float)

    if FLAGS.gc:
      gc = knot_utils.gauss_code(arr)
      safe_write(path("gc.txt"), str(gc))
    if FLAGS.proj:
      knot = knot_utils._create_knot(arr)
      fig, _ = knot.plot_projection(mark_start=True, show=False)
      fig.savefig(path("proj.png"))
      # NOTE(review): `fig` is never closed. If it is a pyplot-managed
      # matplotlib figure, open figures accumulate across directories;
      # consider `matplotlib.pyplot.close(fig)`.
    if FLAGS.html:
      from debug_state_vis import plot

      plot(arr).write_html(path("plot.html"))

  if FLAGS.render:
    if mj_model is None or mj_data is None or renderer is None:
      raise ValueError("--render requires mj_model, mj_data and renderer")
    import mujoco

    # mj_data is shared across directories: reset it, clearing qvel, act and
    # any other state left by the previous call, before posing the rope.
    mujoco.mj_resetData(mj_model, mj_data)
    mj_data.qpos[:] = np.loadtxt(path("qpos.txt"))
    mujoco.mj_forward(mj_model, mj_data)

    renderer.update_scene(mj_data, camera="track")
    pixels = renderer.render()
    media.write_image(path("render.png"), pixels)
    np.save(path("render.npy"), pixels)


def _load_render_context():
  """Build the MuJoCo model, data and 512x512 renderer used by --render.

  Every geom is recoloured with ``colorful`` according to its index
  (presumably so the ordering along the rope is visible in the render).
  The XML path is relative to the current working directory.

  Returns:
    A ``(mj_model, mj_data, renderer)`` triple.
  """
  import mujoco

  mj_model = mujoco.MjModel.from_xml_path(
    "knotgym/knotgym/assets/unknot7_float.xml"
  )
  n_geoms = len(mj_model.geom_rgba)
  mj_model.geom_rgba[:] = np.array(
    [colorful(i / n_geoms) for i in range(n_geoms)]
  )
  mj_data = mujoco.MjData(mj_model)
  renderer = mujoco.Renderer(mj_model, 512, 512)
  return mj_model, mj_data, renderer


def _print_summary(config_dirs):
  """Print configuration counts per crossing number and list each group.

  Args:
    config_dirs: configuration directories whose names are passed to
      ``KnotState.load``.
  """
  states = [KnotState.load(d.name) for d in config_dirs]
  gc_freq = Counter(s.n_crossings for s in states)
  print("(#x, counts)")
  pprint(gc_freq.most_common())

  # One pass to group states; insertion order is kept, so the stable sort
  # by Gauss code below orders ties exactly as `states` does.
  by_crossings = defaultdict(list)
  for s in states:
    by_crossings[s.n_crossings].append(s)

  for c in sorted(by_crossings):
    group = sorted(by_crossings[c], key=lambda s: s.gc)
    listing = "\n\t".join(str(s) for s in group)
    print(
      f"n_crossings: {c} | count: {len(group)} | split {Counter(s.split for s in group)}\n\t{listing}\n"
    )


def main(_):
  """Post-process every sampled configuration directory.

  Directories under ``knotgym.specs.CONFIG_BASE_DIR`` whose names match
  ``RE_DIR`` are visited in sorted order. ``fn`` writes the artifacts selected
  by --gc/--proj/--html/--render into each directory, and --summary prints a
  report grouped by crossing number.
  """
  if FLAGS.render:
    mj_model, mj_data, renderer = _load_render_context()
  else:
    mj_model = mj_data = renderer = None

  # Sorted so that processing order, and the order of ties in the summary,
  # do not depend on filesystem enumeration order.
  all_dirs = sorted(
    d
    for d in knotgym.specs.CONFIG_BASE_DIR.glob("*")
    if d.is_dir() and RE_DIR.match(d.name)
  )

  try:
    # `fn` only does work for these flags; --summary alone skips the pass.
    if FLAGS.gc or FLAGS.proj or FLAGS.html or FLAGS.render:
      for config_dir in tqdm(all_dirs, desc="Processing directories"):
        fn(
          str(config_dir) + "/",
          mj_model=mj_model,
          mj_data=mj_data,
          renderer=renderer,
        )
  finally:
    # Release the renderer's GL context even if a directory fails.
    if renderer is not None:
      renderer.close()

  if FLAGS.summary:
    _print_summary(all_dirs)


if __name__ == "__main__":
  app.run(main)
