# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MeltingPotEnv as a MultiAgentEnv wrapper to interface with RLLib."""

from typing import Tuple

import dm_env
import dmlab2d
import gym.spaces
import torch
from gym import spaces
from ml_collections import config_dict
import numpy as np
from ray.rllib import algorithms
from ray.rllib.env import multi_agent_env
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID

from examples import utils
from meltingpot.python import substrate
from meltingpot.python.utils.policies import policy

# Template for the agent ids used in this module; `index` is the player's
# position in the substrate's per-player observation/action/reward sequences.
PLAYER_STR_FORMAT = 'player_{index}'


class MeltingPotEnv(multi_agent_env.MultiAgentEnv):
  """An adapter between the Melting Pot substrates and RLLib MultiAgentEnv.

  Besides the per-player entries, every agent's observation (and observation
  space) carries an extra 'all' entry. It maps each agent id to that agent's
  'RGB' entry and, when `use_inventory` is truthy, also maps
  '<agent_id>_INVENTORY' and '<agent_id>_READY_TO_SHOOT' to the matching
  entries of every agent.

  Rewards are shared: each agent receives the sum of all players' rewards.
  """

  # Metadata is required by the gym `Env` class that we are extending, to show
  # which modes the `render` method supports.
  metadata = {'render.modes': ['rgb_array']}

  def __init__(self, env: dmlab2d.Environment, **kwargs):
    """Initializes the instance.

    Args:
      env: dmlab2d environment to wrap. Will be closed when this wrapper closes.
      **kwargs: Optional settings. Recognised keys:
        invert: Stored on `self.invert`; it is not read anywhere else in this
          class.
        use_inventory: If truthy, the shared 'all' observation also contains
          every agent's INVENTORY and READY_TO_SHOOT entries.
    """
    self._env = env
    self._num_players = len(self._env.observation_spec())
    self._ordered_agent_ids = [
        PLAYER_STR_FORMAT.format(index=index)
        for index in range(self._num_players)
    ]
    self.invert = kwargs.get("invert")
    # RLLib requires environments to have the following member variables:
    # observation_space, action_space, and _agent_ids
    self._agent_ids = set(self._ordered_agent_ids)
    # RLLib expects a dictionary of agent_id to observation or action,
    # Melting Pot uses a tuple, so we convert.
    # World observations are kept (remove_world_observations=False) so that
    # the world can be rendered.
    self.observation_space = self._convert_spaces_tuple_to_dict(
        utils.spec_to_space(self._env.observation_spec()),
        remove_world_observations=False)

    self.use_inventory = kwargs.get("use_inventory")

    # Extend every agent's observation space with the shared 'all' space. The
    # mapping is identical for all agents, so it is built once; each agent
    # still receives its own `Dict` space instance.
    shared_spaces = self._collect_shared_entries(self.observation_space)
    for agent_id in self._ordered_agent_ids:
      self.observation_space[agent_id]['all'] = gym.spaces.Dict(
          spaces=dict(shared_spaces))

    self.action_space = self._convert_spaces_tuple_to_dict(
        utils.spec_to_space(self._env.action_spec()))
    super().__init__()

  def _collect_shared_entries(self, per_agent):
    """Returns the flat mapping stored under each agent's 'all' key.

    Args:
      per_agent: Mapping from agent id to a mapping with an 'RGB' entry (and
        'INVENTORY' / 'READY_TO_SHOOT' entries when `use_inventory` is truthy).
        Used both for observation spaces and for observations.

    Returns:
      A new dict: agent ids mapped to their 'RGB' entries first, followed, when
      `use_inventory` is truthy, by the '<agent_id>_INVENTORY' entries and then
      the '<agent_id>_READY_TO_SHOOT' entries.
    """
    agent_ids = self._ordered_agent_ids
    shared = {aid: per_agent[aid]['RGB'] for aid in agent_ids}
    if self.use_inventory:
      shared.update(
          {f'{aid}_INVENTORY': per_agent[aid]['INVENTORY']
           for aid in agent_ids})
      shared.update(
          {f'{aid}_READY_TO_SHOOT': per_agent[aid]['READY_TO_SHOOT']
           for aid in agent_ids})
    return shared

  def _attach_shared_observations(self, observations):
    """Adds the shared 'all' entry to every agent's observation, in place.

    Args:
      observations: Mapping from agent id to that agent's observation dict.

    Returns:
      The same `observations` mapping, for convenience.
    """
    # Collect before mutating; every agent then gets its own shallow copy.
    shared = self._collect_shared_entries(observations)
    for agent_id in self._ordered_agent_ids:
      observations[agent_id]['all'] = dict(shared)
    return observations

  def reset(self, **kwargs):
    """See base class.

    Returns:
      A tuple `(observations, infos)`; `infos` is always an empty dict.
    """
    timestep = self._env.reset()
    observations = utils.timestep_to_observations(timestep, remove_world=False)
    self._attach_shared_observations(observations)

    infos = {}

    return observations, infos

  def step(self, action):
    """See base class.

    Args:
      action: Mapping from agent id to that agent's action.

    Returns:
      A tuple `(observations, rewards, done, done, info)`. `rewards` gives every
      agent the summed reward of all players, `done` is
      `{'__all__': timestep.last()}` (the same dict is returned twice,
      presumably as RLLib's terminated and truncated flags), and `info` is an
      empty dict.
    """
    actions = [action[agent_id] for agent_id in self._ordered_agent_ids]
    timestep = self._env.step(actions)

    # Use the total reward instead of individual agent rewards to encourage
    # co-operation. NOTE: `self.invert` is currently not applied to rewards.
    total_reward = sum(
        timestep.reward[index] for index in range(self._num_players))
    rewards = dict.fromkeys(self._ordered_agent_ids, total_reward)

    done = {'__all__': timestep.last()}
    info = {}

    observations = utils.timestep_to_observations(timestep, remove_world=False)
    self._attach_shared_observations(observations)

    return observations, rewards, done, done, info

  def close(self):
    """See base class."""
    self._env.close()

  def get_dmlab2d_env(self):
    """Returns the underlying DM Lab2D environment."""
    return self._env

  def render(self, mode: str) -> np.ndarray:
    """Render the environment.

    This allows you to set `record_env` in your training config, to record
    videos of gameplay.

    Args:
        mode (str): The mode to render with (see
        `MeltingPotEnv.metadata["render.modes"]` for supported modes).

    Returns:
        np.ndarray: This returns a numpy.ndarray with shape (x, y, 3),
        representing RGB values for an x-by-y pixel image, suitable for turning
        into a video.
    """
    observation = self._env.observation()
    # The world view is read from the first player's observation.
    world_rgb = observation[0]['WORLD.RGB']

    # RGB mode is used for recording videos
    if mode == 'rgb_array':
      return world_rgb
    else:
      return super().render(mode=mode)

  def _convert_spaces_tuple_to_dict(
      self,
      input_tuple: spaces.Tuple,
      remove_world_observations: bool = False) -> spaces.Dict:
    """Returns spaces tuple converted to a dictionary.

    Args:
      input_tuple: tuple to convert; entry `i` belongs to the `i`-th agent id.
      remove_world_observations: If True will remove non-player observations.
    """
    return spaces.Dict({
        agent_id: (utils.remove_world_observations_from_space(input_tuple[i])
                   if remove_world_observations else input_tuple[i])
        for i, agent_id in enumerate(self._ordered_agent_ids)
    })


