"""
WANDB_MODE=disabled MUJOCO_GL=egl python baselines/embodied/dreamerv3/main.py --configs=knot,size25m
"""

import json

import elements
import embodied
import embodied.core
import embodied.core.wrappers
import numpy as np

from knotgym.envs import KnotEnv


class GymEnvWrapper:
  """Expose a Gymnasium-style env through the older 4-tuple ``step`` API.

  ``reset`` passes the wrapped env's ``(obs, info)`` pair straight through.
  ``step`` merges the ``terminated`` and ``truncated`` flags into a single
  ``done`` value (``terminated or truncated``).
  """

  def __init__(self, env: "KnotEnv"):
    # The wrapped Gymnasium env; other code reaches it as ``wrapper.env``.
    self.env = env

  def reset(self):
    """Reset the wrapped env and return ``(obs, info)``."""
    observation, extras = self.env.reset()
    return observation, extras

  def step(self, action):
    """Step the wrapped env and return ``(obs, reward, done, info)``."""
    observation, reward, terminated, truncated, extras = self.env.step(action)
    finished = terminated or truncated
    return observation, reward, finished, extras


class KnotBase(embodied.Env):
  """Embodied (DreamerV3) adapter around a pixel-observation ``KnotEnv``.

  Modeled after ``embodied.envs.Crafter``. Each observation dict contains the
  rendered frame (``obs``), the reward, the ``is_first``/``is_last``/
  ``is_terminal`` flags, and one ``log/<key>`` float entry per info key whose
  declared structure in ``info_structure`` is ``()`` (presumably scalars).

  Args:
    task: Task name forwarded to ``KnotEnv``.
    size: ``(height, width)`` of the rendered frames.
    **kwargs: Extra keyword arguments forwarded to ``KnotEnv``.
  """

  def __init__(
    self,
    task: str,
    size=(480, 480),
    **kwargs,
  ):
    self.vision = True
    env = KnotEnv(
      task,
      xml_file="unknot7_float",
      height=size[0],
      width=size[1],
      output_pixels=self.vision,
      **kwargs,
    )
    # Only info entries declared with structure ``()`` are logged.
    self._info_keys = [
      k for k, v in env.get_wrapper_attr("info_structure").items() if v == ()
    ]
    self.env = GymEnvWrapper(env)
    self._done = False

  @property
  def obs_space(self):
    dtype = np.uint8 if self.vision else np.float32
    spaces = {
      "obs": elements.Space(dtype, shape=self.env.env.observation_space.shape),
      "reward": elements.Space(np.float32),
      "is_first": elements.Space(bool),
      "is_last": elements.Space(bool),
      "is_terminal": elements.Space(bool),
      "log/reward": elements.Space(np.float32),
      **{f"log/{k}": elements.Space(np.float32) for k in self._info_keys},
    }
    return spaces

  @property
  def act_space(self):
    action_space = self.env.env.action_space
    return {
      "action": elements.Space(
        action_space.dtype,
        action_space.shape,
        low=action_space.low,
        high=action_space.high,
      ),
      "reset": elements.Space(bool),
    }

  def step(self, action):
    """Advance one step, or reset if requested or the episode has ended.

    Args:
      action: Dict with a ``reset`` flag and, when not resetting, an
        ``action`` array matching ``act_space["action"]``.

    Returns:
      The observation dict built by ``_obs``.
    """
    if action["reset"] or self._done:
      self._done = False
      obs, info = self.env.reset()
      return self._obs(obs, 0.0, info, is_first=True)
    # Step the underlying Gymnasium env directly so termination and
    # truncation stay distinguishable. Only a true terminal state may set
    # ``is_terminal``. A truncated episode (e.g. by a time limit) still ends
    # (``is_last``) but must not cut off value bootstrapping.
    obs, reward, terminated, truncated, info = self.env.env.step(
      action["action"]
    )
    self._done = bool(terminated or truncated)
    return self._obs(
      obs,
      reward,
      info,
      is_last=self._done,
      is_terminal=bool(terminated),
    )

  def close(self):
    """Close the wrapped ``KnotEnv`` (releases its simulator/renderer)."""
    self.env.env.close()

  def _obs(
    self,
    obs,
    reward: float,
    info,
    is_first=False,
    is_last=False,
    is_terminal=False,
  ):
    """Pack one transition into the dict layout declared by ``obs_space``."""
    dtype = np.uint8 if self.vision else np.float32
    return {
      "obs": np.array(obs, dtype=dtype),
      "reward": np.float32(reward),
      "is_first": is_first,
      "is_last": is_last,
      "is_terminal": is_terminal,
      "log/reward": np.float32(reward),
      # Cast so values match the np.float32 spaces declared in obs_space.
      **{f"log/{k}": np.float32(info[k]) for k in self._info_keys},
    }


