"""Atari RL Unplugged datasets adapted from official codes.

Examples in the dataset represent SARSA transitions stored during a
DQN training run as described in https://arxiv.org/pdf/1907.04543.

For every training run we have recorded all 50 million transitions
corresponding to 200 million environment steps (4x factor because of
frame skipping). There are 5 separate datasets for each of the 45 games.

Every transition in the dataset is a tuple containing the following features:

* o_t: Observation at time t. Observations have been processed using the
    canonical Atari frame processing, including 4x frame stacking. The shape
    of a single observation is [84, 84, 4].
* a_t: Action taken at time t.
* r_t: Reward after a_t.
* d_t: Discount after a_t.
* o_tp1: Observation at time t+1.
* a_tp1: Action at time t+1.
* extras:
  * episode_id: Episode identifier.
  * episode_return: Total episode return computed using per-step [-1, 1]
      clipping.
"""
import collections
import functools
import json
import os
import random
from typing import Any, Dict

import dm_env
import gym
import numpy as np
import reverb
import rlds
import tensorflow as tf
import tensorflow_datasets as tfds
import tree
from absl import flags, logging
from acme import wrappers
from dm_env import specs
from dm_env import specs as dm_env_specs
from dopamine.discrete_domains import atari_lib
from skimage.transform import resize

FLAGS = flags.FLAGS
# Optional list of checkpoint indices. `uniformly_subsampled_atari_data` falls
# back to this flag when its `ckpt` argument is not provided.
flags.DEFINE_list("ckpt_number", None, "Specific checkpoints to take.")

# Load the JSON file that sits next to this module. This runs at import time,
# so importing the module fails if the file is missing.
# NOTE(review): presumably per-game baseline scores -- confirm against the
# JSON file; nothing else in the visible code reads BASELINES.
with open(
  os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "rl_unplugged_atari_baselines.json"
  ),
  "r",
) as f:
  BASELINES = json.load(f)

# 9 tuning games (combined with TESTING_SUITE to form ALL).
TUNING_SUITE = [
  "BeamRider",
  "DemonAttack",
  "DoubleDunk",
  "IceHockey",
  "MsPacman",
  "Pooyan",
  "RoadRunner",
  "Robotank",
  "Zaxxon",
]

# 37 testing games.
TESTING_SUITE = [
  "Alien",
  "Amidar",
  "Assault",
  "Asterix",
  "Atlantis",
  "BankHeist",
  "BattleZone",
  "Boxing",
  "Breakout",
  "Carnival",
  "Centipede",
  "ChopperCommand",
  "CrazyClimber",
  "Enduro",
  "FishingDerby",
  "Freeway",
  "Frostbite",
  "Gopher",
  "Gravitar",
  "Hero",
  "Jamesbond",
  "Kangaroo",
  "Krull",
  "KungFuMaster",
  "NameThisGame",
  "Phoenix",
  "Pong",
  "Qbert",
  "Riverraid",
  "Seaquest",
  "SpaceInvaders",
  "StarGunner",
  "TimePilot",
  "UpNDown",
  "VideoPinball",
  "WizardOfWor",
  "YarsRevenge",
]

# All 46 game names in alphabetical order (the same names that appear in
# TUNING_SUITE and TESTING_SUITE).
_GAMES = [
  "Alien",
  "Amidar",
  "Assault",
  "Asterix",
  "Atlantis",
  "BankHeist",
  "BattleZone",
  "BeamRider",
  "Boxing",
  "Breakout",
  "Carnival",
  "Centipede",
  "ChopperCommand",
  "CrazyClimber",
  "DemonAttack",
  "DoubleDunk",
  "Enduro",
  "FishingDerby",
  "Freeway",
  "Frostbite",
  "Gopher",
  "Gravitar",
  "Hero",
  "IceHockey",
  "Jamesbond",
  "Kangaroo",
  "Krull",
  "KungFuMaster",
  "MsPacman",
  "NameThisGame",
  "Phoenix",
  "Pong",
  "Pooyan",
  "Qbert",
  "Riverraid",
  "RoadRunner",
  "Robotank",
  "Seaquest",
  "SpaceInvaders",
  "StarGunner",
  "TimePilot",
  "UpNDown",
  "VideoPinball",
  "WizardOfWor",
  "YarsRevenge",
  "Zaxxon",
]

