import jax
import jax.numpy as jnp
import jax.random as jrnd
import flax.core as fcore


class TimeoutWrapper:
    """Environment wrapper that raises the truncation flag after a step limit.

    A step counter and the limit are kept in ``state.infos`` under the
    ``_timeout_wrapper`` key. ``step`` increments the counter and sets
    ``state.truncs`` once the counter reaches the stored limit; ``reset``
    zeroes both the counter and ``state.truncs``.

    The wrapped environment's states must offer ``replace(...)`` and an
    ``infos`` mapping whose ``copy(updates)`` returns an updated copy
    (presumably a flax ``FrozenDict`` -- confirm against the environment).
    """

    _namespace = "_timeout_wrapper"

    def __init__(self, env, limit):
        """Store the wrapped environment and the step limit.

        Args:
            env: Environment to wrap; its ``gen_tree``, ``make``, ``reset``
                and ``step`` methods are delegated to.
            limit: Step limit, converted to an int32 array.
        """
        self._env = env
        self.limit = jnp.array(limit, dtype=jnp.int32)
        # Build the jitted entry points once. Wrapping a freshly defined
        # closure in jax.jit on every call hands JAX a new function object
        # each time, so its trace cache never hits and every reset/step
        # retraces the wrapped environment.
        self._jit_reset = jax.jit(self._reset_impl)
        self._jit_step = jax.jit(self._step_impl)

    def gen_tree(self):
        """Return the wrapped env's tree extended with this wrapper's entry.

        The entry added under the namespace key holds a scalar ``steps``
        array and a two-element ``limit`` array whose first element is NaN
        when ``self.limit`` is a scalar and 0 otherwise.
        NOTE(review): how callers interpret these placeholder values is not
        visible in this file -- confirm against the consumer of ``gen_tree``.
        """
        _tree = self._env.gen_tree()
        _tree = _tree.replace(
            infos=_tree.infos.copy(
                {
                    self._namespace: dict(
                        steps=jnp.array(0),
                        limit=jnp.array(
                            [jnp.nan if self.limit.ndim == 0 else 0, jnp.nan]
                        ),
                    )
                }
            )
        )
        return _tree

    def make(self, *args, **kwargs):
        """Create a state via the wrapped env and attach the timeout info.

        All arguments are forwarded to ``env.make``. The returned state's
        ``infos`` gains a namespace entry with ``steps`` set to an int32
        zero and ``limit`` set to ``self.limit``.
        """
        state = self._env.make(*args, **kwargs)
        state = state.replace(
            infos=state.infos.copy(
                {
                    self._namespace: {
                        "steps": jnp.zeros((), dtype=jnp.int32),
                        "limit": self.limit,
                    }
                }
            )
        )

        return state

    def _reset_impl(self, *args, **kwargs):
        """Un-jitted body of ``reset``; traced by the cached jit wrapper."""
        state = self._env.reset(*args, **kwargs)
        # Keep the stored limit, only rewind the counter.
        timeout_info = state.infos[self._namespace].copy(
            {"steps": jnp.zeros((), dtype=jnp.int32)}
        )
        return state.replace(
            infos=state.infos.copy({self._namespace: timeout_info}),
            truncs=jnp.zeros_like(state.truncs),
        )

    def reset(self, *args, **kwargs):
        """Reset the wrapped env, zero the step counter and clear ``truncs``.

        All arguments are forwarded to ``env.reset``. The state returned by
        the wrapped env must already carry this wrapper's namespace entry
        in ``infos`` (as attached by ``make``).
        """
        return self._jit_reset(*args, **kwargs)

    def _step_impl(self, state, action, *args, **kwargs):
        """Un-jitted body of ``step``; traced by the cached jit wrapper."""
        state = self._env.step(state, action, *args, **kwargs)
        timeout_info = state.infos[self._namespace]
        _limit = timeout_info["limit"]

        steps = timeout_info["steps"] + 1
        # Use ``terms`` as the template so ``truncs`` gets its shape/dtype.
        one = jnp.ones_like(state.terms)
        zero = jnp.zeros_like(state.terms)

        truncs = jnp.where(steps >= _limit, one, zero)

        return state.replace(
            infos=state.infos.copy(
                {self._namespace: timeout_info.copy({"steps": steps})}
            ),
            truncs=truncs,
        )

    def step(self, state, action: jax.Array, *args, **kwargs):
        """Step the wrapped env and flag truncation when the limit is hit.

        Args:
            state: Current state; its ``infos`` must contain this wrapper's
                namespace entry.
            action: Action forwarded to ``env.step``.
            *args, **kwargs: Extra arguments forwarded to ``env.step``.

        Returns:
            The state from ``env.step`` with the step counter incremented
            by one and ``truncs`` set to ones where the new counter is
            ``>=`` the stored limit, zeros elsewhere.
        """
        return self._jit_step(state, action, *args, **kwargs)

    @property
    def unwrapped(self):
        """Innermost environment, following ``unwrapped`` on wrapped envs."""
        if hasattr(self._env, "unwrapped"):
            return self._env.unwrapped
        return self._env


