"""

python mjc_sample.py --mjcf knotgym/knotgym/assets/unknot7_float.xml --record

# mjpython mjc_interact.py
# LINE_PROFILER=1 python mjc_interact.py --bare --mjcf ...
# testspeed ...filled_mjcf.xml 10000 5 0.01 5
#
# # on macos
# MUJOCO_GL=glfw mjpython mjc_interact.py --mjcf knotgym/knotgym/assets/unknot7_float.xml --random
"""

import os
from datetime import datetime

import mediapy as media
import mujoco
import numpy as np
from absl import app, flags, logging
from gymnasium.vector import AsyncVectorEnv
from gymnasium.wrappers import RecordVideo
from tqdm import tqdm

import knotgym.utils as knot_utils
import qol
from knotgym.envs import KnotEnvBase
from knotgym.utils import _create_knot

# Command-line flags that configure a sampling run.
flags.DEFINE_bool(name="record", default=True, help="record video")
flags.DEFINE_integer(name="num_envs", default=8, help="number of environments")
flags.DEFINE_integer(
  name="num_steps", default=5000, help="number of steps per env"
)
flags.DEFINE_integer(
  name="sample_every_n_steps", default=25, help="sample every n steps"
)
flags.DEFINE_float(
  name="spreadout_threshold",
  default=0.03,
  help="minimum distance between two crossings to be considered spread out",
)
# flags.DEFINE_string("output_dir", "results", "output directory")

# Module-wide handle used to read the flag values defined above.
FLAGS = flags.FLAGS


def configure_display():
  """Tune console output: plain absl log lines and compact numpy printing."""
  # Passing None drops the custom formatter from the absl log handler.
  absl_handler = logging.get_absl_handler()
  absl_handler.setFormatter(None)
  # Two decimals, no scientific notation, when arrays are printed.
  np.set_printoptions(suppress=True, precision=2)


def crossing_to_pos(crossing, points):
  """Interpolate the position of a crossing along a closed polyline.

  Args:
    crossing: sequence whose first entry is a fractional index into `points`
      (integer part selects the segment start, fractional part the offset
      along that segment); remaining entries are ignored.
    points: array of curve vertices; the segment after the last vertex wraps
      around to the first one.

  Returns:
    The linearly interpolated point on the selected segment.
  """
  # ref: see SpaceCurve.plot_projection
  param, *_ = crossing
  segment = int(param)
  fraction = param - segment
  start = points[segment]
  end = points[(segment + 1) % len(points)]
  return start + fraction * (end - start)


def min_pairwise_dist(points: np.ndarray) -> float:
  """Return the smallest Euclidean distance between two distinct rows.

  Args:
    points: array-like of shape (n, d); each row is one point.

  Returns:
    The minimum distance over all pairs of different rows, or `inf` when
    fewer than two points are given (there is no pair to measure).
  """
  # Work in floating point: the diagonal is masked with inf below, which an
  # integer array cannot hold.
  pts = np.asarray(points, dtype=float)
  if len(pts) < 2:
    return np.float64(np.inf)
  # Compute pairwise differences via broadcasting.
  diff = pts[:, np.newaxis, :] - pts[np.newaxis, :, :]  # shape (n, n, d)
  dist_squared = np.sum(diff**2, axis=-1)  # shape (n, n)
  # Ignore self-distances by setting the diagonal to infinity.
  np.fill_diagonal(dist_squared, np.inf)
  # sqrt is monotonic, so apply it once to the minimum rather than to every
  # entry of the (n, n) matrix.
  return np.sqrt(np.min(dist_squared))


def accept(xpos: np.ndarray, threshold: float) -> bool:
  """Decide whether a configuration's crossings are sufficiently spread out.

  A knot is built from `xpos`, each raw crossing is mapped to a position on
  the curve, and every second position is kept (first two coordinates only).
  The configuration is accepted only if exactly four such positions remain
  and no two of them are closer than `threshold`.

  Args:
    xpos: positions passed to `_create_knot`.
    threshold: minimum allowed distance between any two kept crossings.

  Returns:
    True when the configuration passes both checks, False otherwise.
  """
  knot = _create_knot(xpos)
  crossing_positions = np.array(
    [crossing_to_pos(c, knot.points) for c in knot.raw_crossings()]
  )
  if crossing_positions.shape[0] == 0:
    return False  # no crossings at all

  # Every second entry, projected onto the first two coordinates.
  # NOTE(review): presumably each crossing is listed twice in raw_crossings
  # -- confirm against the knot library.
  planar = crossing_positions[::2, :2]
  if planar.shape[0] != 4:
    return False

  closest = min_pairwise_dist(planar)
  return closest >= threshold