# Games for which `_get_num_shards` reports one shard fewer than requested.
_SHORT_GAMES = [
  "Carnival",
  "Gravitar",
  "StarGunner",
]

# Total of 46 games (9 tuning + 37 testing), tuning games first.
ALL = TUNING_SUITE + TESTING_SUITE


def _decode_frames(pngs: tf.Tensor):
  """Turn four PNG-encoded frames into a single stacked observation.

  Args:
    pngs: String Tensor of size (4,) containing PNG encoded images.

  Returns:
    4 84x84 grayscale images packed in a (84, 84, 4) uint8 Tensor.
  """
  # The loop runs in Python, so each of the four frames gets its own
  # decode op (static unrolling).
  decoded = []
  for idx in range(4):
    decoded.append(tf.image.decode_png(pngs[idx], channels=1))
  # Each decoded frame has a single channel; join them on the channel axis.
  stacked = tf.concat(decoded, axis=2)
  stacked.set_shape((84, 84, 4))
  return stacked


def _make_reverb_sample(
  o_t: tf.Tensor,
  a_t: tf.Tensor,
  r_t: tf.Tensor,
  d_t: tf.Tensor,
  o_tp1: tf.Tensor,
  a_tp1: tf.Tensor,
  extras: Dict[str, tf.Tensor],
) -> reverb.ReplaySample:
  """Wrap one offline SARSA transition in a Reverb replay sample.

  Args:
    o_t: Observation at time t.
    a_t: Action at time t.
    r_t: Reward at time t.
    d_t: Discount at time t.
    o_tp1: Observation at time t+1.
    a_tp1: Action at time t+1.
    extras: Dictionary with extra features.

  Returns:
    Replay sample whose data is the 7-tuple
    (o_t, a_t, r_t, d_t, o_tp1, a_tp1, extras) and whose info is fake:
    key=0, probability=1, table_size=0, priority=1.
  """
  transition = (o_t, a_t, r_t, d_t, o_tp1, a_tp1, extras)
  # The data does not come from a live replay table, so the sample info is
  # filled with constant placeholder values.
  placeholder_info = reverb.SampleInfo(
    key=tf.constant(0, tf.uint64),
    probability=tf.constant(1.0, tf.float64),
    table_size=tf.constant(0, tf.int64),
    priority=tf.constant(1.0, tf.float64),
  )
  return reverb.ReplaySample(info=placeholder_info, data=transition)


def _tf_example_to_reverb_sample(
  tf_example: tf.train.Example
) -> reverb.ReplaySample:
  """Parse a serialized transition tf.Example into a Reverb replay sample.

  Args:
    tf_example: Serialized tf.Example holding one SARSA transition.

  Returns:
    The sample built by `_make_reverb_sample`. Observations are decoded with
    `_decode_frames`, actions are cast to int32, and `extras` carries
    `episode_id` (bit-cast to uint64) and `return`.
  """
  schema = {
    "o_t": tf.io.FixedLenFeature([4], tf.string),
    "o_tp1": tf.io.FixedLenFeature([4], tf.string),
    "a_t": tf.io.FixedLenFeature([], tf.int64),
    "a_tp1": tf.io.FixedLenFeature([], tf.int64),
    "r_t": tf.io.FixedLenFeature([], tf.float32),
    "d_t": tf.io.FixedLenFeature([], tf.float32),
    "episode_id": tf.io.FixedLenFeature([], tf.int64),
    "episode_return": tf.io.FixedLenFeature([], tf.float32),
  }
  parsed = tf.io.parse_single_example(tf_example, schema)

  # The id is stored as int64; reinterpret the same bits as uint64.
  extras = {
    "episode_id": tf.bitcast(parsed["episode_id"], tf.uint64),
    "return": parsed["episode_return"],
  }
  return _make_reverb_sample(
    o_t=_decode_frames(parsed["o_t"]),
    a_t=tf.cast(parsed["a_t"], tf.int32),
    r_t=parsed["r_t"],
    d_t=parsed["d_t"],
    o_tp1=_decode_frames(parsed["o_tp1"]),
    a_tp1=tf.cast(parsed["a_tp1"], tf.int32),
    extras=extras,
  )


def _get_num_shards(game: str, shards: int) -> int:
  """Return the shard count for `game`.

  Games listed in `_SHORT_GAMES` have one shard fewer than `shards`; every
  other game has exactly `shards`.
  """
  missing = 1 if game in _SHORT_GAMES else 0
  return shards - missing


