"""
WANDB_MODE=disabled python train_sb3_ppo_v2.py -C.num_envs=16 -C.debug=True

## evaluation

WANDB_MODE=disabled python train_sb3_ppo_v2.py -C.do_eval -C.num_envs=32 -C.n_eval_episodes=128 -C.env_config.task_max_n_states=20 -C.seed=0 \
  -C.env_config.task_max_n_crossings=3 -C.env_config.task=tie_unknot \
  -C.load_from=results/experiments/2025-04-30-19-16-45-7798733/checkpoints/rl_model_1088000_steps.zip

"""

import dataclasses
import json
import os
import pprint
import re
import sys
import shutil
from datetime import datetime
from typing import Tuple

import gymnasium as gym
import yaml
from absl import app, logging
from gymnasium.wrappers import RecordVideo
from ml_collections import ConfigDict, config_flags
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
  BaseCallback,
  CallbackList,
  CheckpointCallback,
  EvalCallback,
)
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import safe_mean
from stable_baselines3.common.vec_env import (
  DummyVecEnv,
  SubprocVecEnv,
  VecTransposeImage,
)

import knotgym  # noqa: F401
import knotgym.mjcf
import wandb
from knotgym.utils import check_cython
from qol import safe_write
from wandb.integration.sb3 import WandbCallback


class DetailedMonitor(gym.Wrapper):
  """Cache per-step info values and publish episode aggregates, ref Monitor.

  On every step the values of the tracked info keys are cached. On the step
  that terminates or truncates the episode, aggregates are written into that
  step's info dict:
      keys not starting with "ep_" -> prefix + key + "_Σ" (sum) and
                                      prefix + key + "_μ" (mean)
      keys starting with "ep_"     -> prefix + key (max over the episode);
                                      only "ep_max/..." keys are accepted

  ** notes:
      must be wrapped inside of Monitor (via ep_info_buffer), passing
          new_detailed_info_keywords as part of its info_keywords
      must be used with DetailedMonitorCallback (to log)
  """

  def __init__(
    self,
    env: gym.Env,
    prefix: str,
    detailed_info_keywords: Tuple[str, ...] = (),
  ) -> None:
    super().__init__(env)
    self.prefix = prefix  # used by DetailedMonitorCallback
    self.info_kws = detailed_info_keywords
    self.episode = {name: [] for name in detailed_info_keywords}  # cache

    # "ep_*" keywords are aggregated differently from per-step keywords
    self.ep_kwargs = []
    self.step_kwargs = []
    for name in self.info_kws:
      bucket = self.ep_kwargs if name.startswith("ep_") else self.step_kwargs
      bucket.append(name)

    # names of the aggregated entries written by step(); exposed so that an
    # outer Monitor can track them (results_writer and info["episode"])
    aggregated = []
    for name in self.step_kwargs:
      aggregated.append(self.prefix + name + "_Σ")
      aggregated.append(self.prefix + name + "_μ")
    for name in self.ep_kwargs:
      aggregated.append(self.prefix + name)
    self.new_detailed_info_keywords = tuple(aggregated)

  def step(self, action):
    """ref Monitor.step"""
    observation, reward, terminated, truncated, info = self.env.step(action)
    cache = self.episode
    for name in self.info_kws:
      cache[name].append(info[name])
    if not (terminated or truncated):
      return observation, reward, terminated, truncated, info

    # episode is over: fold the cached values into this step's info
    for name in self.step_kwargs:
      values = cache[name]
      total = sum(values)
      mean = safe_mean(values)
      info[self.prefix + name + "_Σ"] = round(total, 6)
      info[self.prefix + name + "_μ"] = round(mean, 6)
    for name in self.ep_kwargs:
      assert name.startswith("ep_max/")
      values = cache[name]
      peak = max(values) if values else float("nan")
      info[self.prefix + name] = round(peak, 6)
    return observation, reward, terminated, truncated, info

  def reset(self, **kwargs):
    """ref Monitor._on_reset"""
    # start the next episode with an empty cache for every tracked key
    self.episode = {name: [] for name in self.info_kws}
    return self.env.reset(**kwargs)


