from core.diffusion.schedule import NamedSchedule
import configs.default as default
from .models import *
import os


# Methods that start from a pretrained double-headed UNet, mapped to the
# ``typ`` passed to ``default.get_dt_wrapper_config`` for that method.
_PRETRAINED_WRAPPER_TYPES = {
    'pred_eps_hes_pretrained': 'eps_hes',
    'pred_eps_epsc_blockcirc_pretrained': 'eps_epsc_blockcirc',
    'pred_eps_hes_blockcirc_pretrained': 'eps_hes_blockcirc',
    'pred_eps_eps2_pretrained': 'eps_eps2',
    'pred_eps_epsc_pretrained': 'eps_epsc',
}


def get_train_config(**hparams):
    """Build the CelebA64 training config for the method named in ``hparams``.

    Defaults filled in when absent:
      * ``schedule``: ``NamedSchedule('linear', 1000)``
      * ``shift1``: ``True`` (follows the original DDPM)
    ``hparams['N']`` is always overwritten with ``hparams['schedule'].N``.

    Supported ``hparams['method']`` values:
      * ``'pred_eps'``: plain UNet with the DSM criterion and an ``'eps'`` wrapper.
      * any key of ``_PRETRAINED_WRAPPER_TYPES``: pretrained double UNet with
        ``rev_var_type='optimal'`` and the DSDM criterion. ``'pred_eps_epsc_pretrained'``
        uses the ``_err`` variant of that criterion.

    Evaluator settings (output path under ``hparams['workspace']``, period,
    sample counts, clipping) are fixed here and take precedence over any
    same-named keys in ``hparams``.

    Raises:
        KeyError: if ``'method'`` or ``'workspace'`` is missing from ``hparams``.
        NotImplementedError: if ``hparams['method']`` is not a supported method.
    """
    hparams.setdefault('schedule', NamedSchedule('linear', 1000))
    hparams.setdefault('shift1', True)  # follow original DDPM
    hparams['N'] = hparams['schedule'].N

    config = default.get_train_config(**hparams)
    config.models = ml_collections.ConfigDict()
    config.dataset = default.get_celeba64_config(**hparams)

    method = hparams['method']
    if method == 'pred_eps':
        config.models.model = get_ddpm_unet_config(**hparams)
        config.criterion = default.get_dt_dsm_config(**hparams)
        config.wrapper = default.get_dt_wrapper_config(typ='eps', **hparams)
    elif method in _PRETRAINED_WRAPPER_TYPES:
        # Set before building the model so the model, criterion, wrapper and
        # evaluator below all see it (the base config and dataset above do not).
        hparams['rev_var_type'] = 'optimal'
        config.models.model = get_ddpm_unet_double_pretrained_config(**hparams)
        if method == 'pred_eps_epsc_pretrained':
            config.criterion = default.get_dt_dsdm_err_config(**hparams)
        else:
            config.criterion = default.get_dt_dsdm_config(**hparams)
        config.wrapper = default.get_dt_wrapper_config(
            typ=_PRETRAINED_WRAPPER_TYPES[method], **hparams)
    else:
        raise NotImplementedError(f"unsupported method: {method!r}")

    # Merge into one dict instead of mixing ``**hparams`` with explicit keywords.
    # The mixed form raises TypeError ("multiple values for keyword argument")
    # whenever hparams already carries e.g. ``batch_size``.
    evaluator_hparams = dict(hparams)
    evaluator_hparams.update(
        path=os.path.join(hparams['workspace'], 'train/evaluator/sample2dir/'),
        period=1000,
        n_samples=1000,
        batch_size=128,
        sample_steps=100,
        clip_sigma_idx=1,
        clip_pixel=2,
    )
    config.evaluator = default.get_train_evaluator_config(**evaluator_hparams)
    return config
