# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================


from typing import Any, Dict, Optional, Tuple, Union
from collections.abc import Iterable
from random import randint
import os

import numpy as np
import gymnasium as gym


# from gymnasium import spaces
# from bsuite.utils import gym_wrapper

import dm_env
from dm_env import specs

from .base import Environment
# from bsuite.environments import base
# from bsuite.experiments.discounting_chain import sweep


import numpy as np

# OpenAI gym step format = obs, reward, is_finished, other_info
_GymTimestep = Tuple[np.ndarray, float, bool, Dict[str, Any]]



class DiscountingChain(Environment):
    """Diagnostic environment probing how an agent discounts delayed reward.

    The observation is a (1, 2) float32 array ``[[context, time_fraction]]``.
    ``context`` is -1 until the first action is taken and afterwards equals
    that first action. ``time_fraction`` is the step count divided by the
    episode length (100).

    The first action commits the agent to one of five chains. Each chain pays
    its reward once, at step 1, 3, 10, 30 or 100 respectively. Every chain pays
    1.0 except one, selected by ``mapping_seed``, which pays a 0.1 bonus
    (1.1 in total). The episode terminates after 100 steps.
    """

    def __init__(self, mapping_seed: Optional[int] = None):
        """Builds the Discounting Chain environment.

        Args:
          mapping_seed: Optional integer selecting which chain carries the
            bonus. It is taken modulo the number of chains. When None, the
            bonus chain is drawn with ``np.random.randint`` from NumPy's
            global RNG.
        """
        super().__init__()
        self._episode_len = 100
        self._reward_timestep = [1, 3, 10, 30, 100]
        self._n_actions = len(self._reward_timestep)

        bonus_chain = (
            np.random.randint(0, self._n_actions)
            if mapping_seed is None
            else mapping_seed % self._n_actions
        )
        self._rewards = np.ones(self._n_actions)
        self._rewards[bonus_chain] += 0.1

        self._timestep = 0
        self._context = -1

        # Fixed experiment length (stands in for bsuite's sweep.NUM_EPISODES).
        self.bsuite_num_episodes = 10000

    def _get_observation(self):
        """Returns ``[[context, timestep / episode_len]]`` as a float32 array."""
        obs = np.zeros((1, 2), dtype=np.float32)
        obs[0] = (self._context, self._timestep / self._episode_len)
        return obs

    def _reset(self) -> dm_env.TimeStep:
        self._timestep, self._context = 0, -1
        return dm_env.restart(self._get_observation())

    def _step(self, action: int) -> dm_env.TimeStep:
        # The very first action of an episode locks in the chain to follow.
        if self._timestep == 0:
            self._context = action
        self._timestep += 1

        on_reward_step = self._timestep == self._reward_timestep[self._context]
        reward = self._rewards[self._context] if on_reward_step else 0.0
        observation = self._get_observation()

        if self._timestep != self._episode_len:
            return dm_env.transition(reward=reward, observation=observation)
        return dm_env.termination(reward=reward, observation=observation)

    def observation_spec(self):
        return specs.Array((1, 2), np.float32, name="observation")

    def action_spec(self):
        return specs.DiscreteArray(num_values=self._n_actions, name="action")

    def _save(self, observation):
        # Keep an 8-bit copy of the observation, scaled by 255.
        scaled = observation * 255
        self._raw_observation = scaled.astype(np.uint8)

    @property
    def optimal_return(self):
        # Best episode return: the bonus chain pays 1.0 + 0.1.
        return 1.1

    def bsuite_info(self) -> Dict[str, Any]:
        return dict()




