import jax
import jax.numpy as jnp
import chex
import numpy as np
from flax import struct
from functools import partial
from typing import Optional, Tuple, Union, Any
from gymnax.environments import environment, spaces
# from brax import envs
# from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper
# import navix as nx


class GymnaxWrapper(object):
    """Base class for Gymnax wrappers.

    Stores the wrapped environment as ``self._env`` and forwards any attribute
    that is not found on the wrapper itself to the wrapped environment.
    """

    def __init__(self, env):
        self._env = env

    # provide proxy access to regular attributes of wrapped object
    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``_env`` itself is
        # missing (instance built via ``__new__``, or mid-way through
        # ``copy``/``pickle`` reconstruction), delegating through ``self._env``
        # would re-enter this method forever and raise RecursionError.
        # Raise AttributeError instead so ``hasattr``/``getattr(..., default)``
        # behave normally.
        if name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)


class FlattenObservationWrapper(GymnaxWrapper):
    """Flatten the observations of the environment.

    Observations returned by ``reset``/``step`` are reshaped to 1-D, and
    ``observation_space`` reports a matching 1-D ``Box``.
    """

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    def observation_space(self, params) -> spaces.Box:
        """Return the wrapped ``Box`` observation space reshaped to 1-D.

        Array-valued ``low``/``high`` bounds are broadcast to the original
        space shape and flattened in the same (row-major) order as the
        observations. Scalar bounds are passed through unchanged.

        Raises:
            AssertionError: if the wrapped observation space is not a ``Box``.
        """
        space = self._env.observation_space(params)
        assert isinstance(space, spaces.Box), "Only Box spaces are supported for now."

        def _flatten_bound(bound):
            # Scalar bounds already apply element-wise. Per-element bounds must
            # be flattened like the observation, or they would no longer match
            # the reported 1-D ``shape``.
            if jnp.ndim(bound) == 0:
                return bound
            return jnp.reshape(jnp.broadcast_to(bound, space.shape), (-1,))

        return spaces.Box(
            low=_flatten_bound(space.low),
            high=_flatten_bound(space.high),
            # int(...) so a 0-d space gives shape (1,) rather than the float
            # np.prod(()) == 1.0.
            shape=(int(np.prod(space.shape)),),
            dtype=space.dtype,
        )

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self,
        key: chex.PRNGKey,
        params: Optional[environment.EnvParams] = None,
        idx: Optional[int] = None,
    ) -> Tuple[chex.Array, environment.EnvState]:
        """Reset the wrapped env and flatten the initial observation.

        ``idx`` is forwarded to the wrapped ``reset`` only when it is given,
        so envs whose ``reset`` accepts just ``(key, params)`` keep working.
        """
        if idx is None:
            obs, state = self._env.reset(key, params)
        else:
            obs, state = self._env.reset(key, params, idx)
        obs = jnp.reshape(obs, (-1,))
        return obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
        """Step the wrapped env and flatten the resulting observation."""
        obs, state, reward, done, info = self._env.step(key, state, action, params)
        obs = jnp.reshape(obs, (-1,))
        return obs, state, reward, done, info


@struct.dataclass
class LogEnvState:
    """Wrapped env state plus the episode bookkeeping maintained by ``LogWrapper``.

    Field order matters: ``LogWrapper.reset`` constructs this positionally.
    The scalar annotations are nominal. Presumably the values are jnp
    scalars/arrays (batched when vmapped) — NOTE(review): confirm.
    """

    env_state: environment.EnvState  # state of the wrapped environment
    episode_returns: float  # reward summed so far in the current episode; zeroed on done
    episode_lengths: int  # steps taken so far in the current episode; zeroed on done
    returned_episode_returns: float  # total reward of the most recently finished episode
    returned_episode_lengths: int  # length of the most recently finished episode
    timestep: int  # steps since reset; NOT zeroed when an episode ends