# Schema used by `_tf_example_to_reverb_episode_sample` to parse one serialized
# episode tf.Example: the scalar features describe the whole episode, the
# sequence features (allow_missing=True) hold one entry per step.
# Note (kept from the original comment): rewards and episode_return are
# actually also clipped.
_feature_description = {
  "checkpoint_idx":
    tf.io.FixedLenFeature([], tf.int64),
  "episode_idx":
    tf.io.FixedLenFeature([], tf.int64),
  "episode_return":
    tf.io.FixedLenFeature([], tf.float32),
  "clipped_episode_return":
    tf.io.FixedLenFeature([], tf.float32),
  "observations":
    tf.io.FixedLenSequenceFeature([], tf.string, allow_missing=True),
  "actions":
    tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
  "unclipped_rewards":
    tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
  "clipped_rewards":
    tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
  "discounts":
    tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
}


def _tf_example_to_reverb_episode_sample(
  tf_example: tf.train.Example,
) -> Dict[str, Any]:
  """Parse a serialized episode tf.Example into an episode dictionary.

  Adapted from tensorflow_datasets. Despite the function name, the result is
  a plain nested dict, not a `reverb.ReplaySample`.

  Args:
    tf_example: Serialized tf.Example matching `_feature_description`.

  Returns:
    A dict with the episode metadata (`episode_id`, `checkpoint_id`,
    `episode_return`) and a `steps` dict of per-step tensors: `observation`
    (passed through as parsed, not decoded here), `action` (int32), `reward`
    (the unclipped rewards), `discount`, and the `is_first` / `is_last` /
    `is_terminal` flags cast to int32.
  """

  data = tf.io.parse_single_example(tf_example, _feature_description)

  # Process data.
  # The number of actions is used as the episode length.
  episode_length = tf.size(data["actions"])
  # Only the first step is flagged as first and only the last step as last.
  # NOTE(review): `[False] * tf.ones(...)` relies on TensorFlow converting the
  # Python bool list when multiplying; presumably it yields all-zero padding --
  # confirm the resulting dtype.
  is_first = tf.concat([[True], [False] * tf.ones(episode_length - 1)], axis=0)
  is_last = tf.concat([[False] * tf.ones(episode_length - 1), [True]], axis=0)
  # Default: no step is terminal.
  is_terminal = [False] * tf.ones_like(data["actions"])
  discounts = data["discounts"]
  # NOTE(review): this Python `if` tests a tensor value; presumably it depends
  # on eager execution or AutoGraph conversion when run inside
  # `Dataset.map` -- confirm.
  if discounts[-1] == 0.0:
    # A final discount of zero marks the last step as terminal.
    is_terminal = tf.concat(
      [[False] * tf.ones(episode_length - 1, tf.int64), [True]], axis=0
    )
    # If the episode ends in a terminal state, in the last step only the
    # observation has valid information (the terminal state).
    # The discounts are therefore shifted one step earlier and padded with a
    # trailing 0.0.
    discounts = tf.concat([discounts[1:], [0.0]], axis=0)
  episode = {
    # Episode Metadata
    "episode_id": data["episode_idx"],
    "checkpoint_id": data["checkpoint_idx"],
    "episode_return": data["episode_return"],
    "steps":
      {
        "observation": data["observations"],
        "action": tf.cast(data["actions"], tf.int32),
        "reward": data["unclipped_rewards"],
        "discount": discounts,
        "is_first": tf.cast(is_first, tf.int32),
        "is_last": tf.cast(is_last, tf.int32),
        "is_terminal": tf.cast(is_terminal, tf.int32),
      },
  }
  return episode


# Nominal shard count per run for the episode datasets; `dataset_episode`
# passes it through `_get_num_shards` to get the per-game count used in the
# shard file names.
_TOTAL_SHARDS = 50


