"""Experiment entry for CQL baseline."""
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = "true"
import pickle
import time
from typing import Dict, Iterator, List, Optional, Tuple

import jax
import numpy as np
import optax
import requests
import tensorflow as tf
import wandb
from absl import app, flags, logging
from acme.specs import make_environment_spec
from acme.types import Transition
from acme.utils.loggers import Logger
from dm_env import Environment
import torch
from ml_collections import ConfigDict

from internal.loggers import generate_experiment_name, logger_fn
from internal.notifier import DiscordNotif
from internal.tracer import PyTracer
from rosmo.agent.cql_discrete.actor import CQLEvalActor
from rosmo.agent.cql_discrete.learning import CQLLearner
from rosmo.agent.cql_discrete.network import CQLNetworks, make_networks
from rosmo.data.rl_unplugged import atari
from rosmo.util.env_loop_observer import (
  EvaluationLoop,
  ExtendedEnvLoopObserver,
  LearningStepObserver,
)
from rosmo.agent.world_model import dreamer
from rosmo.buffer.dataset_buffer import UniformBuffer

# Query the default backend through JAX's public top-level API rather than
# reaching into `jax.lib.xla_bridge`.
platform = jax.default_backend()
num_devices = jax.device_count()
logging.warning(f"Compute platform: {platform} with {num_devices} devices.")

# Notifier webhook; the URL is an empty string when NOTIF_WEBHOOK_URL is unset.
notif_url = os.environ.get("NOTIF_WEBHOOK_URL", "")
notif = DiscordNotif(url=notif_url)

# ===== Flags. ===== #
FLAGS = flags.FLAGS
flags.DEFINE_string("exp_id", None, "Experiment id.", required=True)
flags.DEFINE_boolean("debug", True, "Debug run.")
flags.DEFINE_boolean("profile", False, "Profiling run.")
flags.DEFINE_integer("run_number", 1, "Run number of RLU dataset.")
flags.DEFINE_integer(
  "data_percentage",
  100,
  "Percentage of data used for training.",
  lower_bound=0,
  upper_bound=100,
)
# The default seed is taken from the clock once, when this module is imported.
flags.DEFINE_integer("seed", int(time.time()), "Random seed.")
flags.DEFINE_float("minq_weight", 1.0, "Weight for Min-Q loss.")
flags.DEFINE_boolean("use_original", False, "Use original CQL networks.")
flags.DEFINE_string("checkpoint", None, "Checkpoint to resume.")
flags.DEFINE_string("game", None, "Game name to run.", required=True)
flags.DEFINE_boolean("use_dreamer", False, "Use dreamer networks to generate data.")
flags.DEFINE_integer("model_rollout_freq", 10, "dreamer model rollout frequency")
flags.DEFINE_integer("rollout_length", 5, "dreamer model rollout length")
flags.DEFINE_integer("dreamer_init_length", 11, "dreamer init length")
flags.DEFINE_integer("batch_size", 1024, "batch size")
flags.DEFINE_integer("stack_size", 1, "stack size")
# Used as the environment frame size and the CQL dataset image size.
flags.DEFINE_integer("img_size", 84, "image (frame) size")
# Share of every training batch that is sampled from the dreamer rollout
# buffer; the remainder comes from the offline dataset.
flags.DEFINE_float(
  "rollout_ratio", 0.5, "Fraction of each batch drawn from dreamer rollouts."
)

# Passed as `data_dir` to the Atari dataset loader.
DATA_DIR = "/datasets/rl_unplugged/tensorflow_datasets"

# Default hyper-parameters. `get_config` copies this mapping and overlays the
# command-line flags on top of it.
CONFIG = dict(
  # Network architecture.
  num_atoms=200,
  channels=64,
  blocks_torso=6,
  blocks_qf=2,
  reduced_channels_head=128,
  fc_layers_qf=[128, 128],
  output_init_scale=0.0,
  # Learning.
  discount=0.99,
  batch_size=512,  # placeholder: get_config replaces it with --batch_size
  learning_rate=1e-3,
  adam_epsilon=0.0003125,
  huber_param=1.0,
  epsilon_eval=0.001,
  target_update_period=250,
  grad_updates_per_batch=1,
  # Logging / checkpoint / evaluation schedule (in training steps).
  log_interval=400,
  save_period=10_000,
  eval_period=2_000,
  evaluate_episodes=2,
  total_steps=200_000,
)


