# coding=utf-8
# Copyright 2022 The Multi Task Atari Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multi-task TFDS dataset."""

import rlds
import tensorflow_datasets as tfds

import gc
from absl import logging
import gin
import numpy as np
import tensorflow as tf
from multi_task_atari import atari_config

import jax
import collections


# Module-level dictionaries, initialised empty. Nothing in the visible part of
# this file reads or writes them; presumably they were meant to hold per-game
# reward mean / std statistics -- TODO confirm whether external code uses them.
MEAN_DICT = {}
STD_DICT = {}


def get_data(step):
  """Extracts the reward of a step together with a validity mask.

  Args:
    step: mapping indexed by RLDS step keys; `rlds.REWARD` and `rlds.IS_LAST`
      are read.

  Returns:
    A `(data, mask)` tuple of dicts, both keyed by `rlds.REWARD`. `data` holds
    the step's reward; `mask` holds False when the step is flagged as the last
    one and True otherwise.
  """
  reward_data = {rlds.REWARD: step[rlds.REWARD]}
  if step[rlds.IS_LAST]:
    # The reward attached to the last step is masked out.
    reward_mask = {rlds.REWARD: False}
  else:
    reward_mask = {rlds.REWARD: True}
  return reward_data, reward_mask


class BatchToTransition(object):
  """Converts a window of consecutive steps into one (s,a,r,s',a') tuple."""

  def __init__(self, stack_size, update_horizon, gamma):
    """Stores the window geometry and precomputes the discount vector.

    Args:
      stack_size: int, number of frames that make up one state.
      update_horizon: int, number of steps covered by the n-step return.
      gamma: float, discount factor.
    """
    self.stack_size = stack_size
    self.update_horizon = update_horizon
    # A window holds one full frame stack plus `update_horizon` extra steps.
    self.total_frames = stack_size + update_horizon
    # gamma**0, gamma**1, ..., gamma**(update_horizon - 1).
    self.cumulative_discount = tf.pow(gamma, range(update_horizon))

  def create_transitions(self, batch):
    """Builds a single transition from a batched window of steps.

    Args:
      batch: mapping of RLDS step fields batched along the leading axis. The
        window is expected to contain `total_frames` steps (assumption based on
        how `get_transition_dataset_fn` builds it).

    Returns:
      Dict with keys 'state', 'action', 'reward', 'next_state', 'terminal' and
      'next_action'. 'reward' is the discounted sum of the rewards that follow
      the last frame of the state; 'terminal' is True if any step after the
      state stack is terminal.
    """
    last_state_step = self.stack_size - 1
    last_window_step = self.total_frames - 1

    # Drop the trailing size-1 axis, then move the leading (step) axis last so
    # the frames of a state are stacked along the final dimension.
    frames = tf.transpose(
        tf.squeeze(batch[rlds.OBSERVATION], axis=-1), perm=[1, 2, 0])

    n_step_rewards = batch[rlds.REWARD][last_state_step:-1]
    n_step_terminals = batch[rlds.IS_TERMINAL][
        self.stack_size:self.total_frames]
    actions = batch[rlds.ACTION]

    discounted_reward = tf.reduce_sum(
        n_step_rewards * self.cumulative_discount)
    return {
        'state': frames[:, :, :self.stack_size],
        'action': actions[last_state_step],
        'reward': discounted_reward,
        'next_state': frames[:, :, self.update_horizon:],
        'terminal': tf.reduce_any(n_step_terminals),
        'next_action': actions[last_window_step],
    }


