# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib.util
import warnings

import torch
from packaging import version
from tensordict import TensorDict, TensorDictBase

from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.libs.jax_utils import (
    _extract_spec,
    _ndarray_to_tensor,
    _object_to_tensordict,
    _tensor_to_ndarray,
    _tensordict_to_object,
    _tree_flatten,
    _tree_reshape,
)
from torchrl.envs.utils import _classproperty

# True when the ``brax`` package can be located in the current environment
# (checked without importing it).
_has_brax = importlib.util.find_spec("brax") is not None

# Number of steps between two automatic JAX cache clears, used by
# ``BraxWrapper`` when ``cache_clear_frequency`` is ``None`` or ``True``.
_DEFAULT_CACHE_CLEAR_FREQUENCY = 20


def _get_envs():
    """Return the names of the environments registered in ``brax.envs``.

    Returns:
        list: the keys of the ``brax.envs._envs`` registry.

    Raises:
        ImportError: if brax cannot be found in the current environment.
    """
    if _has_brax:
        import brax.envs

        registry = brax.envs._envs
        return list(registry.keys())

    raise ImportError("BRAX is not installed in your virtual environment.")


class BraxWrapper(_EnvWrapper):
    """Google Brax environment wrapper.

    Brax offers a vectorized and differentiable simulation framework based on Jax.
    TorchRL's wrapper incurs some overhead for the jax-to-torch conversion,
    but computational graphs can still be built on top of the simulated trajectories,
    allowing for backpropagation through the rollout.

    GitHub: https://github.com/google/brax

    Paper: https://arxiv.org/abs/2106.13281

    Args:
        env (brax.envs.base.PipelineEnv): the environment to wrap.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.
        cache_clear_frequency (int or bool, optional): number of steps between
            two automatic calls to :meth:`clear_cache`, which clears JAX's
            internal caches to prevent memory leaks when using
            ``requires_grad=True``. ``None`` or ``True`` select the default
            frequency (every 20 steps); ``False`` or ``0`` deactivate automatic
            cache clearing. Defaults to ``None``.

    Keyword Args:
        from_pixels (bool, optional): Not yet supported.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        batch_size (torch.Size, optional): the batch size of the environment.
            In ``brax``, this indicates the number of vectorized environments.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.

    Attributes:
        available_envs: environments available to build

    Examples:
        >>> import brax.envs
        >>> from torchrl.envs import BraxWrapper
        >>> import torch
        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
        >>> base_env = brax.envs.get_environment("ant")
        >>> env = BraxWrapper(base_env, device=device)
        >>> env.set_seed(0)
        >>> td = env.reset()
        >>> td["action"] = env.action_spec.rand()
        >>> td = env.step(td)
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(torch.Size([8]), dtype=torch.float32),
                done: Tensor(torch.Size([1]), dtype=torch.bool),
                next: TensorDict(
                    fields={
                        observation: Tensor(torch.Size([87]), dtype=torch.float32)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(torch.Size([87]), dtype=torch.float32),
                reward: Tensor(torch.Size([1]), dtype=torch.float32),
                state: TensorDict(...)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)
        >>> print(env.available_envs)
        ['acrobot', 'ant', 'fast', 'fetch', ...]

    To take advantage of Brax, one usually executes multiple environments at the
    same time. In the following example, we iteratively test different batch sizes
    and report the execution time for a short rollout:

    Examples:
        >>> import torch
        >>> from torch.utils.benchmark import Timer
        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
        >>> for batch_size in [4, 16, 128]:
        ...     timer = Timer('''
        ... env.rollout(100)
        ... ''',
        ...     setup=f'''
        ... import brax.envs
        ... from torchrl.envs import BraxWrapper
        ... env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[{batch_size}], device="{device}")
        ... env.set_seed(0)
        ... env.rollout(2)
        ... ''')
        ...     print(batch_size, timer.timeit(10))
        4
        env.rollout(100)
        setup: [...]
        310.00 ms
        1 measurement, 10 runs , 1 thread

        16
        env.rollout(100)
        setup: [...]
        268.46 ms
        1 measurement, 10 runs , 1 thread

        128
        env.rollout(100)
        setup: [...]
        433.80 ms
        1 measurement, 10 runs , 1 thread

    One can backpropagate through the rollout and optimize the policy directly:

        >>> import brax.envs
        >>> from torchrl.envs import BraxWrapper
        >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
        >>> import torch
        >>>
        >>> env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[10], requires_grad=True, cache_clear_frequency=100)
        >>> env.set_seed(0)
        >>> torch.manual_seed(0)
        >>> policy = TensorDictModule(nn.Linear(27, 8), in_keys=["observation"], out_keys=["action"])
        >>>
        >>> td = env.rollout(10, policy)
        >>>
        >>> td["next", "reward"].mean().backward(retain_graph=True)
        >>> print(policy.module.weight.grad.norm())
        tensor(213.8605)

    """

    git_url = "https://github.com/google/brax"

    @_classproperty
    def available_envs(cls):
        """Names of the Brax environments, or an empty list if brax is missing."""
        if not _has_brax:
            return []
        return _get_envs()

    libname = "brax"

    # Lazily-imported modules, cached at class level by ``lib`` and ``jax``.
    _lib = None
    _jax = None

    @_classproperty
    def lib(cls):
        """The ``brax`` module, imported on first access and cached on the class."""
        if cls._lib is not None:
            return cls._lib

        import brax
        import brax.envs

        cls._lib = brax
        return brax

    @_classproperty
    def jax(cls):
        """The ``jax`` module, imported on first access and cached on the class."""
        if cls._jax is not None:
            return cls._jax

        import jax

        cls._jax = jax
        return jax

    def __init__(
        self,
        env=None,
        categorical_action_encoding=False,
        cache_clear_frequency: int | bool | None = None,
        **kwargs,
    ):
        if env is not None:
            kwargs["env"] = env
        self._seed_calls_reset = None
        self._categorical_action_encoding = categorical_action_encoding
        # ``None`` / ``True`` select the default frequency; anything else is
        # stored as is (``False`` and ``0`` are falsy and therefore disable
        # automatic clearing in ``_step``).
        # Identity checks are required: ``1 == True`` in Python, so a membership
        # test would silently replace an explicit frequency of 1 by the default.
        if cache_clear_frequency is None or cache_clear_frequency is True:
            self._cache_clear_frequency = _DEFAULT_CACHE_CLEAR_FREQUENCY
        else:
            self._cache_clear_frequency = cache_clear_frequency
        # Number of calls to ``_step`` so far, used to schedule cache clearing.
        self._step_count = 0
        super().__init__(**kwargs)
        if not self.device:
            warnings.warn(
                f"No device is set for env {self}. "
                f"Setting a device in Brax wrapped environments is strongly recommended."
            )

    def _check_kwargs(self, kwargs: dict):
        """Validate the brax version and the ``env`` constructor argument.

        Raises:
            ImportError: if the installed brax is older than v0.10.4.
            TypeError: if ``"env"`` is missing from ``kwargs`` or is not a
                ``brax.envs.Env`` instance.
        """
        brax = self.lib
        if version.parse(brax.__version__) < version.parse("0.10.4"):
            raise ImportError("Brax v0.10.4 or greater is required.")

        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, brax.envs.Env):
            raise TypeError("env is not of type 'brax.envs.Env'.")

    def _build_env(
        self,
        env,
        _seed: int | None = None,
        from_pixels: bool = False,
        render_kwargs: dict | None = None,
        pixels_only: bool = False,
        requires_grad: bool = False,
        camera_id: int | str = 0,
        **kwargs,
    ):
        """Record the build options on the wrapper and return ``env`` unchanged.

        ``_seed``, ``render_kwargs``, ``camera_id`` and any extra keyword
        argument are accepted but not used by this method.

        Raises:
            NotImplementedError: if ``from_pixels`` is ``True``.
        """
        self.from_pixels = from_pixels
        self.pixels_only = pixels_only
        self.requires_grad = requires_grad

        if from_pixels:
            raise NotImplementedError(
                "from_pixels=True is not yest supported within BraxWrapper"
            )
        return env

    def _make_state_spec(self, env: brax.envs.env.Env):  # noqa: F821
        """Build the state spec from a single (unbatched) reset of ``env``."""
        jax = self.jax

        # The key value is irrelevant here: only the structure of the state
        # returned by the reset is used to extract the spec.
        key = jax.random.PRNGKey(0)
        state = env.reset(key)
        state_dict = _object_to_tensordict(state, self.device, batch_size=())
        state_spec = _extract_spec(state_dict).expand(self.batch_size)
        return state_spec

    def _make_specs(self, env: brax.envs.env.Env) -> None:  # noqa: F821
        """Populate the action, reward, observation and state specs."""
        self.action_spec = Bounded(
            low=-1,
            high=1,
            shape=(
                *self.batch_size,
                env.action_size,
            ),
            device=self.device,
        )
        self.reward_spec = Unbounded(
            shape=[
                *self.batch_size,
                1,
            ],
            device=self.device,
        )
        self.observation_spec = Composite(
            observation=Unbounded(
                shape=(
                    *self.batch_size,
                    env.observation_size,
                ),
                device=self.device,
            ),
            shape=self.batch_size,
        )
        # extract state spec from instance
        state_spec = self._make_state_spec(env)
        self.state_spec["state"] = state_spec
        self.observation_spec["state"] = state_spec.clone()

    def _make_state_example(self):
        """Return a batched Brax state used as a template for conversions.

        The result is passed to ``_tensordict_to_object`` when tensordicts are
        converted back to Brax state objects.
        """
        jax = self.jax

        key = jax.random.PRNGKey(0)
        keys = jax.random.split(key, self.batch_size.numel())
        state = self._vmap_jit_env_reset(jax.numpy.stack(keys))
        state = _tree_reshape(state, self.batch_size)
        return state

    def _init_env(self) -> int | None:
        """Create the vmapped/jitted reset and step functions and the state example.

        The PRNG key is left unset (``None``) until :meth:`_set_seed` is called.
        """
        jax = self.jax
        self._key = None
        self._vmap_jit_env_reset = jax.vmap(jax.jit(self._env.reset))
        self._vmap_jit_env_step = jax.vmap(jax.jit(self._env.step))
        self._state_example = self._make_state_example()

    def _set_seed(self, seed: int | None) -> None:
        """Initialize the JAX PRNG key of the environment from ``seed``.

        Raises:
            Exception: if ``seed`` is ``None``.
        """
        jax = self.jax
        if seed is None:
            raise Exception("Brax requires an integer seed.")
        self._key = jax.random.PRNGKey(seed)

    def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
        """Reset all the vectorized environments.

        ``tensordict`` and ``kwargs`` are not used by this method: every
        sub-environment is reset with a fresh PRNG key.

        Returns:
            a tensordict with the ``"observation"``, ``"done"``,
            ``"terminated"`` and ``"state"`` entries.

        Raises:
            RuntimeError: if the PRNG key has not been initialized through
                ``set_seed``.
        """
        jax = self.jax

        if self._key is None:
            # Without this guard, ``jax.random.split`` fails on the ``None`` key
            # with an error that does not hint at the missing seeding.
            raise RuntimeError(
                "The PRNG key of the Brax environment has not been initialized. "
                "Make sure to call `env.set_seed(seed)` before `env.reset()`."
            )

        # generate random keys: one to carry over, one per sub-environment
        self._key, *keys = jax.random.split(self._key, 1 + self.numel())

        # call env reset with jit and vmap
        state = self._vmap_jit_env_reset(jax.numpy.stack(keys))

        # reshape batch size
        state = _tree_reshape(state, self.batch_size)
        state = _object_to_tensordict(state, self.device, self.batch_size)

        # build result
        state["reward"] = state.get("reward").view(*self.reward_spec.shape)
        state["done"] = state.get("done").view(*self.reward_spec.shape)
        done = state["done"].bool()
        tensordict_out = TensorDict._new_unsafe(
            source={
                "observation": state.get("obs"),
                # "reward": reward,
                "done": done,
                "terminated": done.clone(),
                "state": state,
            },
            batch_size=self.batch_size,
            device=self.device,
        )
        return tensordict_out

    def _step_without_grad(self, tensordict: TensorDictBase):
        """Run one simulation step without tracking gradients.

        Reads the ``"state"`` and ``"action"`` entries of ``tensordict`` and
        returns a tensordict with the ``"observation"``, ``"reward"``,
        ``"done"``, ``"terminated"`` and ``"state"`` entries.
        """

        # convert tensors to ndarrays
        state = _tensordict_to_object(tensordict.get("state"), self._state_example)
        action = _tensor_to_ndarray(tensordict.get("action"))

        # flatten batch size
        state = _tree_flatten(state, self.batch_size)
        action = _tree_flatten(action, self.batch_size)

        # call env step with jit and vmap
        next_state = self._vmap_jit_env_step(state, action)

        # reshape batch size and convert ndarrays to tensors
        next_state = _tree_reshape(next_state, self.batch_size)
        next_state = _object_to_tensordict(next_state, self.device, self.batch_size)

        # build result
        next_state.set("reward", next_state.get("reward").view(self.reward_spec.shape))
        next_state.set("done", next_state.get("done").view(self.reward_spec.shape))
        done = next_state["done"].bool()
        reward = next_state["reward"]
        tensordict_out = TensorDict._new_unsafe(
            source={
                "observation": next_state.get("obs"),
                "reward": reward,
                "done": done,
                "terminated": done.clone(),
                "state": next_state,
            },
            batch_size=self.batch_size,
            device=self.device,
        )
        return tensordict_out

    def _step_with_grad(self, tensordict: TensorDictBase):
        """Run one simulation step through :class:`_BraxEnvStep`.

        The observation, reward and pipeline-state tensors of the result are
        outputs of the autograd function, so gradients can flow back to the
        action and to the previous pipeline state.
        """

        # gather the inputs of the autograd function
        action = tensordict.get("action")
        state = tensordict.get("state")
        qp_keys, qp_values = zip(*state.get("pipeline_state").items())

        # call env step with autograd function
        next_state_nograd, next_obs, next_reward, *next_qp_values = _BraxEnvStep.apply(
            self, state, action, *qp_values
        )

        # extract done values: we assume a shape identical to reward
        next_done = next_state_nograd.get("done").view(*self.reward_spec.shape)
        next_reward = next_reward.view(*self.reward_spec.shape)

        # merge with tensors with grad function
        next_state = next_state_nograd
        next_state["obs"] = next_obs
        next_state.set("reward", next_reward)
        next_state.set("done", next_done)
        next_done = next_done.bool()
        next_state.get("pipeline_state").update(dict(zip(qp_keys, next_qp_values)))

        # build result
        tensordict_out = TensorDict._new_unsafe(
            source={
                "observation": next_obs,
                "reward": next_reward,
                "done": next_done,
                # cloned, as in ``_reset`` and ``_step_without_grad``, so that
                # "done" and "terminated" do not share the same tensor
                "terminated": next_done.clone(),
                "state": next_state,
            },
            batch_size=self.batch_size,
            device=self.device,
        )
        return tensordict_out

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        """Run one step and clear the JAX caches when the schedule requires it."""

        if self.requires_grad:
            out = self._step_with_grad(tensordict)
        else:
            out = self._step_without_grad(tensordict)

        self._step_count += 1
        # A falsy frequency (``False`` or ``0``) disables automatic clearing.
        if (
            self._cache_clear_frequency
            and (self._step_count % self._cache_clear_frequency) == 0
        ):
            self.clear_cache()

        return out

    def clear_cache(self):
        """Clear JAX's internal cache to prevent memory leaks.

        This method should be called periodically when using requires_grad=True
        to prevent memory accumulation from JAX's internal computation graph.

        Clearing is best-effort: the cache-clearing entry points are only
        called when they exist, and any exception they raise is ignored so
        that a failed clean-up never interrupts a rollout.
        """
        jax = getattr(self, "jax", None)
        if jax is None:
            return

        try:
            # Clear JAX's compilation cache
            if hasattr(jax.jit, "clear_caches"):
                jax.jit.clear_caches()
            # Alternative: clear JAX's internal cache
            if hasattr(jax, "clear_caches"):
                jax.clear_caches()
        except Exception:
            # Deliberately ignored (best-effort); the XLA clean-up below is
            # skipped in that case.
            return

        # Clear JAX's XLA compilation cache if available
        try:
            import jaxlib

            if hasattr(jaxlib, "xla_extension"):
                jaxlib.xla_extension.clear_caches()
        except Exception:
            # Deliberately ignored (best-effort).
            pass