class MemoryChain(Environment):
    """Diagnostic environment testing whether an agent can remember a context.

    Each observation is a (1, num_bits + 2) float32 array laid out as
    ``[time_remaining, query, context_0, ..., context_{num_bits - 1}]``.
    The context bits, encoded as +1/-1, are visible only in the observation
    built at timestep 0. The query index is visible only at timestep
    ``memory_length - 1``. Actions have no effect until the final step. There
    the agent must output the queried context bit: reward +1 if it matches,
    -1 otherwise.
    """

    def __init__(
        self, memory_length: int, num_bits: int = 1, seed: Optional[int] = 1337
    ):
        """Builds the memory chain environment.

        Args:
          memory_length: Number of steps between seeing the context and
            answering the query.
          num_bits: Number of random context bits drawn per episode.
          seed: Seed for this environment's private ``np.random.RandomState``.
        """
        super().__init__()
        self._memory_length = memory_length
        self._num_bits = num_bits
        self._rng = np.random.RandomState(seed)

        # Per-episode context bits and the index of the bit being queried.
        self._timestep = 0
        self._context = self._rng.binomial(1, 0.5, num_bits)
        self._query = self._rng.randint(num_bits)

        # Running statistics reported by bsuite_info().
        self._total_perfect = 0
        self._total_regret = 0
        self._episode_mistakes = 0

        # bsuite experiment length.
        self.bsuite_num_episodes = 10_000  # Overridden by experiment load().

    def _get_observation(self):
        """Observation of form [time, query, num_bits of context]."""
        t = self._timestep
        obs = np.zeros(shape=(1, self._num_bits + 2), dtype=np.float32)
        row = obs[0]  # view into obs; writes below land in obs
        row[0] = 1 - t / self._memory_length  # fraction of time remaining
        if t == self._memory_length - 1:
            row[1] = self._query
        if t == 0:
            row[2:] = 2 * self._context - 1  # map bits {0, 1} -> {-1, +1}
        return obs

    def _step(self, action: int) -> dm_env.TimeStep:
        # The returned observation reflects the state *before* time advances.
        observation = self._get_observation()
        elapsed = self._timestep
        self._timestep = elapsed + 1

        if elapsed > self._memory_length:
            raise RuntimeError("Invalid state.")  # We shouldn't get here.
        if elapsed != self._memory_length:
            # On all but the last step provide a reward of 0.
            return dm_env.transition(reward=0.0, observation=observation)

        correct = action == self._context[self._query]
        if correct:
            self._total_perfect += 1
            reward = 1.0
        else:
            self._total_regret += 2.0
            reward = -1.0
        return dm_env.termination(reward=reward, observation=observation)

    def _reset(self) -> dm_env.TimeStep:
        self._timestep = 0
        self._episode_mistakes = 0
        # Fresh context and query, drawn in the same order as in __init__.
        self._context = self._rng.binomial(1, 0.5, self._num_bits)
        self._query = self._rng.randint(self._num_bits)
        return dm_env.restart(self._get_observation())

    def observation_spec(self):
        return specs.Array((1, self._num_bits + 2), np.float32, name="observation")

    def action_spec(self):
        return specs.DiscreteArray(num_values=2, name="action")

    def _save(self, observation):
        # Keep an 8-bit copy of the observation, scaled by 255.
        scaled = observation * 255
        self._raw_observation = scaled.astype(np.uint8)

    def bsuite_info(self):
        return {
            "total_perfect": self._total_perfect,
            "total_regret": self._total_regret,
        }


"""bsuite adapter for OpenAI gym run-loops."""