def get_transition_dataset_fn(stack_size, update_horizon=1, gamma=0.99):
  """Returns a function mapping an episode to a dataset of transitions.

  Args:
    stack_size: int, number of frames per state.
    update_horizon: int, length of the n-step return.
    gamma: float, discount factor.

  Returns:
    A callable taking an episode (a mapping with an `rlds.STEPS` dataset) and
    returning a dataset of transitions built by `BatchToTransition`.
  """
  to_transition = BatchToTransition(
      stack_size, update_horizon, gamma).create_transitions
  window_length = stack_size + update_horizon

  def make_transition_dataset(episode):
    """Converts an episode of steps to a dataset of custom transitions."""
    # Slide a window of `window_length` steps over the episode, advancing one
    # step at a time; incomplete trailing windows are dropped.
    windows = rlds.transformations.batch(
        episode[rlds.STEPS],
        size=window_length,
        shift=1,
        drop_remainder=True)
    return windows.map(to_transition)

  return make_transition_dataset

def get_dummy_transition_dataset_fn():
  """Returns a mapper that emits episode ids instead of transitions.

  The produced datasets contain `(episode_id, checkpoint_id)` pairs, which is
  useful for checking how well shuffled a dataset is.
  """
  def make_transition_dataset(episode):
    """Converts an episode to an endlessly repeating dataset of its ids."""
    id_pair = (episode['episode_id'], episode['checkpoint_id'])
    return tf.data.Dataset.from_tensors(id_pair).repeat()

  return make_transition_dataset

def first_k_percent_data(dataset_name, data_percent):
  """Loads every episode from the first `data_percent`% of the splits.

  The number of splits used is `ceil(data_percent / 100 * num_splits)`, taken
  in the order in which the builder lists them.

  Args:
    dataset_name: str, TFDS name of the dataset.
    data_percent: number, percentage of splits to read.

  Returns:
    The dataset returned by `tfds.load` for the selected splits.

  Raises:
    ValueError: if one of the selected splits has no episodes.
  """
  split_infos = tfds.builder(dataset_name).info.splits

  num_splits = len(split_infos.keys())
  print('Num splits: ', num_splits)

  num_splits_to_use = np.ceil(data_percent / 100.0 * num_splits)
  print('Num splits to use: ', num_splits_to_use)

  selected_splits = []
  for position, (split, info) in enumerate(split_infos.items()):
    if not position < num_splits_to_use:
      break
    print('Current split: ', split, info)
    num_episodes = int(info.num_examples)
    if num_episodes == 0:
      raise ValueError(f'{data_percent}% leads to 0 episodes in {split}!')
    print(split, num_episodes)
    selected_splits.append(f'{split}[:{num_episodes}]')

  # Interleave across all selected splits; the shuffle seed differs per
  # process because it is the JAX process index.
  read_config = tfds.ReadConfig(
      interleave_cycle_length=len(selected_splits),
      shuffle_reshuffle_each_iteration=True,
      enable_ordering_guard=False,
      shuffle_seed=jax.process_index())
  return tfds.load(
      dataset_name,
      split='+'.join(selected_splits),
      read_config=read_config,
      shuffle_files=True)


def expert_data(dataset_name):
  """Loads every episode from the last split of a TFDS dataset.

  Only the final split (in the order listed by the builder) is used;
  presumably this is the checkpoint of the most trained agent -- TODO confirm
  against the dataset definition.

  Args:
    dataset_name: str, TFDS name of the dataset.

  Returns:
    The dataset returned by `tfds.load` for the selected split.

  Raises:
    ValueError: if the selected split contains no episodes.
  """
  ds_builder = tfds.builder(dataset_name)
  data_splits = []

  num_splits = len(ds_builder.info.splits.keys())
  print ('Num splits: ', num_splits)

  num_splits_to_use = 1
  print ('Num splits to use: ', num_splits_to_use)

  for idx, (split, info) in enumerate(ds_builder.info.splits.items()):
    # Keep only the trailing `num_splits_to_use` splits.
    if idx >= num_splits - num_splits_to_use:
      num_episodes = int(info.num_examples)
      if num_episodes == 0:
        # The previous message interpolated `data_percent`, which does not
        # exist in this function and turned this error into a NameError.
        raise ValueError(f'Expert split {split} contains 0 episodes!')
      print (split, num_episodes)
      data_splits.append(f'{split}[:{num_episodes}]')

  # The shuffle seed is the JAX process index, so it differs per process.
  read_config = tfds.ReadConfig(
      interleave_cycle_length=len(data_splits),
      shuffle_reshuffle_each_iteration=True,
      enable_ordering_guard=False,
      shuffle_seed=jax.process_index())
  return tfds.load(dataset_name, split='+'.join(data_splits),
                   read_config=read_config, shuffle_files=True)


