import pdb
import torch.nn as nn
from nflows.transforms.base import CompositeTransform, MultiscaleCompositeTransform
from nflows.transforms.reshape import SqueezeTransform

from .neural_networks import MLP, CNN, T_CNN, GaussianMixtureLSTM,ResidualDecoder,ResidualEncoder,SharedMLP,SharedCNN,SharedT_CNN
from .invertible_networks import SimpleNSFTransform,SharedSimpleNSFTransform
from .generalized_autoencoder import AutoEncoder, WassersteinAutoEncoder, BiGAN, GAN
from .density_estimator import NormalizingFlow, EnergyBasedModel, GaussianMixtureLSTMModel
from . import TwoStepDensityEstimator, GaussianVAE, AdversarialVariationalBayes, ClusterModule, SingleClusterModule, MemEfficientSingleClusterModule, MixedSingleClusteringModule


# Maps the activation names used in config dicts to `torch.nn` module classes
# (the classes themselves, not instances). `None` maps to `None` so that an
# optional config entry fetched with `cfg.get(key, None)` can be looked up
# directly.
activation_map = {
    "relu": nn.ReLU,
    "leaky_relu": nn.LeakyReLU, # TODO: allow for non default leak value
    "tanh": nn.Tanh,
    "swish": nn.SiLU,
    "sigmoid": nn.Sigmoid,
    None: None
}
# Activation name used when a config does not specify one.
_DEFAULT_ACTIVATION = "relu"

# Maps the normalization names used in config dicts to `torch.nn` 2d
# normalization classes; `None` maps to `None`.
norm_map = {
    "batchnorm": nn.BatchNorm2d,
    "instance": nn.InstanceNorm2d,
    None: None
}

def get_single_clustering_module(cfg, cluster_cfg, clusterer, get_module_fn, run_name):
    """Build one sub-module per cluster and wrap them in a clustering module.

    Args:
        cfg: Model config dict; must contain "data_dim", "data_shape" and
            "train_dataset_size".
        cluster_cfg: Clustering config dict; reads "num_clusters",
            "memory_efficient", "param_sharing" and (in memory-efficient mode)
            "module_save_dir".
        clusterer: Clusterer object handed to the wrapping module unchanged.
        get_module_fn: Callable `(cfg, data_dim=..., data_shape=...,
            train_dataset_size=...)` that builds a single sub-module.
        run_name: Run identifier used to build the save directory in
            memory-efficient mode.

    Returns:
        A `MemEfficientSingleClusterModule`, `MixedSingleClusteringModule` or
        `SingleClusterModule`, depending on `cluster_cfg`.

    Raises:
        ValueError: If both "memory_efficient" and "param_sharing" are enabled.
    """
    memory_efficient = cluster_cfg["memory_efficient"]

    # Validate up front with a real exception: an `assert` is stripped under
    # `python -O`, and checking before the loop avoids building modules for a
    # configuration that is going to be rejected anyway.
    if memory_efficient and cluster_cfg["param_sharing"]:
        raise ValueError(
            "`memory_efficient` and `param_sharing` cannot both be enabled "
            "in the cluster config"
        )

    modules = []
    for cidx in range(cluster_cfg["num_clusters"]):
        if memory_efficient:
            # Store only a recipe for the module; it is flagged as not yet
            # instantiated so the wrapping module can build it later.
            module = {
                "get_module_fn": get_module_fn,
                "cfg": cfg,
                "data_dim": cfg["data_dim"],
                "data_shape": cfg["data_shape"],
                "train_dataset_size": cfg["train_dataset_size"],
                "cluster_component": cidx,
                "instantiated": False,
            }
        else:
            module = get_module_fn(
                cfg,
                data_dim=cfg["data_dim"],
                data_shape=cfg["data_shape"],
                train_dataset_size=cfg["train_dataset_size"],
            )
            module.cluster_component = cidx

        modules.append(module)

    if memory_efficient:
        save_dir = f"./runs/{run_name}/{cluster_cfg['module_save_dir']}"
        return MemEfficientSingleClusterModule(modules, clusterer, save_dir)

    if cluster_cfg["param_sharing"]:
        return MixedSingleClusteringModule(modules, clusterer, cfg, cluster_cfg, get_module_fn)

    return SingleClusterModule(modules, clusterer)

def get_clustering_module(gae_cfg, de_cfg, shared_cfg, cluster_cfg, clusterer, id_estimates):
    """Build one two-step module per cluster and wrap them in a `ClusterModule`.

    The i-th two-step module is built with `id_estimates[i]`, and the module
    together with its generalized autoencoder and density estimator are all
    tagged with `cluster_component = i`.
    """

    def _tagged_two_step(component):
        built = get_two_step_module(gae_cfg, de_cfg, shared_cfg, id_estimates[component])
        for part_name in ("generalized_autoencoder", "density_estimator"):
            getattr(built, part_name).cluster_component = component
        built.cluster_component = component
        return built

    per_cluster = [
        _tagged_two_step(component)
        for component in range(cluster_cfg["num_clusters"])
    ]

    return ClusterModule(per_cluster, clusterer)