class LogWrapper(GymnaxWrapper):
    """Record per-episode return and length and expose them through ``info``.

    The wrapper's state is a ``LogEnvState`` wrapping the inner env state.
    On every ``step`` the following keys are written into the inner env's
    ``info`` dict: ``returned_episode_returns``, ``returned_episode_lengths``,
    ``timestep`` and ``returned_episode``.
    """

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None, idx: int = 0
    ) -> Tuple[chex.Array, environment.EnvState]:
        """Reset the wrapped env (forwarding ``idx``) and zero every counter."""
        obs, inner_state = self._env.reset(key, params, idx)
        return obs, LogEnvState(
            env_state=inner_state,
            episode_returns=0,
            episode_lengths=0,
            returned_episode_returns=0,
            returned_episode_lengths=0,
            timestep=0,
        )

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
        """Step the wrapped env and update the episode statistics.

        ``state`` is the ``LogEnvState`` produced by ``reset``/``step``. When
        ``done`` is set, the running return/length are copied into the
        ``returned_*`` fields and then cleared. Otherwise the ``returned_*``
        fields keep their previous values.
        """
        obs, inner_state, reward, done, info = self._env.step(
            key, state.env_state, action, params
        )
        running_return = state.episode_returns + reward
        running_length = state.episode_lengths + 1
        # 1 while the episode continues, 0 on the step that ends it.
        alive = 1 - done

        next_state = LogEnvState(
            env_state=inner_state,
            episode_returns=running_return * alive,
            episode_lengths=running_length * alive,
            returned_episode_returns=state.returned_episode_returns * alive
            + running_return * done,
            returned_episode_lengths=state.returned_episode_lengths * alive
            + running_length * done,
            timestep=state.timestep + 1,
        )

        info["returned_episode_returns"] = next_state.returned_episode_returns
        info["returned_episode_lengths"] = next_state.returned_episode_lengths
        info["timestep"] = next_state.timestep
        info["returned_episode"] = done
        return obs, next_state, reward, done, info


# class BraxGymnaxWrapper:
#     def __init__(self, env_name, backend="positional"):
#         env = envs.get_environment(env_name=env_name, backend=backend)
#         env = EpisodeWrapper(env, episode_length=1000, action_repeat=1)
#         env = AutoResetWrapper(env)
#         self._env = env
#         self.action_size = env.action_size
#         self.observation_size = (env.observation_size,)

#     def reset(self, key, params=None):
#         state = self._env.reset(key)
#         return state.obs, state

#     def step(self, key, state, action, params=None):
#         next_state = self._env.step(state, action)
#         return next_state.obs, next_state, next_state.reward, next_state.done > 0.5, {}

#     def observation_space(self, params):
#         return spaces.Box(
#             low=-jnp.inf,
#             high=jnp.inf,
#             shape=(self._env.observation_size,),
#         )

#     def action_space(self, params):
#         return spaces.Box(
#             low=-1.0,
#             high=1.0,
#             shape=(self._env.action_size,),
#         )

# class NavixGymnaxWrapper:
#     def __init__(self, env_name):
#         self._env = nx.make(env_name)

#     def reset(self, key, params=None):
#         timestep = self._env.reset(key)
#         return timestep.observation, timestep

#     def step(self, key, state, action, params=None):
#         timestep = self._env.step(state, action)
#         return timestep.observation, timestep, timestep.reward, timestep.is_done(), {}

#     def observation_space(self, params):
#         return spaces.Box(
#             low=self._env.observation_space.minimum,
#             high=self._env.observation_space.maximum,
#             shape=(np.prod(self._env.observation_space.shape),),
#             dtype=self._env.observation_space.dtype,
#         )

#     def action_space(self, params):
#         return spaces.Discrete(
#             num_categories=self._env.action_space.maximum.item() + 1,
#         )


class ClipAction(GymnaxWrapper):
    """Clamp actions element-wise into the fixed box ``[low, high]`` before stepping."""

    def __init__(self, env, low=-1.0, high=1.0):
        super().__init__(env)
        self.low, self.high = low, high

    def step(self, key, state, action, params=None):
        """Clip ``action`` to ``[self.low, self.high]``, then step the wrapped env.

        TODO: In theory the bounds should come from the wrapped env's action
        space rather than constructor arguments, e.g.
        ``jnp.clip(action, self.env.action_space.low, self.env.action_space.high)``
        (NOTE(review): the wrapped env is stored as ``self._env``, and gymnax
        ``action_space`` is presumably a method taking ``params`` — adjust
        accordingly).
        """
        bounded_action = jnp.clip(action, self.low, self.high)
        return self._env.step(key, state, bounded_action, params)