def dataset(
  path: str,
  game: str,
  run: int,
  num_shards: int = 100,
  total_shards: int = 100,
  shuffle_buffer_size: int = 100000,
) -> tf.data.Dataset:
  """Build an endlessly repeating, shuffled dataset of Atari SARSA tuples.

  Args:
    path: Root directory (or bucket prefix) of the transition shards.
    game: Game name, used as the sub-directory under `path`.
    run: Run index, used in the shard file prefix `run_{run}`.
    num_shards: How many shards (indices 0..num_shards-1) to read.
    total_shards: Shard count that appears in the `-of-XXXXX` file suffix.
    shuffle_buffer_size: Buffer size of the example-level shuffle.

  Returns:
    A dataset of `reverb.ReplaySample` transitions, as produced by
    `_tf_example_to_reverb_sample`.
  """
  prefix = os.path.join(path, f"{game}/run_{run}")
  shard_files = []
  for shard in range(num_shards):
    shard_files.append(f"{prefix}-{shard:05d}-of-{total_shards:05d}")

  read_gzip_records = functools.partial(
    tf.data.TFRecordDataset, compression_type="GZIP"
  )
  # Repeat the file list forever and shuffle it, then read several shards
  # concurrently, taking 5 consecutive records from each in turn.
  files = tf.data.Dataset.from_tensor_slices(shard_files)
  files = files.repeat().shuffle(num_shards)
  examples = files.interleave(
    read_gzip_records,
    cycle_length=tf.data.experimental.AUTOTUNE,
    block_length=5,
  )
  examples = examples.shuffle(shuffle_buffer_size)
  return examples.map(
    _tf_example_to_reverb_sample,
    num_parallel_calls=tf.data.experimental.AUTOTUNE
  )


def dataset_episode(
  game: str,
  run: int,
  path: str = "gs://rl_unplugged/atari_episodes_ordered",
  num_shards: int = 50,
  order: str = "original",
) -> tf.data.Dataset:
  """Build a dataset of whole Atari Unplugged episodes.

  Args:
    game: Game name, used as the sub-directory under `path`.
    run: Run index, used in the shard file prefix `run_{run}`.
    path: Root location of the ordered episode shards.
    num_shards: Number of shards to read; clipped (with a warning) to the
      number of shards available for the game.
    order: Order in which shard files are visited: "original" (ascending
      index), "reversed", or "random" (shuffled with Python's `random`).

  Returns:
    A dataset of episode dicts produced by
    `_tf_example_to_reverb_episode_sample`, read shard after shard.
  """
  assert order in ["random", "reversed", "original"]

  prefix = os.path.join(path, f"{game}/run_{run}")
  total_shards = _get_num_shards(game, _TOTAL_SHARDS)
  # Never enumerate more than 50 shards, nor more than the game provides.
  available = min(50, total_shards)
  filenames = [
    f"{prefix}-{idx:05d}-of-{total_shards:05d}" for idx in range(available)
  ]

  if order == "random":
    random.shuffle(filenames)
  elif order == "reversed":
    filenames.reverse()

  if num_shards > len(filenames):
    logging.warning(
      f"{num_shards} shards out of bound, clip to {len(filenames)} shards."
    )
    num_shards = len(filenames)
  filenames = filenames[:num_shards]

  read_gzip_records = functools.partial(
    tf.data.TFRecordDataset, compression_type="GZIP"
  )
  # flat_map keeps the records in file order, one shard after another.
  examples = tf.data.Dataset.from_tensor_slices(filenames).flat_map(
    read_gzip_records
  )
  return examples.map(
    _tf_example_to_reverb_episode_sample,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
  )


class BatchToTransition(object):
  """Turns a window of consecutive RLDS steps into one transition dict."""

  def __init__(self, image_size) -> None:
    # Side length of the output observation; 84 means "leave unchanged".
    self.image_size = image_size

  def create_transitions(self, batch):
    """Collapse a batched window of steps into a single transition.

    Args:
      batch: Dict of batched step fields keyed by the `rlds` constants.
        NOTE(review): presumably each field has the window (stack) axis
        first and observations are (stack, H, W, 1) -- confirm against the
        caller.

    Returns:
      Dict with the stacked `observation` (window axis moved last, resized
      when `image_size` is not 84), the `action`, `reward`, `discount`,
      `is_last` and `is_terminal` of the window's last step, and the
      `is_first` of the window's first step.
    """
    # Drop the trailing axis, then move the leading window axis to the end.
    frames = tf.transpose(
      tf.squeeze(batch[rlds.OBSERVATION], axis=-1), perm=[1, 2, 0]
    )
    if self.image_size != 84:
      target_size = (self.image_size, self.image_size)
      frames = tf.image.resize(frames, target_size, antialias=True)

    return {
      "observation": frames,
      "action": batch[rlds.ACTION][-1],
      "reward": batch[rlds.REWARD][-1],
      "discount": batch[rlds.DISCOUNT][-1],
      "is_first": batch[rlds.IS_FIRST][0],
      "is_last": batch[rlds.IS_LAST][-1],
      "is_terminal": batch[rlds.IS_TERMINAL][-1],
    }


