"""
Benchmark environment-step throughput (steps per second) of the KnotGym environments.

Example:
  python fps.py --backend=mjc --num_episodes=1600 --episode_length=20 --num_envs=16 --reset_noise_scale 0.005
"""

import json
import os
import pprint
import threading
import time

import jax
import jax.numpy as jp
import numpy as np
from absl import app, flags
from ml_collections import config_dict
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from tqdm import tqdm
import platform

import knotgym.envs as cpu_envs
import knotgym.envs_jax as jax_envs


# Backend names listed in the --backend help text. Only "mjc" and "mjx" have
# wrapper classes in this script; the others are placeholders.
BACKENDS = ["mjc", "mjx", "isaaclab", "genesis"]
flags.DEFINE_string("mjcf", "unknot7_float", "mjcf file to start off from")
flags.DEFINE_integer("num_episodes", 10, "Number of episodes to run")
# NOTE: --task and --reset_noise_scale are only read by the mjc backend; the
# mjx backend builds its environment from a hard-coded config.
flags.DEFINE_string("task", "unknot", "Task to run")
flags.DEFINE_float("reset_noise_scale", 0.005, "Reset noise scale")
flags.DEFINE_integer("episode_length", 20, "Length of each episode")
flags.DEFINE_integer("num_envs", 64, "Number of environments (batch size)")
flags.DEFINE_string("backend", "mjc", f"Backend to use, one of {BACKENDS}")
flags.DEFINE_bool("output_pixels", False, "Whether to output pixels")
flags.DEFINE_string("logdir", "results/perf", "Output directory")
FLAGS = flags.FLAGS


class FPS:
  """Thread-safe event counter that reports an average rate in events/second.

  Time is measured with ``time.perf_counter`` (monotonic, high resolution),
  so the rate is not distorted by wall-clock adjustments.

  Attributes:
    start: ``perf_counter`` timestamp at which the current window began.
    total: number of events recorded in the current window.
    lock: guards ``start`` and ``total`` against concurrent ``step``/``result``.
  """

  def __init__(self):
    self.start = time.perf_counter()
    self.total = 0
    self.lock = threading.Lock()

  def step(self, amount=1):
    """Record ``amount`` events in the current window."""
    with self.lock:
      self.total += amount

  def result(self, reset=True):
    """Return the average events per second over the current window.

    Args:
      reset: when True, start a new window (now, zero events) after reading.

    Returns:
      ``total / elapsed_seconds``; 0.0 if no positive time has elapsed, which
      avoids a ZeroDivisionError on very short windows.
    """
    with self.lock:
      now = time.perf_counter()
      elapsed = now - self.start
      fps = self.total / elapsed if elapsed > 0 else 0.0
      if reset:
        self.start = now
        self.total = 0
      return fps


class Env:
  """Minimal interface shared by the benchmark's backend wrappers.

  Subclasses wrap a batch of environments and implement ``reset``, ``step``
  and ``sample_action``; ``close`` defaults to a no-op.

  Attributes:
    config: description of the wrapped environment's configuration. It is
      written into the JSON benchmark report, so it should be serializable.
  """

  def __init__(self, config):
    self.config = config

  def reset(self):
    """Reset every environment in the batch."""
    raise NotImplementedError

  def step(self, action=None):
    """Advance every environment in the batch by one step using ``action``."""
    raise NotImplementedError

  def close(self):
    """Release backend resources; the base implementation does nothing."""

  def sample_action(self):
    """Return a batched action suitable for passing to ``step``."""
    raise NotImplementedError