class ConfigurationSaver:
  """Dumps a snapshot of a MuJoCo state into its own timestamped directory."""

  def __init__(self, m):
    """Set the output location, number format and an offscreen renderer.

    Args:
      m: MuJoCo model handed to `mujoco.Renderer` (512x512 output).
    """
    self._dir = "sampled_configurations"  # TODO: epath
    # self._dir = FLAGS.output_dir
    self._n_decimals = 4
    self._renderer = mujoco.Renderer(m, height=512, width=512)

  def on_step(self, mj_model, mj_data, *args, **kwargs):
    """Write the current configuration to `<self._dir>/<timestamp>/`.

    Files written: qpos.txt, xpos.txt, render.png, render.npy, gc.txt and
    split.txt (a random 50/50 "train"/"val" label).

    Args:
      mj_model: unused here; accepted so callers can pass (model, data).
      mj_data: MuJoCo data whose `qpos` and `xpos` are saved and rendered.
      *args: ignored.
      **kwargs: ignored.

    Raises:
      FileExistsError: if a directory for the same microsecond timestamp
        already exists.
    """
    time_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")
    _dir = os.path.join(self._dir, time_str)
    # makedirs rather than mkdir: the parent output directory may not exist
    # yet on the first save, and mkdir would fail with FileNotFoundError.
    os.makedirs(_dir)

    # Row 0 of xpos is skipped (presumably the world body -- confirm).
    body_xpos = mj_data.xpos[1:]

    _fmt = f"%.{self._n_decimals}f"
    np.savetxt(os.path.join(_dir, "qpos.txt"), mj_data.qpos, fmt=_fmt)
    np.savetxt(os.path.join(_dir, "xpos.txt"), body_xpos, fmt=_fmt)

    # Render from the camera named "track" and store both PNG and raw array.
    self._renderer.update_scene(mj_data, camera="track")
    pixels = self._renderer.render()
    media.write_image(os.path.join(_dir, "render.png"), pixels)
    np.save(os.path.join(_dir, "render.npy"), pixels)

    gc = knot_utils.gauss_code(body_xpos)
    qol.safe_write(os.path.join(_dir, "gc.txt"), str(gc))

    split = "train" if np.random.rand() < 0.5 else "val"
    qol.safe_write(os.path.join(_dir, "split.txt"), split)

    logging.info("saved to %s", _dir)


class KnotEnvWithSave(KnotEnvBase):
  """KnotEnvBase variant that periodically saves accepted configurations."""

  def __init__(
    self, sample_every_n_steps: int, spreadout_threshold: float, *args, **kwargs
  ):
    """Build the base env and attach a configuration saver.

    Args:
      sample_every_n_steps: a save is considered on every step count that is
        a multiple of this value (and greater than 1).
      spreadout_threshold: threshold forwarded to `accept`.
      *args: forwarded to `KnotEnvBase`.
      **kwargs: forwarded to `KnotEnvBase`.
    """
    super().__init__(*args, **kwargs)
    self._configuration_saver = ConfigurationSaver(self.model)
    self._spreadout_threshold = spreadout_threshold

    def _is_sampling_step(i):
      return i > 1 and i % sample_every_n_steps == 0

    self.should_sample = _is_sampling_step  # TODO
    self._step = 0

  def step(self, action):
    """Step the base env; save the state if this step is sampled and accepted.

    Returns:
      The (obs, reward, term, trunc, info) tuple from the base env, unchanged.
    """
    observation, reward, terminated, truncated, info = super().step(action)
    outcome = (observation, reward, terminated, truncated, info)

    self._step += 1
    if not self.should_sample(self._step):
      return outcome

    if accept(info["obs/xpos"], self._spreadout_threshold):
      self._configuration_saver.on_step(self.model, self.data)
      print("saved configuration")
    return outcome

  def reset(self, **kwargs):
    """Reset the base env and restart the per-episode step counter."""
    observation, info = super().reset(**kwargs)
    self._step = 0
    return observation, info


def make_env(rank: int):
  """Return a zero-argument factory that builds one sampling environment.

  Flag values are read here, when the factory is created, and captured by
  the closure, so `_init` itself does not touch FLAGS.

  Args:
    rank: index of the environment; only rank 0 is wrapped in `RecordVideo`,
      and only when `--record` is set.

  Returns:
    A callable that constructs and returns the environment.
  """
  sample_every_n_steps = FLAGS.sample_every_n_steps
  spreadout_threshold = FLAGS.spreadout_threshold
  # Honor --record (default True); previously the flag was defined but
  # never consulted, so video was recorded unconditionally.
  record = FLAGS.record

  def _init():
    env = KnotEnvWithSave(
      sample_every_n_steps=sample_every_n_steps,
      spreadout_threshold=spreadout_threshold,
      task="unknot",
      # xml_file="unknot7_float",
      output_pixels=False,
      split="train",  # TODO: overwrite.
      render_both=False,
      height=512,
      width=512,
    )
    if record and rank == 0:
      env = RecordVideo(env, "sampled_videos")
    return env

  return _init


def main(_):
  """Drive one sampling environment with random actions for --num_steps steps.

  Saving of configurations happens inside the environment's `step`; this
  function only feeds it random actions and restarts finished episodes.

  Args:
    _: positional argv remainder supplied by `app.run`; unused.
  """
  configure_display()

  # A single environment is stepped here. An AsyncVectorEnv used to be built
  # first and was immediately replaced by this env, which only spawned
  # workers that were never used.
  # TODO: --num_envs is currently not honored; decide whether to switch to
  # the vectorized env.
  env = make_env(0)()
  try:
    env.reset()
    for _step_idx in tqdm(range(FLAGS.num_steps)):
      action = env.action_space.sample()
      _obs, _reward, terminated, truncated, _info = env.step(action)
      # A non-vectorized env does not reset itself; start a new episode
      # once the current one has ended.
      if terminated or truncated:
        env.reset()
  finally:
    # Close even if stepping raises so renderer/video resources are released.
    env.close()

  logging.info("done")

if __name__ == "__main__":
  app.run(main)
