import core
from core import diffusion
from core import distill
from core.ar import ARCore
from core.hybrid_ar_diff import HybridDiffAR
from core.diffusion.mdlm_remask import MDLMRemask
from core.debug_reload import Debug


def get_diffusion(config, tokenizer):
    """Build the diffusion model named by ``config.parameterization.name``.

    Supported names and the class each one instantiates:

    * ``"mdlm"`` -> ``diffusion.MDLM``
    * ``"distill-mdlm-double-k"`` -> ``distill.DistillMDLMDoubleEveryKCorrect``

    The selected class is called as ``cls(config, tokenizer)`` and the new
    instance is returned.

    Raises:
        ValueError: if the parameterization name is not one listed above.
    """
    parameterization = config.parameterization.name

    # Guard clauses: each supported name returns immediately, so only the
    # chosen class is ever looked up on its package.
    if parameterization == "mdlm":
        return diffusion.MDLM(config, tokenizer)

    if parameterization == "distill-mdlm-double-k":
        return distill.DistillMDLMDoubleEveryKCorrect(config, tokenizer)

    raise ValueError(f"Unknown parameterization `{parameterization}`")


def get_diffusion_module(config):
    """Return the implementation object paired with ``config.parameterization.name``.

    Supported names and the attribute each one returns:

    * ``"mdlm"`` -> ``diffusion.mdlm``
    * ``"distill-mdlm-double-k"`` -> ``distill.mdlm_double_dt_correct``

    These attributes are presumably the submodules that hold the matching
    implementations (the lowercase names suggest modules); confirm against
    the ``core`` package layout.

    Raises:
        ValueError: if the parameterization name is not one listed above.
    """
    parameterization = config.parameterization.name

    # Each supported name returns immediately; only the selected attribute
    # is accessed.
    if parameterization == "mdlm":
        return diffusion.mdlm

    if parameterization == "distill-mdlm-double-k":
        return distill.mdlm_double_dt_correct

    raise ValueError(f"Unknown parameterization `{parameterization}`")