class MJCEnv(Env):
  """CPU MuJoCo backend built on Stable-Baselines3 vectorized environments.

  Runs ``--num_envs`` copies of ``cpu_envs.KnotEnv``: a ``SubprocVecEnv``
  (one worker process per copy) when there is more than one copy, otherwise
  a single in-process ``DummyVecEnv``.
  """

  backend = "mjc"

  def __init__(self):
    env_kwargs = {
      "task": FLAGS.task,
      "xml_file": FLAGS.mjcf,
      "reset_noise_scale": FLAGS.reset_noise_scale,
      "output_pixels": FLAGS.output_pixels,
      "width": 128,
      "height": 128,
    }

    def build_one():
      return cpu_envs.KnotEnv(**env_kwargs)

    vec_env_cls = DummyVecEnv if FLAGS.num_envs <= 1 else SubprocVecEnv
    self._env = vec_env_cls([build_one] * FLAGS.num_envs)
    super().__init__(config=env_kwargs)

  def reset(self):
    """Reset all copies; returns whatever the vec env's ``reset`` returns."""
    observations = self._env.reset()
    return observations

  def step(self, action):
    """Step all copies with a batched ``action``; returns the vec env's step result."""
    outcome = self._env.step(action)
    return outcome

  def sample_action(self):
    """Return a constant float32 action of 0.1 shaped (num_envs, *action_shape)."""
    batch_shape = (FLAGS.num_envs, *self._env.action_space.shape)
    return np.full(batch_shape, 0.1, dtype=np.float32)

  def close(self):
    """Shut down the underlying vectorized environment."""
    self._env.close()


class MJXEnv(Env):
  """MuJoCo-XLA (JAX) backend: ``jax_envs.KnotEnv`` vmapped over ``--num_envs``.

  The environment config is hard-coded here: ``--task`` and
  ``--reset_noise_scale`` are NOT read. Values tagged "diff" differ from the
  pipeline's defaults, per the original comments.
  The batched simulation state lives in ``self._state``. ``reset``/``step``
  return None.
  """

  backend = "mjx"

  def __init__(self):
    config = config_dict.create(
      xml_file=FLAGS.mjcf,
      a_frc_max=0.2,
      z_flat_threshold=0.006,
      reset_noise_scale=0.0001,  # diff (--reset_noise_scale is ignored here)
      vision=FLAGS.output_pixels,
      ctrl_dt=0.02,  # diff default to 0.1
      sim_dt=0.002,  # 0.004 gives nans  # diff
      action_repeat=1,
      episode_length=100,
      vision_config=dict(res=128, N=500),
      color=False,
    )
    # Sanity-check the hard-coded timing: ctrl_dt is expected to be 10x sim_dt.
    # Use a tolerance rather than exact float equality, which is fragile if
    # these values are edited. Raise explicitly because asserts vanish under -O.
    if not np.isclose(config.ctrl_dt, config.sim_dt * 10):
      raise ValueError(
        f"ctrl_dt ({config.ctrl_dt}) must be 10x sim_dt ({config.sim_dt})"
      )
    super().__init__(config=config.to_dict())
    self._pipeline = jax_envs.KnotEnv(config, None)
    self._jit_step = jax.jit(jax.vmap(self._pipeline.step))
    self._jit_reset = jax.jit(jax.vmap(self._pipeline.reset))
    # One PRNG key per environment, reused by every reset() call. Episodes
    # therefore restart from the same initial states, assuming the pipeline's
    # reset depends only on its key.
    self._reset_keys = jax.random.split(jax.random.PRNGKey(0), FLAGS.num_envs)
    self._reset_keys = jax.device_put(self._reset_keys)

    def _sample_action():
      return 0.1 * jp.ones((FLAGS.num_envs, self._pipeline.action_size))

    self._jit_sample_action = jax.jit(_sample_action)
    # Not used by reset(), which reuses the fixed keys above. Kept for callers
    # that may reference it.
    self._key = jax.random.PRNGKey(0)
    self._state = None

  def reset(self):
    """Reset all environments with the fixed per-env keys and store the state."""
    self._state = self._jit_reset(self._reset_keys)
    return None

  def step(self, action):
    """Dispatch one jitted batched step and store the new state.

    JAX dispatches asynchronously, so this may return before the step has
    finished computing. Block on ``self._state`` to wait for it.

    Raises:
      RuntimeError: if called before ``reset``.
    """
    if self._state is None:
      raise RuntimeError("reset() must be called before step()")
    self._state = self._jit_step(self._state, action)
    return None

  def close(self):
    """Nothing to release for the JAX backend."""

  def sample_action(self):
    """Return a jitted constant action of 0.1 shaped (num_envs, action_size)."""
    return self._jit_sample_action()