def get_trajectory_dataset_fn(stack_size, image_size, trajectory_length=1):
  """Build a function mapping an episode to a dataset of trajectories.

  Args:
    stack_size: Number of consecutive steps combined into one transition.
    image_size: Output observation side length, forwarded to
      `BatchToTransition`.
    trajectory_length: Number of transitions per trajectory. With a value of
      1 (or less) the transitions are returned unbatched.

  Returns:
    A callable taking an episode dict and returning a `tf.data.Dataset`.
  """
  to_transition = BatchToTransition(image_size).create_transitions

  def make_trajectory_dataset(episode):
    """Converts an episode of steps to a dataset of custom transitions.
        Episode spec: {
          'checkpoint_id': <tf.Tensor: shape=(), dtype=int64, numpy=0>,
          'episode_id': <tf.Tensor: shape=(), dtype=int64, numpy=0>,
          'episode_return': <tf.Tensor: shape=(), dtype=float32, numpy=0>,
          'steps': <_VariantDataset element_spec={
            'action': TensorSpec(shape=(), dtype=tf.int64, name=None),
            'discount': TensorSpec(shape=(), dtype=tf.float32, name=None),
            'is_first': TensorSpec(shape=(), dtype=tf.bool, name=None),
            'is_last': TensorSpec(shape=(), dtype=tf.bool, name=None),
            'is_terminal': TensorSpec(shape=(), dtype=tf.bool, name=None),
            'observation': TensorSpec(shape=(84, 84, 1), dtype=tf.uint8,
              name=None),
            'reward': TensorSpec(shape=(), dtype=tf.float32, name=None)
            }
          >}
        """
    # Slide a window of `stack_size` steps over the episode, moving one step
    # at a time, and turn every full window into a transition.
    steps: tf.data.Dataset = episode[rlds.STEPS]
    windows = rlds.transformations.batch(
      steps,
      size=stack_size,
      shift=1,
      drop_remainder=True,
    )
    transitions = windows.map(to_transition)
    if trajectory_length <= 1:
      return transitions

    # Go through the episode twice, drop a random number of leading
    # transitions (fewer than `trajectory_length`), then cut what remains
    # into fixed-length trajectories.
    doubled = transitions.repeat(2)
    offset = tf.random.uniform([], 0, trajectory_length, dtype=tf.int64)
    return doubled.skip(offset).batch(trajectory_length, drop_remainder=True)

  return make_trajectory_dataset


