from typing import Any

import jax


def tree_chunk(tree: Any, n_chunk: int, axis: int = 0) -> Any:
    """Split one axis of every leaf in ``tree`` into ``n_chunk`` chunks.

    A leaf whose size along ``axis`` is ``d`` is reshaped so that this axis
    becomes two axes ``(n_chunk, d // n_chunk)``; all other axes are kept.
    The reshape is row-major, so chunk ``k`` holds the contiguous slice
    ``[k * d // n_chunk, (k + 1) * d // n_chunk)`` of the original axis.

    Args:
        tree: A pytree whose leaves are arrays exposing ``.ndim``,
            ``.shape`` and ``.reshape``.
        n_chunk: Number of chunks; must evenly divide the size of ``axis``.
        axis: Axis to split. Negative values count from the end of each
            leaf (``-1`` is the last axis).

    Returns:
        A pytree with the same structure where each leaf has one more axis.

    Raises:
        ValueError: If ``axis`` is out of range for some leaf. The underlying
            reshape also raises if ``n_chunk`` does not divide the axis size.
    """

    def _chunk(v):
        ndim = v.ndim
        # Normalize negative axes; without this, axis=-1 made
        # ``v.shape[axis + 1:]`` equal the *whole* shape.
        ax = axis + ndim if axis < 0 else axis
        if not 0 <= ax < ndim:
            raise ValueError(
                f"axis {axis} is out of bounds for leaf of rank {ndim}"
            )
        return v.reshape(v.shape[:ax] + (n_chunk, -1) + v.shape[ax + 1:])

    # ``jax.tree_map`` is deprecated/removed in recent JAX; use tree_util.
    return jax.tree_util.tree_map(_chunk, tree)


def tree_unchunk(tree: Any, axis: int = 0) -> Any:
    """Merge axes ``axis`` and ``axis + 1`` of every leaf in ``tree``.

    A leaf of shape ``(..., a, b, ...)``, with ``a`` at position ``axis``,
    becomes ``(..., a * b, ...)`` via a row-major reshape; all other axes
    are kept. For non-negative ``axis``,
    ``tree_unchunk(tree_chunk(t, n, axis), axis)`` reconstructs ``t``.

    Args:
        tree: A pytree whose leaves are arrays exposing ``.ndim``,
            ``.shape`` and ``.reshape``.
        axis: Position of the first of the two merged axes in each leaf.
            Negative values count from the end of the leaf, so ``axis=-2``
            merges the last two axes.

    Returns:
        A pytree with the same structure where each leaf has one fewer axis.

    Raises:
        ValueError: If ``axis`` and ``axis + 1`` are not both valid axes of
            some leaf.
    """

    def _unchunk(x):
        ndim = x.ndim
        # Normalize negative axes; without this, axis=-1/-2 made
        # ``x.shape[axis + 2:]`` wrap around to the front of the shape.
        ax = axis + ndim if axis < 0 else axis
        if not 0 <= ax < ndim - 1:
            raise ValueError(
                f"axis {axis} needs a following axis to merge with, "
                f"but leaf has rank {ndim}"
            )
        return x.reshape(x.shape[:ax] + (-1,) + x.shape[ax + 2:])

    # ``jax.tree_map`` is deprecated/removed in recent JAX; use tree_util.
    return jax.tree_util.tree_map(_unchunk, tree)