import jax
import jax.numpy as jnp
from typing import Any, Callable
from modules.architectures.linear_model import linear_module

def make_linear_summary(cfg: dict) -> dict[str, bool | Callable[..., Any]]:
    """Build a parameterised linear summary component from a config dict.

    Reads three optional keys from ``cfg`` and forwards them to
    ``linear_module``:

    - ``"input_dim"``  (default ``3``)
    - ``"output_dim"`` (default ``1``)
    - ``"layers"``     (default ``[128]``)

    Args:
        cfg: Configuration mapping. Keys other than the three above are
            ignored.

    Returns:
        A dict with:

        - ``"has_params"``: always ``True``.
        - ``"init"``: ``init(rng, dtype)``, which delegates to the
          ``init_fn`` produced by ``linear_module`` and returns its result.
        - ``"apply"``: ``apply(ψ, traj)``, which delegates to the
          ``apply_fn`` produced by ``linear_module`` and returns its result.
    """
    input_dim = cfg.get("input_dim", 3)
    output_dim = cfg.get("output_dim", 1)
    layers = cfg.get("layers", [128])

    init_fn, apply_fn = linear_module({
        "input_dim": input_dim,
        "output_dim": output_dim,
        "layers": layers,
    })

    def init(rng: jax.Array, dtype: jnp.dtype) -> Any:
        """Initialise parameters by delegating to ``linear_module``'s init."""
        # NOTE(review): returns whatever ``init_fn`` produces. Presumably this
        # is the parameter object later passed to ``apply`` as ``ψ`` — confirm.
        # Because a value is returned, the factory's return annotation types
        # this entry as ``Callable[..., Any]`` rather than ``Callable[..., None]``.
        return init_fn(rng, dtype)

    def apply(ψ: Any, traj: jax.Array) -> jax.Array:
        """Evaluate the linear module with parameters ``ψ`` on ``traj``."""
        # NOTE(review): ``traj`` presumably has trailing size ``input_dim``;
        # the shape contract lives in ``linear_module`` — verify there.
        return apply_fn(ψ, traj)

    return {
        "has_params": True,
        "init": init,
        "apply": apply,
    }