from collections.abc import Callable
from typing import NamedTuple

import jax
from mujoco_playground._src.wrapper import Wrapper


class TaskParams(NamedTuple):
    """Task parameters sampled once per reset and passed to the environment."""

    # Scalar float array; presumably a multiplier on body masses — confirm
    # against how the wrapped environment consumes it.
    mass_scale: jax.Array
    # Scalar float array; presumably a multiplier on link/segment lengths —
    # confirm against how the wrapped environment consumes it.
    length_scale: jax.Array


class BraxMultiTaskWrapper(Wrapper):
    """Wrapper that draws new task parameters every time ``reset`` is called.

    On ``reset`` the incoming key is split in two. The second half is given to
    ``task_sampler`` to produce a ``TaskParams``. The first half is forwarded,
    together with the sampled parameters (as the ``task_params`` keyword), to
    the wrapped environment's ``reset``. ``step`` is passed through unchanged.

    NOTE(review): a new task is drawn only when this ``reset`` runs. If an
    outer auto-reset wrapper restarts episodes without calling ``reset`` again,
    the task would stay fixed across those episodes. Confirm against how this
    wrapper is composed.
    """

    def __init__(
        self, env, task_sampler: Callable[[jax.Array], TaskParams]
    ) -> None:
        """Wrap ``env`` so that each reset samples a task.

        Args:
            env: Environment to wrap. It must not be vmapped yet.
            task_sampler: Callable mapping a PRNG key to ``TaskParams``.
                Should be jittable.
        """
        super().__init__(env)
        self.task_sampler = task_sampler

    def reset(self, rng):
        """Sample task parameters and reset the wrapped environment with them.

        Args:
            rng: PRNG key for this reset.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        subkeys = jax.random.split(rng)
        # First subkey drives the environment reset; second drives sampling.
        env_key, sampler_key = subkeys[0], subkeys[1]
        params = self.task_sampler(sampler_key)
        return self.env.reset(env_key, task_params=params)

    def step(self, state, action):
        """Advance the wrapped environment by one step; no task logic here."""
        wrapped = self.env
        return wrapped.step(state, action)
