import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from typing import Iterator, List

import tensorflow as tf
import torch
import tree
import wandb
from absl import app, flags, logging
from acme.specs import make_environment_spec
from ml_collections import ConfigDict
from torch import Tensor
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from internal.loggers import generate_experiment_name, logger_fn
from rosmo.agent.world_model import preprocessing
from rosmo.agent.world_model.dreamer import Dreamer
from rosmo.agent.world_model.tools import Timer, print_once, save_npz
from rosmo.data.rl_unplugged import atari

tf.config.experimental.set_visible_devices([], "GPU")
from rosmo.agent.world_model.dreamer import Dreamer
from rosmo.data.rl_unplugged import atari

# Default root directory of the RL Unplugged TensorFlow datasets.
DATA_DIR = "/datasets/rl_unplugged/tensorflow_datasets"

# The defaults are declared one section at a time and merged into CONFIG below;
# the merge keeps the sections (and the keys inside them) in declaration order.

# Job
_JOB_CONFIG = {
  "n_steps": 1_000_000,
  "n_env_steps": 1_000_000_000,
  "log_interval": 100,
  "logbatch_interval": 1000,
  "save_interval": 10_000,
  "eval_interval": 2000,
  "data_workers": 4,
  "enable_profiler": False,
  "verbose": False,
}

# Features
_FEATURE_CONFIG = {
  "image_key": "image",
  "image_size": 64,
  "stack_size": 1,
  "image_categorical": False,
}

# Training
_TRAINING_CONFIG = {
  "reset_interval": 200,
  "iwae_samples": 1,
  "kl_balance": 0.8,
  "kl_weight": 0.1,  # For atari.
  "image_weight": 1.0,
  "vecobs_weight": 1.0,
  "reward_weight": 1.0,
  "terminal_weight": 1.0,
  "adam_lr": 3.0e-4,
  "adam_lr_actor": 1.0e-4,
  "adam_lr_critic": 1.0e-4,
  "adam_eps": 1.0e-5,
  "keep_state": True,
  "batch_length": 50,
  "batch_size": 50,
  "device": "cuda:0",
  "grad_clip": 200,
  "grad_clip_ac": 200,
  "image_decoder_min_prob": 0,
  "amp": True,
}

# Model
_MODEL_CONFIG = {
  "model": "dreamer",
  "deter_dim": 1024,  # For atari.
  "stoch_dim": 32,
  "stoch_discrete": 32,
  "hidden_dim": 1000,
  "gru_layers": 1,
  "gru_type": "gru",
  "layer_norm": True,
  "vecobs_size": 0,
  "image_encoder": "cnn",
  "cnn_depth": 48,
  "image_encoder_layers": 0,
  "image_decoder": "cnn",
  "image_decoder_layers": 0,
  "reward_input": False,
  "reward_decoder_layers": 4,
  "reward_decoder_categorical": None,
  "terminal_decoder_layers": 4,
  "map_stoch_dim": 64,
  "probe_model": "none",
  "map_decoder": "dense",
  "map_hidden_layers": 4,
  "map_hidden_dim": 1024,
}

# Data.
_DATA_CONFIG = {
  "run_number": 1,
}

# Actor Critic
_ACTOR_CRITIC_CONFIG = {
  "gamma": 0.995,
  "lambda_gae": 0.95,
  "entropy": 0.003,
  "target_interval": 100,
  "imag_horizon": 15,
  "actor_grad": "reinforce",
  "actor_dist": "onehot",
}

# Flat dictionary of default settings, copied into a ConfigDict by main().
CONFIG = {
  **_JOB_CONFIG,
  **_FEATURE_CONFIG,
  **_TRAINING_CONFIG,
  **_MODEL_CONFIG,
  **_DATA_CONFIG,
  **_ACTOR_CRITIC_CONFIG,
}


def get_environment(config):
  """Build the Atari environment for the configured game.

  Reads ``config["game_name"]`` and ``config["stack_size"]`` and forwards them
  to ``atari.environment``.
  """
  game = config["game_name"]
  stack_size = config["stack_size"]
  # NOTE use sticky action
  return atari.environment(game=game, stack_size=stack_size)


# ===== Dataset & Buffer ===== #
def get_data_loader(
  config, environment, data_dir=None, half=True, use_local=True
) -> Iterator:
  """Get trajectory data loader.

  Builds the Atari dataset for ``config["game_name"]``, batches it and
  transposes every batch from batch-major to time-major layout.

  Args:
    config: Mapping with the keys ``game_name``, ``run_number``, ``stack_size``,
      ``image_size``, ``data_percentage``, ``batch_length``, ``batch_size`` and
      ``amp``. Only item access is used, so a plain ``dict`` works as well as
      a ``ConfigDict``.
    environment: Environment whose spec provides the number of actions.
    data_dir: Dataset root; falls back to ``DATA_DIR`` when falsy.
    half: If true and ``config["amp"]`` is set, images are cast to float16,
      otherwise to float32.
    use_local: Forwarded unchanged to ``atari.create_atari_ds_loader``.

  Returns:
    An endless NumPy iterator over dicts with the keys ``image``, ``reward``,
    ``terminal``, ``reset`` and ``action``.
  """
  environment_spec = make_environment_spec(environment)

  dataset = atari.create_atari_ds_loader(
    game=config["game_name"],
    run_number=config["run_number"],
    data_dir=data_dir or DATA_DIR,
    num_actions=environment_spec.actions.num_values,
    stack_size=config["stack_size"],
    image_size=config["image_size"],
    data_percent=config["data_percentage"],
    trajectory_length=config["batch_length"],
    shuffle_num_steps=50000,
    use_local=use_local,
  )

  # Half precision is only used for the images, and only when AMP is enabled.
  image_dtype = tf.float16 if half and config["amp"] else tf.float32

  def _to_time_major(data):
    """Rescale the images and convert every entry from [B, T] to [T, B]."""
    image = tf.cast(data["observation"], image_dtype)
    image = image / 255.0 - 0.5  # type: ignore
    # (B, T, H, W, C) => (T, B, C, H, W)
    image = tf.transpose(image, perm=(1, 0, 4, 2, 3))

    terminal = tf.transpose(data["is_terminal"], perm=(1, 0))
    return {
      "image": image,
      "reward": tf.cast(tf.transpose(data["reward"], perm=(1, 0)), tf.float32),
      "terminal": tf.cast(terminal, tf.float32),
      # The reset flag is the terminal flag in its original dtype.
      "reset": terminal,
      # Actions keep their dataset dtype; callers one-hot encode them later.
      "action": tf.transpose(data["action"], perm=(1, 0)),
    }

  dataset = dataset.repeat().batch(config["batch_size"])
  dataset = dataset.map(_to_time_major).prefetch(tf.data.AUTOTUNE)

  options = tf.data.Options()
  options.threading.max_intra_op_parallelism = 1
  dataset = dataset.with_options(options)
  return dataset.as_numpy_iterator()


