import jax
import jax.numpy as jnp
from gymnax.environments import environment, spaces
from typing import Tuple, Optional
import chex
from flax import struct

from gymnax.environments.minatar.asterix import MinAsterix, EnvState as AsterixEnvState, EnvParams as AsterixEnvParams
from gymnax.environments.minatar.breakout import MinBreakout, EnvState as BreakoutEnvState, EnvParams as BreakoutEnvParams
from gymnax.environments.minatar.freeway import MinFreeway, EnvState as FreewayEnvState, EnvParams as FreewayEnvParams
from gymnax.environments.minatar.seaquest import MinSeaquest, EnvState as SeaquestEnvState, EnvParams as SeaquestEnvParams
from gymnax.environments.minatar.space_invaders import MinSpaceInvaders, EnvState as SpaceInvadersEnvState, EnvParams as SpaceInvadersEnvParams

@struct.dataclass
class EnvState:
    """Combined state of the MinAtar suite.

    Holds one state object per wrapped MinAtar game plus a flag recording
    which game is currently being played.
    """
    asterix_state: AsterixEnvState  # state of the MinAsterix sub-environment
    breakout_state: BreakoutEnvState  # state of the MinBreakout sub-environment
    freeway_state: FreewayEnvState  # state of the MinFreeway sub-environment
    seaquest_state: SeaquestEnvState  # state of the MinSeaquest sub-environment
    space_invaders_state: SpaceInvadersEnvState  # state of the MinSpaceInvaders sub-environment
    game_mode: int # which game are we currently playing (index convention: TODO confirm against MinAtarSuite)
    
@struct.dataclass
class EnvParams:
    """Combined parameters of the MinAtar suite.

    Holds one parameter object per wrapped MinAtar game plus the game the
    suite should start in after a reset.
    """
    asterix_params: AsterixEnvParams  # params forwarded to MinAsterix
    breakout_params: BreakoutEnvParams  # params forwarded to MinBreakout
    freeway_params: FreewayEnvParams  # params forwarded to MinFreeway
    seaquest_params: SeaquestEnvParams  # params forwarded to MinSeaquest
    space_invaders_params: SpaceInvadersEnvParams  # params forwarded to MinSpaceInvaders
    game_mode: int # which game should we reset to (index convention: TODO confirm against MinAtarSuite)
    