class TransformObservation(GymnaxWrapper):
    """Apply ``transform_obs`` to every observation produced by the wrapped env."""

    def __init__(self, env, transform_obs):
        super().__init__(env)
        self.transform_obs = transform_obs

    def reset(self, key, params=None, idx=None):
        """Reset the wrapped env and transform the initial observation.

        ``idx`` is forwarded only when it is given. This matches the
        ``reset(key, params, idx)`` call used by the index-aware wrappers in
        this module, while envs whose ``reset`` accepts just
        ``(key, params)`` keep working.
        """
        if idx is None:
            obs, state = self._env.reset(key, params)
        else:
            obs, state = self._env.reset(key, params, idx)
        return self.transform_obs(obs), state

    def step(self, key, state, action, params=None):
        """Step the wrapped env and transform the resulting observation."""
        obs, state, reward, done, info = self._env.step(key, state, action, params)
        return self.transform_obs(obs), state, reward, done, info


class TransformReward(GymnaxWrapper):
    """Pass every reward from the wrapped env through ``transform_reward``.

    Observations, states, done flags and info are returned untouched.
    """

    def __init__(self, env, transform_reward):
        super().__init__(env)
        self.transform_reward = transform_reward

    def step(self, key, state, action, params=None):
        """Step the wrapped env and return the transformed reward."""
        next_obs, next_state, raw_reward, is_done, step_info = self._env.step(
            key, state, action, params
        )
        shaped_reward = self.transform_reward(raw_reward)
        return next_obs, next_state, shaped_reward, is_done, step_info


class VecEnv(GymnaxWrapper):
    """Run the wrapped env over a leading batch axis using ``jax.vmap``.

    ``reset(keys, params)`` and ``step(keys, states, actions, params)`` are
    vmapped so that everything except ``params`` is batched; ``params`` is
    shared by all environments.
    """

    def __init__(self, env):
        super().__init__(env)
        batched_reset = jax.vmap(self._env.reset, in_axes=(0, None))
        batched_step = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))
        # Instance attributes shadow any same-named attribute on the wrapped env.
        self.reset = batched_reset
        self.step = batched_step

@struct.dataclass
class ACGlobalState:
    """Batch-level state carried by ``ACVecEnv``.

    ``solved_idx`` is a single vector shared by all vectorized environments.
    ``env_state`` is the vmapped state of the wrapped environment.
    """

    # Boolean flag per entry of the wrapped env's ``init_states`` (created with
    # that length in ``ACVecEnv.reset``). Set to True once an environment
    # reports ``info["terminated"]`` for that ``info["idx"]``.
    solved_idx: jnp.ndarray
    # Batched (vmapped) state of the wrapped environment.
    env_state: environment.EnvState

class ACVecEnv(GymnaxWrapper):
    """Vectorized environment for AC. Keeps track of solved states.

    ``reset`` and ``step`` vmap the wrapped env over a leading batch axis. They
    carry an ``ACGlobalState`` whose ``solved_idx`` holds one boolean per
    element of ``self._env.init_states``. An entry becomes True when any
    environment reports ``info["terminated"]`` for that ``info["idx"]``, and
    ``step`` never clears it.
    """

    # We could have alternatively kept track of solved states in each environment,
    # and combined the results when logging.
    # There are two reasons for doing it the current way instead:
    # 1) We can use the global state of solved_idx to pick the next solved / unsolved states.
    # 2) Instead of writing the code that merges the result from each env in wandb.log or other
    # callback functions, we have placed it in one wrapper that now may be used anywhere.
    # This wrapper could also be easily used for other math problems where we may care about
    # solving certain specific examples.

    def __init__(self, env):
        super().__init__(env)
        # reset(key, params, idx): keys and idx batched, params shared.
        self.vmap_reset = jax.vmap(self._env.reset, in_axes=(0, None, 0))
        # step(key, state, action, params): all but params batched.
        self.vmap_step = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))

    def reset(self, key, params=None, idx=None):
        """Reset every environment and start with no state marked solved."""
        obsv, env_state = self.vmap_reset(key, params, idx)

        # In ppo_ac.py, we only call reset() once and then never again.
        state = ACGlobalState(
            solved_idx=jnp.zeros(len(self._env.init_states), dtype=jnp.bool_),
            env_state=env_state,
        )
        return obsv, state

    def step(self, key, state, action, params=None):
        """Step every environment and merge newly solved indices into ``solved_idx``.

        Expects the wrapped env's ``info`` to contain per-environment ``"idx"``
        and ``"terminated"`` entries.
        """
        obs, env_state, reward, done, info = self.vmap_step(
            key, state.env_state, action, params
        )

        def update_solved_idx(solved_idx, idx, terminated):
            # Several environments can report the same idx in one step. A
            # scatter-``set`` with repeated indices keeps an unspecified one of
            # the written values, so a False could overwrite a True and lose a
            # solve. A scatter-``add`` accumulates every write, so any True
            # for an index marks it solved regardless of update order.
            hits = jnp.zeros(solved_idx.shape, dtype=jnp.int32)
            hits = hits.at[idx].add(jnp.asarray(terminated).astype(jnp.int32))
            return solved_idx | (hits > 0)

        state = ACGlobalState(
            solved_idx=update_solved_idx(state.solved_idx, info["idx"], info["terminated"]),
            env_state=env_state,
        )

        return obs, state, reward, done, info