def preprare_pessimistic_mdp(
  models: List[Dreamer], dataloader, config, num_batch=100
):
  """Estimate a maximum delta over sampled transitions.

  For each of ``num_batch`` batches drawn from ``dataloader``, the features of
  every model are fetched with ``get_features()`` and the largest L2 distance
  (over the last dimension) between any two models is tracked.

  Args:
    models: Ensemble of models to compare pairwise.
    dataloader: Iterator yielding batches with an ``"action"`` entry.
    config: Configuration providing ``action_dim``.
    num_batch: Number of batches to draw.

  Returns:
    Scalar tensor with the largest pairwise disagreement that was observed
    (``0.`` when fewer than two models are given).
  """

  def _preprocess(data):
    data["action"] = preprocessing.to_onehot(data["action"], config.action_dim)
    return data

  logging.info("Burn in pessimistic MDP.")
  delta = torch.tensor(0.)
  # Gradients are never needed here; without no_grad the running maximum would
  # keep the autograd graph of every batch alive.
  with torch.no_grad():
    for _ in tqdm(range(num_batch)):
      # NOTE(review): the batch is drawn and preprocessed but never handed to
      # the models, because get_features() is called without arguments. Confirm
      # against the Dreamer API how the batch is supposed to be fed in.
      _preprocess(next(dataloader))
      # Query every model once per batch instead of once per pair.
      features = [model.get_features() for model in models]
      for idx, feat_a in enumerate(features):
        for feat_b in features[idx + 1:]:
          disagreement: Tensor = (feat_a - feat_b).norm(dim=-1)
          # Keep the running maximum on the same device as the features.
          delta = torch.maximum(
            delta.to(disagreement.device), disagreement.max()
          )
  logging.info(f"Max delta = {delta.item()}.")
  return delta


def main(_):
  """Estimate the maximum pairwise feature disagreement of a Dreamer ensemble.

  Loads one model per checkpoint, runs every model on the same batches of
  MsPacman data and tracks the largest L2 distance (over the last feature
  dimension) between the features of any two models.

  Args:
    _: Positional command-line arguments passed by ``app.run``; unused.
  """
  config = ConfigDict()
  config.update(CONFIG)
  config.image_channels = config.stack_size
  config.game_name = "MsPacman"
  config.algo = "morel"
  config.data_percentage = 100
  config.negative_reward = -10

  ckpt_paths = [
    "./wm/14kcebow-train_state_290001.pt",
    "./wm/2n4ccnd8-train_state_290001.pt",
    "./wm/2sgszd1k-train_state_290001.pt",
  ]
  num_batches = 500
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  env = get_environment(config)
  env_spec = make_environment_spec(env)
  config.action_dim = env_spec.actions.num_values
  dataloader = get_data_loader(config, env, half=False)

  def _to_model_input(data):
    """One-hot encode the actions and move the whole batch to ``device``."""
    data["action"] = preprocessing.to_onehot(data["action"], config.action_dim)
    return tree.map_structure(lambda x: torch.tensor(x).to(device), data)

  models = []
  for path in ckpt_paths:
    model = Dreamer(config).to(device)
    # map_location lets checkpoints saved on a GPU load on the CPU fallback.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    models.append(model)

  print("*" * 10, "Model loaded.", "*" * 10)
  state_batch_size = config.batch_size * config.iwae_samples
  delta = torch.tensor(0.).to(device)
  # Inference only: without no_grad the running maximum would keep the autograd
  # graph of every forward pass alive and memory would grow with each batch.
  with torch.no_grad():
    for _ in range(num_batches):
      data_batch = _to_model_input(next(dataloader))
      # Run every model once per batch, then compare all pairs.
      features = []
      for model in models:
        rnn_state = model.init_state(state_batch_size)
        feat, _ = model.forward(data_batch, rnn_state)
        features.append(feat)
      for idx, feat_a in enumerate(features):
        for feat_b in features[idx + 1:]:
          disagreement = torch.norm(feat_a - feat_b, dim=-1)
          delta = torch.maximum(delta, disagreement.max())
      print(delta)
  logging.info(f"Max delta = {delta.item()}.")


# Script entry point: hand control to absl, which then calls main().
if __name__ == "__main__":
  app.run(main)
"""
CUDA_VISIBLE_DEVICES=2 python src/rosmo/agent/mbrl/morel/utils.py

delta = 24.7545
"""
