"""Train a DreamerV2 world model (in PyTorch)."""
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import random
import time
from collections import defaultdict
from typing import Dict, Iterator, Optional

import numpy as np
import scipy
import tensorflow as tf
import torch
import tree
import wandb
from absl import app, flags, logging
from acme.specs import make_environment_spec
from acme.utils.loggers import Logger
from dm_env import Environment
from ml_collections import ConfigDict
from torch import Tensor
from torch.cuda.amp import GradScaler, autocast

from internal.loggers import generate_experiment_name, logger_fn
from rosmo.agent.mbrl.morel.utils import (
  CONFIG,
  get_data_loader,
  get_environment,
)
from rosmo.agent.world_model import preprocessing
from rosmo.agent.world_model.dreamer import Dreamer
from rosmo.agent.world_model.tools import Timer, print_once, save_npz
from rosmo.data.rl_unplugged import atari

# Hide all GPUs from TensorFlow. Presumably this keeps TF (used only for data
# handling here) from reserving GPU memory that PyTorch training needs --
# TODO confirm.
tf.config.experimental.set_visible_devices([], "GPU")

# ===== Command-line flags ===== #
FLAGS = flags.FLAGS
flags.DEFINE_string("exp_id", None, "Experiment id.", required=True)
# Defaults to the current wall-clock time, so runs are not reproducible unless
# --seed is passed explicitly.
flags.DEFINE_integer("seed", int(time.time()), "Random seed.")
# NOTE: debug is ON by default. It shrinks the config (see get_config), turns
# off W&B / data saving in get_logger_fn, and stops main() after the first
# metrics write. Pass --nodebug for a real run.
flags.DEFINE_boolean("debug", True, "Debug run.")
# Only referenced by a commented-out line in get_logger_fn; currently unused.
flags.DEFINE_boolean("profile", False, "Profiling run.")
# The trailing 0, 100 are absl's lower/upper bounds for the integer flag.
flags.DEFINE_integer(
  "data_percentage", 100, "Percentage of data used for training.", 0, 100
)
flags.DEFINE_string("game", None, "Game name to run.", required=True)

# Weights & Biases logging; only active when --wb is set and --debug is off.
flags.DEFINE_boolean("wb", True, "Use WB.")
flags.DEFINE_string("user", "", "Wandb user id.")
flags.DEFINE_string("project", "", "Wandb project id.")


# ===== Misc. ===== #
def get_logger_fn(
  exp_full_name: str,
  job_name: str,
  is_eval: bool = False,
  config: Optional[Dict] = None,
) -> Logger:
  """Build the experiment logger for one job.

  The backends are chosen from the command-line flags:
    * data is persisted only for evaluation loggers in non-debug runs;
    * TensorBoard is used whenever W&B is switched off (``--nowb``);
    * W&B is used only when requested and the run is not a debug run;
    * SOTA logging is always disabled.

  Args:
    exp_full_name: Full experiment name, forwarded as ``exp_name``.
    job_name: Label identifying the job, forwarded as ``label``.
    is_eval: Whether the logger belongs to an evaluation job.
    config: Optional run configuration forwarded to the logger.

  Returns:
    The logger created by ``logger_fn``.
  """
  debug_run = FLAGS.debug
  wb_requested = FLAGS.wb
  return logger_fn(
    exp_name=exp_full_name,
    label=job_name,
    save_data=is_eval and not debug_run,
    use_tb=not wb_requested,
    use_wb=(not debug_run) and wb_requested,
    use_sota=False,
    config=config,
  )


def get_config(game_name: str) -> ConfigDict:
  """Assemble the world-model run configuration for ``game_name``.

  The config is populated from the shared ``CONFIG`` defaults. On top of
  those it sets the game name, the image channel count (equal to
  ``stack_size``), the requested data percentage, the benchmark tag and a
  generated experiment name. Under ``--debug``, ``batch_size``,
  ``batch_length``, ``deter_dim`` and ``data_percentage`` are overridden
  with tiny values for a quick smoke run.
  """
  cfg = ConfigDict()
  cfg.update(CONFIG)
  cfg.game_name = game_name
  cfg.image_channels = cfg.stack_size
  cfg.data_percentage = FLAGS.data_percentage
  cfg.benchmark = "rlu-wm"
  cfg.exp_full_name = generate_experiment_name(
    f"{FLAGS.exp_id}_WM-{game_name}"
  )

  if not FLAGS.debug:
    return cfg

  # Tiny settings so a debug run finishes quickly.
  debug_overrides = (
    ("batch_size", 2),
    ("batch_length", 4),
    ("deter_dim", 32),
    ("data_percentage", 1),
  )
  for key, value in debug_overrides:
    setattr(cfg, key, value)
  return cfg