class BsuiteGymWrapper(gym.Env):
  """A wrapper that converts a dm_env.Environment to an OpenAI gym.Env."""

  metadata = {'render.modes': ['human', 'rgb_array']}

  def __init__(self, env_id, **kwargs):

    # print(kwargs)
    if env_id == 'MemoryLength':
        self._env = MemoryChain(memory_length = int(kwargs['memory_length']), num_bits = int(kwargs['num_bits']))
        self.max_episode_steps = kwargs['max_episode_steps'] if 'max_episode_steps' in kwargs else kwargs['memory_length'] + 1

    elif env_id == 'DiscountingChain':
        # print('2' * 1000)
        # print(kwargs)
        self._env = DiscountingChain(int(kwargs['mapping_seed'])) if 'mapping_seed' in kwargs.keys() else None
        self.max_episode_steps = kwargs['max_episode_steps'] if 'max_episode_steps' in kwargs.keys() else 100

    # print('1' * 1000)
        
    self._last_observation = None  # type: Optional[np.ndarray]
    self.viewer = None
    self.game_over = False  # Needed for Dopamine agents.

  def step(self, action: int) -> _GymTimestep:
    timestep = self._env.step(action)
    self._last_observation = timestep.observation
    reward = timestep.reward or 0.
    if timestep.last():
      self.game_over = True
    return timestep.observation.flatten(), reward, timestep.last(), timestep.last(), {}

  def reset(self, seed = None, **kwargs) -> np.ndarray:
    info = {}
    np.random.seed(seed)
    self.game_over = False
    timestep = self._env.reset()
    self._last_observation = timestep.observation
    return timestep.observation.flatten(), info

  def render(self, mode: str = 'rgb_array') -> Union[np.ndarray, bool]:
    if self._last_observation is None:
      raise ValueError('Environment not ready to render. Call reset() first.')

    if mode == 'rgb_array':
      return self._last_observation

    if mode == 'human':
      if self.viewer is None:
        # pylint: disable=import-outside-toplevel
        # pylint: disable=g-import-not-at-top
        from gym.envs.classic_control import rendering
        self.viewer = rendering.SimpleImageViewer()
      self.viewer.imshow(self._last_observation)
      return self.viewer.isopen

  @property
  def action_space(self) -> gym.spaces.Discrete:
    action_spec = self._env.action_spec()  # type: specs.DiscreteArray
    return gym.spaces.Discrete(action_spec.num_values)

  @property
  def observation_space(self) -> gym.spaces.Box:
    obs_spec = self._env.observation_spec()  # type: specs.Array

    ### TO-DO: FIX 
    # if isinstance(obs_spec, specs.BoundedArray):
    #   return gym.spaces.Box(
    #       low=float(obs_spec.minimum),
    #       high=float(obs_spec.maximum),
    #       shape=obs_spec.shape,
    #       dtype=obs_spec.dtype)
    return gym.spaces.Box(
        low=-float('inf'),
        high=float('inf'),
        shape=(obs_spec.shape[1], ),
        dtype=obs_spec.dtype)

  @property
  def reward_range(self) -> Tuple[float, float]:
    reward_spec = self._env.reward_spec()
    if isinstance(reward_spec, specs.BoundedArray):
      return reward_spec.minimum, reward_spec.maximum
    return -float('inf'), float('inf')

  def __getattr__(self, attr):
    """Delegate attribute access to underlying environment."""
    return getattr(self._env, attr)


def space2spec(space: gym.Space, name: Optional[str] = None):
    """Converts a Gym(nasium) space to a dm_env spec or nested structure of specs.

    Box, MultiBinary and MultiDiscrete spaces are converted to BoundedArray
    specs, and Discrete spaces to DiscreteArray specs. Tuple and Dict spaces
    are converted recursively to tuples and dictionaries of specs.

    Args:
      space: The Gym space to convert.
      name: Optional name to apply to all returned spec(s).

    Returns:
      A dm_env spec or nested structure of specs, corresponding to the input
      space.

    Raises:
      ValueError: If ``space`` is not one of the supported space types.
    """
    if isinstance(space, gym.spaces.Discrete):
        return specs.DiscreteArray(num_values=space.n, dtype=space.dtype, name=name)

    if isinstance(space, gym.spaces.Box):
        return specs.BoundedArray(shape=space.shape, dtype=space.dtype,
                                  minimum=space.low, maximum=space.high, name=name)

    if isinstance(space, gym.spaces.MultiBinary):
        return specs.BoundedArray(shape=space.shape, dtype=space.dtype,
                                  minimum=0.0, maximum=1.0, name=name)

    if isinstance(space, gym.spaces.MultiDiscrete):
        # Component i takes values in [0, nvec[i] - 1]. BoundedArray bounds are
        # inclusive, so the upper bound is nvec - 1, not nvec.
        return specs.BoundedArray(shape=space.shape, dtype=space.dtype,
                                  minimum=np.zeros(space.shape),
                                  maximum=space.nvec - 1, name=name)

    if isinstance(space, gym.spaces.Tuple):
        # Sub-spaces live in `space.spaces`.
        return tuple(space2spec(s, name) for s in space.spaces)

    if isinstance(space, gym.spaces.Dict):
        return {key: space2spec(value, name)
                for key, value in space.spaces.items()}

    raise ValueError('Unexpected gym space: {}'.format(space))



# class AutoencodeMedium(Autoencode):
#     def __init__(self):
#         super().__init__(num_decks=2)
