from .ballet_environment import BalletEnvironment
from typing import Union, Tuple, Optional, Iterable
from numpy.random import Generator
import numpy as np
from gym import spaces
# import gymnasium as gym
import gym

# def seed
# def close 
# def render
# def reset

# env.observation_space
# env.action_space.n
# env.action_space.nvec
# env.max_episode_steps
# env._max_episode_steps


class BalletWrapper(gym.Env):
    """Gym-style adapter around ``BalletEnvironment``.

    ``reset`` follows the newer gym convention and returns ``(obs, info)``,
    while ``step`` returns the legacy 4-tuple ``(obs, reward, done, info)``.
    """

    def __init__(self, level_name, seed=None) -> None:
        """Build the wrapped environment from a level name.

        Args:
            level_name: String of the form ``"<num_dancers>_<prefix><delay>"``,
                e.g. ``"2_delay16"``. The part before ``_`` is the number of
                dancers. The first five characters after it are skipped
                (presumably the literal ``"delay"``) before parsing the delay.
            seed: Forwarded to ``BalletEnvironment`` as its ``seed`` argument.
        """
        num_dancers, dance_delay = level_name.split("_")
        num_dancers = int(num_dancers)
        dance_delay = int(dance_delay[5:])
        # Episode budget: 320 steps for a dance delay of 16, 1024 otherwise.
        max_steps = 320 if dance_delay == 16 else 1024
        level_args = dict(
            num_dancers=num_dancers,
            dance_delay=dance_delay,
            max_steps=max_steps,
            seed=seed,
        )

        self._env = BalletEnvironment(**level_args)

        self._rewards = []  # rewards received in the current episode, one per step
        self.t = 0  # number of steps taken in the current episode

        self.max_episode_steps = max_steps

        # NOTE(review): assumes the wrapped env emits 99x99 RGB float images
        # normalized to [0, 1] — confirm against BalletEnvironment.observation_spec.
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=(99, 99, 3),
            dtype=np.float32,
        )
        # Built once, so state attached to the space (e.g. its sampling seed)
        # persists across accesses of the ``action_space`` property.
        self._action_space = spaces.discrete.Discrete(8)

    @property
    def action_space(self):
        """Discrete action space with 8 actions."""
        return self._action_space

    def reset(self, seed=None, **kwargs):
        """Start a new episode.

        Args:
            seed: If given, used to reseed NumPy's global RNG. When ``None``
                (the default), the global RNG is left untouched.
            **kwargs: Accepted for gym API compatibility (e.g. ``options``)
                and ignored.

        Returns:
            ``(obs, info)`` where ``info`` is an empty dict. An ``int``
            observation is wrapped into a 1-element array. Any other iterable
            observation is converted with ``np.array``.
        """
        # Only reseed on request: np.random.seed(None) would reseed the global
        # RNG from OS entropy and discard any seed the caller had set.
        if seed is not None:
            np.random.seed(seed)
        self.t = 0
        self._rewards = []

        timestep = self._env.reset()
        obs = timestep.observation
        if isinstance(obs, int):
            obs = np.array([obs])
        elif isinstance(obs, Iterable):
            obs = np.array(obs)

        return obs, {}

    def step(self, action):
        """Advance the wrapped environment by one step.

        Args:
            action: An action index, or a length-1 sequence/array holding one
                (it is unwrapped before being forwarded to the wrapped env).

        Returns:
            Tuple ``(obs, reward, done, info)``. ``done`` is True when the
            wrapped environment reports its last timestep, or when
            ``max_episode_steps`` steps have been taken this episode. On that
            final step, ``info`` holds the summed episode reward (``"reward"``)
            and the number of steps (``"length"``). Otherwise it is empty.
        """
        if isinstance(action, Iterable):
            try:
                size = len(action)
            except TypeError:
                # e.g. a 0-d numpy array: its type is iterable, but it is unsized.
                size = None
            if size == 1:
                action = action[0]

        timestep = self._env.step(action)

        obs, reward, terminated, info = timestep.observation, timestep.reward, timestep.last(), {}
        self._rewards.append(reward)

        # Count this step before the limit check, so the episode ends after
        # exactly max_episode_steps steps rather than one step later.
        self.t += 1
        if self.t >= self.max_episode_steps:
            terminated = True

        if terminated:
            info = {"reward": sum(self._rewards),
                    "length": len(self._rewards)}

        return obs, reward, terminated, info

    def seed(self, seed):
        """Reseed the wrapped environment's random generator.

        Replaces ``self._env._rng`` with ``np.random.default_rng(seed)``.

        Returns:
            ``[seed]``, following the legacy gym ``Env.seed`` convention of
            returning the list of seeds used.
        """
        self._env._rng = np.random.default_rng(seed)
        return [seed]

    def render(self):
        """Render by delegating to the wrapped environment."""
        self._env.render()

    def close(self):
        """Shut down the wrapped environment."""
        self._env.close()