def main(_):
  """Train a Dreamer world model on offline data for one game.

  The run proceeds as follows:
    1. Seed the Python, NumPy and PyTorch RNGs.
    2. Build the config, environment, data loader, model and optimizer.
    3. Run the training loop until ``config.n_steps`` steps have been taken.
       With ``--debug`` it stops instead at the first metrics write.

  Inside the loop it periodically saves a training checkpoint, dumps
  dream-tensor samples to ``.npz`` files, and writes training metrics
  averaged over each logging interval.

  Args:
    _: Positional command-line arguments from ``absl.app.run`` (unused).
  """
  logging.info(f"Debug mode: {FLAGS.debug}")
  random.seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  torch.manual_seed(FLAGS.seed)

  config = get_config(FLAGS.game)
  env = get_environment(config)
  dataloader = get_data_loader(config, env, use_local=False)

  env_spec = make_environment_spec(env)
  config.action_dim = env_spec.actions.num_values

  device = torch.device(config.device if torch.cuda.is_available() else "cpu")

  # Lazily-created named timers, reused across iterations so the metrics
  # block can report each phase's most recent duration.
  timers = {}

  def timer(name, verbose=False):
    if name not in timers:
      timers[name] = Timer(name, verbose)
    return timers[name]

  logger = get_logger_fn(
    config["exp_full_name"],
    "world_model",
    config=config,
  )

  # World model.
  model = Dreamer(config)
  model.to(device)

  # Training.
  optimizer_wm = model.init_optimizers(
    config.adam_lr,
    eps=config.adam_eps,
  )
  scaler = GradScaler(enabled=config.amp)

  def _new_preprocess(data):
    # One-hot encode the integer actions; every other field passes through.
    data["action"] = preprocessing.to_onehot(data["action"], config.action_dim)
    return data

  if not FLAGS.debug and FLAGS.wb:
    wandb.init(
      project=FLAGS.project,
      entity=FLAGS.user,
      name=config["exp_full_name"],
      config=config.to_dict(),
      sync_tensorboard=False,
    )

  logging.info("Start training loop...")

  start_time = time.time()
  steps = 0
  last_time = start_time
  last_steps = steps
  # Per-step metric values, averaged and flushed at every logging interval.
  metrics = defaultdict(list)
  checkpoint_dir = f"./checkpoint/wm/{config.exp_full_name}"
  dream_dir = f"{checkpoint_dir}/d2_wm_dream"
  # makedirs also creates checkpoint_dir, since it is a parent of dream_dir.
  os.makedirs(dream_dir, exist_ok=True)

  while True:
    with timer("total"):

      # Sample offline data.
      with timer("data"):
        raw_batch = _new_preprocess(next(dataloader))
        batch: Dict[str, Tensor] = tree.map_structure(
          lambda x: torch.tensor(x).to(device), raw_batch
        )

      # Forward.
      with timer("forward"):
        with autocast(enabled=config.amp):
          # A new initial state is built for every batch, so no recurrent
          # state is carried over between iterations.
          rnn_state = model.init_state(config.batch_size * config.iwae_samples)
          (losses, _, loss_metrics, _,
           dream_tensors) = model.training_step(
             batch,
             rnn_state,
             iwae_samples=config.iwae_samples,
             imag_horizon=config.imag_horizon,
             do_image_pred=steps % config.log_interval >=
             int(config.log_interval * 0.9),  # 10% of batches
             do_dream_tensors=steps % config.logbatch_interval == 1
           )

      # Backward.
      with timer("backward"):
        optimizer_wm.zero_grad()
        for loss in losses:
          scaler.scale(loss).backward()

      # Grad step.
      with timer("gradstep"):
        # Unscale first so gradient clipping sees the true gradient scale.
        scaler.unscale_(optimizer_wm)
        grad_metrics = model.grad_clip(config.grad_clip)
        scaler.step(optimizer_wm)
        scaler.update()

      # Accumulate this step's metrics; they are averaged over the whole
      # logging interval when written below.
      for k, v in loss_metrics.items():
        value = v.item()
        if not np.isnan(value):
          metrics[k].append(value)
      for k, v in grad_metrics.items():
        value = v.item()
        if np.isfinite(value):  # It's ok for grad norm to be inf, when using amp
          metrics[k].append(value)

      # Save train state.
      if steps % config.save_interval == 1:
        logging.info("Saving training state...")
        # NOTE(review): this pickles the optimizer object itself. The usual
        # PyTorch convention is optimizer_wm.state_dict() (plus
        # scaler.state_dict() for AMP resumes); change it together with
        # whatever loads these checkpoints.
        torch.save(
          {
            "model_state_dict": model.state_dict(),
            "optimizer": optimizer_wm,
            "steps": steps,
            "seed": FLAGS.seed,
          }, f"{checkpoint_dir}/train_state_{steps}.pt"
        )

      # Log samples.
      if dream_tensors:
        logging.info("Saving dream tensors...")
        log_batch_npz(
          batch,
          dream_tensors,
          f'{steps:07}.npz',
          subdir=dream_dir
        )

      steps += 1

      # Metrics.
      if steps % config.log_interval == 0:
        for k, v in timers.items():
          metrics[f'timer_{k}'].append(v.dt_ms)
        metrics = {f'train/{k}': np.mean(v) for k, v in metrics.items()}
        metrics['step'] = steps  # type: ignore
        t = time.time()
        fps = (steps - last_steps) / (t - last_time)
        metrics['train/fps'] = fps  # type: ignore
        last_time, last_steps = t, steps
        logger.write(metrics)
        metrics = defaultdict(list)
        if FLAGS.debug:
          break

      # Stop. Checked after logging so that a final full interval is written.
      if steps >= config.n_steps:
        break

  logger.close()
  logging.info("Finished!")


