# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Atari RL Unplugged datasets.

Examples in the dataset represent SARSA transitions stored during a
DQN training run as described in https://arxiv.org/pdf/1907.04543.

For every training run we have recorded all 50 million transitions corresponding
to 200 million environment steps (4x factor because of frame skipping). There
are 5 separate datasets (training runs) for each of the 46 games.

Every transition in the dataset is a tuple containing the following features:

* o_t: Observation at time t. Observations have been processed using the
    canonical Atari frame processing, including 4x frame stacking. The shape
    of a single observation is [84, 84, 4].
* a_t: Action taken at time t.
* r_t: Reward after a_t.
* d_t: Discount after a_t.
* o_tp1: Observation at time t+1.
* a_tp1: Action at time t+1.
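* extras:
  * episode_id: Episode identifier.
  * episode_return: Total episode return computed using per-step [-1, 1]
      clipping. In the decoded samples produced by `dataset`, this value is
      stored under the `'return'` key of `extras`.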
"""
import functools
import os
from typing import Dict

from acme import wrappers
import dm_env
from dm_env import specs
from dopamine.discrete_domains import atari_lib
import reverb
import tensorflow as tf


# 9 tuning games.
TUNING_SUITE = [
    'BeamRider',
    'DemonAttack',
    'DoubleDunk',
    'IceHockey',
    'MsPacman',
    'Pooyan',
    'RoadRunner',
    'Robotank',
    'Zaxxon',
]

# 37 testing games.
TESTING_SUITE = [
    'Alien',
    'Amidar',
    'Assault',
    'Asterix',
    'Atlantis',
    'BankHeist',
    'BattleZone',
    'Boxing',
    'Breakout',
    'Carnival',
    'Centipede',
    'ChopperCommand',
    'CrazyClimber',
    'Enduro',
    'FishingDerby',
    'Freeway',
    'Frostbite',
    'Gopher',
    'Gravitar',
    'Hero',
    'Jamesbond',
    'Kangaroo',
    'Krull',
    'KungFuMaster',
    'NameThisGame',
    'Phoenix',
    'Pong',
    'Qbert',
    'Riverraid',
    'Seaquest',
    'SpaceInvaders',
    'StarGunner',
    'TimePilot',
    'UpNDown',
    'VideoPinball',
    'WizardOfWor',
    'YarsRevenge',
]

# Total of 46 games: the tuning games followed by the testing games.
ALL = TUNING_SUITE + TESTING_SUITE


def _decode_frames(pngs: tf.Tensor):
  """Decodes a stack of PNG-encoded frames into one image tensor.

  Args:
    pngs: String Tensor of size (4,) whose entries are PNG-encoded frames.

  Returns:
    A (84, 84, 4) uint8 Tensor holding the four 84x84 grayscale frames,
    stacked along the last axis in input order.
  """
  decoded = []
  # The loop runs in Python, so the four decode ops are unrolled statically
  # into the graph.
  for frame_index in range(4):
    decoded.append(tf.image.decode_png(pngs[frame_index], channels=1))
  stacked = tf.concat(decoded, axis=2)
  # decode_png cannot infer the spatial size, so pin the static shape here.
  stacked.set_shape((84, 84, 4))
  return stacked


def _make_reverb_sample(o_t: tf.Tensor,
                        a_t: tf.Tensor,
                        r_t: tf.Tensor,
                        d_t: tf.Tensor,
                        o_tp1: tf.Tensor,
                        a_tp1: tf.Tensor,
                        extras: Dict[str, tf.Tensor]) -> reverb.ReplaySample:
  """Packs one offline SARSA transition into a Reverb replay sample.

  Args:
    o_t: Observation at time t.
    a_t: Action at time t.
    r_t: Reward at time t.
    d_t: Discount at time t.
    o_tp1: Observation at time t+1.
    a_tp1: Action at time t+1.
    extras: Dictionary with extra features.

  Returns:
    A `reverb.ReplaySample` whose data is the tuple
    (o_t, a_t, r_t, d_t, o_tp1, a_tp1, extras). Its info is a fake
    placeholder: key=0, probability=1, table_size=0, priority=1.
  """
  transition = (o_t, a_t, r_t, d_t, o_tp1, a_tp1, extras)
  placeholder_info = reverb.SampleInfo(
      key=tf.constant(0, tf.uint64),
      probability=tf.constant(1.0, tf.float64),
      table_size=tf.constant(0, tf.int64),
      priority=tf.constant(1.0, tf.float64))
  return reverb.ReplaySample(info=placeholder_info, data=transition)


def _tf_example_to_reverb_sample(tf_example: tf.Tensor
                                 ) -> reverb.ReplaySample:
  """Parses one serialized tf.Example into a Reverb replay sample.

  Args:
    tf_example: Scalar string Tensor holding a serialized `tf.train.Example`,
      as yielded by `tf.data.TFRecordDataset`. It is not a
      `tf.train.Example` proto object.

  Returns:
    A `reverb.ReplaySample` built by `_make_reverb_sample`. It contains:
      * o_t, o_tp1: decoded (84, 84, 4) uint8 frame stacks.
      * a_t, a_tp1: int32 actions.
      * r_t, d_t: float32 reward and discount.
      * extras: a dict with 'episode_id' (uint64) and 'return' (float32).
        The 'return' entry holds the stored `episode_return` feature.
  """
  # Parse tf.Example. Each observation is stored as 4 PNG-encoded frames.
  feature_description = {
      'o_t': tf.io.FixedLenFeature([4], tf.string),
      'o_tp1': tf.io.FixedLenFeature([4], tf.string),
      'a_t': tf.io.FixedLenFeature([], tf.int64),
      'a_tp1': tf.io.FixedLenFeature([], tf.int64),
      'r_t': tf.io.FixedLenFeature([], tf.float32),
      'd_t': tf.io.FixedLenFeature([], tf.float32),
      'episode_id': tf.io.FixedLenFeature([], tf.int64),
      'episode_return': tf.io.FixedLenFeature([], tf.float32),
  }
  data = tf.io.parse_single_example(tf_example, feature_description)

  # Process data.
  o_t = _decode_frames(data['o_t'])
  o_tp1 = _decode_frames(data['o_tp1'])
  a_t = tf.cast(data['a_t'], tf.int32)
  a_tp1 = tf.cast(data['a_tp1'], tf.int32)
  # tf.Example has no unsigned int type. Reinterpret the int64 bits as uint64
  # rather than casting numerically.
  episode_id = tf.bitcast(data['episode_id'], tf.uint64)

  # Build Reverb replay sample.
  extras = {
      'episode_id': episode_id,
      'return': data['episode_return']
  }
  return _make_reverb_sample(o_t, a_t, data['r_t'], data['d_t'], o_tp1, a_tp1,
                             extras)


def dataset(path: str,
            game: str,
            run: int,
            num_shards: int = 100,
            shuffle_buffer_size: int = 100000) -> tf.data.Dataset:
  """TF dataset of Atari SARSA tuples.

  Reads GZIP-compressed TFRecord shards named
  `<path>/<game>/run_<run>-XXXXX-of-YYYYY` and decodes each record into a
  `reverb.ReplaySample`.

  Args:
    path: Root directory containing the per-game shard files.
    game: Name of the game to load.
    run: Index of the training run to load.
    num_shards: Number of shard files the run is split into.
    shuffle_buffer_size: Buffer size for the example-level shuffle.

  Returns:
    An infinitely repeating, shuffled dataset of replay samples.
  """
  prefix = os.path.join(path, f'{game}/run_{run}')
  shard_files = [
      f'{prefix}-{shard:05d}-of-{num_shards:05d}'
      for shard in range(num_shards)
  ]
  read_gzip_records = functools.partial(
      tf.data.TFRecordDataset, compression_type='GZIP')
  # Repeat forever and shuffle the shard order. Then interleave records from
  # several shards, taking blocks of 5 consecutive records from each.
  shards = (
      tf.data.Dataset.from_tensor_slices(shard_files)
      .repeat()
      .shuffle(num_shards))
  records = shards.interleave(
      read_gzip_records,
      cycle_length=tf.data.experimental.AUTOTUNE,
      block_length=5)
  return records.shuffle(shuffle_buffer_size).map(
      _tf_example_to_reverb_sample,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)


class AtariDopamineWrapper(dm_env.Environment):
  """Wrapper exposing a Dopamine Atari environment as a `dm_env.Environment`.

  The wrapped env is used through `reset()`, which returns an observation, and
  `step(action)`, which returns `(observation, reward, terminal, info)`.
  The trailing axis of every observation is squeezed away. Episodes are
  truncated after `max_episode_steps` steps.
  """

  def __init__(self, env, max_episode_steps=108000):
    """Initializes the wrapper.

    Args:
      env: The Dopamine Atari environment to wrap.
      max_episode_steps: Number of steps after which an episode is truncated.
    """
    self._env = env
    self._max_episode_steps = max_episode_steps
    self._episode_steps = 0
    # Start in the "needs reset" state. A `step` call before any `reset` then
    # begins a fresh episode instead of failing on a missing attribute.
    self._reset_next_step = True

  def reset(self):
    """Starts a new episode and returns its first timestep."""
    self._episode_steps = 0
    self._reset_next_step = False
    observation = self._env.reset()
    return dm_env.restart(observation.squeeze(-1))

  def step(self, action):
    """Advances the environment by one step.

    If the previous episode ended (by termination or truncation), or `reset`
    was never called, this starts a new episode and returns its first
    timestep instead, as `dm_env` requires.

    Args:
      action: The action to take. `.item()` is called on it, so it is assumed
        to be a NumPy scalar or array (TODO confirm against callers).

    Returns:
      A `dm_env.TimeStep`.
    """
    if self._reset_next_step:
      return self.reset()

    observation, reward, terminal, _ = self._env.step(action.item())
    observation = observation.squeeze(-1)
    discount = 1 - float(terminal)
    self._episode_steps += 1
    if terminal:
      self._reset_next_step = True
      return dm_env.termination(reward, observation)
    if self._episode_steps == self._max_episode_steps:
      # Time-limit truncation: the episode ends but keeps a nonzero discount.
      self._reset_next_step = True
      return dm_env.truncation(reward, observation, discount)
    return dm_env.transition(reward, observation, discount)

  def observation_spec(self):
    """Returns the observation spec, dropping the squeezed trailing axis."""
    space = self._env.observation_space
    return specs.Array(space.shape[:-1], space.dtype)

  def action_spec(self):
    """Returns a discrete action spec sized by the wrapped action space."""
    return specs.DiscreteArray(self._env.action_space.n)


def environment(game: str) -> dm_env.Environment:
  """Builds the Atari evaluation environment for `game`.

  Creates the Dopamine Atari environment with sticky actions enabled, adapts
  it to `dm_env` and stacks the last 4 frames. Finally it wraps the result in
  acme's `SinglePrecisionWrapper`.

  Args:
    game: Name of the Atari game.

  Returns:
    The wrapped `dm_env.Environment`.
  """
  raw_env = atari_lib.create_atari_environment(
      game_name=game, sticky_actions=True)
  frame_stacked = wrappers.FrameStackingWrapper(
      AtariDopamineWrapper(raw_env), num_frames=4)
  return wrappers.SinglePrecisionWrapper(frame_stacked)