def get_gae_module(gae_cfg, **kwargs):
    """Build a generalized-autoencoder module and configure its optimizer.

    Args:
        gae_cfg: Config dict; `gae_cfg["model"]` must be one of "vae", "avb",
            "ae" or "wae". The dict is updated in place.
        **kwargs: Optional "data_dim" / "data_shape" (copied into `gae_cfg`
            when given), plus "max_epochs", "train_dataset_size" and
            "train_batch_size", which are required here only if they are
            missing from `gae_cfg`.

    Returns:
        The constructed module, after `set_optimizer(gae_cfg)` has been called.

    Raises:
        KeyError: For an unsupported model name, or when a learning-rate
            scheduler argument is in neither `gae_cfg` nor `kwargs`.
    """
    model_to_module_map = {
        "vae": get_vae_module,
        "avb": get_avb_module,
        "ae": get_ae_module,
        "wae": get_wae_module,
    }

    # The builders in this file accept only `cfg`, so forwarding **kwargs to
    # them raised TypeError. Data dimensions go through the cfg instead, and
    # are only overwritten when the caller actually supplies them.
    for key in ("data_dim", "data_shape"):
        if key in kwargs:
            gae_cfg[key] = kwargs[key]

    module = model_to_module_map[gae_cfg["model"]](gae_cfg)

    lr_scheduler_args = ["max_epochs", "train_dataset_size", "train_batch_size"]
    for arg in lr_scheduler_args:
        if arg not in gae_cfg:
            gae_cfg[arg] = kwargs[arg]

    module.set_optimizer(gae_cfg)

    return module


def get_de_module(de_cfg, **kwargs):
    """Build a density-estimator module and configure its optimizer.

    Args:
        de_cfg: Config dict; `de_cfg["model"]` must be one of "vae", "avb",
            "flow" or "ebm". The dict is updated in place.
        **kwargs: Optional "data_dim" / "data_shape" (copied into `de_cfg`
            when given), plus "max_epochs", "train_dataset_size" and
            "train_batch_size", which are required here only if they are
            missing from `de_cfg`.

    Returns:
        The constructed module, after `set_optimizer(de_cfg)` has been called.

    Raises:
        KeyError: For an unsupported model name, or when a learning-rate
            scheduler argument is in neither `de_cfg` nor `kwargs`.
    """
    model_to_module_map = {
        "vae": get_vae_module,
        "avb": get_avb_module,
        "flow": get_flow_module,
        "ebm": get_ebm_module,
    }

    # The builders in this file accept only `cfg`, so forwarding **kwargs to
    # them raised TypeError. Data dimensions go through the cfg instead, and
    # are only overwritten when the caller actually supplies them.
    for key in ("data_dim", "data_shape"):
        if key in kwargs:
            de_cfg[key] = kwargs[key]

    module = model_to_module_map[de_cfg["model"]](de_cfg)

    lr_scheduler_args = ["max_epochs", "train_dataset_size", "train_batch_size"]
    for arg in lr_scheduler_args:
        if arg not in de_cfg:
            de_cfg[arg] = kwargs[arg]

    module.set_optimizer(de_cfg)
    return module


def get_two_step_module(gae_cfg, de_cfg, shared_cfg, id_estimate=None):
    """Build a `TwoStepDensityEstimator` from a GAE config and a DE config.

    When `id_estimate` is given it overrides both the autoencoder's
    "latent_dim" and the density estimator's "data_dim" (both config dicts are
    modified in place). The density estimator is then built on flat data of
    shape `(de_cfg["data_dim"],)`.
    """
    if id_estimate is not None:
        de_cfg["data_dim"] = id_estimate
        gae_cfg["latent_dim"] = id_estimate

    autoencoder = get_single_module(gae_cfg, **shared_cfg)

    # HACK: Allows specification of inferred `data_dim` using kwargs in single_main
    de_kwargs = {
        **shared_cfg,
        "data_dim": de_cfg["data_dim"],
        "data_shape": (de_cfg["data_dim"],),
    }
    estimator = get_single_module(de_cfg, **de_kwargs)

    return TwoStepDensityEstimator(
        generalized_autoencoder=autoencoder,
        density_estimator=estimator,
    )

def get_single_module(cfg, **kwargs):
    """Build the module named by `cfg["model"]` and configure its optimizer.

    `cfg["data_dim"]` and `cfg["data_shape"]` are always overwritten from
    `kwargs` (with `None` when absent). The learning-rate scheduler arguments
    are copied from `kwargs` into `cfg` only when `cfg` lacks them.
    """
    for shape_key in ("data_dim", "data_shape"):
        cfg[shape_key] = kwargs.get(shape_key, None)

    builders = {
        "vae": get_vae_module,
        "avb": get_avb_module,
        "ae": get_ae_module,
        "wae": get_wae_module,
        "bigan": get_bigan_module,
        "gan": get_gan_module,
        "flow": get_flow_module,
        "ebm": get_ebm_module,
        "arm": get_arm_module,
    }
    build = builders[cfg["model"]]
    module = build(cfg)

    for scheduler_arg in ("max_epochs", "train_dataset_size", "train_batch_size"):
        if scheduler_arg in cfg:
            continue
        cfg[scheduler_arg] = kwargs[scheduler_arg]

    module.set_optimizer(cfg)

    return module