def uniformly_subsampled_atari_data(dataset_name, data_percent):
  """Loads the first `data_percent`% of episodes from every split.

  Args:
    dataset_name: str, TFDS name of the dataset.
    data_percent: number, percentage of each split's episodes to keep;
      fractional percentages are allowed.

  Returns:
    The dataset returned by `tfds.load` over the truncated splits.

  Raises:
    ValueError: if the percentage rounds down to zero episodes for a split.
  """
  builder = tfds.builder(dataset_name)

  split_specs = []
  for split_name, split_info in builder.info.splits.items():
    # Work with an episode count so that fractional percentages are usable.
    kept_episodes = int((data_percent / 100) * split_info.num_examples)
    if kept_episodes == 0:
      raise ValueError(f'{data_percent}% leads to 0 episodes in {split_name}!')
    # Take the leading `kept_episodes` episodes of this split.
    split_specs.append(f'{split_name}[:{kept_episodes}]')

  # Interleave episodes across the different splits/checkpoints, and shuffle
  # episodes across files within splits (`shuffle_files=True`).
  read_config = tfds.ReadConfig(
      interleave_cycle_length=len(split_specs),
      shuffle_reshuffle_each_iteration=True,
      enable_ordering_guard=False,
      shuffle_seed=jax.process_index())
  return tfds.load(
      dataset_name,
      split='+'.join(split_specs),
      read_config=read_config,
      shuffle_files=True)


def compute_average_episode_return(dataset):
  """Computes return statistics over all episodes of a dataset.

  Args:
    dataset: iterable of episodes; each episode is indexed with
      'episode_return'.

  Returns:
    Dict with 'avg_return', 'total_return' and 'total_trajs' (the episode
    count, accumulated as a float).
  """
  episode_count = 0.0
  return_sum = 0.0
  for episode in dataset:
    return_sum += episode['episode_return']
    episode_count += 1

  average_return = return_sum / episode_count
  print('Total trajs: ', episode_count)
  print('Total return: ', return_sum)
  print('Average return: ', average_return)
  return {
      'avg_return': average_return,
      'total_return': return_sum,
      'total_trajs': episode_count,
  }