def env_creator(env_config):
  """Builds a `MeltingPotEnv` from an environment config, for registering.

  Args:
    env_config: Mapping with a 'substrate' entry and a 'roles' entry, both
      forwarded to `substrate.build`.

  Returns:
    The built substrate wrapped in `MeltingPotEnv`. `use_inventory` is enabled
    when 'matrix' occurs in the configured 'substrate' value.
  """
  config = config_dict.ConfigDict(env_config)
  substrate_name = config['substrate']
  base_env = substrate.build(substrate_name, roles=config['roles'])
  wants_inventory = 'matrix' in substrate_name
  return MeltingPotEnv(base_env, use_inventory=wants_inventory)


class RayModelPolicy(policy.Policy):
  """Melting Pot policy that delegates action selection to an RLLib model.

  Note: Only a single observation is handled per call; batching is not
  enabled. The previous action is remembered on the instance between calls.
  """

  def __init__(self,
               model: algorithms.Algorithm,
               policy_id: str = DEFAULT_POLICY_ID) -> None:
    """Initialize a policy instance.

    Args:
      model: The RLLib algorithm used to compute actions.
      policy_id: Which policy to use (if trained in multi_agent mode)
    """
    self._policy_id = policy_id
    self._model = model
    # Previous action fed to the model; starts at (and resets to) 0.
    self._prev_action = 0

  def step(self, timestep: dm_env.TimeStep,
           prev_state: policy.State) -> Tuple[int, policy.State]:
    """See base class.

    Observation entries whose key contains 'WORLD' are dropped before the
    remaining entries are passed to the model.
    """
    player_observations = {}
    for name, value in timestep.observation.items():
      if 'WORLD' in name:
        continue
      player_observations[name] = value

    result = self._model.compute_single_action(
        player_observations,
        prev_state,
        policy_id=self._policy_id,
        prev_action=self._prev_action,
        prev_reward=timestep.reward)
    action, next_state, _ = result

    # Remember the chosen action so it can be supplied on the next call.
    self._prev_action = action
    return action, next_state

  def initial_state(self) -> policy.State:
    """See base class. Also resets the remembered previous action to 0."""
    self._prev_action = 0
    rllib_policy = self._model.get_policy(self._policy_id)
    return rllib_policy.get_initial_state()

  def close(self) -> None:
    """See base class."""
