# coding=utf-8
# Copyright 2022 The Multi Task Atari Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Environment and wrappers for implementing the multi-task atari env."""

from typing import Any, Sequence

import gin
import gym
from gym import spaces
from multi_task_atari import atari_config
from multi_task_atari import atari_helpers
import numpy as np
from tf_agents.typing import types

class AtariFullActionWrapper(gym.Env):
  """Exposes the shared full Atari action set on top of a single-game env.

  Actions are expressed as indices into `atari_config.FULL_ACTION_SET`. Each
  one is translated through
  `atari_config.FULL_ACTION_TO_GAME_ACTION[game_name]` before it is forwarded
  to the wrapped environment.
  """

  def __init__(self, env, game_name='Pong'):
    """Initializes the wrapper.

    Args:
      env: The single-game environment to wrap. It must provide `step`,
        `reset` and `game_over`.
      game_name: Key into `atari_config.FULL_ACTION_TO_GAME_ACTION` selecting
        the action mapping to use for `env`.
    """
    self._env = env
    self._action_spec = spaces.Discrete(n=len(atari_config.FULL_ACTION_SET))
    self._full_action_to_game_action = atari_config.FULL_ACTION_TO_GAME_ACTION[
        game_name]
    self._game_name = game_name
    super().__init__()

  def action_spec(self):
    """Returns the discrete space covering the full action set."""
    return self._action_spec

  @property
  def action_space(self):
    """Same space as `action_spec()`, exposed under the gym attribute name."""
    return self._action_spec

  def _get_game_action(self, full_action):
    """Converts full action to game action."""
    return self._full_action_to_game_action[full_action]

  def step(self, action):
    """Steps the wrapped environment with the translated action.

    Args:
      action: Index into the full action set.

    Returns:
      Whatever the wrapped environment's `step` returns.
    """
    game_action = self._get_game_action(action)
    return self._env.step(game_action)

  def reset(self):
    """Resets the wrapped environment and returns its result unchanged."""
    return self._env.reset()

  def close(self):
    """Closes the wrapped environment.

    Without this override, `close()` would resolve to the `gym.Env` base
    implementation and the wrapped environment would never be closed.
    """
    # Guarded so that wrapping an environment without `close` keeps working.
    close_fn = getattr(self._env, 'close', None)
    if callable(close_fn):
      close_fn()

  @property
  def game_over(self):
    """Forwards the wrapped environment's `game_over` attribute."""
    return self._env.game_over


@gin.configurable
class MultiTaskAtariEnv(gym.Env):
  """Multi-task Atari environment backed by a set of single-game envs.

  One of the supplied environments is selected at random during construction;
  `reset`, `step` and `game_over` are forwarded to the selected environment.
  """

  def __init__(self, envs):
    """Initializes the multi-task environment.

    Args:
      envs: Non-empty sequence of environments. All of them must expose an
        equal `action_space`.
    """
    assert envs, 'No environments passed'

    # Environments need to have the same action space before we can actually
    # pass them into a multi-task environment
    action_spec = envs[0].action_space
    for env in envs:
      assert env.action_space == action_spec, (
          'All environments must share the same action space')

    self._envs = envs
    self._env_idx = -1
    self._num_envs = len(envs)
    self.np_random = np.random
    self._sample_env()
    # Zero-argument form, matching AtariFullActionWrapper; it does not depend
    # on what the (decorated) module-level class name is bound to.
    super().__init__()

  def _set_env(self, env_idx):
    """Makes `envs[env_idx]` the active environment and records its index."""
    self._env_idx = env_idx
    self._env = self._envs[env_idx]

  def _sample_env(self):
    """Draws a random environment index and activates that environment."""
    # `_set_env` stores the index, so it is not assigned again here.
    self._set_env(self.np_random.randint(self._num_envs))

  def reset(self):
    """Resets the currently active environment.

    The active environment is not re-sampled here, so the same task keeps
    running across resets.

    Returns:
      A tuple `(reset_result, env_idx)`: the value returned by the active
      environment's `reset` and the index of that environment in `envs`.
    """
    return (self._env.reset(), self._env_idx)

  def step(self, action):
    """Steps the active environment and returns its result unchanged."""
    return self._env.step(action)

  @property
  def num_envs(self):
    """Number of environments this multi-task environment was built from."""
    return self._num_envs

  @property
  def env_idx(self):
    """Index of the currently active environment."""
    return self._env_idx

  def close(self):
    """Closes every underlying environment, not only the active one."""
    for env in self._envs:
      env.close()

  @property
  def game_over(self):
    """Forwards the active environment's `game_over` attribute."""
    return self._env.game_over


@gin.configurable
def create_multi_task_atari_environment(game_names=None,
                                        sticky_actions=True,
                                        use_single_game_action_space=False,
                                        difficulty=None,
                                        game_mode=None):
  """Generates a multi-task Atari environment for the specified games.

  Args:
    game_names: list(str), the names of the Atari 2600 games to include. A
      single name passed as a plain string is treated as a one-element list.
    sticky_actions: bool, whether to use sticky_actions as per Machado et al.
    use_single_game_action_space: bool, if True each per-game environment is
      used as created; otherwise each one is wrapped in
      `AtariFullActionWrapper`.
    difficulty: Game difficulty level, forwarded to
      `atari_helpers.create_atari_environment`.
    game_mode: Atari game mode, forwarded to
      `atari_helpers.create_atari_environment`.

  Returns:
    A `MultiTaskAtariEnv` holding one environment per entry of `game_names`.
  """
  assert game_names is not None, 'game_names must be provided'
  if isinstance(game_names, str):
    # Iterating a bare string would create one "game" per character.
    game_names = [game_names]

  list_of_single_game_envs = []
  for game in game_names:
    single_env = atari_helpers.create_atari_environment(
        game,
        sticky_actions=sticky_actions,
        difficulty=difficulty,
        game_mode=game_mode)

    if use_single_game_action_space:
      print('Not using full action wrapper')
      single_env_wrapped = single_env
    else:
      single_env_wrapped = AtariFullActionWrapper(single_env, game_name=game)
    list_of_single_game_envs.append(single_env_wrapped)

  return MultiTaskAtariEnv(list_of_single_game_envs)


# def main(argv: Sequence[str]) -> None:
#   env = create_multi_task_atari_environment(
#       game_names=['Asterix', 'Breakout', 'Pong'])

# if __name__ == '__main__':
#   app.run(main)