def get_vae_module(cfg):
    """Build a `GaussianVAE` from `cfg`, including its encoder and decoder."""
    encoder, decoder = get_encoder_decoder(cfg)

    vae_kwargs = {
        "latent_dim": cfg["latent_dim"],
        "encoder": encoder,
        "decoder": decoder,
    }
    # These constructor arguments are read from cfg keys of the same name.
    for key in (
        "base_distribution",
        "distribution_mean_spacing",
        "num_prior_components",
        "conditioning",
        "conditioning_dimension",
    ):
        vae_kwargs[key] = cfg[key]
    vae_kwargs.update(get_data_transform_kwargs(cfg))

    return GaussianVAE(**vae_kwargs)


def get_avb_module(cfg):
    """Build an `AdversarialVariationalBayes` module from `cfg`.

    The `cnn` flag passed to the model is true exactly when
    `cfg["encoder_net"]` is "cnn".
    """
    encoder, decoder = get_encoder_decoder(cfg)
    discriminator = get_discriminator(cfg)

    avb_kwargs = {
        "latent_dim": cfg["latent_dim"],
        "noise_dim": cfg["noise_dim"],
        "encoder": encoder,
        "decoder": decoder,
        "discriminator": discriminator,
        "input_sigma": cfg["input_sigma"],
        "prior_sigma": cfg["prior_sigma"],
        "cnn": cfg["encoder_net"] == "cnn",
    }
    avb_kwargs.update(get_data_transform_kwargs(cfg))

    return AdversarialVariationalBayes(**avb_kwargs)


def get_ae_module(cfg):
    """Build an `AutoEncoder` from `cfg`, including its encoder and decoder."""
    networks = get_encoder_decoder(cfg)
    encoder, decoder = networks

    ae_kwargs = {"latent_dim": cfg["latent_dim"], "encoder": encoder, "decoder": decoder}
    ae_kwargs.update(get_data_transform_kwargs(cfg))

    return AutoEncoder(**ae_kwargs)


def get_wae_module(cfg):
    """Build a `WassersteinAutoEncoder` from `cfg`.

    Builds the encoder, decoder and discriminator first, then reads the
    remaining constructor arguments from cfg keys of the same name.
    """
    encoder, decoder = get_encoder_decoder(cfg)
    discriminator = get_discriminator(cfg)

    wae_kwargs = {
        "latent_dim": cfg["latent_dim"],
        "encoder": encoder,
        "decoder": decoder,
        "discriminator": discriminator,
    }
    for key in (
        "_lambda",
        "sigma",
        "base_distribution",
        "num_mixture_components",
        "distribution_mean_spacing",
        "conditioning",
        "conditioning_dimension",
    ):
        wae_kwargs[key] = cfg[key]
    wae_kwargs.update(get_data_transform_kwargs(cfg))

    return WassersteinAutoEncoder(**wae_kwargs)

def get_gan_module(cfg):
    """Build a `GAN` from `cfg`.

    Only the decoder half of `get_encoder_decoder` is used. Adversarial
    training options fall back to the defaults below when absent from `cfg`;
    note that the model's `_lambda` argument is read from the cfg key "lambda".
    """
    _, decoder = get_encoder_decoder(cfg)
    discriminator = get_discriminator(cfg)

    gan_kwargs = {
        "latent_dim": cfg["latent_dim"],
        "decoder": decoder,
        "discriminator": discriminator,
        "wasserstein": cfg.get("wasserstein", True),
        "clamp": cfg.get("clamp", 0.01),
        "gradient_penalty": cfg.get("gradient_penalty", True),
        "_lambda": cfg.get("lambda", 10),
        "num_discriminator_steps": cfg.get("num_discriminator_steps", 5),
        **get_data_transform_kwargs(cfg),
        "flatten_disc": cfg.get("flatten_disc", True),
        "base_distribution": cfg["base_distribution"],
        "num_mixture_components": cfg["num_mixture_components"],
        "distribution_mean_spacing": cfg["distribution_mean_spacing"],
        "conditioning": cfg["conditioning"],
        "conditioning_dimension": cfg["conditioning_dimension"],
    }

    return GAN(**gan_kwargs)

