# coding=utf-8
# Copyright 2022 The Multi Task Atari Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Compact implementation of an offline multi-task DQN agent in JAX."""

from absl import logging
from multi_task_atari import multi_task_fixed_replay as multi_game_fixed_replay
from multi_task_atari import multi_task_dqn_agent
import gin
import numpy as onp
from multi_task_atari import atari_config


@gin.configurable
class OfflineMultiTaskJaxDQNAgent(multi_task_dqn_agent.MultiTaskJaxDQNAgent):
  """A JAX implementation of the offline multi-task DQN agent.

  The agent trains from a fixed replay buffer loaded from `replay_data_dir`
  and keeps one action mask per game in `game_valid_actions`, which is passed
  to action selection in `step`.
  """

  def __init__(self,
               num_actions,
               replay_data_dir,
               summary_writer=None,
               replay_buffer_builder=None,
               preprocess_fn=None,
               network=None,
               game_names=('Asterix',),
               num_devices=1,
               use_single_game_action_space=False,
               override_num_games=-1,
               override_game_index=-1,
               with_task_ids=False):
    """Initializes the agent and constructs the necessary components.

    Args:
      num_actions: int, number of actions the agent can take at any state.
        Also the length of every per-game action mask.
      replay_data_dir: str, log Directory from which to load the replay buffer.
      summary_writer: SummaryWriter object for outputting training statistics
      replay_buffer_builder: Callable object that returns a replay buffer to
        use for training offline. If None, it will use the default
        JaxMultiTaskFixedReplayBuffer built by `_build_replay_buffer`.
        NOTE(review): the callable is stored as an instance attribute, so
        Python does not bind it as a method; it only receives "self" if the
        caller passes it explicitly -- confirm against the parent class.
      preprocess_fn: forwarded unchanged to the parent constructor.
      network: forwarded unchanged to the parent constructor.
      game_names: sequence of str, the games to train on. Every name must be
        a key of `atari_config.GAME_TO_FULL_ACTION_SET`.
      num_devices: int, forwarded to the replay buffer and to the parent
        constructor.
      use_single_game_action_space: bool. When False, each game's mask is set
        to 1.0 at that game's action indices. When True, no entry is set to
        1.0 and every mask entry stays at 1e-16. Also forwarded to the replay
        buffer and to the parent constructor.
      override_num_games: forwarded unchanged to the parent constructor.
      override_game_index: forwarded unchanged to the parent constructor.
      with_task_ids: forwarded unchanged to the parent constructor.
    """
    logging.info('Creating %s agent with the following parameters:',
                 self.__class__.__name__)
    logging.info('\t replay directory: %s', replay_data_dir)

    # Everything read by `_build_replay_buffer` is assigned before
    # `super().__init__()` runs; presumably the parent constructor builds the
    # replay buffer -- TODO confirm.
    self.replay_data_dir = replay_data_dir
    if replay_buffer_builder is not None:
      self._build_replay_buffer = replay_buffer_builder

    self._game_names = game_names
    self._num_devices = num_devices
    self.indices_of_game_actions = [
        atari_config.GAME_TO_FULL_ACTION_SET[game_name]
        for game_name in self._game_names
    ]

    # One mask per game: 1e-16 everywhere, raised to 1.0 at the game's own
    # action indices unless a single-game action space is used.
    self.game_valid_actions = []
    for action_indices in self.indices_of_game_actions:
      mask = onp.full((num_actions,), 1e-16)
      if not use_single_game_action_space:
        for action_idx in action_indices:
          mask[action_idx] = 1.0
      self.game_valid_actions.append(mask)

    self._use_single_game_action_space = use_single_game_action_space
    logging.info('Multi task offline DQN: %s %s', self._game_names,
                 self._num_devices)
    logging.info('network in Multi-task offline DQN: %s', network)
    super().__init__(
        num_actions, update_period=1, summary_writer=summary_writer,
        preprocess_fn=preprocess_fn,
        network=network,
        num_games=len(self._game_names),
        num_devices=self._num_devices,
        use_single_game_action_space=use_single_game_action_space,
        override_num_games=override_num_games,
        override_game_index=override_game_index,
        with_task_ids=with_task_ids)

  def _build_replay_buffer(self):
    """Creates the fixed replay buffer used by the agent.

    Returns:
      A JaxMultiTaskFixedReplayBuffer configured from this agent's
      attributes (replay directory, observation settings, game names, device
      count and action-space flag).
    """
    return multi_game_fixed_replay.JaxMultiTaskFixedReplayBuffer(
        data_dir=self.replay_data_dir,
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype,
        game_names=self._game_names,
        num_devices=self._num_devices,
        use_single_game_action_space=self._use_single_game_action_space)

  def reload_data(self):
    """Delegates to `reload_data()` on the agent's replay buffer."""
    self._replay.reload_data()

  def step(self, reward, observation, game_index=None):
    """Returns the agent's next action and update agent's state.

    No transition is stored and no training happens here; the method only
    records the observation and selects an action.

    Args:
      reward: float, the reward received from the agent's most recent action.
        Not used by this method.
      observation: numpy array, the most recent observation.
      game_index: passed through to `_record_observation`. Defaults to None.

    Returns:
      numpy array holding the selected action (also stored in `self.action`).
    """
    self._record_observation(observation, game_index=game_index)
    state = self.preprocess_fn(self.state)
    # Weight every game's mask by `self.game_index` and sum over the game
    # axis. Assumes `self.game_index` is a per-game weight vector (e.g.
    # one-hot) with one entry per game -- TODO confirm against the parent.
    game_valid_actions_local = onp.sum(
        onp.array(self.game_valid_actions) * self.game_index[Ellipsis, None],
        axis=0)
    self._rng, self.action = multi_task_dqn_agent.select_action(
        self.network_def, self.online_params, state, self._rng,
        self.num_actions, self.eval_mode, self.epsilon_eval, self.epsilon_train,
        self.epsilon_decay_period, self.training_steps, self.min_replay_history,
        self.epsilon_fn, game_index=self.game_index,
        game_valid_actions=game_valid_actions_local)
    self.action = onp.asarray(self.action)
    return self.action

  def train_step(self):
    """Exposes the train step for offline learning.

    Calls the parent class's `_train_step` directly via `super()`.
    """
    super()._train_step()