# ===== Learner. ===== #
def get_learner(config, networks, data_iterator, logger) -> CQLLearner:
  """Build the CQL learner from the experiment configuration.

  Args:
    config: Experiment configuration mapping (see `get_config`).
    networks: CQL networks handed to the learner.
    data_iterator: Data iterator passed to the learner as `iterator`.
    logger: Logger the learner writes to.

  Returns:
    The constructed `CQLLearner`.
  """
  rng_key = jax.random.PRNGKey(config["seed"])
  # Hyper-parameters forwarded to the learner under their config names.
  forwarded = (
    "target_update_period",
    "num_atoms",
    "minq_weight",
    "huber_param",
    "batch_size",
    "grad_updates_per_batch",
    "log_interval",
  )
  hparams = {name: config[name] for name in forwarded}
  adam = optax.adam(
    learning_rate=config["learning_rate"],
    eps=config["adam_epsilon"],
  )
  return CQLLearner(
    networks,
    random_key=rng_key,
    iterator=data_iterator,
    optimizer=adam,
    logger=logger,
    **hparams,
  )


# ===== Eval Actor-Env Loop. ===== #
def get_actor_env_eval_loop(
  config, networks: CQLNetworks, environment, observers, logger
) -> Tuple[CQLEvalActor, EvaluationLoop]:
  """Build the evaluation actor and the loop that runs it in `environment`.

  Args:
    config: Experiment configuration, forwarded to `CQLEvalActor`.
    networks: CQL networks the actor is built from.
    environment: Environment to evaluate in; its spec supplies the number of
      discrete actions.
    observers: Observers attached to the evaluation loop.
    logger: Logger handed to the evaluation loop.

  Returns:
    An `(actor, eval_loop)` tuple.
  """
  num_actions = make_environment_spec(environment).actions.num_values
  actor = CQLEvalActor(networks, num_actions, config)
  # The loop is created with should_update=False.
  eval_loop = EvaluationLoop(
    environment=environment,
    actor=actor,
    logger=logger,
    should_update=False,
    observers=observers,
  )
  return actor, eval_loop


def get_env_loop_observers() -> List[ExtendedEnvLoopObserver]:
  """Return the environment-loop observers.

  Currently this is a single `LearningStepObserver`.
  """
  return [LearningStepObserver()]


def get_environment(config) -> Environment:
  """Create the Atari environment for the configured game.

  The game comes from `config["game_name"]`; frame stacking and frame size
  come from the `stack_size` and `img_size` flags.
  """
  env_kwargs = {
    "game": config["game_name"],
    "stack_size": FLAGS.stack_size,
    "frame_size": FLAGS.img_size,
  }
  # NOTE use sticky action
  return atari.environment(**env_kwargs)


def get_data_loader(config, environment, for_dreamer=False) -> Iterator:
  """Build an endless NumPy iterator over batched offline Atari data.

  Args:
    config: Experiment configuration (`game_name`, `run_number`,
      `data_percentage` and `batch_size` are read).
    environment: Environment whose spec supplies the number of actions.
    for_dreamer: When False, trajectories of length 2 are mapped to
      `Transition`s (step 0 -> step 1) at `img_size` resolution. When True,
      trajectories of `dreamer_init_length` steps at 64 pixels are mapped to a
      dict with transposed `image`, `action`, `reset` and `discount` entries.

  Returns:
    The dataset's `as_numpy_iterator()`.
  """

  def _to_transition(steps: Dict[str, np.ndarray]) -> Transition:
    # Index 0 along axis 1 is the current step, index 1 the next one.
    frames = steps["observation"]
    return Transition(
      observation=frames[:, 0],
      action=steps["action"][:, 0],
      reward=steps["reward"][:, 0],
      discount=steps["discount"][:, 0],
      next_observation=frames[:, 1],
    )

  def _to_dreamer_batch(data):
    swap_leading = (1, 0)
    # Scale pixel values from [0, 255] into [-0.5, 0.5].
    frames = tf.cast(data["observation"], tf.float32) / 255.0 - 0.5  # type: ignore
    # Original note: (B, T, H, W, C) => (T, B, C, H, W)
    frames = tf.transpose(frames, perm=(1, 0, 4, 2, 3))
    return {
      "image": frames,
      "action": tf.transpose(data["action"], perm=swap_leading),
      "reset": tf.transpose(data["is_terminal"], perm=swap_leading),
      "discount": tf.transpose(data["discount"], perm=swap_leading),
    }

  # keep the same number of transitions for learning
  spec = make_environment_spec(environment)
  if for_dreamer:
    seq_len = FLAGS.dreamer_init_length
    frame_size = 64
    transform = _to_dreamer_batch
  else:
    seq_len = 2
    frame_size = FLAGS.img_size
    transform = _to_transition

  source = atari.create_atari_ds_loader(
    game=config["game_name"],
    run_number=config["run_number"],
    data_dir=DATA_DIR,
    num_actions=spec.actions.num_values,
    stack_size=FLAGS.stack_size,
    image_size=frame_size,
    data_percent=config["data_percentage"],
    trajectory_length=seq_len,
    shuffle_num_steps=5000 if FLAGS.debug else 50000,
  )

  # ===== Dataset & Buffer ===== #
  if for_dreamer:
    batch_size = config["batch_size"] // 2
  elif FLAGS.use_dreamer:
    # The rest of each training batch is filled with dreamer rollouts.
    batch_size = int(config["batch_size"] * (1 - FLAGS.rollout_ratio))
  else:
    batch_size = config["batch_size"]

  pipeline = source.repeat()
  pipeline = pipeline.batch(batch_size)
  pipeline = pipeline.map(transform)
  pipeline = pipeline.prefetch(tf.data.AUTOTUNE)

  threading_options = tf.data.Options()
  threading_options.threading.max_intra_op_parallelism = 1
  return pipeline.with_options(threading_options).as_numpy_iterator()