class MinAtarSuite(environment.Environment):
    """Run the five gymnax MinAtar games behind a single environment interface.

    Every sub-environment is advanced on each step, so the five per-game
    states in ``EnvState`` stay in lock-step. ``game_mode`` selects which
    game's observation, reward and termination flag are reported.

    ``game_mode`` index convention (the field order of ``EnvState`` and
    ``EnvParams``):
        0 = Asterix, 1 = Breakout, 2 = Freeway, 3 = Seaquest,
        4 = Space Invaders.
    Out-of-range indices are not validated here.
    """

    def __init__(self):
        super().__init__()
        # Action set exposed by the suite: six discrete actions.
        self.action_set = jnp.array([0, 1, 2, 3, 4, 5])
        # Each game is constructed with ``False``. Presumably this disables the
        # per-game minimal action set, so every game accepts the same six
        # actions. TODO confirm against the MinAtar constructors.
        self.asterix_env = MinAsterix(False)
        self.breakout_env = MinBreakout(False)
        self.freeway_env = MinFreeway(False)
        self.seaquest_env = MinSeaquest(False)
        self.space_invaders_env = MinSpaceInvaders(False)

    # ------------------------------------------------------------------
    # Private helpers: all per-game tuples are in ``game_mode`` index order.
    # ------------------------------------------------------------------

    def _sub_envs(self) -> tuple:
        """Return the wrapped environments in ``game_mode`` index order."""
        return (
            self.asterix_env,
            self.breakout_env,
            self.freeway_env,
            self.seaquest_env,
            self.space_invaders_env,
        )

    @staticmethod
    def _sub_states(state: EnvState) -> tuple:
        """Return the per-game states held in ``state``, in index order."""
        return (
            state.asterix_state,
            state.breakout_state,
            state.freeway_state,
            state.seaquest_state,
            state.space_invaders_state,
        )

    @staticmethod
    def _sub_params(params: EnvParams) -> tuple:
        """Return the per-game params held in ``params``, in index order."""
        return (
            params.asterix_params,
            params.breakout_params,
            params.freeway_params,
            params.seaquest_params,
            params.space_invaders_params,
        )

    @staticmethod
    def _pack_state(sub_states, game_mode) -> EnvState:
        """Build an ``EnvState`` from per-game states given in index order."""
        asterix, breakout, freeway, seaquest, space_invaders = sub_states
        return EnvState(
            asterix_state=asterix,
            breakout_state=breakout,
            freeway_state=freeway,
            seaquest_state=seaquest,
            space_invaders_state=space_invaders,
            game_mode=game_mode,
        )

    @staticmethod
    def _select(values, game_mode, dtype=None) -> chex.Array:
        """Stack one value per game (cast to ``dtype``) and pick ``game_mode``."""
        return jnp.stack([jnp.asarray(v, dtype=dtype) for v in values])[game_mode]

    @staticmethod
    def _select_obs(obs_per_game, game_mode) -> chex.Array:
        """Zero-pad observations to a common channel count and pick one.

        Games may expose different numbers of channels. Each observation is
        therefore padded along its last axis to the largest channel count
        before stacking. Assumes every game shares the same leading (spatial)
        shape. TODO confirm for all MinAtar games.
        """
        n_channels = max(obs.shape[-1] for obs in obs_per_game)
        padded = [
            jnp.pad(
                obs,
                [(0, 0)] * (obs.ndim - 1) + [(0, n_channels - obs.shape[-1])],
            )
            for obs in obs_per_game
        ]
        return jnp.stack(padded)[game_mode]

    # ------------------------------------------------------------------
    # Environment API
    # ------------------------------------------------------------------

    @property
    def default_params(self) -> EnvParams:
        """Default params: each game's own defaults, starting in game 0."""
        return EnvParams(
            # Was ``asterix_state=``. That is not an ``EnvParams`` field, so
            # constructing the params raised TypeError.
            asterix_params=self.asterix_env.default_params,
            breakout_params=self.breakout_env.default_params,
            freeway_params=self.freeway_env.default_params,
            seaquest_params=self.seaquest_env.default_params,
            space_invaders_params=self.space_invaders_env.default_params,
            game_mode=0,
        )

    def step_env(
        self,
        rng: chex.PRNGKey,
        state: EnvState,
        action: int,
        params: EnvParams,
    ) -> Tuple[chex.Array, EnvState, float, bool, dict]:
        """Advance every game by one step and report the active one.

        Each game receives its own PRNG key, its own sub-state and its own
        params. The same ``action`` is forwarded to all of them. Games are
        advanced with their ``step_env``, so no per-game auto-reset happens
        here; resetting goes through this suite's ``reset_env``.

        Inactive games are stepped too. This costs extra compute but keeps the
        step free of data-dependent branching.

        Returns:
            ``(obs, state, reward, done, info)``:
            - ``obs``, ``reward`` and ``done`` belong to the game selected by
              ``state.game_mode``.
            - ``reward`` is cast to float32.
            - ``info["discount"]`` is 0.0 when ``done`` and 1.0 otherwise.
        """
        keys = jax.random.split(rng, 5)
        results = [
            env.step_env(key, sub_state, action, sub_params)
            for env, key, sub_state, sub_params in zip(
                self._sub_envs(),
                keys,
                self._sub_states(state),
                self._sub_params(params),
            )
        ]
        obs_per_game, new_sub_states, rewards, dones, _infos = zip(*results)

        game_mode = state.game_mode
        new_state = self._pack_state(new_sub_states, game_mode)
        obs = self._select_obs(obs_per_game, game_mode)
        reward = self._select(rewards, game_mode, dtype=jnp.float32)
        done = self._select(dones, game_mode, dtype=bool)
        info = {"discount": jnp.where(done, 0.0, 1.0)}
        return obs, new_state, reward, done, info

    def reset_env(
        self, rng: chex.PRNGKey, params: EnvParams
    ) -> Tuple[chex.Array, EnvState]:
        """Reset every game and start in ``params.game_mode``.

        Each game is reset with its own PRNG key and its own params. The
        returned observation is that of the selected game.
        """
        keys = jax.random.split(rng, 5)
        results = [
            env.reset_env(key, sub_params)
            for env, key, sub_params in zip(
                self._sub_envs(), keys, self._sub_params(params)
            )
        ]
        obs_per_game, sub_states = zip(*results)
        state = self._pack_state(sub_states, params.game_mode)
        return self._select_obs(obs_per_game, params.game_mode), state

    def get_obs(self, state: EnvState) -> chex.Array:
        """Return the (channel-padded) observation of the active game."""
        obs_per_game = [
            env.get_obs(sub_state)
            for env, sub_state in zip(self._sub_envs(), self._sub_states(state))
        ]
        return self._select_obs(obs_per_game, state.game_mode)

    def is_terminal(self, state: EnvState, params: EnvParams) -> bool:
        """Return the termination flag of the active game."""
        flags = [
            env.is_terminal(sub_state, sub_params)
            for env, sub_state, sub_params in zip(
                self._sub_envs(),
                self._sub_states(state),
                self._sub_params(params),
            )
        ]
        return self._select(flags, state.game_mode, dtype=bool)

    @property
    def name(self) -> str:
        """Environment name."""
        return "MinAtar-Suite"

    @property
    def num_actions(self) -> int:
        """Number of actions possible in environment."""
        return len(self.action_set)

    def action_space(
        self, params: Optional[EnvParams] = None
    ) -> spaces.Discrete:
        """Action space of the environment."""
        return spaces.Discrete(len(self.action_set))