import jax
import jax.numpy as jnp
import jax.random as jrnd
from typing import Optional


class VmapWrapper:
    """Run an algorithm's ``make`` / ``make_action`` / ``update`` under ``jax.vmap``.

    Attributes:
        in_trees: Mapping from method name (``"make_action"``, ``"update"``)
            to the ``in_axes`` specification handed to ``jax.vmap`` for that
            method.
        out_tree: ``out_axes`` specification used when vectorizing ``make``.
        nums: Optional explicit ``axis_size`` used when vectorizing ``make``.
    """

    def __init__(self, algo, in_trees, out_tree, nums: Optional[int] = None):
        """Store the wrapped algorithm and its vmap axis specifications."""
        self._algo = algo
        self.in_trees = in_trees
        self.out_tree = out_tree
        self.nums = nums
        # method name -> (in_axes object it was built with, jitted callable).
        # Filled lazily so a missing ``in_trees`` entry only fails when the
        # corresponding method is actually called.
        self._batched_cache = {}

    def _batched(self, name):
        """Return the cached ``jit(vmap(algo.<name>))`` callable.

        ``jax.jit`` caches traces per wrapped function object, and
        ``jax.vmap`` returns a new function object each time it is called.
        Rebuilding the pair on every invocation therefore forces a re-trace
        on every invocation; building it once lets repeated calls with the
        same shapes and dtypes reuse the existing trace.

        The callable is rebuilt if ``self.in_trees[name]`` has been replaced
        by a different object since it was cached. In-place mutation of a
        spec object is not detected.

        Raises:
            KeyError: If ``self.in_trees`` has no entry for ``name``.
        """
        in_axes = self.in_trees[name]
        cached = self._batched_cache.get(name)
        if cached is not None and cached[0] is in_axes:
            return cached[1]
        fn = jax.jit(jax.vmap(getattr(self._algo, name), in_axes=in_axes))
        self._batched_cache[name] = (in_axes, fn)
        return fn

    def make(self, *args, **kwargs):
        """Call ``algo.make`` vectorized over the leading axis of its inputs.

        ``in_axes`` is left at the ``jax.vmap`` default, ``out_axes`` is
        ``self.out_tree`` and ``axis_size`` is ``self.nums`` (``None`` lets
        ``jax.vmap`` infer it from the inputs). This path is not jitted.
        """
        return jax.vmap(
            self._algo.make,
            out_axes=self.out_tree,
            axis_size=self.nums,
        )(*args, **kwargs)

    def make_action(self, *args, **kwargs):
        """Call the jitted, vmapped ``algo.make_action``.

        Uses ``self.in_trees["make_action"]`` as ``in_axes``.
        """
        return self._batched("make_action")(*args, **kwargs)

    def update(self, state, *args, **kwargs):
        """Call the jitted, vmapped ``algo.update``.

        Uses ``self.in_trees["update"]`` as ``in_axes``; ``state`` is passed
        as the first positional argument.
        """
        return self._batched("update")(state, *args, **kwargs)

    @property
    def unwrapped(self):
        """Return ``algo.unwrapped`` if the algorithm has it, else the algorithm."""
        return getattr(self._algo, "unwrapped", self._algo)