# ===== Network ===== #
def get_networks(config, environment) -> CQLNetworks:
  """Create the CQL networks for the environment's spec.

  The spec is logged, then `make_networks` is called with the architecture
  values from `config`; the `use_original` flag is passed as `original`.
  """
  spec = make_environment_spec(environment)
  logging.info(spec)
  # Config keys in the positional order `make_networks` is called with.
  arch_keys = (
    "channels",
    "num_atoms",
    "output_init_scale",
    "blocks_torso",
    "blocks_qf",
    "reduced_channels_head",
    "fc_layers_qf",
  )
  arch_args = [config[key] for key in arch_keys]
  return make_networks(spec, *arch_args, original=FLAGS.use_original)


# ===== Configurations ===== #
def get_config(game_name: str) -> Dict:
  """Assemble the experiment configuration for one game.

  Starts from a shallow copy of `CONFIG` and overlays the command-line flags,
  so the `batch_size` default in `CONFIG` is always replaced by
  `--batch_size`. The resulting configuration is logged.

  Args:
    game_name: Name of the game; stored under `game_name` and used in the
      generated experiment name.

  Returns:
    The configuration dictionary, including the generated `exp_full_name`.
  """
  config = CONFIG.copy()
  config.update(
    algo="CQL",
    game_name=game_name,
    seed=FLAGS.seed,
    minq_weight=FLAGS.minq_weight,
    use_original=FLAGS.use_original,
    run_number=FLAGS.run_number,
    # `ckpt_number` is not among the flags defined in this file, and absl
    # raises AttributeError when an unregistered flag is read. Fall back to
    # None so start-up does not depend on another module registering it.
    ckpt_number=getattr(FLAGS, "ckpt_number", None),
    data_percentage=FLAGS.data_percentage,
    batch_size=FLAGS.batch_size,
  )
  config["exp_full_name"] = generate_experiment_name(
    f"{FLAGS.exp_id}_{game_name}"
  )
  logging.info(f"Configs: {config}")
  return config


# ===== Misc. ===== #
def get_logger_fn(
  exp_full_name: str,
  job_name: str,
  is_eval: bool = False,
  config: Optional[Dict] = None,
) -> Logger:
  """Create the logger for one job of the experiment.

  Args:
    exp_full_name: Experiment name passed to `logger_fn` as `exp_name`.
    job_name: Label of the job (e.g. "learner" or "evaluator").
    is_eval: Whether this is the evaluation logger; only then may data be
      saved.
    config: Optional experiment configuration forwarded to `logger_fn`.

  Returns:
    The logger built by `logger_fn`.
  """
  # Data saving and the `use_wb` option are switched off for debug and
  # profiling runs.
  full_run = not FLAGS.debug and not FLAGS.profile
  return logger_fn(
    exp_name=exp_full_name,
    label=job_name,
    save_data=is_eval and full_run,
    use_tb=False,
    use_wb=full_run,
    use_sota=False,
    config=config,
  )


