import jax
import jax.numpy as jnp

from typing import Any, Tuple, Dict
import chex
from flax import struct
from src.envs.ogc.underspecified_env import UnderspecifiedMultiAgentEnv, EnvState, Observation, Level, EnvParams
from jaxmarl.environments import spaces


@struct.dataclass
class AutoReplayState:
    """State carried by ``AutoReplayWrapper``.

    Wraps the inner environment state together with the level being replayed
    and per-agent episode statistics. The statistic fields are arrays with one
    entry per agent (initialised as ``jnp.zeros((num_agents,))`` by the
    wrapper's reset).
    """

    # State of the wrapped (inner) environment.
    env_state: EnvState
    # Level that is reset to whenever an episode ends.
    level: Level
    # Per-agent return accumulated so far in the current episode.
    episode_returns: chex.Array
    # Per-agent number of steps taken so far in the current episode.
    episode_lengths: chex.Array
    # Per-agent return of the most recently completed episode.
    returned_episode_returns: chex.Array
    # Per-agent length of the most recently completed episode.
    returned_episode_lengths: chex.Array


class AutoReplayWrapper(UnderspecifiedMultiAgentEnv):
    """
    This wrapper replays the **same** level over and over again by resetting to the same level after each episode.
    This is useful for training/rolling out multiple times on the same level.

    It also tracks per-agent episode statistics in an ``AutoReplayState``: the
    running return/length of the current episode and the return/length of the
    most recently completed episode.
    """

    def __init__(self, env: UnderspecifiedMultiAgentEnv):
        """Wrap ``env``, mirroring its agent list and width/height."""
        self._env = env
        self.agents = self._env.agents
        self.num_agents = len(self.agents)
        self.width = self._env.width
        self.height = self._env.height

    @property
    def default_params(self) -> EnvParams:
        """Default parameters of the wrapped environment."""
        return self._env.default_params

    def _batchify_floats(self, x: dict):
        """Stack a per-agent dict into one array ordered like ``self._env.agents``."""
        return jnp.stack([x[a] for a in self._env.agents])

    def step_env(
        self,
        rng: chex.PRNGKey,
        state: AutoReplayState,
        action: Dict[str, int],
        params: EnvParams,
    ) -> Tuple[chex.ArrayTree, AutoReplayState, Dict[str, float], Dict[str, bool], dict]:
        """Step the wrapped env and, when the episode ends, reset to ``state.level``.

        Both a reset to the stored level and a regular step are computed on
        every call. The result is chosen with ``jax.lax.select`` so no Python
        branching on traced values is needed.

        Args:
            rng: PRNG key, split into a reset key and a step key.
            state: Current wrapper state (inner env state, replayed level and
                episode statistics).
            action: Per-agent actions forwarded to the wrapped env.
            params: Environment parameters.

        Returns:
            ``(obs, state, reward, done, info)``. ``reward`` and ``done`` come
            straight from the wrapped env. ``info`` is the wrapped env's info
            dict augmented with ``returned_episode_returns``,
            ``returned_episode_lengths`` and ``returned_episode`` (a per-agent
            flag that an episode ended on this step).
        """
        rng_reset, rng_step = jax.random.split(rng)
        obs_re, env_state_re = self._env.reset_to_level(
            rng_reset, state.level, params)
        obs_st, env_state_st, reward, done, info = self._env.step(
            rng_step, state.env_state, action, params
        )

        ep_done = done["__all__"]
        new_episode_return = state.episode_returns + \
            self._batchify_floats(reward)
        new_episode_length = state.episode_lengths + 1

        # When the episode is over, continue from a fresh reset of the same level.
        env_state = jax.tree.map(
            lambda x, y: jax.lax.select(ep_done, x, y), env_state_re, env_state_st)
        obs = jax.tree.map(
            lambda x, y: jax.lax.select(ep_done, x, y), obs_re, obs_st)

        # Running statistics restart from zero when an episode ends. The
        # "returned" statistics keep the last completed episode's values until
        # the next episode ends. They are intentionally NOT cleared afterwards,
        # otherwise the persistence term below would be dead code.
        state = AutoReplayState(
            env_state=env_state,
            level=state.level,
            episode_returns=new_episode_return * (1 - ep_done),
            episode_lengths=new_episode_length * (1 - ep_done),
            returned_episode_returns=state.returned_episode_returns * (1 - ep_done)
            + new_episode_return * ep_done,
            returned_episode_lengths=state.returned_episode_lengths * (1 - ep_done)
            + new_episode_length * ep_done,
        )

        info["returned_episode_returns"] = state.returned_episode_returns
        info["returned_episode_lengths"] = state.returned_episode_lengths
        info["returned_episode"] = jnp.full((self._env.num_agents,), ep_done)

        return obs, state, reward, done, info

    def reset_env_to_level(
        self,
        rng: chex.PRNGKey,
        level: Level,
        params: EnvParams
    ) -> Tuple[Observation, AutoReplayState]:
        """Reset the wrapped env to ``level``, with all per-agent statistics set to zero.

        The level is stored in the returned state so that ``step_env`` can
        reset back to it after every episode.
        """
        obs, env_state = self._env.reset_to_level(rng, level, params)
        return obs, AutoReplayState(
            env_state=env_state,
            level=level,
            episode_returns=jnp.zeros((self._env.num_agents,)),
            episode_lengths=jnp.zeros((self._env.num_agents,)),
            returned_episode_returns=jnp.zeros((self._env.num_agents,)),
            returned_episode_lengths=jnp.zeros((self._env.num_agents,)),
        )

    def action_space(self, params: EnvParams) -> Any:
        """Delegate to the wrapped environment's action space."""
        return self._env.action_space(params)

    def observation_space(self, agent: str):
        """Returns the flattened observation space.

        The shape is the product of the three dimensions of the wrapped env's
        ``obs_shape``, and values are bounded to ``[0, 255]``. ``agent`` is
        unused: every agent shares the same space.
        """
        # Calculate flattened observation shape
        flat_obs_shape = (
            self._env.obs_shape[0] * self._env.obs_shape[1] * self._env.obs_shape[2],)
        return spaces.Box(0, 255, flat_obs_shape)
