import jax
import jax.numpy as jnp
import chex
from flax import struct
from functools import partial

from typing import Dict, Optional, List, Tuple, Union  # noqa: F401
from jaxmarl.environments.multi_agent_env import MultiAgentEnv, State
from jaxmarl.wrappers.baselines import JaxMARLWrapper


@struct.dataclass
class LogEnvState:
    """Wrapped environment state bundled with episode statistics.

    In this file ``LogWrapper.reset`` fills the four statistics fields with
    ``jnp.zeros((num_agents,))`` arrays, so despite the scalar annotations
    they hold one entry per agent.
    """
    # State object produced by the wrapped environment's reset/step.
    env_state: State
    # Reward accumulated so far in the episode currently being played.
    episode_returns: float
    # Number of steps taken so far in the episode currently being played.
    episode_lengths: int
    # Slot for the total return of a finished episode; LogWrapper.step writes
    # it when done["__all__"] is set and exposes it through ``info``.
    returned_episode_returns: float
    # Slot for the length of a finished episode; handled like the field above.
    returned_episode_lengths: int

class LogWrapper(JaxMARLWrapper):
    """Track per-agent episode returns and lengths of a wrapped environment.

    Derived from the JaxMARL LogWrapper, adapted so that it also works with
    wrapped envs that auto-reset: every counter is cleared as soon as
    ``done["__all__"]`` is set.
    """

    def __init__(self, env: MultiAgentEnv, replace_info: bool = False):
        """Wrap ``env``.

        Args:
            env: Environment to wrap.
            replace_info: When true, ``step`` drops the wrapped env's info
                dict and returns only the logging entries.
        """
        super().__init__(env)
        self.replace_info = replace_info

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, State]:
        """Reset the wrapped env and pair its state with all-zero counters.

        Returns:
            ``(obs, state)`` where ``state`` is a ``LogEnvState``.
        """
        obs, inner_state = self._env.reset(key)
        stats_shape = (self._env.num_agents,)
        # One independent zero array per statistics field.
        counters = [jnp.zeros(stats_shape) for _ in range(4)]
        return obs, LogEnvState(inner_state, *counters)

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: LogEnvState,
        action: Union[int, float],
    ) -> Tuple[chex.Array, LogEnvState, float, bool, dict]:
        """Step the wrapped env and update the episode statistics.

        Args:
            key: PRNG key forwarded to the wrapped env.
            state: Logging state from the previous ``reset``/``step``.
            action: Action(s) forwarded unchanged to the wrapped env.

        Returns:
            ``(obs, state, reward, done, info)``. ``info`` gains the keys
            ``returned_episode_returns``, ``returned_episode_lengths`` and
            ``returned_episode`` (``done["__all__"]`` broadcast to one entry
            per agent). These are computed before the counters are cleared,
            so on the step that ends an episode they carry its totals.
        """
        obs, inner_state, reward, done, info = self._env.step(
            key, state.env_state, action
        )
        ep_done = done["__all__"]

        # Running totals including the current step.
        # NOTE(review): _batchify_floats is inherited; presumably it stacks
        # the per-agent rewards into a (num_agents,) array -- confirm.
        running_return = state.episode_returns + self._batchify_floats(reward)
        running_length = state.episode_lengths + 1

        # Arithmetic mask: 1 while the episode continues, 0 once it is over.
        keep = 1 - ep_done
        logged = LogEnvState(
            env_state=inner_state,
            episode_returns=running_return * keep,
            episode_lengths=running_length * keep,
            returned_episode_returns=(
                state.returned_episode_returns * keep + running_return * ep_done
            ),
            returned_episode_lengths=(
                state.returned_episode_lengths * keep + running_length * ep_done
            ),
        )

        info = {} if self.replace_info else info
        info["returned_episode_returns"] = logged.returned_episode_returns
        info["returned_episode_lengths"] = logged.returned_episode_lengths
        info["returned_episode"] = jnp.full((self._env.num_agents,), ep_done)

        # for compatibility with auto-resetting wrapped envs: once the episode
        # is over, carry forward zeroed statistics (the env state is the same
        # on both sides of the select).
        stats_shape = (self._env.num_agents,)
        cleared = LogEnvState(
            inner_state,
            *[jnp.zeros(stats_shape) for _ in range(4)],
        )

        def _choose(cleared_leaf, logged_leaf):
            return jax.lax.select(ep_done, cleared_leaf, logged_leaf)

        next_state = jax.tree.map(_choose, cleared, logged)

        return obs, next_state, reward, done, info
