from typing import Any, Dict, Optional, Tuple, Union

import gym
import numpy as np
import torch
import dm_env

from dm_env import specs
from gym import spaces

################################################################################################################
# get_state
#
# Converts the state given by the environment to a tensor of size (in_channel, 10, 10), and then
# unsqueezes it along the 0th dimension, so the function returns a tensor of size (1, in_channel, 10, 10).
#
# Input:
#   s: current state as numpy array
#
# Output: current state as tensor, permuted to match expected dimensions
#
################################################################################################################
# def get_state(s):
#     return (torch.tensor(s, device=device).permute(2, 0, 1)).unsqueeze(0).float()

# Type of the value returned by ``step`` in the OpenAI gym format:
# (observation, reward, is_finished, other_info).
_GymTimestep = Tuple[np.ndarray, float, bool, Dict[str, Any]]


class MinAtariDMEnv(gym.Env):
    """Expose a wrapped environment through the OpenAI ``gym.Env`` interface.

    The constructor is annotated with ``dm_env.Environment``, but the methods
    this class actually calls on the wrapped object are ``act(action)``,
    ``state()``, ``reset()``, ``num_actions()`` and ``state_shape()``.  Any
    object providing those will work (presumably a MinAtar environment, going
    by the class name -- confirm against the caller).

    Attributes that are not found on the wrapper itself are looked up on the
    wrapped environment.
    """

    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self, env: dm_env.Environment):
        """Store the wrapped environment and initialise bookkeeping state.

        Args:
            env: The environment to wrap; see the class docstring for the
                methods it must provide.
        """
        self._env = env  # type: dm_env.Environment
        self._last_observation = None  # type: Optional[np.ndarray]
        self.viewer = None
        self.game_over = False  # Needed for Dopamine agents.
        # Spaces are created on first access and then reused, so that state
        # held by the space object (e.g. its seeded RNG) is not thrown away
        # on every property access.
        self._cached_action_space = None  # type: Optional[spaces.Discrete]
        self._cached_observation_space = None  # type: Optional[spaces.Box]

    def step(self, action: int) -> _GymTimestep:
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        The wrapped ``act`` call yields ``(reward, terminated)``; the
        observation is read separately through ``state()``.  ``info`` is
        always an empty dict.
        """
        reward, terminated = self._env.act(action)
        self._last_observation = self._env.state()
        if terminated:
            self.game_over = True
        return self._last_observation, reward, terminated, {}

    def reset(self) -> np.ndarray:
        """Reset the wrapped environment and return its initial state."""
        self.game_over = False
        self._env.reset()
        self._last_observation = self._env.state()
        return self._last_observation

    def render(self, mode: str = 'rgb_array') -> Union[np.ndarray, bool, None]:
        """Rendering is currently disabled; always returns ``None``.

        ``mode`` is accepted for interface compatibility but ignored.

        NOTE: an earlier implementation (now removed) returned the last
        observation for ``'rgb_array'`` and displayed it through
        ``gym.envs.classic_control.rendering.SimpleImageViewer`` for
        ``'human'``.
        """
        return None

    @property
    def action_space(self) -> spaces.Discrete:
        """Discrete space with ``num_actions()`` actions (built once)."""
        # Read through __dict__ so instances created without __init__ still
        # work and a missing cache never falls through to __getattr__.
        space = self.__dict__.get('_cached_action_space')
        if space is None:
            # Assumes the number of actions is fixed for the lifetime of the
            # wrapped environment -- TODO confirm.
            space = spaces.Discrete(self._env.num_actions())
            self._cached_action_space = space
        return space

    @property
    def observation_space(self) -> spaces.Box:
        """Unbounded float ``Box`` shaped like ``state_shape()`` (built once)."""
        space = self.__dict__.get('_cached_observation_space')
        if space is None:
            # Assumes the state shape is fixed for the lifetime of the
            # wrapped environment -- TODO confirm.
            space = spaces.Box(
                low=-float('inf'),
                high=float('inf'),
                shape=self._env.state_shape(),
                dtype=float,
            )
            self._cached_observation_space = space
        return space

    @property
    def reward_range(self) -> Tuple[float, float]:
        """Rewards are unbounded in both directions."""
        return -float('inf'), float('inf')

    def __getattr__(self, attr):
        """Delegate attribute access to underlying environment.

        Only invoked when normal lookup fails.

        Raises:
            AttributeError: If ``attr`` is ``'_env'`` (the wrapped environment
                has not been set yet) or the wrapped environment does not
                have the attribute either.
        """
        # If ``_env`` itself is missing (e.g. while unpickling or copying,
        # before __init__/__setstate__ has run), reading ``self._env`` below
        # would re-enter this method forever.  Fail fast instead.
        if attr == '_env':
            raise AttributeError(
                '{!r} object has no attribute {!r}'.format(
                    type(self).__name__, attr))
        return getattr(self._env, attr)