def uniformly_subsampled_atari_data(
  dataset_name,
  data_percent,
  data_dir,
  ckpt=None,
  use_local=True,
):
  """Load Atari episodes, optionally restricted to selected checkpoints.

  Args:
    dataset_name: TFDS dataset name, e.g.
      "rlu_atari_checkpoints_ordered/<game>_run_<n>".
    data_percent: Percentage of episodes taken from the start of every
      split. Only applied when `use_local` is False.
    data_dir: TFDS data directory. Only used when `use_local` is False.
    ckpt: Optional iterable of checkpoint indices (ints or numeric strings).
      When empty or None, the `--ckpt_number` flag is consulted; when that
      is unavailable too, all checkpoints are used.
    use_local: If True, read version "1.1.0" of the dataset from a prepared
      local directory; otherwise load it through `tfds.load`.

  Returns:
    A `tf.data.Dataset` of episodes interleaved across the selected splits.

  Raises:
    ValueError: If `data_percent` selects zero episodes from a split, or if
      the checkpoint filter leaves no split to read.
  """
  try:
    ckpt = ckpt or FLAGS.ckpt_number
  except Exception:
    # Best effort: absl flags may not be parsed (e.g. when this module is
    # used as a library); in that case no checkpoint filter is applied.
    pass
  ckpts = []
  if ckpt is not None:
    ckpts = list(map(int, ckpt))
    logging.info(f"Checkpoints to take: {ckpts}")

  data_splits = []
  if not use_local:
    ds_builder = tfds.builder(dataset_name)
    total_num_episode = 0
    for i, (split, info) in enumerate(ds_builder.info.splits.items()):
      if ckpts and i not in ckpts:
        continue
      # Convert `data_percent` to number of episodes to allow
      # for fractional percentages.
      num_episodes = int((data_percent / 100) * info.num_examples)
      total_num_episode += num_episodes
      if num_episodes == 0:
        raise ValueError(f"{data_percent}% leads to 0 episodes in {split}!")
      # Sample first `data_percent` episodes from each of the data split.
      data_splits.append(f"{split}[:{num_episodes}]")
    logging.info(f"Total number of episode = {total_num_episode}")
  else:
    # Apply the same checkpoint filter as the non-local branch; previously
    # the requested checkpoints were logged but every split was loaded.
    # NOTE(review): `data_percent` is not applied in local mode.
    data_splits = [
      f"checkpoint_{i:02d}" for i in range(50) if not ckpts or i in ckpts
    ]
    logging.info(data_splits)
  if not data_splits:
    raise ValueError(f"No data split selected for checkpoints {ckpts}.")

  # Interleave episodes across different splits/checkpoints.
  read_config = tfds.ReadConfig(
    interleave_cycle_length=len(data_splits),
    shuffle_reshuffle_each_iteration=True,
    enable_ordering_guard=False,
  )
  if not use_local:
    logging.info("try to load")
    # `shuffle_files=True` shuffles episodes across files within splits.
    return tfds.load(
      dataset_name,
      data_dir=data_dir,
      split="+".join(data_splits),
      read_config=read_config,
      shuffle_files=True,
    )
  else:
    # NOTE(review): the caller's `data_dir` is overridden by this hard-coded
    # location in local mode -- confirm whether that is intended.
    data_dir = "/datasets/rl_unplugged/tensorflow_datasets"
    return tfds.builder_from_directory(
      os.path.join(data_dir, dataset_name, "1.1.0")
    ).as_dataset(
      split="+".join(data_splits),
      read_config=read_config,
    )


def create_atari_ds_loader(
  game,
  run_number,
  data_dir,
  num_actions=0,
  stack_size=4,
  image_size=84,
  data_percent=10,
  trajectory_fn=None,
  shuffle_num_episodes=1000,
  shuffle_num_steps=50000,
  trajectory_length=10,
  use_local=False,
):
  """Create a shuffled dataset of Atari trajectories for one game and run.

  Args:
    game: Game name.
    run_number: Run index; together with `game` it forms the dataset name
      "rlu_atari_checkpoints_ordered/{game}_run_{run_number}".
    data_dir: Data directory forwarded to `uniformly_subsampled_atari_data`.
    num_actions: Unused; accepted and discarded.
    stack_size: Steps per transition for the default `trajectory_fn`.
    image_size: Observation side length for the default `trajectory_fn`.
    data_percent: Percentage of episodes to sample from each data split.
    trajectory_fn: Optional episode -> dataset function. Defaults to
      `get_trajectory_dataset_fn(stack_size, image_size, trajectory_length)`.
    shuffle_num_episodes: Buffer size of the episode-level shuffle.
    shuffle_num_steps: Step budget of the final shuffle; the buffer holds
      `shuffle_num_steps // trajectory_length` trajectories.
    trajectory_length: Transitions per trajectory. It also sizes the final
      shuffle buffer, even when a custom `trajectory_fn` is supplied.
    use_local: Forwarded to `uniformly_subsampled_atari_data`.

  Returns:
    A `tf.data.Dataset` of trajectories.
  """
  del num_actions
  if trajectory_fn is None:
    trajectory_fn = get_trajectory_dataset_fn(
      stack_size, image_size, trajectory_length
    )

  # Episodes sampled from every data split, shuffled so that consecutive
  # episodes are not visited back to back.
  episodes = uniformly_subsampled_atari_data(
    f"rlu_atari_checkpoints_ordered/{game}_run_{run_number}",
    data_percent,
    data_dir,
    use_local=use_local,
  ).shuffle(shuffle_num_episodes)

  # Draw one trajectory at a time (block_length=1) from up to 100 episodes
  # that are expanded concurrently.
  trajectories = episodes.interleave(
    trajectory_fn,
    cycle_length=100,
    block_length=1,
    deterministic=False,
    num_parallel_calls=tf.data.AUTOTUNE,
  )

  # Final shuffle at trajectory granularity.
  return trajectories.shuffle(
    shuffle_num_steps // trajectory_length,
    reshuffle_each_iteration=True,
  )