@struct.dataclass
class NormalizeVecObsEnvState:
    """Running observation statistics carried by ``NormalizeVecObservation``."""

    # Running mean of the observations, taken over the batch axis (axis 0).
    mean: jnp.ndarray
    # Running population variance of the observations over the batch axis.
    var: jnp.ndarray
    # Number of observations folded in so far. Reset seeds it with 1e-4, which
    # gives the initial mean=0 / var=1 a negligible weight in the first merge.
    count: float
    # State of the wrapped environment.
    env_state: environment.EnvState


class NormalizeVecObservation(GymnaxWrapper):
    """Standardize batched observations with running per-feature mean/variance.

    Assumes the wrapped env is vectorized: observations are arrays whose
    leading axis is the environment/batch axis — NOTE(review): confirm for
    every env this wraps. Each batch's moments are merged into the running
    statistics with the parallel (Chan et al.) mean/variance update. The
    result is carried in ``NormalizeVecObsEnvState``.
    """

    def __init__(self, env):
        super().__init__(env)

    @staticmethod
    def _update_stats(state, obs, env_state):
        """Merge the axis-0 moments of ``obs`` into ``state``'s running statistics.

        Returns a new ``NormalizeVecObsEnvState`` holding the merged
        ``mean``/``var``/``count`` and the given ``env_state``.
        """
        batch_mean = jnp.mean(obs, axis=0)
        batch_var = jnp.var(obs, axis=0)
        batch_count = obs.shape[0]

        delta = batch_mean - state.mean
        tot_count = state.count + batch_count

        new_mean = state.mean + delta * batch_count / tot_count
        m_a = state.var * state.count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
        return NormalizeVecObsEnvState(
            mean=new_mean,
            var=m2 / tot_count,
            count=tot_count,
            env_state=env_state,
        )

    @staticmethod
    def _normalize(obs, state):
        """Standardize ``obs`` with the statistics held in ``state``."""
        return (obs - state.mean) / jnp.sqrt(state.var + 1e-8)

    def reset(self, key, params=None, idx=None):
        """Reset the wrapped env, seed the statistics from the first batch, normalize it.

        ``idx`` is forwarded to the wrapped ``reset`` only when it is given.
        """
        if idx is None:
            obs, env_state = self._env.reset(key, params)
        else:
            obs, env_state = self._env.reset(key, params, idx)

        # Statistics are per feature (shape obs.shape[1:]). Broadcasting
        # applies them to every environment, so there is no need to store one
        # identical row per environment.
        feature_shape = obs.shape[1:]
        prior = NormalizeVecObsEnvState(
            mean=jnp.zeros(feature_shape, dtype=obs.dtype),
            var=jnp.ones(feature_shape, dtype=obs.dtype),
            count=1e-4,
            env_state=env_state,
        )
        state = self._update_stats(prior, obs, env_state)
        return self._normalize(obs, state), state

    def step(self, key, state, action, params=None):
        """Step the wrapped env, update the statistics, and normalize the observation."""
        obs, env_state, reward, done, info = self._env.step(
            key, state.env_state, action, params
        )
        state = self._update_stats(state, obs, env_state)
        return (
            self._normalize(obs, state),
            state,
            reward,
            done,
            info,
        )


@struct.dataclass
class NormalizeVecRewEnvState:
    """Running discounted-return statistics carried by ``NormalizeVecReward``."""

    # Running mean of the per-environment discounted returns. NOTE: reset
    # stores the Python float 0.0 here despite the ndarray annotation.
    mean: jnp.ndarray
    # Running population variance of the discounted returns; reset stores 1.0.
    var: jnp.ndarray
    # Number of return samples folded in so far; reset seeds it with 1e-4.
    count: float
    # Discounted running return, one entry per environment (reset stores a
    # zeros array despite the float annotation). Where ``done`` is set, the
    # accumulated value is dropped before the new reward is added.
    return_val: float
    # State of the wrapped environment.
    env_state: environment.EnvState