def _make_env():
  """Construct the benchmark wrapper selected by ``--backend``.

  Raises:
    ValueError: if the backend has no wrapper in this script.
  """
  if FLAGS.backend == "mjc":
    return MJCEnv()
  if FLAGS.backend == "mjx":
    return MJXEnv()
  raise ValueError(f"Unknown backend {FLAGS.backend}")


def _warmup(env):
  """Run two untimed reset+step cycles so one-off costs (e.g. JIT) are excluded."""
  for _ in range(2):
    env.reset()
    env.step(env.sample_action())


def _run_benchmark(env, num_batch_episodes):
  """Time ``num_batch_episodes`` batched episodes of ``--episode_length`` steps.

  Returns:
    (fps, batch_fps): environment steps per second (each batched step counts
    ``--num_envs`` steps) and batched ``env.step`` calls per second.
  """
  fps = FPS()
  batch_fps = FPS()
  # Turn any implicit host<->device transfer inside the timed loop into an error.
  jax.config.update("jax_transfer_guard", "disallow")
  for _ in tqdm(range(num_batch_episodes)):
    env.reset()
    for _ in range(FLAGS.episode_length):
      env.step(env.sample_action())
      fps.step(FLAGS.num_envs)
      batch_fps.step()
  # JAX dispatches asynchronously. Wait for the final batched state before
  # reading the timers, otherwise mjx reports the dispatch rate rather than
  # the simulation rate. Envs without `_state` yield None, which is a no-op.
  jax.block_until_ready(getattr(env, "_state", None))
  # Timers are read here, before the caller closes the env, so worker
  # shutdown time is not counted as simulation time.
  return fps.result(reset=False), batch_fps.result(reset=False)


def _write_report(report):
  """Pretty-print ``report`` and save it as ``<logdir>/<timestamp>.json``."""
  pprint.pprint(report)
  os.makedirs(FLAGS.logdir, exist_ok=True)
  path = os.path.join(FLAGS.logdir, time.strftime("%Y%m%d-%H%M%S") + ".json")
  with open(path, "w", encoding="utf-8") as f:
    json.dump(report, f, indent=2)
  print(f"Wrote report to {path}")
  return path


def main(_):
  """Benchmark the selected backend and write a JSON report to ``--logdir``."""
  # Validate sizes before spawning any workers: a run that times zero
  # batches would otherwise report a meaningless 0 fps.
  if FLAGS.num_envs < 1:
    raise ValueError(f"--num_envs must be >= 1, got {FLAGS.num_envs}")
  num_batch_episodes = FLAGS.num_episodes // FLAGS.num_envs
  if num_batch_episodes == 0:
    raise ValueError(
      f"--num_episodes ({FLAGS.num_episodes}) must be >= --num_envs "
      f"({FLAGS.num_envs}) to run at least one batch"
    )

  env = _make_env()
  try:
    print("Running warmup episodes")
    _warmup(env)
    print("Running benchmark episodes")
    print(f"Running {num_batch_episodes} episodes in batch")
    fps, batch_fps = _run_benchmark(env, num_batch_episodes)
  finally:
    # Always shut down workers, even if the benchmark raises.
    env.close()

  report = {
    "env_config": env.config,
    "platform": platform.system(),
    "backend": FLAGS.backend,
    "output_pixels": FLAGS.output_pixels,
    "num_envs": FLAGS.num_envs,
    "num_episodes": FLAGS.num_episodes,
    # Episodes actually timed: num_episodes rounded down to a multiple of num_envs.
    "num_episodes_run": num_batch_episodes * FLAGS.num_envs,
    "episode_length": FLAGS.episode_length,
    "results": {
      "fps": fps,
      "batch_fps": batch_fps,
    },
  }
  _write_report(report)


if __name__ == "__main__":
  app.run(main)
