from ..utils import register_module
from .basic import (
    ModelWrap,
    ModelWrap2,
    Sequential,
    ModuleList,
    Embedding,
    Conv2d,
    ConvTranspose2d,
    Linear,
    AdaptiveAvgPool2d,
    GroupNorm,
    LayerNorm,
    ReLU,
    GELU,
    SiLU,
    Mish,
    QueryKeyAttention,
    MultiHeadAttention,
    TransformEncodeBlock,
    TransformEncode,
    TransformDecodeBlock,
    TransformDecode,
    CNN,
    MLP,
    Mean,
    ResNet,
    Identity,
    BigLittle,
    Dinolet,
    ConvNeXt,
    ResEncoderAkl,
)
from .ocl import (
    SlotAttention,
    NormalInitializ,
    LearntInitializ,
    CartesianPositionalEmbedding2d,
    LearntPositionalEmbedding1d,
    dVAE,
    dVAEGrouped,
    VQVAE,
    VQVAEMultiScale,
    VQVAEGroupedMultiScale,
    VQVAEGrouped,
    Codebook,
    CodebookGrouped,
    LinearPinv2d,
)
from .slate_steve import SLATE, STEVE, TransformDecodeOCL, Parameter
from .steve_multiview import STEVEMultiView
from .slotdiffuz import SlotDiffusionImage, SlotDiffusionVideo
from .unet_slotdiffuz import UNet2dConditionWzy
from .unet_lsd import UNet2dConditionJjd, NoiseSchedJjd, UNet2dCondition
from .utils import find_groups
from .tcc_ocl import TCCOCL
from .mae import SLATEMAE

# Register every class bound in this module's namespace (i.e. the classes
# imported above) with ``register_module``. Non-class objects, such as
# ``find_groups`` and ``register_module`` itself, are skipped by the
# ``isinstance(..., type)`` check.
#
# The namespace values are snapshotted with ``list(...)`` before looping
# because the loop variable ``_obj`` is itself a new module-level binding.
# Adding it while iterating the live globals dict would raise
# ``RuntimeError: dictionary changed size during iteration``. The snapshot
# keeps definition order, so classes are registered in import order.
for _obj in list(globals().values()):
    if isinstance(_obj, type):
        register_module(_obj)
# Remove the loop variable so the package namespace exposes only the imports.
del _obj