class BraxEnv(BraxWrapper):
    """Google Brax environment wrapper built with the environment name.

    Brax offers a vectorized and differentiable simulation framework based on Jax.
    TorchRL's wrapper incurs some overhead for the jax-to-torch conversion,
    but computational graphs can still be built on top of the simulated trajectories,
    allowing for backpropagation through the rollout.

    GitHub: https://github.com/google/brax

    Paper: https://arxiv.org/abs/2106.13281

    Args:
        env_name (str): the environment name of the env to wrap. Must be part of
            :attr:`~.available_envs`.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.
        cache_clear_frequency (int or bool, optional): number of steps between
            two automatic clears of JAX's internal caches, to prevent memory
            leaks when using ``requires_grad=True``. ``None`` or ``True`` select
            the default frequency (every 20 steps); ``False`` or ``0``
            deactivate automatic cache clearing. Defaults to ``None``.

    Keyword Args:
        from_pixels (bool, optional): Not yet supported.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        batch_size (torch.Size, optional): the batch size of the environment.
            In ``brax``, this indicates the number of vectorized environments.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.

    Attributes:
        available_envs: environments available to build

    Examples:
        >>> from torchrl.envs import BraxEnv
        >>> import torch
        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
        >>> env = BraxEnv("ant", device=device)
        >>> env.set_seed(0)
        >>> td = env.reset()
        >>> td["action"] = env.action_spec.rand()
        >>> td = env.step(td)
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(torch.Size([8]), dtype=torch.float32),
                done: Tensor(torch.Size([1]), dtype=torch.bool),
                next: TensorDict(
                    fields={
                        observation: Tensor(torch.Size([87]), dtype=torch.float32)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(torch.Size([87]), dtype=torch.float32),
                reward: Tensor(torch.Size([1]), dtype=torch.float32),
                state: TensorDict(...)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)
        >>> print(env.available_envs)
        ['acrobot', 'ant', 'fast', 'fetch', ...]

    To take advantage of Brax, one usually executes multiple environments at the
    same time. In the following example, we iteratively test different batch sizes
    and report the execution time for a short rollout:

    Examples:
        >>> import torch
        >>> from torch.utils.benchmark import Timer
        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
        >>> for batch_size in [4, 16, 128]:
        ...     timer = Timer('''
        ... env.rollout(100)
        ... ''',
        ...     setup=f'''
        ... from torchrl.envs import BraxEnv
        ... env = BraxEnv("ant", batch_size=[{batch_size}], device="{device}")
        ... env.set_seed(0)
        ... env.rollout(2)
        ... ''')
        ...     print(batch_size, timer.timeit(10))
        4
        env.rollout(100)
        setup: [...]
        310.00 ms
        1 measurement, 10 runs , 1 thread

        16
        env.rollout(100)
        setup: [...]
        268.46 ms
        1 measurement, 10 runs , 1 thread

        128
        env.rollout(100)
        setup: [...]
        433.80 ms
        1 measurement, 10 runs , 1 thread

    One can backpropagate through the rollout and optimize the policy directly:

        >>> from torchrl.envs import BraxEnv
        >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
        >>> import torch
        >>>
        >>> env = BraxEnv("ant", batch_size=[10], requires_grad=True, cache_clear_frequency=100)
        >>> env.set_seed(0)
        >>> torch.manual_seed(0)
        >>> policy = TensorDictModule(nn.Linear(27, 8), in_keys=["observation"], out_keys=["action"])
        >>>
        >>> td = env.rollout(10, policy)
        >>>
        >>> td["next", "reward"].mean().backward(retain_graph=True)
        >>> print(policy.module.weight.grad.norm())
        tensor(213.8605)

    """

    def __init__(self, env_name, **kwargs):
        kwargs["env_name"] = env_name
        super().__init__(**kwargs)

    def _build_env(
        self,
        env_name: str,
        **kwargs,
    ) -> brax.envs.env.Env:  # noqa: F821
        """Instantiate the Brax environment registered under ``env_name``.

        Only ``from_pixels``, ``pixels_only``, ``requires_grad`` and
        ``cache_clear_frequency`` are accepted as extra keyword arguments.

        Raises:
            ImportError: if brax is not installed.
            ValueError: if any other keyword argument is passed.
        """
        if not _has_brax:
            raise ImportError(
                f"brax not found, unable to create {env_name}. "
                f"Consider downloading and installing brax from"
                f" {self.git_url}"
            )
        from_pixels = kwargs.pop("from_pixels", False)
        pixels_only = kwargs.pop("pixels_only", True)
        requires_grad = kwargs.pop("requires_grad", False)
        # Accepted so that it is not rejected below, but not forwarded: the
        # parent ``_build_env`` has no use for it (it would only be swallowed
        # by its ``**kwargs``).
        kwargs.pop("cache_clear_frequency", False)
        if kwargs:
            raise ValueError("kwargs not supported.")
        self.wrapper_frame_skip = 1
        # ``kwargs`` is necessarily empty at this point.
        env = self.lib.envs.get_environment(env_name)
        return super()._build_env(
            env,
            pixels_only=pixels_only,
            from_pixels=from_pixels,
            requires_grad=requires_grad,
        )

    @property
    def env_name(self):
        """Name of the environment, as passed to the constructor."""
        return self._constructor_kwargs["env_name"]

    def _check_kwargs(self, kwargs: dict):
        """Ensure that ``"env_name"`` is part of the constructor kwargs.

        Raises:
            TypeError: if ``"env_name"`` is missing.
        """
        if "env_name" not in kwargs:
            raise TypeError("Expected 'env_name' to be part of kwargs")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})"