def get_bigan_module(cfg):
    """Build a `BiGAN` from `cfg`.

    Adversarial training options fall back to the defaults below when absent
    from `cfg`; the model's `_lambda` argument is read from the cfg key
    "lambda".
    """
    encoder, decoder = get_encoder_decoder(cfg)
    discriminator = get_discriminator(cfg)

    bigan_kwargs = {
        "latent_dim": cfg["latent_dim"],
        "encoder": encoder,
        "decoder": decoder,
        "discriminator": discriminator,
        "wasserstein": cfg.get("wasserstein", True),
        "clamp": cfg.get("clamp", 0.01),
        "gradient_penalty": cfg.get("gradient_penalty", True),
        "_lambda": cfg.get("lambda", 10),
        "num_discriminator_steps": cfg.get("num_discriminator_steps", 5),
        "recon_weight": cfg.get("recon_weight", 1.0),
        **get_data_transform_kwargs(cfg),
    }

    return BiGAN(**bigan_kwargs)


def get_flow_module(cfg):
    """Build a `NormalizingFlow` whose transform is chosen by `cfg["transform"]`.

    Supported transforms are "simple_nsf", "shared_simple_nsf" and
    "multiscale"; any other value raises `NotImplementedError`.
    """
    transform_name = cfg["transform"]

    if transform_name == "simple_nsf":
        nsf_kwargs = {
            "features": cfg["data_dim"],
            "hidden_features": cfg["hidden_units"],
            "num_layers": cfg["num_layers"],
            "num_blocks_per_layer": cfg["num_blocks_per_layer"],
            "do_batchnorm": cfg["do_batchnorm"],
            "conditioning": cfg["conditioning"],
            "conditioning_dimension": cfg["conditioning_dimension"],
            "net": "mlp",
        }
        transform = SimpleNSFTransform(**nsf_kwargs)

    elif transform_name == "shared_simple_nsf":
        shared_nsf_kwargs = {
            "features": cfg["data_dim"],
            "hidden_features": cfg["hidden_units"],
            "share_start": cfg["transform_share_start"],
            "share_middle": cfg["transform_share_middle"],
            "share_end": cfg["transform_share_end"],
            "num_layers": cfg["num_layers"],
            "num_blocks_per_layer": cfg["num_blocks_per_layer"],
            "do_batchnorm": cfg["do_batchnorm"],
            "conditioning": cfg["conditioning"],
            "conditioning_dimension": cfg["conditioning_dimension"],
            "net": "mlp",
        }
        transform = SharedSimpleNSFTransform(**shared_nsf_kwargs)

    elif transform_name == "multiscale":
        transform = MultiscaleCompositeTransform(num_transforms=2)

        def _cnn_nsf(shape):
            # Both multiscale stages use the same settings and differ only in
            # the data shape they operate on.
            return SimpleNSFTransform(
                features=cfg["data_dim"],
                hidden_features=cfg["hidden_units"],
                num_layers=cfg["num_layers"],
                num_blocks_per_layer=cfg["num_blocks_per_layer"],
                net="cnn",
                data_shape=shape,
            )

        # NOTE: Assumes square img
        half_side = cfg["data_shape"][1] // 2
        squeezed_shape = (4 * cfg["data_shape"][0], half_side, half_side)

        # First stage: squeeze, then a CNN-based NSF on the squeezed shape.
        first_nsf = _cnn_nsf(squeezed_shape)
        first_stage = CompositeTransform(transforms=(SqueezeTransform(), first_nsf))

        # `add_transform` returns the shape the next stage should operate on.
        remaining_shape = transform.add_transform(first_stage, squeezed_shape)

        second_stage = _cnn_nsf(remaining_shape)
        transform.add_transform(second_stage, remaining_shape)

    else:
        raise NotImplementedError(f"Transform {cfg['transform']} not implemented")

    flow_kwargs = {
        "dim": cfg["data_dim"],
        "transform": transform,
        "base_distribution": cfg.get("base_distribution", None),
        "num_mixture_components": cfg.get("num_mixture_components", 0),
        "distribution_mean_spacing": cfg.get("distribution_mean_spacing", 1),
        "conditioning": cfg["conditioning"],
        "conditioning_dimension": cfg["conditioning_dimension"],
        **get_data_transform_kwargs(cfg),
    }

    return NormalizingFlow(**flow_kwargs)


