# from src.models.sample_model import SampleModel
from src.models.model_toy_2D import ToyDiffusion
from src.diffusion.SDE import VESDE, VPSDE
from src.utils.logutils import get_logger

logger = get_logger(__name__)


def get_lightning_model(cfg, denoiser, sde):
    """Build the Lightning model selected by ``cfg.model.model_name``.

    Args:
        cfg (config): Whole experiment config; ``cfg.model.model_name`` picks
            the model class, and the full config is passed to its constructor.
        denoiser: Passed unchanged to the model constructor.
        sde: SDE object (e.g. as returned by ``get_sde``), passed unchanged to
            the model constructor.

    Returns:
        The constructed model (currently only ``ToyDiffusion`` is supported).

    Raises:
        ValueError: If ``cfg.model.model_name`` is not a supported model name.
    """
    # Resuming from a checkpoint is currently disabled; the previous
    # implementation is kept below for reference.
    # if cfg.load_from_checkpoint is not None:
    #     logger.info("Loading model from checkpoint %s", cfg.load_from_checkpoint)
    #     model = SampleModel.load_from_checkpoint(cfg.load_from_checkpoint)
    #     model.config = cfg # NOTE  will this break things?
    # else:
    #     logger.info("Initialising new model")
    model_name = cfg.model.model_name
    if model_name == "model_toy_2D":
        return ToyDiffusion(cfg, denoiser, sde)  # TODO

    # Previously an unknown name fell through to `return model` and surfaced
    # as an opaque UnboundLocalError; fail with an actionable message instead.
    raise ValueError(
        f"Unknown model_name {model_name!r}; supported: 'model_toy_2D'"
    )

def get_sde(cfg):
    """Build the SDE selected by ``cfg.sde.name``.

    Args:
        cfg (config): Whole experiment config. Reads ``cfg.sde.name``
            ('VE' or 'VP'), the matching ``cfg.sde`` hyper-parameters
            (``sigma_min``/``sigma_max`` for 'VE', ``beta_min``/``beta_max``
            for 'VP') and ``cfg.model.p_steps``, which is passed as ``N``.

    Returns:
        A ``VESDE`` (for 'VE') or ``VPSDE`` (for 'VP') instance.

    Raises:
        ValueError: If ``cfg.sde.name`` is neither 'VE' nor 'VP'.
    """
    name = cfg.sde.name
    # NOTE(review): N presumably is the number of discretisation steps of the
    # SDE — confirm against src.diffusion.SDE.
    if name == 'VE':
        return VESDE(sigma_min=cfg.sde.sigma_min, sigma_max=cfg.sde.sigma_max, N=cfg.model.p_steps)
    if name == 'VP':
        return VPSDE(beta_min=cfg.sde.beta_min, beta_max=cfg.sde.beta_max, N=cfg.model.p_steps)

    # Previously an unknown name fell through to `return sde` and surfaced as
    # an opaque UnboundLocalError; fail with an actionable message instead.
    raise ValueError(f"Unknown SDE name {name!r}; expected 'VE' or 'VP'")