class _BraxEnvStep(torch.autograd.Function):
    """Autograd function running one Brax step and its vector-Jacobian product.

    ``forward`` runs the environment step through ``jax.vjp`` and stores the
    resulting VJP function on the context; ``backward`` calls it to compute the
    gradients with respect to the action and the pipeline-state tensors.

    The references stored on the context are released at the end of the first
    backward pass to prevent memory leaks. As a consequence, a second backward
    pass through the same node does not propagate any gradient (a warning is
    emitted).
    """

    @staticmethod
    def forward(ctx, env: BraxWrapper, state_td, action_tensor, *qp_values):
        """Run one environment step and record the VJP function.

        Args:
            ctx: autograd context.
            env (BraxWrapper): the environment being stepped.
            state_td: tensordict holding the current Brax state.
            action_tensor: the action to execute.
            *qp_values: values of the ``"pipeline_state"`` entry of the state.
                They are not read here (the state is taken from ``state_td``);
                presumably they are passed so that autograd treats them as
                differentiable inputs -- ``backward`` returns one gradient for
                each of them.

        Returns:
            a tuple ``(next_state, next_obs, next_reward, *next_qp_values)``
            where ``next_state`` is the full next-state tensordict.
        """
        import jax

        # convert tensors to ndarrays
        state_obj = _tensordict_to_object(state_td, env._state_example)
        action_nd = _tensor_to_ndarray(action_tensor)

        # flatten batch size
        state = _tree_flatten(state_obj, env.batch_size)
        action = _tree_flatten(action_nd, env.batch_size)

        # call vjp with jit and vmap
        next_state, vjp_fn = jax.vjp(env._vmap_jit_env_step, state, action)

        # reshape batch size
        next_state_reshape = _tree_reshape(next_state, env.batch_size)

        # convert ndarrays to tensors
        next_state_tensor = _object_to_tensordict(
            next_state_reshape, device=env.device, batch_size=env.batch_size
        )

        # save context
        ctx.vjp_fn = vjp_fn
        ctx.next_state = next_state_tensor
        ctx.env = env
        # Mark that backward hasn't been called yet
        ctx._backward_called = False

        return (
            next_state_tensor,  # no gradient
            next_state_tensor["obs"],
            next_state_tensor["reward"],
            *next_state_tensor["pipeline_state"].values(),
        )

    @staticmethod
    def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values):
        """Compute the gradients of the step inputs.

        Returns:
            one entry per input of ``forward``: ``None`` for ``env`` and
            ``state_td``, then the gradient of the action, then one gradient
            (possibly ``None``) per pipeline-state tensor.
        """
        # Prevent multiple backward calls on the same context: the VJP function
        # and the stored state are released at the end of the first call.
        if getattr(ctx, "_backward_called", False):
            warnings.warn(
                "Backward was already run through this Brax step: its context "
                "has been released and no gradient is propagated a second time."
            )
            # One gradient slot per forward input:
            # env, state_td, action_tensor and each of the qp_values.
            return (None, None, None, *([None] * len(grad_next_qp_values)))

        ctx._backward_called = True

        pipeline_state = dict(
            zip(ctx.next_state.get("pipeline_state").keys(), grad_next_qp_values)
        )
        # Keys whose incoming gradient is None: they are fed to the VJP as
        # zeros and their output gradient is reported as None again.
        none_keys = []

        def _make_none(key, val):
            if val is not None:
                return val
            none_keys.append(key)
            return torch.zeros_like(ctx.next_state.get(("pipeline_state", key)))

        pipeline_state = {
            key: _make_none(key, val) for key, val in pipeline_state.items()
        }
        metrics = ctx.next_state.get("metrics", None)
        if metrics is None:
            metrics = {}
        info = ctx.next_state.get("info", None)
        if info is None:
            info = {}
        # Cotangent of the full next state: "done", "metrics" and "info"
        # receive zero gradients.
        grad_next_state_td = TensorDict(
            source={
                "pipeline_state": pipeline_state,
                "obs": grad_next_obs,
                "reward": grad_next_reward,
                "done": torch.zeros_like(ctx.next_state.get("done")),
                "metrics": {k: torch.zeros_like(v) for k, v in metrics.items()},
                "info": {k: torch.zeros_like(v) for k, v in info.items()},
            },
            device=ctx.env.device,
            batch_size=ctx.env.batch_size,
        )
        # convert tensors to ndarrays
        grad_next_state_obj = _tensordict_to_object(
            grad_next_state_td, ctx.env._state_example
        )

        # flatten batch size
        grad_next_state_flat = _tree_flatten(grad_next_state_obj, ctx.env.batch_size)

        # call vjp to get gradients
        grad_state, grad_action = ctx.vjp_fn(grad_next_state_flat)

        # reshape batch size
        grad_state = _tree_reshape(grad_state, ctx.env.batch_size)
        grad_action = _tree_reshape(grad_action, ctx.env.batch_size)

        # convert ndarrays to tensors
        grad_state_qp = _object_to_tensordict(
            grad_state.pipeline_state,
            device=ctx.env.device,
            batch_size=ctx.env.batch_size,
        )
        grad_action = _ndarray_to_tensor(grad_action).to(ctx.env.device)
        grad_state_qp = {
            key: val if key not in none_keys else None
            for key, val in grad_state_qp.items()
        }
        grads = (grad_action, *grad_state_qp.values())

        # Clean up context to prevent memory leaks: drop the JAX VJP function,
        # the stored tensors and the environment reference. The
        # ``_backward_called`` flag is kept on purpose so that the guard at the
        # top of this method can detect a repeated call.
        for attr in ("vjp_fn", "next_state", "env"):
            try:
                delattr(ctx, attr)
            except AttributeError:
                pass

        return (None, None, *grads)