def log_batch_npz(
  batch: Dict[str, Tensor], tensors: Dict[str, Tensor], filename: str,
  subdir: str
):
  """Merge a batch with model tensors and save them to one ``.npz`` file.

  ``batch`` and ``tensors`` must not share keys; merging them raises
  ``TypeError`` otherwise. Before and after conversion by
  ``prepare_batch_npz``, a shape summary is passed to ``print_once``.

  Args:
    batch: Input batch tensors.
    tensors: Model-produced tensors saved alongside the batch.
    filename: Output file name inside ``subdir``.
    subdir: Directory the file is written to.
  """

  def _shape_summary(arrays):
    return {name: tuple(arr.shape) for name, arr in arrays.items()}

  merged = dict(**batch, **tensors)
  print_once(f'Saving batch {subdir} (input): ', _shape_summary(merged))
  converted = prepare_batch_npz(merged)
  print_once(f'Saving batch {subdir} (proc.): ', _shape_summary(converted))
  save_npz(converted, os.path.join(subdir, filename))


def prepare_batch_npz(data: Dict[str, Tensor], take_b=999):
  """Convert a dict of ``(T, B, ...)`` tensors to ``(B, T, ...)`` numpy arrays.

  For each tensor:
    * keep at most ``take_b`` batch entries;
    * move it to host memory and convert it to numpy, casting
      float16/float64 to float32;
    * if it is rank 5, treat it as an image:
        - channel-first ``(T, B, C, W, W)`` layouts become channel-last;
        - 1/3/4-channel images are mapped to ``uint8`` via ``(x + 0.5) * 255``;
        - one-hot maps are collapsed with ``argmax``;
        - any other categorical logits are turned into probabilities with a
          softmax.

  Args:
    data: Mapping from name to tensor; axis 0 is time and axis 1 is batch.
    take_b: Maximum number of batch entries kept per tensor.

  Returns:
    Dict with the same keys, mapping to numpy arrays whose time and batch
    axes are swapped.

  Raises:
    AssertionError: For rank-4/5 tensors with an unexpected key, dtype or
      shape (only when assertions are enabled).
  """
  # Import the submodule explicitly: a bare ``import scipy`` only exposes
  # ``scipy.special`` on SciPy versions that lazy-load their submodules.
  from scipy.special import softmax

  def unpreprocess(key: str, val: Tensor) -> np.ndarray:
    # Keep at most `take_b` entries along the batch axis.
    if take_b < val.shape[1]:
      val = val[:, :take_b]

    x = val.cpu().numpy()  # (T,B,*)
    # Store reduced/extended precision floats as float32.
    if x.dtype in [np.float16, np.float64]:
      x = x.astype(np.float32)

    if x.ndim == 2:  # Scalar
      pass

    elif x.ndim == 3:  # 1D vector
      pass

    elif x.ndim == 4:  # 2D tensor - categorical image
      assert (x.dtype == np.int64 or x.dtype == np.uint8) and key.startswith('map'), \
          f'Unexpected 2D tensor: {key}: {x.shape}, {x.dtype}'

    elif x.ndim == 5:  # 3D tensor - image
      assert x.dtype == np.float32 and (key.startswith('image') or key.startswith('map')), \
          f'Unexpected 3D tensor: {key}: {x.shape}, {x.dtype}'

      if x.shape[-1] == x.shape[-2]:  # (T,B,C,W,W)
        x = x.transpose(0, 1, 3, 4, 2)  # => (T,B,W,W,C)
      assert x.shape[-2] == x.shape[
        -3], 'Assuming rectangular images, otherwise need to improve logic'

      if x.shape[-1] in [1, 3, 4]:
        # RGB or (stacked) grayscale. The +0.5 shift assumes pixels were
        # normalized to [-0.5, 0.5] -- TODO confirm against preprocessing.
        x = ((x + 0.5) * 255.0).clip(0, 255).astype('uint8')
      elif np.allclose(x.sum(axis=-1),
                       1.0) and np.allclose(x.max(axis=-1), 1.0):
        # One-hot
        x = x.argmax(axis=-1)
      else:
        # Categorical logits
        assert key in ['map_rec', 'image_rec', 'image_pred'], \
            f'Unexpected 3D categorical logits: {key}: {x.shape}'
        x = softmax(x, axis=-1)

    x = x.swapaxes(0, 1)  # type: ignore  # (T,B,*) => (B,T,*)
    return x

  return {k: unpreprocess(k, v) for k, v in data.items()}


if __name__ == "__main__":
  # absl parses the command-line flags defined above, then calls main(argv).
  app.run(main)