def get_ebm_module(cfg):
    """Build an `EnergyBasedModel` with an MLP or CNN energy function.

    `cfg["net"]` selects the energy network ("mlp" or "cnn"); any other value
    raises `ValueError`. The energy activation defaults to "swish".
    """
    fallback_activation = "swish"

    def _energy_activation():
        return activation_map[cfg.get("energy_func_activation", fallback_activation)]

    net_type = cfg["net"]

    if net_type == "mlp":
        energy_func = MLP(
            input_dim=cfg["data_dim"],
            hidden_dims=cfg["energy_func_hidden_dims"],
            output_dim=1,
            activation=_energy_activation(),
            spectral_norm=cfg.get("spectral_norm", False),
        )
    elif net_type == "cnn":
        energy_func = CNN(
            input_channels=cfg["data_shape"][0],
            hidden_channels_list=cfg["energy_func_hidden_channels"],
            output_dim=1,
            kernel_size=cfg["energy_func_kernel_size"],
            stride=cfg["energy_func_stride"],
            image_height=cfg["data_shape"][1],
            do_bn=cfg.get("do_bn", False),
            activation=_energy_activation(),
            spectral_norm=cfg.get("spectral_norm", False),
        )
    else:
        raise ValueError(f"Unknown network type {cfg['net']} for EBM")

    # With `flatten` set, the model works on flat vectors instead of the
    # original data shape.
    x_shape = (cfg["data_dim"],) if cfg.get("flatten", False) else cfg["data_shape"]

    # These constructor arguments are read from cfg keys of the same name.
    passthrough_keys = (
        "max_length_buffer",
        "x_lims",
        "ld_steps",
        "ld_step_size",
        "ld_eps_new",
        "ld_sigma",
        "ld_grad_clamp",
        "loss_alpha",
    )
    ebm_kwargs = {"energy_func": energy_func, "x_shape": x_shape}
    for key in passthrough_keys:
        ebm_kwargs[key] = cfg[key]
    ebm_kwargs.update(get_data_transform_kwargs(cfg))

    return EnergyBasedModel(**ebm_kwargs)


def get_arm_module(cfg):
    """Build a `GaussianMixtureLSTMModel` autoregressive model from `cfg`.

    For 1-d `data_shape` the LSTM input size is 1, the image height is `None`
    and the input length is `data_shape[0]`. Otherwise the input size is
    `data_shape[0]`, the height is `data_shape[1]` and the input length is
    `data_shape[1] * data_shape[2]`.
    """
    is_flat = len(cfg["data_shape"]) == 1

    if is_flat:
        lstm_input_size = 1
    else:
        lstm_input_size = cfg["data_shape"][0]

    ar_network = GaussianMixtureLSTM(
        input_size=lstm_input_size,
        hidden_size=cfg["hidden_size"],
        num_layers=cfg["num_layers"],
        k_mixture=cfg["k_mixture"],
    )

    if len(cfg["data_shape"]) == 1:
        image_height = None
        sequence_length = cfg["data_shape"][0]
    else:
        image_height = cfg["data_shape"][1]
        sequence_length = cfg["data_shape"][1] * cfg["data_shape"][2]

    model_kwargs = {
        "ar_network": ar_network,
        "image_height": image_height,
        "input_length": sequence_length,
    }
    model_kwargs.update(get_data_transform_kwargs(cfg))

    return GaussianMixtureLSTMModel(**model_kwargs)


def get_encoder_decoder(cfg):
    """Build the encoder and decoder networks described by `cfg`.

    The encoder type is chosen by `cfg["encoder_net"]` and the decoder type by
    `cfg["decoder_net"]`; each may be "mlp", "shared_mlp", "cnn", "shared_cnn"
    or "residual". The encoder is built first.

    Args:
        cfg: Model config dict. Always reads "model", "latent_dim",
            "encoder_net" and "decoder_net"; further keys depend on the
            selected network types.

    Returns:
        Tuple `(encoder, decoder)`.

    Raises:
        ValueError: If the encoder or decoder network type is unknown.
    """
    encoder = _build_encoder(cfg)
    decoder = _build_decoder(cfg)

    return encoder, decoder


def _lookup_activation(cfg, key):
    """Return the activation class named by `cfg[key]`, or the default one."""
    return activation_map[cfg.get(key, _DEFAULT_ACTIVATION)]


def _flat_decoder_output(cfg):
    """Return `(output_dim, split_sizes)` for a decoder emitting flat vectors."""
    if cfg["model"] in ["avb", "vae"]:
        # These models also decode a sigma: one value, or one per data dim.
        sigma_dim = 1 if cfg["single_sigma"] else cfg["data_dim"]
        return cfg["data_dim"] + sigma_dim, [cfg["data_dim"], sigma_dim]
    return cfg["data_dim"], None


def _image_decoder_output(cfg):
    """Return `(output_channels, split_sizes)` for a decoder emitting images."""
    if cfg["model"] in ["avb", "vae"]:
        # These models also decode a sigma alongside the image channels.
        sigma_dim = 1 if cfg["single_sigma"] else cfg["data_dim"]
        channels = cfg["data_shape"][0]
        return channels + sigma_dim, [channels, sigma_dim]
    return cfg["data_shape"][0], None


