from .base_cd import BaseCDVAE, CDVAEDataParallel

# Names exported by ``from <this module> import *``: the factory function
# defined below, plus the two classes imported from ``.base_cd`` above.
__all__ = [
    "select_architecture", 
    "BaseCDVAE",
    "CDVAEDataParallel", 
]


def select_architecture(arch, n_classes, input_channels, **kwargs):
    """Build the model that corresponds to the architecture name ``arch``.

    Args:
        arch: Architecture name, matched case-insensitively. The accepted
            names are "CONV", "CONVOLUTIONAL", "RES" and "RESIDUAL"; all of
            them are built by ``ConvCDVAE``.
        n_classes: Forwarded to the model constructor as ``n_classes``.
        input_channels: Forwarded to the model constructor as
            ``input_channels``.
        **kwargs: Extra keyword arguments forwarded unchanged to the model
            constructor.

    Returns:
        A newly constructed ``ConvCDVAE`` instance.

    Raises:
        ValueError: If ``arch`` is not one of the accepted names.
    """
    supported = ("CONV", "CONVOLUTIONAL", "RES", "RESIDUAL")

    if arch.upper() not in supported:
        # Name the offending value and the valid choices so that a typo in a
        # config or CLI flag can be diagnosed from the error alone.
        raise ValueError(
            f"Unknown architecture {arch!r}; expected one of "
            f"{', '.join(supported)} (case-insensitive)."
        )

    # Imported lazily, only once a supported architecture has been requested.
    from .conv import ConvCDVAE

    # ``arch`` is passed through exactly as given (original casing);
    # presumably ConvCDVAE uses it to pick the layer flavour -- confirm there.
    return ConvCDVAE(
        n_classes=n_classes,
        input_channels=input_channels,
        layer_type=arch,
        **kwargs
    )