import jax
import jax.numpy as jnp
from typing import Any, Callable

def make_projection_summary(cfg: dict[str, Any])-> dict[str, bool | Callable[..., None] | Callable[..., jax.Array]]:
    def init(rng: jax.Array, dtype: jnp.dtype) -> None:
        return None

    def apply(ψ: Any , traj: jax.Array) -> jax.Array:
        return traj[..., cfg['state_index']]

    return {
        "has_params": False,
        "init": init,
        "apply": apply,
    }