class AtariDopamineWrapper(dm_env.Environment):
  """Wrapper exposing an Atari Dopamine environment through `dm_env`.

  Observations returned by the wrapped environment have their trailing axis
  squeezed away. An episode ends when the wrapped environment reports a
  terminal state or after `max_episode_steps` steps (truncation). Following
  the `dm_env` convention, calling `step` before the first `reset`, or after
  an episode has ended, starts a new episode and ignores the action.
  """

  def __init__(self, env, max_episode_steps=108000):
    """Stores the wrapped environment and the per-episode step limit.

    Args:
      env: Gym-style environment whose `reset()` returns an observation and
        whose `step(action)` returns (observation, reward, terminal, info).
      max_episode_steps: Number of steps after which the episode is
        truncated.
    """
    self._env = env
    self._max_episode_steps = max_episode_steps
    self._episode_steps = 0
    # Single flag shared by `reset` and `step`. The original code wrote to
    # `_reset_next_episode` but read `_reset_next_step`, so the automatic
    # reset never triggered and `step` before `reset` raised AttributeError.
    self._reset_next_step = True

  def reset(self):
    """Starts a new episode and returns its first timestep."""
    self._episode_steps = 0
    self._reset_next_step = False
    observation = self._env.reset()
    return dm_env.restart(observation.squeeze(-1))

  def step(self, action):
    """Applies `action` and returns the resulting timestep.

    Args:
      action: A Python int, or an object with `.item()` (e.g. a NumPy
        scalar) that is converted to one.

    Returns:
      A termination timestep when the wrapped environment is terminal, a
      truncation timestep when the step limit is reached, and a regular
      transition otherwise. If a reset is pending, the first timestep of a
      new episode is returned instead.
    """
    if self._reset_next_step:
      return self.reset()
    if not isinstance(action, int):
      action = action.item()
    observation, reward, terminal, _ = self._env.step(action)
    observation = observation.squeeze(-1)
    discount = 1 - float(terminal)
    self._episode_steps += 1
    if terminal:
      self._reset_next_step = True
      return dm_env.termination(reward, observation)
    elif self._episode_steps == self._max_episode_steps:
      self._reset_next_step = True
      return dm_env.truncation(reward, observation, discount)
    else:
      return dm_env.transition(reward, observation, discount)

  def observation_spec(self):
    """Returns the observation spec without the squeezed trailing axis."""
    space = self._env.observation_space
    return specs.Array(space.shape[:-1], space.dtype)

  def action_spec(self):
    """Returns a discrete spec with as many values as the wrapped env."""
    return specs.DiscreteArray(self._env.action_space.n)

  def render(self, mode="rgb_array"):
    """Delegates rendering to the wrapped environment."""
    return self._env.render(mode)


class FrameActionStacker:
  """Keeps bounded histories of recent frames and the actions behind them.

  Both histories are deques holding at most `num_frames` entries.
  """

  def __init__(self, num_frames: int) -> None:
    self._num_frames = num_frames
    self.reset()

  @property
  def num_frames(self) -> int:
    """Maximum number of frames (and actions) kept in the history."""
    return self._num_frames

  def reset(self):
    """Discards the stored frames and actions."""
    self._frame_stack = collections.deque(maxlen=self._num_frames)
    self._action_stack = collections.deque(maxlen=self._num_frames)

  def step(self, frame: np.ndarray, action: np.ndarray) -> np.ndarray:
    """Append frame and the action leading to that frame,
        and return the stack."""
    # TODO(review): not implemented -- nothing is appended and None is
    # returned despite the annotated return type.
    pass

  def update_spec(self, spec: dm_env_specs.Array) -> dm_env_specs.Array:
    """Returns `spec` extended with a trailing axis of size `num_frames`.

    The original computed the new shape but returned nothing, so callers
    received None instead of a spec.
    """
    new_shape = spec.shape + (self._num_frames,)
    return dm_env_specs.Array(
      shape=new_shape, dtype=spec.dtype, name=spec.name
    )