def _build_encoder(cfg):
    """Build the encoder network selected by `cfg["encoder_net"]`."""
    if cfg["model"] == "vae":
        # A VAE encoder emits two latent-sized chunks.
        output_dim = 2*cfg["latent_dim"]
        split_sizes = [cfg["latent_dim"], cfg["latent_dim"]]
    else:
        output_dim = cfg["latent_dim"]
        split_sizes = None

    encoder_net = cfg["encoder_net"]

    if encoder_net == "mlp":
        return MLP(
            input_dim=cfg["data_dim"]+cfg.get("noise_dim", 0),
            hidden_dims=cfg["encoder_hidden_dims"],
            output_dim=output_dim,
            activation=_lookup_activation(cfg, "encoder_activation"),
            output_split_sizes=split_sizes,
            spectral_norm=cfg.get("spectral_norm", False),
            conditioning_dimension=cfg.get("conditioning_dimension", 0)
        )

    if encoder_net == "shared_mlp":
        return SharedMLP(
            input_dim=cfg["data_dim"]+cfg.get("noise_dim", 0),
            beginning_hidden_dims=cfg["encoder_beginning_hidden_dims"],
            middle_hidden_dims=cfg["encoder_middle_hidden_dims"],
            end_hidden_dims=cfg["encoder_end_hidden_dims"],
            share_start=cfg["encoder_share_start"],
            share_middle=cfg["encoder_share_middle"],
            share_end=cfg["encoder_share_end"],
            output_dim=output_dim,
            activation=_lookup_activation(cfg, "encoder_activation"),
            output_split_sizes=split_sizes,
            spectral_norm=cfg.get("spectral_norm", False),
        )

    if encoder_net == "cnn":
        return CNN(
            input_channels=cfg["data_shape"][0],
            hidden_channels_list=cfg["encoder_hidden_channels"],
            output_dim=output_dim,
            kernel_size=cfg["encoder_kernel_size"],
            stride=cfg["encoder_stride"],
            padding=cfg.get("encoder_padding", 0),
            shared_module=cfg.get("encoder_shared_module", 0),
            image_height=cfg["data_shape"][1],
            do_bn=cfg.get("do_bn", False),
            activation=_lookup_activation(cfg, "encoder_activation"),
            output_split_sizes=split_sizes,
            noise_dim=cfg.get("noise_dim", 0),
            spectral_norm=cfg.get("spectral_norm", False),
            final_activation=cfg.get("encoder_final_activation", None),
            conv_bias=cfg.get("encoder_conv_bias", True),
            final_linear=cfg.get("encoder_final_linear", True),
            conditioning_dimension=cfg.get("conditioning_dimension", 0)
        )

    if encoder_net == "shared_cnn":
        return SharedCNN(
            input_channels=cfg["data_shape"][0],
            beginning_hidden_channels=cfg["encoder_beginning_hidden_channels"],
            middle_hidden_channels=cfg["encoder_middle_hidden_channels"],
            end_hidden_channels=cfg["encoder_end_hidden_channels"],
            share_start=cfg["encoder_share_start"],
            share_middle=cfg["encoder_share_middle"],
            share_end=cfg["encoder_share_end"],
            beginning_kernel_size=cfg["encoder_beginning_kernel_size"],
            middle_kernel_size=cfg["encoder_middle_kernel_size"],
            end_kernel_size=cfg["encoder_end_kernel_size"],
            beginning_stride=cfg["encoder_beginning_stride"],
            middle_stride=cfg["encoder_middle_stride"],
            end_stride=cfg["encoder_end_stride"],
            beginning_padding=cfg["encoder_beginning_padding"],
            middle_padding=cfg["encoder_middle_padding"],
            end_padding=cfg["encoder_end_padding"],
            output_dim=output_dim,
            image_height=cfg["data_shape"][1],
            activation=_lookup_activation(cfg, "encoder_activation"),
            output_split_sizes=split_sizes,
            noise_dim=cfg.get("noise_dim", 0),
            spectral_norm=cfg.get("spectral_norm", False),
            final_activation=cfg.get("encoder_final_activation", None),
            conv_bias=cfg.get("encoder_conv_bias", True),
            final_linear=cfg.get("encoder_final_linear", True)
        )

    if encoder_net == "residual":
        # NOTE(review): unlike the other encoders this passes `latent_dim`
        # rather than the (possibly doubled) VAE output dim while still
        # passing the VAE split sizes -- confirm against ResidualEncoder.
        return ResidualEncoder(
            layer_channels=cfg["encoder_layer_channels"],
            blocks_per_layer=cfg["encoder_blocks_per_layer"],
            input_channels=cfg["data_shape"][0],
            output_dim=cfg["latent_dim"],
            norm=norm_map[cfg.get("norm", "batchnorm")],
            output_split_sizes=split_sizes,
        )

    raise ValueError(f"Unknown encoder network type {cfg['encoder_net']}")