class WriteStats(embodied.core.wrappers.Wrapper):
  """Append one JSON line of episode statistics per finished episode.

  Each line written to ``<logdir>/stats.jsonl`` holds the episode counter,
  the number of non-first steps in the episode, and the sum of rewards over
  those steps rounded to one decimal.

  Args:
    env: The embodied env to wrap.
    logdir: Directory to write ``stats.jsonl`` into (created if needed).

  Raises:
    ValueError: If ``logdir`` is None.
  """

  def __init__(self, env, logdir):
    super().__init__(env)
    if logdir is None:
      raise ValueError("WriteStats requires a logdir")
    self._logdir = elements.Path(logdir)
    self._logdir.mkdir()
    self._episode = 0
    self._length = -1
    self._reward = -float("inf")

  def step(self, action):
    obs = self.env.step(action)
    if obs["is_first"]:
      # A new episode starts; the reset observation's reward is not counted.
      self._episode += 1
      self._length = 0
      self._reward = 0
    else:
      self._length += 1
      self._reward += obs["reward"].item()
    if obs["is_last"]:
      self._write_stats()
    return obs

  def _write_stats(self):
    """Append the current episode's stats as one line of ``stats.jsonl``."""
    stats = {
      "episode": self._episode,
      "length": self._length,
      "reward": round(self._reward, 1),
    }
    filename = self._logdir / "stats.jsonl"
    # Append instead of re-reading and rewriting the whole file each episode,
    # which grew quadratically with run length.
    with filename.open("a") as f:
      f.write(json.dumps(stats) + "\n")


class Knot(embodied.core.wrappers.Wrapper):
  """KnotGym env as used by the embodied/DreamerV3 configs.

  Args:
    *args: Positional arguments forwarded to ``KnotBase`` (e.g. ``task``).
    length: If truthy, forwarded to the env as ``duration``.
    logdir: If truthy, forwarded to the env as ``logdir``, and episode stats
      are additionally appended by ``WriteStats``.
    **kwargs: Further keyword arguments forwarded to ``KnotBase``.
  """

  def __init__(self, *args, length=None, logdir=None, **kwargs):
    if length:
      kwargs["duration"] = length
    if not logdir:
      super().__init__(KnotBase(*args, **kwargs))
      return
    kwargs["logdir"] = logdir  # TODO: log for fewer envs
    # Extra per-episode log in embodied's jsonl format.
    super().__init__(WriteStats(KnotBase(*args, **kwargs), logdir))

  def step(self, action):
    """Forward the step and warn if an observation both starts and ends."""
    obs = self.env.step(action)
    if obs["is_first"] and obs["is_last"]:
      print("this is really weird: obs['is_last'] and obs['is_first']")
    return obs


if __name__ == "__main__":
  # Smoke test: reset, save the first frame, then take a few small random
  # actions and print reward / episode-end flags.
  env = Knot(length=10, task="unknot", size=(240, 240))
  try:
    obs = env.step({"reset": True})
    print("obs_space ", env.obs_space)
    print("obs ", obs["obs"].shape)
    from matplotlib import pyplot as plt

    plt.imsave("test.png", np.uint8(obs["obs"]))
    # Derive the action shape from the env rather than hard-coding it.
    action_shape = env.act_space["action"].shape
    for _ in range(10):
      action = 0.01 * np.random.randn(*action_shape).astype(np.float32)
      obs = env.step({"action": action, "reset": False})
      print(obs["reward"], obs["is_last"])
  finally:
    env.close()
  print("done")
