"""Package for jax NN modules."""

# Public API of this package: the names bound by ``from <package> import *``.
# Every entry here is imported from a sibling submodule below.
__all__ = ["PFNN", "FNN", "NN"]

# Re-export each name from its submodule so callers can import it directly
# from the package instead of reaching into the individual modules.
# NOTE(review): the import order is kept as-is; the submodules may import one
# another, so confirm there is no ordering dependency before rearranging.
from .pfnn import PFNN
from .fnn import FNN
from .nn import NN