def _build_decoder(cfg):
    """Build the decoder network selected by `cfg["decoder_net"]`."""
    decoder_net = cfg["decoder_net"]

    if decoder_net == "mlp":
        output_dim, split_sizes = _flat_decoder_output(cfg)
        return MLP(
            input_dim=cfg["latent_dim"],
            hidden_dims=cfg["decoder_hidden_dims"],
            output_dim=output_dim,
            activation=_lookup_activation(cfg, "decoder_activation"),
            output_split_sizes=split_sizes,
            conditioning_dimension=cfg.get("conditioning_dimension", 0)
        )

    # Fixed: this branch used to test cfg["encoder_net"], so the decoder type
    # was silently driven by the encoder setting.
    if decoder_net == "shared_mlp":
        output_dim, split_sizes = _flat_decoder_output(cfg)
        return SharedMLP(
            input_dim=cfg["latent_dim"],
            beginning_hidden_dims=cfg["decoder_beginning_hidden_dims"],
            middle_hidden_dims=cfg["decoder_middle_hidden_dims"],
            end_hidden_dims=cfg["decoder_end_hidden_dims"],
            share_start=cfg["decoder_share_start"],
            share_middle=cfg["decoder_share_middle"],
            share_end=cfg["decoder_share_end"],
            output_dim=output_dim,
            activation=_lookup_activation(cfg, "decoder_activation"),
            output_split_sizes=split_sizes,
            spectral_norm=cfg.get("spectral_norm", False),
        )

    if decoder_net == "cnn":
        output_channels, split_sizes = _image_decoder_output(cfg)
        return T_CNN(
            input_dim=cfg["latent_dim"],
            hidden_channels_list=cfg["decoder_hidden_channels"],
            output_channels=output_channels,
            kernel_size=cfg["decoder_kernel_size"],
            stride=cfg["decoder_stride"],
            padding=cfg.get("decoder_padding", 0),
            shared_module=cfg.get("decoder_shared_module", 0),
            image_height=cfg["data_shape"][1],
            do_bn=cfg.get("do_bn", False),
            activation=_lookup_activation(cfg, "decoder_activation"),
            output_split_sizes=split_sizes,
            single_sigma=cfg.get("single_sigma", False),
            final_activation=activation_map[cfg.get("decoder_final_activation", None)],
            conv_bias=cfg.get("decoder_conv_bias", True),
            norm=norm_map[cfg.get("decoder_norm", None)],
            norm_args=cfg.get("decoder_norm_args", None),
            initial_linear=cfg.get("decoder_initial_linear", True),
            conditioning_dimension=cfg.get("conditioning_dimension", 0)
        )

    # Fixed: this branch used to test cfg["encoder_net"] as well.
    if decoder_net == "shared_cnn":
        output_channels, split_sizes = _image_decoder_output(cfg)
        return SharedT_CNN(
            input_dim=cfg["latent_dim"],
            beginning_hidden_channels=cfg["decoder_beginning_hidden_channels"],
            middle_hidden_channels=cfg["decoder_middle_hidden_channels"],
            end_hidden_channels=cfg["decoder_end_hidden_channels"],
            share_start=cfg["decoder_share_start"],
            share_middle=cfg["decoder_share_middle"],
            share_end=cfg["decoder_share_end"],
            beginning_kernel_size=cfg["decoder_beginning_kernel_size"],
            middle_kernel_size=cfg["decoder_middle_kernel_size"],
            end_kernel_size=cfg["decoder_end_kernel_size"],
            beginning_stride=cfg["decoder_beginning_stride"],
            middle_stride=cfg["decoder_middle_stride"],
            end_stride=cfg["decoder_end_stride"],
            beginning_padding=cfg["decoder_beginning_padding"],
            middle_padding=cfg["decoder_middle_padding"],
            end_padding=cfg["decoder_end_padding"],
            output_channels=output_channels,
            image_height=cfg["data_shape"][1],
            activation=_lookup_activation(cfg, "decoder_activation"),
            output_split_sizes=split_sizes,
            spectral_norm=cfg.get("spectral_norm", False),
            single_sigma=cfg.get("single_sigma", False),
            final_activation=activation_map[cfg.get("decoder_final_activation", None)],
            conv_bias=cfg.get("decoder_conv_bias", True),
            norm=norm_map[cfg.get("decoder_norm", None)],
            norm_args=cfg.get("decoder_norm_args", None),
            initial_linear=cfg.get("decoder_initial_linear", True),
            force_zero_op=cfg.get("decoder_force_zero_op", False)
        )

    if decoder_net == "residual":
        output_channels, split_sizes = _image_decoder_output(cfg)
        return ResidualDecoder(
            input_dim=cfg["latent_dim"],
            layer_channels=cfg["decoder_layer_channels"],
            blocks_per_layer=cfg["decoder_blocks_per_layer"],
            output_channels=output_channels,
            output_split_sizes=split_sizes,
            image_height=cfg["data_shape"][1]
        )

    raise ValueError(f"Unknown decoder network type {cfg['decoder_net']}")