class FrameActionStackingWrapper(wrappers.FrameStackingWrapper):
  """Frame-stacking wrapper built on `FrameActionStacker` instances.

  One stacker is created per leaf of the wrapped environment's observation
  spec, and the stacked observation spec is derived from those stackers.
  """

  def __init__(self, environment: dm_env.Environment, num_frames: int = 4):
    # The parent constructor is not called; the attributes are set directly.
    self._environment = environment
    original_spec = self._environment.observation_spec()

    def _make_stacker(_):
      return FrameActionStacker(num_frames=num_frames)

    def _stacked_spec(stacker, spec):
      return stacker.update_spec(spec)

    self._stackers = tree.map_structure(
      _make_stacker,
      self._environment.observation_spec(),
    )
    self._observation_spec = tree.map_structure(
      _stacked_spec,
      self._stackers,
      original_spec,
    )

  def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
    # TODO(review): not implemented -- returns None instead of a timestep.
    return None


def create_atari_environment(
  game_name=None, sticky_actions=True, screen_size=84
):
  """Builds a preprocessed Atari 2600 Gym environment.

  The preprocessing follows the guidelines of Machado et al. (2017),
  "Revisiting the Arcade Learning Environment: Evaluation Protocols and Open
  Problems for General Agents". The environment is the Gym wrapper around
  the Arcade Learning Environment.

  Sticky actions, as prescribed by Machado et al., make the previous action
  persist with some probability (0.25) when a new command is sent to the
  ALE, which adds mild stochasticity. They are enabled by default and select
  the 'v0' game version; without them 'v4' is used.

  Args:
    game_name: str, the name of the Atari 2600 domain. Must not be None.
    sticky_actions: bool, whether to use sticky_actions as per Machado et al.
    screen_size: forwarded to `atari_lib.AtariPreprocessing` as its
      `screen_size` argument.

  Returns:
    An Atari 2600 environment with some standard preprocessing.
  """
  assert game_name is not None
  if sticky_actions:
    game_version = 'v0'
  else:
    game_version = 'v4'
  gym_env = gym.make('{}NoFrameskip-{}'.format(game_name, game_version))
  # Unwrap Gym's TimeLimit wrapper, which caps episodes at 100k frames. The
  # limit is handled internally instead, allowing a cap of 108k frames
  # (30 minutes). The TimeLimit wrapper also plays poorly with saving and
  # restoring states.
  unwrapped = gym_env.env
  return atari_lib.AtariPreprocessing(unwrapped, screen_size=screen_size)


class ResizeWrapper(wrappers.EnvironmentWrapper):
  """Resizes every observation to a square of side `frame_size`.

  NOTE(review): observations go through `skimage.transform.resize` with its
  default arguments, which presumably changes dtype and value range, while
  `observation_spec` only updates the shape -- confirm.
  """

  def __init__(self, environment: dm_env.Environment, frame_size: int):
    # The parent constructor is not called; the attributes are set directly.
    self._environment = environment
    self._frame_size = frame_size

  def _convert_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
    """Returns a copy of `timestep` with the observation resized."""
    target_shape = (self._frame_size, self._frame_size)
    resized = resize(timestep.observation, target_shape)
    return timestep._replace(observation=resized)

  def step(self, action) -> dm_env.TimeStep:
    """Steps the wrapped environment and resizes the observation."""
    raw_timestep = self._environment.step(action)
    return self._convert_timestep(raw_timestep)

  def reset(self) -> dm_env.TimeStep:
    """Resets the wrapped environment and resizes the observation."""
    raw_timestep = self._environment.reset()
    return self._convert_timestep(raw_timestep)

  def observation_spec(self):
    """Returns the wrapped spec with shape (frame_size, frame_size)."""
    side = self._frame_size
    return self._environment.observation_spec().replace(shape=(side, side))


def environment(
  game: str, stack_size: int, screen_size=84
) -> dm_env.Environment:
  """Builds the fully wrapped Atari environment.

  Args:
    game: Name of the Atari game.
    stack_size: Number of frames stacked by the frame-stacking wrapper.
    screen_size: Screen size given to the Atari preprocessing; when it is
      not 84, a `ResizeWrapper` with the same size is added as well.

  Returns:
    A sticky-action Atari environment limited to 20,000 steps per episode,
    with frame stacking and single-precision conversion applied.
  """
  atari_env = AtariDopamineWrapper(
    create_atari_environment(
      game_name=game, sticky_actions=True, screen_size=screen_size
    ),
    max_episode_steps=20_000,
  )
  if screen_size != 84:
    atari_env = ResizeWrapper(atari_env, screen_size)
  stacked_env = wrappers.FrameStackingWrapper(atari_env, num_frames=stack_size)
  return wrappers.SinglePrecisionWrapper(stacked_env)