class NormalizeVecReward(GymnaxWrapper):
    """Divide rewards by a running std-dev estimate of the discounted return.

    ``return_val`` tracks one discounted return per environment. Its batch
    moments are folded into the running ``mean``/``var``/``count`` every step.
    The raw reward is scaled by ``1 / sqrt(var + 1e-8)``; the mean is not
    subtracted.
    """

    def __init__(self, env, gamma):
        super().__init__(env)
        self.gamma = gamma

    def reset(self, key, params=None, idx=None):
        """Reset the wrapped env (forwarding ``idx``) and start fresh statistics."""
        obs, inner_state = self._env.reset(key, params, idx)
        num_envs = obs.shape[0]
        fresh_state = NormalizeVecRewEnvState(
            mean=0.0,
            var=1.0,
            count=1e-4,
            return_val=jnp.zeros((num_envs,)),
            env_state=inner_state,
        )
        return obs, fresh_state

    @staticmethod
    def _merge_moments(mean, var, count, samples, n):
        """Fold the axis-0 moments of ``samples`` (``n`` items) into ``(mean, var, count)``."""
        delta = jnp.mean(samples, axis=0) - mean
        total = count + n
        merged_mean = mean + delta * n / total
        merged_m2 = (
            var * count
            + jnp.var(samples, axis=0) * n
            + jnp.square(delta) * count * n / total
        )
        return merged_mean, merged_m2 / total, total

    def step(self, key, state, action, params=None):
        """Step the wrapped env, update the return statistics and scale the reward."""
        obs, inner_state, reward, done, info = self._env.step(
            key, state.env_state, action, params
        )
        # Drop the accumulated return where an episode ended, then add the new reward.
        discounted = state.return_val * self.gamma * (1 - done) + reward

        run_mean, run_var, run_count = self._merge_moments(
            state.mean, state.var, state.count, discounted, obs.shape[0]
        )
        next_state = NormalizeVecRewEnvState(
            mean=run_mean,
            var=run_var,
            count=run_count,
            return_val=discounted,
            env_state=inner_state,
        )
        scaled_reward = reward / jnp.sqrt(next_state.var + 1e-8)
        return obs, next_state, scaled_reward, done, info


# @struct.dataclass
# class ACLogEnvState(LogEnvState):
#     solved_idx: set

# # if done, then env_state.idx should be added to solved_idx.

# class ACLogWrapper(GymnaxWrapper):
#     def __init__(self, env: environment.Environment):
#         super().__init__(env)
    
#     @partial(jax.jit, static_argnums=(0,))
#     def step(
#         self,
#         key: chex.PRNGKey,
#         state: environment.EnvState,
#         action: Union[int, float],
#         params: Optional[environment.EnvParams] = None,
#     ) -> Tuple[chex.Array, LogEnvState, float, bool, dict]:
#         # call step_env instead of step as we do not want to reset yet.
#         obs, env_state, reward, done, info = self._env.step_env(
#             key, state.env_state, action, params
#         )
#         new_episode_return = state.episode_returns + reward # running return
#         new_episode_length = state.episode_lengths + 1
#         state = LogEnvState(
#             env_state=env_state,
#             episode_returns=new_episode_return * (1 - done),
#             episode_lengths=new_episode_length * (1 - done),
#             returned_episode_returns=state.returned_episode_returns * (1 - done)
#             + new_episode_return * done, # NOTE: returned_episode_returns = new_episode_return when done, otherwise 0.
#             returned_episode_lengths=state.returned_episode_lengths * (1 - done)
#             + new_episode_length * done,
#             timestep=state.timestep + 1,
#             solved_idx = jax.lax.cond(done, lambda: state.solved_idx.union((env_state.idx,)), state.solved_idx)
#         )
#         info["returned_episode_returns"] = state.returned_episode_returns
#         info["returned_episode_lengths"] = state.returned_episode_lengths
#         info["timestep"] = state.timestep
#         info["returned_episode"] = done
#         info["solved_idx"] = solved_idx
#         return obs, state, reward, done, info