def get_discriminator(cfg):
    """Build the discriminator network selected by `cfg["disc_net"]`.

    A missing "disc_net" is treated as "mlp". "shared_mlp" and "shared_cnn"
    select the shared variants; any other value builds a `CNN`.

    For the MLP variants the input size is `cfg["data_dim"]` when the model is
    "gan", otherwise `cfg["latent_dim"]`, plus `cfg["data_dim"]` for "avb" and
    "bigan".
    """
    if cfg["model"] in ["avb", "bigan"]:
        extra_dims = cfg["data_dim"]
    else:
        extra_dims = 0

    disc_net = cfg.get("disc_net", "mlp")

    if disc_net in ("mlp", "shared_mlp"):
        if cfg["model"] != "gan":
            vector_input_dim = cfg["latent_dim"] + extra_dims
        else:
            vector_input_dim = cfg["data_dim"]

        if disc_net == "mlp":
            return MLP(
                input_dim=vector_input_dim,
                hidden_dims=cfg["discriminator_hidden_dims"],
                output_dim=1,
                activation=activation_map[cfg.get("discriminator_activation", _DEFAULT_ACTIVATION)],
                spectral_norm=cfg.get("disc_spectral_norm", False)
            )

        return SharedMLP(
            input_dim=vector_input_dim,
            beginning_hidden_dims=cfg["discriminator_beginning_hidden_dims"],
            middle_hidden_dims=cfg["discriminator_middle_hidden_dims"],
            end_hidden_dims=cfg["discriminator_end_hidden_dims"],
            share_start=cfg["discriminator_share_start"],
            share_middle=cfg["discriminator_share_middle"],
            share_end=cfg["discriminator_share_end"],
            output_dim=1,
            activation=activation_map[cfg.get("discriminator_activation", _DEFAULT_ACTIVATION)],
            spectral_norm=cfg.get("disc_spectral_norm", False),
        )

    if disc_net == "shared_cnn":
        shared_cnn_kwargs = {
            "input_channels": cfg["data_shape"][0],
            "output_dim": 1,
            "beginning_hidden_channels": cfg["discriminator_beginning_hidden_channels"],
            "middle_hidden_channels": cfg["discriminator_middle_hidden_channels"],
            "end_hidden_channels": cfg["discriminator_end_hidden_channels"],
            "share_start": cfg["discriminator_share_start"],
            "share_middle": cfg["discriminator_share_middle"],
            "share_end": cfg["discriminator_share_end"],
            "beginning_kernel_size": cfg["discriminator_beginning_kernel_size"],
            "middle_kernel_size": cfg["discriminator_middle_kernel_size"],
            "end_kernel_size": cfg["discriminator_end_kernel_size"],
            "beginning_stride": cfg["discriminator_beginning_stride"],
            "middle_stride": cfg["discriminator_middle_stride"],
            "end_stride": cfg["discriminator_end_stride"],
            "beginning_padding": cfg["discriminator_beginning_padding"],
            "middle_padding": cfg["discriminator_middle_padding"],
            "end_padding": cfg["discriminator_end_padding"],
            "image_height": cfg["data_shape"][1],
            "activation": activation_map[cfg.get("discriminator_activation", _DEFAULT_ACTIVATION)],
            "noise_dim": cfg.get("noise_dim", 0),
            "spectral_norm": cfg.get("spectral_norm", False),
            "final_activation": cfg.get("discriminator_final_activation", None),
            "conv_bias": cfg.get("discriminator_conv_bias", True),
            "norm": norm_map[cfg.get("discriminator_norm", None)],
            "norm_args": cfg.get("discriminator_norm_args", None),
            "final_linear": cfg.get("discriminator_final_linear", True),
        }
        return SharedCNN(**shared_cnn_kwargs)

    # Every remaining `disc_net` value falls through to a plain CNN.
    cnn_kwargs = {
        "input_channels": cfg["data_shape"][0],
        "hidden_channels_list": cfg["disc_hidden_channels"],
        "output_dim": 1,
        "kernel_size": cfg["disc_kernel_size"],
        "stride": cfg["disc_stride"],
        "padding": cfg.get("discriminator_padding", 0),
        "shared_module": cfg.get("discriminator_shared_module", 0),
        "image_height": cfg["data_shape"][1],
        "activation": activation_map[cfg.get("discriminator_activation", _DEFAULT_ACTIVATION)],
        "output_split_sizes": None,
        "do_bn": cfg.get("do_bn", False),
        "noise_dim": cfg.get("noise_dim", 0),
        "spectral_norm": cfg.get("disc_spectral_norm", False),
        "final_activation": activation_map[cfg.get("disc_final_activation", None)],
        "conv_bias": cfg.get("disc_conv_bias", True),
        "norm": norm_map[cfg.get("discriminator_norm", None)],
        "norm_args": cfg.get("discriminator_norm_args", None),
        "final_linear": cfg.get("discriminator_final_linear", True),
    }
    return CNN(**cnn_kwargs)


def get_data_transform_kwargs(cfg):
    """Collect the data-transform options from `cfg` into a new dict.

    Each option is read with `cfg.get`, so a missing key falls back to the
    default listed below. Keys of `cfg` that are not listed are ignored.
    """
    option_defaults = (
        ("flatten", False),
        ("data_shape", None),
        ("denoising_sigma", None),
        ("dequantize", False),
        ("scale_data", False),
        ("whitening_transform", False),
        ("logit_transform", False),
        ("clamp_samples", False),
        ("logit_transform_alpha", None),
        ("custom_normalize", False),
    )

    return {option: cfg.get(option, default) for option, default in option_defaults}
