import torch.nn as nn
import segmentation_models_pytorch as smp

from conf.model import BackboneParams
from src.Backbones.Unet_cold_diffusion import Unet
from src.Backbones.Unet_cold_diffusion_MultiTime import Unet_Cold_Multi_Domain
from src.Backbones.Celeba1 import celeba1_model
from src.Backbones.Celeba3 import adapt_model, freeze_model
from src.UMM_CSGM.Unet_cold_diffusion_UMM_CSGM_General import Unet_UMM_CSGM_General


def _multi_domain_kwargs(params: BackboneParams) -> dict:
    """Collect the constructor keyword arguments shared by the multi-domain U-Nets.

    ``Unet_Cold_Multi_Domain`` and ``Unet_UMM_CSGM_General`` are both built
    from this same set of configuration fields; keeping the mapping in one
    place prevents the two call sites from drifting apart.

    Args:
        params: Backbone configuration the values are read from.

    Returns:
        A dict that can be unpacked with ``**`` into either constructor.
    """
    return dict(
        dim_per_dom=params.dimension_per_domain,

        # Encoder options.
        encoder_split=params.encoder_split,
        encoder_attention_per_block=params.encoder_attention_per_block,
        encoder_time_embedding_per_block=params.encoder_time_embedding_per_block,

        # Latent strategies.
        pz_strat=params.pz_strat,
        z_mid_strat=params.z_mid_strat,

        # Middle (bottleneck) options.
        middle_linear_attention=params.middle_linear_attention,
        middle_attention=params.middle_attention,
        middle_time_embedding=params.middle_time_embedding,
        middle_nb_block_following=params.middle_nb_block_following,

        # Decoder options.
        decoder_attention_per_block=params.decoder_attention_per_block,
        decoder_time_embedding_per_block=params.decoder_time_embedding_per_block,
        decoder_split=params.decoder_split,

        use_double_skip=params.use_double_skip,

        # Options these models have in common with the plain cold-diffusion
        # U-Net. Unlike that branch, dim_mults is converted to a tuple here.
        dim=params.dim,
        time_dim=params.time_dim,
        init_dim=params.init_dim,
        out_dim=params.out_dim,
        dim_mults=tuple(params.dim_mults),
        channels=params.channels,
        resnet_block_groups=params.resnet_block_groups,
        with_time_emb=params.with_time_emb,
        use_convnext=params.use_convnext,
        convnext_mult=params.convnext_mult,
        residual=params.residual,
    )


def get_model(params: BackboneParams) -> nn.Module:
    """Build the backbone network selected by ``params.name``.

    Supported names:

    * ``'unet_cold'`` -- ``Unet``.
    * ``'unet_cold_multi_time'`` -- ``Unet_Cold_Multi_Domain``.
    * ``'celeba_model3'`` -- a ``Unet_Cold_Multi_Domain`` that is then passed
      through ``adapt_model`` (with ``'google/ddpm-ema-celebahq-256'`` as the
      reference model) and ``freeze_model``.
    * ``'celeba_model1'`` -- ``celeba1_model`` built from
      ``'google/ddpm-ema-celebahq-256'``; no other field of ``params`` is read.
    * ``'umm_csgm'`` -- ``Unet_UMM_CSGM_General``.

    Args:
        params: Backbone configuration; ``params.name`` selects the
            architecture and the remaining fields are forwarded to the
            corresponding constructor.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: If ``params.name`` is not one of the supported
            names listed above.
    """
    name = params.name

    if name == 'unet_cold':
        return Unet(
            dim=params.dim,
            time_dim=params.time_dim,
            init_dim=params.init_dim,
            out_dim=params.out_dim,
            dim_mults=params.dim_mults,
            channels=params.channels,
            with_time_emb=params.with_time_emb,
            residual=params.residual,
            resnet_block_groups=params.resnet_block_groups,
            use_convnext=params.use_convnext,
            convnext_mult=params.convnext_mult,
        )

    if name in ('unet_cold_multi_time', 'celeba_model3'):
        model = Unet_Cold_Multi_Domain(**_multi_domain_kwargs(params))
        if name == 'celeba_model3':
            # NOTE(review): the gref_model string looks like a pretrained
            # checkpoint identifier -- confirm how adapt_model resolves it.
            model = adapt_model(
                original=model,
                gref_model='google/ddpm-ema-celebahq-256',
                params=params,
            )
            freeze_model(model, params)
        return model

    if name == 'celeba_model1':
        return celeba1_model(src='google/ddpm-ema-celebahq-256')

    if name == 'umm_csgm':
        return Unet_UMM_CSGM_General(
            **_multi_domain_kwargs(params),
            umm_csgm_vanilla_middle=params.umm_csgm_vanilla_middle,
            with_mode_emb=params.with_mode_emb,
            use_pz_m=params.use_pz_m,
        )

    # Same exception type as before, but now naming the offending value so a
    # typo in the configuration can be diagnosed from the traceback alone.
    raise NotImplementedError(
        f"Unknown backbone name {name!r}; expected one of 'unet_cold', "
        "'unet_cold_multi_time', 'celeba_model3', 'celeba_model1', 'umm_csgm'."
    )
