from flax import nnx
import jax
import numpy as np

from offline.modules.mlp import MLP
from offline.types import FloatArray
from offline.utils.data import ArrayDataLoader


@jax.jit
def compute_batch_mlp_outputs(
    graphdef: nnx.GraphDef[MLP],
    graphstate: nnx.GraphState | nnx.VariableState,
    batch: jax.Array,
):
    """Apply the MLP described by ``graphdef``/``graphstate`` to one batch.

    The function is wrapped in ``jax.jit``; the model is taken in its split
    form (graph definition + state) and reassembled with ``nnx.merge`` inside
    the jitted function before being called.

    Args:
        graphdef: Graph definition half of a split model.
        graphstate: State half of the same split model.
        batch: Input array forwarded unchanged to the model's ``__call__``.

    Returns:
        Whatever the merged model returns for ``batch``.
    """
    # Rebuild the module and invoke it in a single expression.
    return nnx.merge(graphdef, graphstate)(batch)


def compute_mlp_outputs(model: MLP, inputs: FloatArray, batch_size: int = 256):
    """Evaluate ``model`` on ``inputs`` batch by batch and join the results.

    The model is split once with ``nnx.split``; every batch produced by
    ``ArrayDataLoader`` is then passed, together with the split pieces, to
    ``compute_batch_mlp_outputs``.

    Args:
        model: The MLP to evaluate.
        inputs: Array handed to ``ArrayDataLoader`` for batching.
        batch_size: Batch size given to the loader (default 256).

    Returns:
        The per-batch outputs joined by ``np.concatenate`` (default axis).

    Note:
        The loader is created with ``drop_last=False`` (presumably so a final
        partial batch is kept -- confirm against ``ArrayDataLoader``). If the
        loader yields no batches at all, ``np.concatenate`` raises
        ``ValueError``.
    """
    graphdef, graphstate = nnx.split(model)
    loader = ArrayDataLoader(inputs, batch_size=batch_size, drop_last=False)
    per_batch = [
        compute_batch_mlp_outputs(graphdef, graphstate, chunk)
        for chunk in loader
    ]
    return np.concatenate(per_batch)


def count_parameters(model: nnx.Module) -> int:
    """Return the total number of scalar elements in ``model``'s parameters.

    Only variables matching the ``nnx.Param`` filter are counted, so other
    state stored on the module (e.g. RNG keys/counters or running
    statistics) does not inflate the result.

    Args:
        model: The NNX module to inspect.

    Returns:
        The sum, over every parameter leaf, of the product of its shape, as
        a plain Python ``int`` (``0`` when there are no parameters).
    """
    params = nnx.state(model, nnx.Param)
    # ``np.prod(())`` is the float 1.0 and non-empty shapes give NumPy
    # scalars; converting each term keeps the total an exact Python int.
    return sum(int(np.prod(leaf.shape)) for leaf in jax.tree.leaves(params))


def default_nnx_rngs(seed: int, **kwargs):
    """Create an ``nnx.Rngs`` seeded with ``seed``.

    ``seed`` is passed to ``nnx.Rngs`` twice: once positionally and once as
    the ``params`` keyword. Any extra keyword arguments are forwarded as-is.

    Args:
        seed: Integer seed used for both the positional argument and
            ``params``.
        **kwargs: Additional keyword arguments for ``nnx.Rngs``. Supplying
            ``params`` here collides with the explicit ``params=seed`` and
            raises ``TypeError``.

    Returns:
        The constructed ``nnx.Rngs`` instance.
    """
    rngs = nnx.Rngs(seed, params=seed, **kwargs)
    return rngs