def create_atari_ds_loader(game, run_number,
                           data_percent=10,
                           transition_fn=None,
                           shuffle_num_episodes=1000,
                           shuffle_num_steps=50000,
                           cycle_length=100,
                           replay_type='uniform',
                           process_or_not=True):
  """Builds a shuffled dataset of transitions for one game and run.

  Args:
    game: str, game name used in the TFDS dataset name.
    run_number: run index used in the TFDS dataset name.
    data_percent: number, percentage of data to load (its meaning depends on
      `replay_type`).
    transition_fn: callable mapping an episode to a dataset of transitions;
      defaults to `get_transition_dataset_fn(4)`.
    shuffle_num_episodes: int, shuffle buffer size at the episode level.
    shuffle_num_steps: int, shuffle buffer size at the transition level.
    cycle_length: int, number of episodes interleaved concurrently.
    replay_type: 'uniform' (leading part of every split) or 'initial' (all
      episodes of the leading splits).
    process_or_not: not used by this function.

  Returns:
    A tuple `(dataset, return_stats, reward_stats)`: the shuffled transition
    dataset, the dict from `compute_average_episode_return`, and a dict with
    the constant 'reward_mean' (0.0) and 'reward_std' ({'reward': 1.0}).

  Raises:
    RuntimeError: if `replay_type` is not supported.
  """
  if transition_fn is None:
    transition_fn = get_transition_dataset_fn(4)
  dataset_name = f'rlu_atari_checkpoints_ordered/{game}_run_{run_number}'

  if replay_type == 'uniform':
    print('Loading uniform data....')
    load_episodes = uniformly_subsampled_atari_data
  elif replay_type == 'initial':
    print('Loading initial data....')
    load_episodes = first_k_percent_data
  else:
    raise RuntimeError('Not a valid replay type which is supported.....')

  # A first, throw-away copy of the dataset is consumed only to gather return
  # statistics; it is released before the training copy is created.
  stats_dataset = load_episodes(dataset_name, data_percent)
  return_stats = compute_average_episode_return(stats_dataset)
  del stats_dataset
  gc.collect()
  episodes = load_episodes(dataset_name, data_percent)

  # Reward normalisation statistics are constants here.
  mean = 0.0
  std = {'reward': 1.0}
  print('Mean and std of buffer: ', game, mean, std)

  gc.collect()

  # Shuffle at the episode level so consecutive episodes are separated.
  episodes = episodes.shuffle(shuffle_num_episodes)
  # Draw one transition at a time from many episodes in parallel.
  transitions = episodes.interleave(
      transition_fn,
      cycle_length=cycle_length,
      block_length=1,
      deterministic=False,
      num_parallel_calls=tf.data.AUTOTUNE)
  # Finally shuffle the individual transitions.
  shuffled_dataset = transitions.shuffle(
      shuffle_num_steps, reshuffle_each_iteration=True)

  reward_stats = {'reward_mean': mean, 'reward_std': std}
  return shuffled_dataset, return_stats, reward_stats

def _create_ds_iterator(ds, batch_size=32):
  """Wraps a tf dataset `ds` into an endless numpy iterator of batches.

  Args:
    ds: dataset to draw elements from.
    batch_size: int, number of elements per batch.

  Returns:
    The numpy iterator over the repeated, batched and prefetched dataset.
  """
  endless_ds = ds.repeat()
  batched_ds = endless_ds.batch(batch_size)
  prefetched_ds = batched_ds.prefetch(tf.data.AUTOTUNE)
  return prefetched_ds.as_numpy_iterator()


def build_tfds_replay(game,
                      data_percent,
                      stack_size,
                      update_horizon,
                      gamma,
                      batch_size=32,
                      replay_type='uniform',
                      process_or_not=True):
  """Creates a batched transition iterator for `game` (run 1).

  Args:
    game: str, game name.
    data_percent: number, percentage of data to load.
    stack_size: int, number of frames per state.
    update_horizon: int, length of the n-step return.
    gamma: float, discount factor.
    batch_size: int, number of transitions per batch.
    replay_type: str, forwarded to `create_atari_ds_loader`.
    process_or_not: forwarded to `create_atari_ds_loader`.

  Returns:
    A tuple `(iterator, return_stats, reward_stats)` where `iterator` yields
    numpy batches and the two stats entries come from
    `create_atari_ds_loader`.
  """
  make_transitions = get_transition_dataset_fn(
      stack_size, update_horizon, gamma)
  transition_ds, return_stats, reward_stats = create_atari_ds_loader(
      game=game,
      run_number=1,
      data_percent=data_percent,
      transition_fn=make_transitions,
      replay_type=replay_type,
      process_or_not=process_or_not)
  batch_iterator = _create_ds_iterator(transition_ds, batch_size)
  return batch_iterator, return_stats, reward_stats