def main(_):
  """Train CQL on offline Atari data, optionally mixed with dreamer rollouts.

  Builds the environment, networks, data loader, learner and evaluation loop,
  optionally restores a checkpoint given by `--checkpoint`, then runs the
  training loop with periodic checkpointing and evaluation. In debug mode a
  single training step followed by one evaluation is executed. In profile
  mode steps 110-210 are traced and the loop stops afterwards.

  Args:
    _: Unused positional argument supplied by `app.run`.

  Raises:
    FileNotFoundError: If the checkpoint file, or the checkpoint directory of
      the resumed experiment, does not exist.
  """
  logging.info(f"Debug mode: {FLAGS.debug}")
  # ===== Profiler. ===== #
  profile_dir = "./profile"
  os.makedirs(profile_dir, exist_ok=True)
  tracer = PyTracer("./", FLAGS.exp_id, with_jax=True)

  # ===== Setup. ===== #
  cfg = get_config(FLAGS.game)

  env = get_environment(cfg)
  env_spec = make_environment_spec(env)
  networks = get_networks(cfg, env)
  dataloader = get_data_loader(cfg, env)

  # Number of model-generated transitions mixed into every training batch.
  rollout_batch_size = int(cfg["batch_size"] * FLAGS.rollout_ratio)

  # dreamer setup
  if FLAGS.use_dreamer:
    dreamer_iterator = get_data_loader(cfg, env, for_dreamer=True)
    dreamer_cfg = ConfigDict()
    dreamer_cfg.update(dreamer.DREAMER_CONFIG)
    dreamer_cfg.action_dim = env_spec.actions.num_values
    model_dir = "/rosmo/checkpoint/wm/"
    # NOTE(review): these checkpoint names are hard-coded and contain
    # "MsPacman" regardless of --game; confirm before running other games.
    ckpts = [
      "new-ensemble-reset-rnn_WM-MsPacman-14kcebow",
      "new-ensemble-reset-rnn_WM-MsPacman-2n4ccnd8",
      # "new-ensemble-reset-rnn_WM-MsPacman-2sgszd1k",
      # "new-ensemble-reset-rnn_WM-MsPacman-7cno3kja",
      # "new-ensemble-reset-rnn_WM-MsPacman-kmhzan3y",
    ]
    # Bound once, outside the loop, because the training loop needs it too.
    device = torch.device("cuda:0")
    dreamer_models = []
    for ckpt in ckpts:
      model_path = os.path.join(model_dir, ckpt, "train_state_290001.pt")
      dreamer_model = dreamer.Dreamer(dreamer_cfg)
      dreamer_model.to(device)
      dreamer_model.load_state_dict(torch.load(model_path)["model_state_dict"])
      dreamer_models.append(dreamer_model)

    # model rollout buffer
    dreamer_buffer = UniformBuffer(
      0,
      10240,
      1,
      rollout_batch_size,
    )
    # Rollouts are generated on steps 1, 1 + freq, 1 + 2 * freq, ...
    # Taking the phase modulo the frequency keeps that schedule and also makes
    # a frequency of 1 roll out on every step (`i % 1 == 1` is never true).
    rollout_phase = 1 % FLAGS.model_rollout_freq

  # ===== Resume. ===== #
  ckpt_fn = FLAGS.checkpoint
  resume = ckpt_fn is not None
  resumed_exp_name = None
  if resume:
    # Raise explicitly: `assert` statements are removed under `python -O`.
    if not os.path.isfile(ckpt_fn):
      raise FileNotFoundError(f"Checkpoint file not found: {ckpt_fn}")
    # The experiment name is the directory that contains the checkpoint.
    resumed_exp_name = os.path.basename(os.path.dirname(ckpt_fn))
    logging.warning(f"Resuming from {ckpt_fn}...")

  # ===== Essentials. ===== #
  exp_name = resumed_exp_name or cfg["exp_full_name"]
  learner = get_learner(
    cfg,
    networks,
    dataloader,
    get_logger_fn(exp_name, "learner", config=cfg),
  )
  observers = get_env_loop_observers()
  actor, eval_loop = get_actor_env_eval_loop(
    cfg,
    networks,
    env,
    observers,
    get_logger_fn(exp_name, "evaluator", is_eval=True, config=cfg),
  )
  evaluate_episodes = 2 if FLAGS.debug else cfg["evaluate_episodes"]

  # ===== Restore. ===== #
  init_step = 0
  if resume:
    save_path = os.path.join("./checkpoint", resumed_exp_name)
    if not os.path.isdir(save_path):
      raise FileNotFoundError(f"Checkpoint directory not found: {save_path}")
    # NOTE: unpickling can execute arbitrary code; only resume from
    # checkpoints you trust.
    with open(ckpt_fn, "rb") as f:
      train_state = pickle.load(f)
    # Checkpoints are written as "ckpt_<step>.pkl". Drop the extension with
    # splitext: `str.strip(".pkl")` removes any of the characters '.', 'p',
    # 'k', 'l' from both ends instead of removing the suffix.
    stem = os.path.splitext(os.path.basename(ckpt_fn))[0]
    init_step = int(stem.split("_")[-1])
    learner.restore(train_state)
    for ob in observers:
      ob.restore(init_step + 1)
  else:
    save_path = os.path.join("./checkpoint", cfg["exp_full_name"])
    os.makedirs(save_path, exist_ok=True)

  if not (FLAGS.debug or FLAGS.profile):
    wb_name = f"{resumed_exp_name}_resume" if resume else cfg["exp_full_name"]

    wandb.init(
      project="onestep",
      entity="siyili",
      name=wb_name,
      config=cfg,
      sync_tensorboard=False,
    )

    # Send the start message now; the finish message is queued here and sent
    # by the `notif.execute()` call at the end of this function.
    notif.register(f"[Experiment started] {wb_name}")
    notif.execute()
    notif.register(f"[Experiment finished] {wb_name}")

  def _concat_batches(x, y):
    return np.concatenate([x, y], axis=0)

  # ===== Training Loop. ===== #
  first_step = init_step + 1
  # Stays None until step 110 is reached, so a run resumed past that step does
  # not report a timing (or stop a tracer) that was never started.
  trace_start = None
  for i in range(first_step, cfg["total_steps"]):
    data_batch = next(dataloader)
    if FLAGS.use_dreamer:
      # dreamer rollouts
      # Always roll out on the first iteration as well: the buffer is created
      # empty in this function, and a resumed run may start on a step that
      # does not match the rollout schedule.
      if i == first_step or i % FLAGS.model_rollout_freq == rollout_phase:
        actor.update_params(learner.save().params)
        dreamer_batch = next(dreamer_iterator)
        select_idx = np.random.choice(len(dreamer_models))
        rollout_data = dreamer_models[select_idx].combo_policy_rollout(
          dreamer_batch,
          actor,
          FLAGS.rollout_length,
          device,
          img_size=FLAGS.img_size,
        )
        for x in rollout_data:
          dreamer_buffer.extend(x)
      rollout_batch = dreamer_buffer.sample(rollout_batch_size)
      # Append the rollout transitions to the offline ones, leaf by leaf.
      data_batch = jax.tree_util.tree_map(
        _concat_batches, data_batch, rollout_batch
      )
    # train cql
    learner.step(data_batch)
    for ob in observers:
      ob.step()

    if (i + 1) % cfg["save_period"] == 0:
      with open(os.path.join(save_path, f"ckpt_{i}.pkl"), "wb") as f:
        pickle.dump(learner.save(), f)

    if FLAGS.debug or (i + 1) % cfg["eval_period"] == 0:
      actor.update_params(learner.save().params)
      eval_loop.run(evaluate_episodes)

    if FLAGS.debug:
      break

    if i == 110:
      trace_start = time.perf_counter()
      if FLAGS.profile:
        tracer.start()
        logging.info("Start tracing...")
    if i == 210 and trace_start is not None:
      logging.info(f"100 steps took {time.perf_counter() - trace_start} seconds")
      if FLAGS.profile:
        tracer.stop_and_save()
        break

  # ===== Cleanup. ===== #
  learner._logger.close()
  eval_loop._logger.close()
  del env, networks, dataloader, learner, observers, actor, eval_loop

  # ===== Notif. ===== #
  notif.execute()


if __name__ == "__main__":
  app.run(main)

"""
python experiments/run_cql.py --exp_id cql_baseline_s1_i84 --game MsPacman --nodebug --batch_size 3072 --stack_size 1 --img_size 84
python experiments/run_cql.py --exp_id cql_dreamer_s1_i84 --game MsPacman --nodebug --batch_size 1024 --stack_size 1 --img_size 84 --use_dreamer

"""