class DetailedMonitorCallback(BaseCallback):
  """Log the extra episode statistics as rollout/ep_<key>, like rollout/ep_rew_mean.

  Only keys of the episode-info entries that start with target_prefix are
  logged; each is averaged over all entries currently in ep_info_buffer.
  """

  def __init__(self, target_prefix, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.target_prefix = target_prefix

  def _on_rollout_end(self):
    """ref rollout/ep_rew_mean OnPolicyAlgorithm._dump_logs"""
    episodes = self.model.ep_info_buffer
    # nothing to report until at least one non-empty episode info exists
    if not episodes or not episodes[0]:
      return
    # the first entry decides which keys are reported
    wanted = [
      name for name in episodes[0] if name.startswith(self.target_prefix)
    ]
    for name in wanted:
      values = [episode[name] for episode in episodes]
      # average across rollout episodes
      self.logger.record(f"rollout/ep_{name}", safe_mean(values))

  def _on_step(self):
    # never requests early stopping
    return True


@dataclasses.dataclass
class MjcfConfig:
  """Deprecated MJCF template settings; every instantiation raises.

  The fields are kept only as a record of the old options.

  Raises:
      DeprecationWarning: from __post_init__, on every construction attempt.
  """

  template_xml_file: str = "unknot7_float"
  data_file: str = "initial.txt"
  num_subcables: int = 7
  num_beads: int = 100

  def __post_init__(self):
    # raised as an exception (not issued via warnings.warn), so constructing
    # this class fails immediately; the message names the offending class
    raise DeprecationWarning(
      "MjcfConfig is deprecated and can no longer be instantiated"
    )


@dataclasses.dataclass
class KnotEnvConfig:
  """Environment options for the knot env.

  build_env converts this dataclass to a dict and forwards every field as a
  keyword argument to gym.make, so field names must match what the
  environment constructor accepts.
  """

  task: str = "tie_unknot"  # one in specs.parse
  task_max_n_states: int = 1024  # cap for easier training
  task_max_n_crossings: int = 100  # cap for easier training  # OVERRIDE
  task_subset_seed: int = -1  # -1 means no shuffling when subsetting 10 targets
  xml_file: str = "unknot7_float"
  reset_noise_scale: float = 0.015  # >0, the scale of noise for reset
  frame_skip: int = 24  # model.opt.timestep = 0.01 -> dt = 0.1 sec/frame
  a_frc_max: float = 0.2  # >0, bounding the action space
  r_scale_dt_cross: float = 0.0  # <0, punishes adding a cross
  r_scale_zero_cross: float = 5.0  # >0, reward based on number of crossings
  r_scale_gc: float = 5.0  # >0, reward based on gauss code
  r_gc_allow_flipped_or_mirrored: bool = True  # simpler criteria to meet
  normalize_obs: bool = False  # skips for images
  normalize_action: bool = True  # assume outputs of policy [-1, 1]
  output_pixels: bool = True  # whether to output pixels  # vision or not
  done_after: int = 1  # done after n steps of continuous is_success
  height: int = 128
  width: int = 128
  duration: int = 50  # max episode steps
  r_scale_timeout_punish: float = -5.0  # <0, punish for timeout


@dataclasses.dataclass
class Config:
  """Top-level experiment configuration (experiment, policy, env, loading).

  __post_init__ converts a non-KnotEnvConfig env_config into KnotEnvConfig,
  shrinks the run when debug is set, and switches training off when do_eval
  is set.
  """

  # experiment
  note: str = ""
  seed: int = 0
  debug: bool = False
  # one of train/eval
  do_train: bool = True
  do_eval: bool = False

  # policy
  policy_type: str = "CnnPolicy"
  learning_rate: float = 1e-5
  ent_coef: float = 0.0
  batch_size: int = 64  # ppo mini batch size
  log_std_init: float = 0.0  # set with consideration of the action space
  features_dim: int = 768  # NatureCNN in sb3.common.policies
  total_timesteps: int = 1_100_000
  record_every_n_episode: int = 20
  num_envs: int = 8
  num_eval_envs: int = 8
  n_eval_episodes: int = 128

  # env
  env_name: str = "knotgym/Unknot-v0"
  env_config: KnotEnvConfig = dataclasses.field(default_factory=KnotEnvConfig)

  # policy init or eval
  load_from: str = ""

  def __post_init__(self):
    """Normalize env_config and apply the debug / eval overrides."""
    env_cfg = self.env_config
    if not isinstance(env_cfg, KnotEnvConfig):
      # anything else is unpacked as keyword arguments for KnotEnvConfig
      env_cfg = KnotEnvConfig(**env_cfg)
      self.env_config = env_cfg

    # pixel observations with an MLP policy are reported but not rejected
    if self.policy_type == "MlpPolicy" and env_cfg.output_pixels:
      for message in (
        "output_pixels=True but policy_type=MlpPolicy",
        "consider changing policy_type to CnnPolicy",
        "features_dim is ignored too",
      ):
        logging.error(message)

    if self.debug:
      logging.info("Debug mode - overwriting config")
      debug_overrides = dict(
        batch_size=4,
        total_timesteps=10_000,
        record_every_n_episode=1,
        num_envs=1,
        num_eval_envs=1,
        n_eval_episodes=2,
      )
      for field_name, value in debug_overrides.items():
        setattr(self, field_name, value)
      env_cfg.duration = 20  # 10fps -> 2 sec

    if self.do_eval:
      # eval takes precedence over the do_train default
      logging.info("Eval mode")
      self.do_train = False


# Defines the "C" config flag, seeded with the defaults of Config; fields are
# overridden on the command line as -C.<field>=<value> (see the usage examples
# in the module docstring). main() reads the parsed result via _CONFIG.value.
_CONFIG = config_flags.DEFINE_config_dict(
  "C", ConfigDict(dataclasses.asdict(Config())), "Config dict", lock_config=True
)


def get_unique_output_dir(base_folder: str, slurm_job_id: str) -> str:
  """Create and return a fresh output directory named <timestamp>-<job id>.

  Args:
      base_folder (str): parent folder; created (with parents) if missing
      slurm_job_id (str): suffix appended after the second-resolution timestamp

  Returns:
      str: "<base_folder>/<YYYY-mm-dd-HH-MM-SS>-<slurm_job_id>"

  Raises:
      FileExistsError: if that directory already exists, e.g. two runs with
          the same job id started within the same second.
  """
  time_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
  output_dir = f"{base_folder}/{time_str}-{slurm_job_id}"
  # makedirs also creates a missing base_folder (os.mkdir would fail on the
  # first run); exist_ok is left off so an existing leaf directory still raises
  os.makedirs(output_dir)
  return output_dir


def parse_slurm_job_id(load_from) -> str:
  """Extract the job id from a checkpoint path under results/experiments.

  Example:
      results/experiments/2025-04-30-19-16-45-7798733/checkpoints/rl_model_1088000_steps.zip
      -> "7798733"

  The id is whatever follows the timestamp in the run directory name, so
  directories of non-slurm runs ("...-local") are parsed as well.

  Args:
      load_from: checkpoint path (may be empty or None)

  Returns:
      str: the job id, or "unknown" when the path does not follow the
          results/experiments/<timestamp>-<job id>/ layout.
  """
  match = re.search(
    r"results/experiments/\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-([^/]+)/",
    load_from or "",
  )
  if match is None:
    # callers use this only for reporting; do not crash on a foreign path
    logging.warning(f"Cannot parse slurm job id from: {load_from!r}")
    return "unknown"
  return match.group(1)


def io_maybe_resume(
  base_folder: str, slurm_job_id: str, spec_load_from: str
) -> Tuple[str, bool, str, str]:
  """create output_dir and resume from checkpoint if exists

  Args:
      base_folder (str): usually "./results/experiments"; created if missing
      slurm_job_id (str): global variable for slurm job id
      spec_load_from (str): takes lower priority than ckpts found in output_dir

  Returns:
      Tuple[str, bool, str, str]: output_dir, resume, maybe_run_id,
      maybe_load_from
          resume -> continue a past experiment (incl. tracking/folder)
          maybe_run_id -> wandb run id for resuming from a preempted job,
              None when not resuming
          maybe_load_from -> initialize by loading from a previous checkpoint,
              None when there is nothing to load

  Raises:
      AssertionError: more than one directory matches slurm_job_id
      FileNotFoundError: the matched directory has no run_id.txt
  """
  spec_load_from = spec_load_from if spec_load_from else None

  # first run on a fresh checkout: the base folder may not exist yet
  os.makedirs(base_folder, exist_ok=True)
  # run directories are named "<timestamp>-<job id>"; include the separator so
  # that job "123" does not match the directory of job "4123"
  output_dir = [
    d for d in os.listdir(base_folder) if d.endswith(f"-{slurm_job_id}")
  ]
  if slurm_job_id == "local" or len(output_dir) == 0:
    # local runs never resume; unknown job ids start a new experiment
    output_dir = get_unique_output_dir(base_folder, slurm_job_id)
    return output_dir, False, None, spec_load_from
  assert len(output_dir) == 1, (
    f"multiple dirs found: {output_dir} with {slurm_job_id}"
  )
  output_dir = os.path.join(base_folder, output_dir[0])
  with open(os.path.join(output_dir, "run_id.txt"), "r") as f:
    run_id = f.read().strip()
  logging.info(f"Resuming from {output_dir}, run_id: {run_id}")

  checkpoint_dir = os.path.join(output_dir, "checkpoints")
  checkpoints = []
  if os.path.isdir(checkpoint_dir):
    # "rl_model_8000_steps.zip" -> 8000; files that do not look like a
    # checkpoint archive are skipped instead of crashing the sort
    step_pattern = re.compile(r"_(\d+)_steps\.zip$")
    found = []
    for name in os.listdir(checkpoint_dir):
      step_match = step_pattern.search(name)
      if step_match is not None:
        found.append((int(step_match.group(1)), name))
    found.sort()  # ascending by step count
    checkpoints = [name for _, name in found]
  else:
    logging.warning(f"Folders {checkpoint_dir} not found")

  if len(checkpoints) > 0:
    # latest checkpoint wins over the user-specified one
    load_from = os.path.join(checkpoint_dir, checkpoints[-1])
    logging.info(f"Resuming from {load_from}")
    if spec_load_from is not None:
      logging.warning(f"Ignoring Config.load_from: {spec_load_from}")
  else:
    logging.warning(f"No checkpoints found in {checkpoint_dir}")
    logging.warning(f"Fallback to Config.load_from: {spec_load_from}")
    load_from = spec_load_from
  return output_dir, True, run_id, load_from


def build_env(config: Config, output_dir, resume: bool, split: str):
  """Build the seeded, vectorized environment for one split.

  Args:
      config: experiment config; config.env_config is forwarded to gym.make
      output_dir: experiment folder; artifacts go to <output_dir>/<split>
      resume: when True, existing monitor files are not overridden
      split: passed to the env and used in the stats prefix; "tr" uses
          config.num_envs, any other value uses config.num_eval_envs

  Returns:
      VecTransposeImage around a DummyVecEnv (at most one env) or a
      SubprocVecEnv (several envs), seeded with config.seed.
  """
  split_dir = os.path.join(output_dir, split)
  os.makedirs(split_dir, exist_ok=True)
  # keep a copy of the model xml next to the results of this split
  mjcf_path = knotgym.mjcf.load_xml_from_asset(config.env_config.xml_file)
  shutil.copy(
    os.path.abspath(mjcf_path), os.path.join(split_dir, "filled_mjcf.xml")
  )

  def make_env(rank: int):
    # only rank 1 writes episode logs and records videos
    is_logging_rank = rank == 1
    is_recording_rank = rank == 1
    logdir = None
    if is_logging_rank:
      logdir = os.path.join(split_dir, "episodes", f"{rank:>04}")

    def _init():
      base_env = gym.make(
        config.env_name,
        split=split,
        logfreq=25,
        logdir=logdir,
        **dataclasses.asdict(config.env_config),
      )
      # track only the info entries whose declared structure is ()
      tracked = tuple(
        key
        for key, spec in base_env.get_wrapper_attr("info_structure").items()
        if spec == ()
      )
      base_env.action_space.seed(config.seed + rank)
      # the prefix reduces to strings like xtr/, xea/, xeb/ (xeb/ is not
      # implemented)
      # xtr stands for extended training stats
      # xea stands for extended type-a eval stats
      detailed = DetailedMonitor(
        base_env, prefix=f"x{split}/", detailed_info_keywords=tracked
      )
      monitored = Monitor(
        detailed,
        info_keywords=detailed.new_detailed_info_keywords + ("is_success",),
        filename=os.path.join(split_dir, "monitor", f"{rank:>04}_monitor.csv"),
        override_existing=not resume,
      )
      if not is_recording_rank:
        return monitored

      def should_record(episode_id):
        return episode_id % config.record_every_n_episode == 0

      return RecordVideo(
        monitored,
        os.path.join(split_dir, "videos"),
        episode_trigger=should_record,
        name_prefix=f"rank_{rank:04}",
      )

    return _init

  n_envs = config.num_envs if split == "tr" else config.num_eval_envs
  # ranks are 1-based
  env_fns = [make_env(rank) for rank in range(1, n_envs + 1)]
  vec_cls = SubprocVecEnv if n_envs > 1 else DummyVecEnv
  # see sb3.common.base_class.BaseAlgorithm._wrap_env
  vec_env = VecTransposeImage(vec_cls(env_fns))
  vec_env.seed(config.seed)
  return vec_env


def sanity_checks():
  """Log an error (without aborting) when check_cython() reports failure."""
  has_c_extension = check_cython()
  if has_c_extension:
    return
  logging.error("no c extension found for pyknotid")


def main(_):
  """Entry point: train a PPO agent or evaluate a saved checkpoint.

  Training (config.do_train): creates or resumes an experiment folder under
  ./results/experiments, builds train/eval envs, creates or loads the model,
  and runs model.learn with checkpoint, wandb and eval callbacks.

  Evaluation (config.do_eval): creates a folder under ./results/evaluations,
  loads config.load_from, runs evaluate_policy and writes eval_report.json.

  Raises:
      ValueError: neither train nor eval is requested, or eval is requested
          without config.load_from.
  """
  logging.info("Raw config:\n" + pprint.pformat(_CONFIG.value))
  config = Config(**_CONFIG.value)
  logging.info("Final config:\n" + pprint.pformat(config))
  slurm_job_id = os.environ.get("SLURM_JOB_ID", "local")

  sanity_checks()

  if config.do_train:
    output_dir, resume, maybe_run_id, maybe_load_from = io_maybe_resume(
      "./results/experiments", slurm_job_id, config.load_from
    )
  elif config.do_eval:
    # Config.load_from defaults to "" (never None), so test for emptiness, and
    # do it before creating an output directory for a run that cannot start
    if not config.load_from:
      raise ValueError("must provide load_from for eval")
    output_dir = get_unique_output_dir("./results/evaluations", slurm_job_id)
    resume = False
    maybe_run_id = None
    maybe_load_from = config.load_from
  else:
    raise ValueError("must do at least one of train/eval")
  safe_write(
    os.path.join(output_dir, "config.yaml"),
    yaml.dump(dataclasses.asdict(config)),
  )

  if config.do_train:
    env = build_env(config, output_dir, resume, split="tr")
    eval_env = build_env(config, output_dir, resume, split="ea")
  else:
    assert config.do_eval
    env = build_env(config, output_dir, resume, split="ea")
    eval_env = None

  # create ppo model
  if maybe_load_from is not None:
    # hyperparameters come from the checkpoint, not from config
    logging.info(f"loading from {maybe_load_from}")
    model = PPO.load(
      maybe_load_from,
      env=env,
      print_system_info=True,
    )
  else:
    model = PPO(
      config.policy_type,
      env,
      learning_rate=config.learning_rate,
      ent_coef=config.ent_coef,
      n_steps=config.env_config.duration,
      batch_size=config.batch_size,
      verbose=1,
      tensorboard_log=os.path.join(output_dir, "runs"),
      seed=config.seed,
      # use_sde=True,
      policy_kwargs=dict(
        #   squash_output=True,
        log_std_init=config.log_std_init,
        features_extractor_kwargs=dict(
          features_dim=config.features_dim,
        ),
      ),
    )

  total_params = sum(
    p.numel() for p in model.policy.parameters() if p.requires_grad
  )
  logging.info(f"Total trainable parameters: {total_params}")

  if config.do_train:
    # wandb init and misc io
    run = wandb.init(
      project="knots-dev-ppo",
      notes=config.note if config.note != "" else None,
      config=dataclasses.asdict(config) | {"slurm_job_id": slurm_job_id},
      sync_tensorboard=True,
      monitor_gym=True,  # log video?
      save_code=True,
      id=maybe_run_id,
      resume="must" if resume else None,
    )
    # run_id.txt is what io_maybe_resume reads when this job is restarted
    safe_write(os.path.join(output_dir, "run_id.txt"), run.id)
    safe_write(os.path.join(output_dir, "slurm_job_id.txt"), slurm_job_id)

    # set up callbacks (even if resume)
    cbs = []
    cbs.append(
      DetailedMonitorCallback(
        target_prefix=env.get_attr("prefix", indices=[0])[0]
      )
    )
    cbs.append(
      CheckpointCallback(
        save_freq=1000,
        save_path=os.path.join(output_dir, "checkpoints"),
      )
    )
    cbs.append(WandbCallback(verbose=2))
    if eval_env is not None:
      cbs.append(
        EvalCallback(
          eval_env=eval_env,
          eval_freq=1000,
          # every (n_envs * eval_freq) global time steps, similar to save_freq
          log_path=output_dir,
        )
      )
    cbs = CallbackList(cbs)

    model.learn(
      total_timesteps=config.total_timesteps,
      callback=cbs,
      progress_bar=True,
      tb_log_name="ppo",
      reset_num_timesteps=not resume,
    )
    run.finish()

  if config.do_eval:
    safe_write(os.path.join(output_dir, "slurm_job_id.txt"), slurm_job_id)
    safe_write(os.path.join(output_dir, "load_from.txt"), config.load_from)
    episode_rewards, episode_lengths = evaluate_policy(
      model,
      model.get_env(),
      n_eval_episodes=config.n_eval_episodes,
      deterministic=True,
      render=False,
      return_episode_rewards=True,
    )
    logging.info(f"Eval episode_rewards: {episode_rewards}")
    logging.info(f"Eval episode_lengths: {episode_lengths}")

    report = dict(
      # checkpointing
      eval_slurm_job_id=parse_slurm_job_id(config.load_from),
      load_from=config.load_from,
      command=" ".join(sys.argv),
      # tasks
      task=config.env_config.task,
      num_envs=config.num_eval_envs,
      task_max_n_crossings=config.env_config.task_max_n_crossings,
      task_max_n_states=config.env_config.task_max_n_states,
      # results
      # an episode counts as a success when its total reward is positive
      episode_success_rate=safe_mean(
        [1.0 if rew > 0 else 0.0 for rew in episode_rewards]
      ),
      episode_rewards_mean=safe_mean(episode_rewards),
      episode_lengths_mean=safe_mean(episode_lengths),
      n_eval_episodes=len(episode_rewards),  # actually run
      episode_rewards=episode_rewards,
      episode_lengths=episode_lengths,
      # aux
      timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
      seed=config.seed,
      env_config=dataclasses.asdict(config.env_config),
    )
    report_path = os.path.join(output_dir, "eval_report.json")
    safe_write(report_path, json.dumps(report, indent=2))
    logging.info(f"Eval report saved to {report_path}")


if __name__ == "__main__":
  app.run(main)