@gin.configurable
class JaxMultiTaskFixedReplayBufferTFDS(object):
  """Replay buffer serving pre-collected multi-game data loaded through TFDS.

  One batched iterator is created per game with `build_tfds_replay`. Sampling
  picks a game, pulls the next batch from that game's iterator and decorates
  it with task information (see `switch_action_space`). The mutating parts of
  the usual replay-buffer interface (`add`, `save`, `load`, ...) are no-ops.
  """

  def __init__(self,
               observation_shape,
               stack_size,
               replay_capacity,
               batch_size,
               num_games=1,
               game_names=('Asterix',),
               replay_suffix=None,
               replay_file_start_index=0,
               replay_file_end_index=None,
               replay_transitions_start_index=0,
               num_buffers_to_load=5,
               update_horizon=1,
               gamma=0.99,
               observation_dtype=np.uint8,
               num_devices=1,
               use_single_game_action_space=False,
               memory_saver=True,
               replay_type='uniform',
               sample_single_game_id=-1):
    """Initialize the JaxMultiTaskFixedReplayBufferTFDS class.

    Args:
      observation_shape: tuple of ints. Not used by this class.
      stack_size: int, number of frames to use in state stack.
      replay_capacity: int, number of transitions to use. It is converted to
        a data percentage as `replay_capacity / 1e6 * 100`.
      batch_size: int, Batch size for sampling data from buffer.
      num_games: int, Number of games to train on; must equal
        `len(game_names)`.
      game_names: sequence of str, names of the games to load.
      replay_suffix: stored and logged only.
      replay_file_start_index: int, logged only.
      replay_file_end_index: int, logged only.
      replay_transitions_start_index: int, logged only.
      num_buffers_to_load: int. Not used by this class.
      update_horizon: int, length of update ('n' in n-step update).
      gamma: float, the discount factor.
      observation_dtype: np.dtype. Not used by this class.
      num_devices: int, number of per-device batches produced per sample; must
        equal the number of local JAX devices.
      use_single_game_action_space: bool, logged only.
      memory_saver: bool, if True each process only samples from its own
        share of the games (see `_process_buffers_to_sample_from`).
      replay_type: str, forwarded to `build_tfds_replay`.
      sample_single_game_id: int, if non-negative
        `sample_batch_for_one_device` always samples from this game index.
    """

    logging.info('Creating %s with the following parameters:',
                 self.__class__.__name__)
    logging.info('\t replay_transitions_start_index %d',
                 replay_transitions_start_index)
    logging.info('\t replay_file_start_index %d', replay_file_start_index)
    logging.info('\t replay_file_end_index %s', replay_file_end_index)
    logging.info('\t replay_suffix %s', replay_suffix)
    logging.info('\t number of games to study, %d', num_games)
    logging.info('\t num devices: %s', num_devices)
    logging.info('\t use single game action space: %s',
                 use_single_game_action_space)

    self._num_games = num_games
    self._game_names = game_names
    self._num_devices = num_devices
    self._replay_suffix = replay_suffix
    print ('Devices / games in buffer: ', self._num_games, self._game_names,
           self._num_devices, use_single_game_action_space, batch_size)
    self._memory_saver = memory_saver
    assert len(game_names) == self._num_games, "Not matching games found"

    assert num_devices == len(jax.local_devices()), "Made sure there are equal jax devices"

    # Express the requested capacity as a percentage of 1M transitions.
    self._data_percent = float(replay_capacity) / (1000000.0) * 100.0
    self._replay_capacity = replay_capacity
    self._all_buffers = []
    self._all_buffers_stats = []
    self._all_buffer_reward_stats = []

    # Must run before the loop below, which reads `self._games_to_choose`.
    self._process_buffers_to_sample_from()

    self.sample_single_game_id = sample_single_game_id

    for idx in range(num_games):
      print ('Buffer created: ', idx, game_names[idx],
             self._data_percent, replay_type)
      buffer, buffer_stats, reward_stats = build_tfds_replay(
          game=game_names[idx],
          data_percent=self._data_percent,
          stack_size=stack_size,
          update_horizon=update_horizon,
          gamma=gamma,
          batch_size=batch_size,
          replay_type=replay_type,
          process_or_not=bool(idx in self._games_to_choose),
      )
      self._all_buffers.append(buffer)
      self._all_buffers_stats.append(buffer_stats)
      self._all_buffer_reward_stats.append(reward_stats)

    # Process the buffer return for computing the weights
    self._process_buffer_stats()

  def _process_buffers_to_sample_from(self,):
    """Selects the game indices this process samples from.

    All games are used unless `memory_saver` is set, in which case the game
    indices are split across processes with `np.array_split` and this process
    keeps the chunk at its own process index.
    """
    current_process_index = jax.process_index()
    total_processes = int (jax.device_count() // jax.local_device_count())
    self._games_to_choose = np.arange(self._num_games)
    if self._memory_saver:
      self._games_to_choose = np.array_split(
          np.arange(self._num_games),
          indices_or_sections=total_processes)[current_process_index]
    print ('Number of games for each process: ',
           current_process_index, self._games_to_choose)

  def compute_reward_scaling(self,):
    """Compute the mean and standard deviation of rewards for each buffer.

    Results are stored in `self._means` and `self._stds`, one entry per buffer.

    NOTE(review): both passes iterate a buffer until it is exhausted. The
    iterators built by `_create_ds_iterator` repeat their dataset forever, so
    this only terminates for finite buffers, and the second pass needs a
    buffer that can be iterated again -- confirm before calling.
    """
    self._stds = np.zeros((len(self._all_buffers),), dtype=np.float32)
    self._means = np.zeros((len(self._all_buffers),), dtype=np.float32)
    for idx, buffer in enumerate(self._all_buffers):
      num_transitions = 0.0

      # First compute mean reward (as the average of the per-batch means).
      for transitions in iter(buffer):
        self._means[idx] += np.mean(transitions['reward'])
        num_transitions += 1.0

      self._means[idx] = self._means[idx] / num_transitions

      # Now compute std of reward
      # NOTE(review): the sample counter starts at 0.9 rather than 0; it may
      # be a guard against a zero denominator below or a typo -- confirm.
      total_denom = 0.9
      for transition in iter(buffer):
        self._stds[idx] += np.sum((transition['reward'] - self._means[idx])**2)
        total_denom += transition['reward'].shape[0]

      self._stds[idx] = np.sqrt(self._stds[idx] / (total_denom - 1))
      print ('Average of ', idx, 'buffer is:', self._means[idx])
      print ('Std of ', idx, 'buffer is: ', self._stds[idx])

  def _process_buffer_stats(self):
    """Computes per-game weights such that the median game has weight 1.0.

    Each weight is `median(m) / m_i` with `m_i = |avg_return_i| + 1e-3`, so
    games with larger absolute average return get smaller weights.
    """
    all_weights = np.array(
        [np.abs(buffer_idx[
            'avg_return']) + 1e-3 for buffer_idx in self._all_buffers_stats])
    median_return = np.median(all_weights)
    weights_to_use = median_return / all_weights
    self._task_weights = weights_to_use
    print ('Using task weights: ', self._task_weights)

  def _load_buffer(self, suffix):
    """Not needed with tfds datasets."""
    pass

  def load_single_buffer(self, suffix):
    """Not needed with tfds datasets"""
    pass

  def switch_action_space(self, batch, game_idx):
    """Maps actions to the full action set and attaches task information.

    The batch is modified in place and also returned.

    Args:
      batch: dict with an 'action' array of game-specific action indices.
      game_idx: int, index of the game the batch was sampled from.

    Returns:
      The same dict with 'action' remapped through
      `atari_config.GAME_TO_FULL_ACTION_SET` and with the added keys
      'task_id' (one-hot rows over games), 'loss_weight' and
      'reward_scaling'.
    """
    action_space_map = atari_config.GAME_TO_FULL_ACTION_SET[
        self._game_names[game_idx]]
    temp_indices = np.zeros((batch['action'].shape[0], self._num_games))
    temp_indices[:, game_idx] = 1.0
    batch['action'] = np.array([action_space_map[x] for x in batch['action']],
                               dtype=np.int32)
    batch['task_id'] = temp_indices
    # The task weight is multiplied by 0.0, so for finite weights the loss
    # weight is always 1.0.
    batch['loss_weight'] = 0.0 * self._task_weights[game_idx] + 1.0
    batch['reward_scaling'] = float(
        self._all_buffer_reward_stats[game_idx]['reward_std']['reward'])
    return batch

  def _load_replay_buffers(self, num_buffers):
    """Not needed with tfds datasets"""
    pass

  def get_transition_elements(self):
    """Returns the transition elements of the first underlying buffer.

    NOTE(review): this reads `_replay_buffers` from the first entry of
    `self._all_buffers`; the iterators stored by `__init__` presumably do not
    provide that attribute -- confirm whether this method is still used.
    """
    return self._all_buffers[0]._replay_buffers[0].get_transition_elements()

  def sample_transition_batch(self,):
    """Samples one batch per device and stacks them along a new leading axis.

    For every device a game is drawn uniformly from this process's games.

    Returns:
      Dict mapping each batch key to the per-device values stacked on axis 0.
    """
    all_batches = []
    for _ in range(self._num_devices):
      game_index = np.random.choice(self._games_to_choose)
      batch_game = next(self._all_buffers[game_index])
      batch_game = self.switch_action_space(batch_game, game_idx=game_index)
      all_batches.append(batch_game)

    all_batch_dict = {}
    for key in batch_game:
      all_batch_dict[key] = np.stack(
        [x[key] for x in all_batches], axis=0
      )

    return all_batch_dict

  def sample_batch_for_one_device(self,):
    """Samples a single batch from one game.

    Returns:
      The batch dict produced by `switch_action_space` for the chosen game.
    """
    if self.sample_single_game_id < 0:
      # Choose game from all the games
      game_index = np.random.choice(self._games_to_choose)
    else:
      # A non-negative id pins sampling to that single game.
      game_index = self.sample_single_game_id
    batch_game = next(self._all_buffers[game_index])
    batch_game = self.switch_action_space(batch_game, game_idx=game_index)
    return batch_game

  def get_iterator_with_jax_parallel(self,):
    """Pre-fetching that includes putting shards on jax devices.

    Yields:
      The result of `jax.device_put_sharded` applied to one batch per local
      device. The next element is prepared right after each yield resumes.
    """
    queue = collections.deque()

    def enqueue(n):
      for _ in range(n):
        data = [self.sample_batch_for_one_device() for _ in range(self._num_devices)]
        queue.append(jax.device_put_sharded(data, jax.local_devices()))

    enqueue(1)
    while queue:
      # Consume in FIFO order so a deeper prefetch queue would stay ordered.
      yield queue.popleft()
      enqueue(1)

  def load(self, *args, **kwargs):
    """No-op: data is read through TFDS."""
    pass

  def reload_buffer(self, num_buffers):
    """No-op: data is read through TFDS."""
    pass

  def save(self, *args, **kwargs):
    """No-op: the buffer is fixed."""
    pass

  def add(self, *args, **kwargs):
    """No-op: the buffer is fixed."""
    pass

  @property
  def add_count(self,):
    """Sum of `add_count` over all buffers.

    NOTE(review): relies on each buffer exposing `add_count`; the iterators
    stored by `__init__` presumably do not -- confirm.
    """
    return np.sum([buffer.add_count for buffer in self._all_buffers])

  @property
  def replay_capacity(self):
    """The `replay_capacity` value passed to the constructor."""
    return self._replay_capacity

  def reload_data(self):
    """No-op: data is read through TFDS."""
    pass

