from models.diffusion import ddpm
from models.diffusion import unet_conv_v2, unet_conv_v1, unet_nnpd


def load_diff(unet_name, dim=128, dim_cond=768, self_cond=True, layer_channels=None, init_ch=None, final_ch=None, use_embeddings=True, timesteps=1000):
    """Build a denoising network by name and wrap it in a ``ddpm.DDPM``.

    Not every argument is used by every backbone; arguments that the
    selected backbone does not take are ignored.

    Args:
        unet_name: Which backbone to build: ``'unet_v1'``, ``'unet_v2'``
            or ``'unet_nnpd'``.
        dim: Forwarded as ``dim`` to both UNet variants and as
            ``model_dim`` to ``unet_nnpd.OneDimCNN``.
        dim_cond: Forwarded to ``unet_conv_v2.UNet`` only.
        self_cond: Forwarded as ``self_condition`` to both UNet variants;
            not used by ``'unet_nnpd'``.
        layer_channels: Forwarded to ``'unet_v1'`` and ``'unet_nnpd'``.
            For ``'unet_nnpd'``, ``None`` selects the default
            ``[2, 64, 128, 256, 512, 256, 128, 64, 2]``.
        init_ch: Forwarded to ``unet_conv_v1.UNet`` only.
        final_ch: Forwarded to ``unet_conv_v1.UNet`` only.
        use_embeddings: Forwarded to ``unet_conv_v2.UNet`` only.
        timesteps: Passed to ``ddpm.DDPM`` as ``timesteps``.

    Returns:
        A ``ddpm.DDPM`` instance wrapping the constructed network.

    Raises:
        ValueError: If ``unet_name`` is not one of the names listed above.
    """
    print('-' * 50)
    print('DDPM NAME:', unet_name)
    print('-' * 50)

    if unet_name == 'unet_v1':
        unet = unet_conv_v1.UNet(
            dim=dim,
            layer_channels=layer_channels,
            init_ch=init_ch,
            final_ch=final_ch,
            self_condition=self_cond,
        )
    elif unet_name == 'unet_v2':
        # Channel count, channel multipliers and head count are fixed here
        # rather than exposed as arguments.
        unet = unet_conv_v2.UNet(
            dim=dim,
            dim_cond=dim_cond,
            ch=1,
            ch_mults=(1, 2, 4,),
            num_heads=4,
            self_condition=self_cond,
            use_embeddings=use_embeddings,
        )
    elif unet_name == 'unet_nnpd':
        # The default channel list is built fresh on every call, so callers
        # never share (or mutate) a common list object.
        channels = layer_channels if layer_channels is not None else [2, 64, 128, 256, 512, 256, 128, 64, 2]
        unet = unet_nnpd.OneDimCNN(
            layer_channels=channels,
            model_dim=dim,
            kernel_size=7,
        )
    else:
        # Fail fast: wrapping ``model=None`` in a DDPM would only defer the
        # error to some later, harder-to-diagnose call site.
        raise ValueError(
            f"Unknown DDPM model name: {unet_name!r}; "
            "expected one of 'unet_v1', 'unet_v2', 'unet_nnpd'"
        )

    return ddpm.DDPM(timesteps=timesteps, model=unet)
