"""
Flax loss functions and neural-network model classes, re-exported from the
``.loss`` and ``.model`` submodules as this package's public API.
"""

from .loss import (
   NllFlax,
   MSEFlax,
   AccuracyClassificationFlax,
   PredictionsClassificationFlax
)

from .model import (
   FlaxUpdateParameters,
   TrainState,
   FlaxNet,
   FlaxNetDNN, 
   FlaxNetDNNBN,
   FlaxNetLeNet,
   FlaxNetVGG,
   FlaxNetResNet,
   FlaxNetResNet18_Cust
)

# Public API of this package: the names exported by ``from <package> import *``.
# Order is kept grouped by the submodule each name is imported from.
__all__ = [
    # From .loss -- loss / metric helpers
    "NllFlax",
    "MSEFlax",
    "AccuracyClassificationFlax",
    "PredictionsClassificationFlax",
    # From .model -- training utilities and network definitions
    "FlaxUpdateParameters",
    "TrainState",
    "FlaxNet",
    "FlaxNetDNN",
    "FlaxNetDNNBN",
    "FlaxNetLeNet",
    "FlaxNetVGG",
    "FlaxNetResNet",
    "FlaxNetResNet18_Cust",
]