class VmapWrapper:
    """Batch an environment's ``make``/``reset``/``step`` with ``jax.vmap``.

    ``in_tree`` is the vmap axis specification used for the state argument
    and for the outputs of ``reset``/``step``; ``out_tree`` is the output
    axis specification for ``make``.

    The jitted, vmapped ``reset``/``step`` functions are built on first use
    and cached, so ``_env`` and ``in_tree`` are read at that moment; do not
    reassign them afterwards.
    """

    def __init__(self, env, in_tree, out_tree):
        """Store the wrapped environment and the vmap axis specifications."""
        self._env = env
        self.in_tree = in_tree
        self.out_tree = out_tree
        # (method name, number of extra positional args) -> jitted vmap.
        # Rebuilding jax.jit(jax.vmap(...)) on every call would give JAX a
        # new function object each time and force a retrace per call.
        self._jit_cache = {}

    def _batched(self, name, lead_axes, n_extra):
        """Return the cached jit(vmap(env.<name>)) for ``n_extra`` extras.

        ``lead_axes`` are the axes of the two leading positional arguments;
        every additional positional argument is mapped over axis 0.
        Keyword arguments get no ``in_axes`` entry: ``jax.vmap`` always
        maps them over their leading axis and requires ``in_axes`` to match
        the positional arguments only.
        """
        cache_key = (name, n_extra)
        fn = self._jit_cache.get(cache_key)
        if fn is None:
            fn = jax.jit(
                jax.vmap(
                    getattr(self._env, name),
                    in_axes=lead_axes + (0,) * n_extra,
                    out_axes=self.in_tree,
                )
            )
            self._jit_cache[cache_key] = fn
        return fn

    def make(self, *args, **kwargs):
        """Vmap ``env.make`` over all arguments (default ``in_axes``).

        Outputs are laid out according to ``self.out_tree``.
        """
        return jax.vmap(
            self._env.make,
            out_axes=self.out_tree,
        )(*args, **kwargs)

    def reset(self, keys, env_state, *args, **kwargs):
        """Batched ``env.reset``.

        Args:
            keys: Mapped over axis 0.
            env_state: Mapped according to ``self.in_tree``.
            *args: Extra positional arguments, each mapped over axis 0.
            **kwargs: Extra keyword arguments, mapped over axis 0 by vmap.

        Returns:
            The batched result of ``env.reset`` with ``out_axes`` set to
            ``self.in_tree``.
        """
        batched = self._batched("reset", (0, self.in_tree), len(args))
        return batched(keys, env_state, *args, **kwargs)

    def step(self, state, action, *args, **kwargs):
        """Batched ``env.step``.

        Args:
            state: Mapped according to ``self.in_tree``.
            action: Mapped over axis 0.
            *args: Extra positional arguments, each mapped over axis 0.
            **kwargs: Extra keyword arguments, mapped over axis 0 by vmap.

        Returns:
            The batched result of ``env.step`` with ``out_axes`` set to
            ``self.in_tree``.
        """
        batched = self._batched("step", (self.in_tree, 0), len(args))
        return batched(state, action, *args, **kwargs)

    @property
    def unwrapped(self):
        """Innermost environment, following ``unwrapped`` on wrapped envs."""
        if hasattr(self._env, "unwrapped"):
            return self._env.unwrapped
        return self._env
