JAX models
=========================

JAX-based estimators of the conditional average treatment effect (CATE).

.. toctree::
    :glob:
    :maxdepth: 2

    T-Learners <generated/catenets.models.jax.tnet.rst>
    R-Learners <generated/catenets.models.jax.rnet.rst>
    X-Learners <generated/catenets.models.jax.xnet.rst>
    Pseudo-Outcome Nets <generated/catenets.models.jax.pseudo_outcome_nets.rst>
    Representation Nets <generated/catenets.models.jax.representation_nets.rst>
    Disentangled Nets <generated/catenets.models.jax.disentangled_nets.rst>
    S-Nets <generated/catenets.models.jax.snet.rst>
    FlexTENet <generated/catenets.models.jax.flextenet.rst>
    OffsetNet <generated/catenets.models.jax.offsetnet.rst>
