import torch
from baselines.model_factory import FieldDecoderParams, LatentODEParams, LatentProcessParams, SetEncoderParams

from memKNO.decoder import FieldDecoder
def build_field_decoder(x_grid: torch.Tensor, model_cfg: FieldDecoderParams):
    """Construct a ``FieldDecoder`` bound to ``x_grid``.

    Args:
        x_grid: Tensor forwarded to ``FieldDecoder`` as the ``x_grid`` keyword.
        model_cfg: Mapping of extra keyword arguments for ``FieldDecoder``;
            it is unpacked with ``**``, so it must not itself contain an
            ``x_grid`` key (Python would raise ``TypeError`` for a duplicate
            keyword).

    Returns:
        The newly created ``FieldDecoder`` instance.
    """
    return FieldDecoder(x_grid=x_grid, **model_cfg)

from memKNO.network import SetEncoder
def build_set_encoder(model_cfg: SetEncoderParams):
    """Construct a ``SetEncoder`` from a keyword-argument config.

    Args:
        model_cfg: Mapping whose entries are passed to ``SetEncoder`` as
            keyword arguments (unpacked with ``**``).

    Returns:
        The newly created ``SetEncoder`` instance.
    """
    return SetEncoder(**model_cfg)

from memKNO.latent import LatentODEfunc, LatentProcess
def build_latent_ode(model_cfg_ode: LatentODEParams, model_cfg_latent: LatentProcessParams):
    """Construct a ``LatentProcess`` driven by a freshly built ``LatentODEfunc``.

    Args:
        model_cfg_ode: Mapping of keyword arguments for ``LatentODEfunc``.
        model_cfg_latent: Mapping of keyword arguments for ``LatentProcess``;
            it must not contain an ``ode_func`` key, since that keyword is
            supplied here (a duplicate would raise ``TypeError``).

    Returns:
        The ``LatentProcess`` instance, holding the new ``LatentODEfunc`` as
        its ``ode_func`` argument.
    """
    # Build the dynamics module first, as a separate step, so it is created
    # before the LatentProcess call is assembled.
    dynamics = LatentODEfunc(**model_cfg_ode)
    return LatentProcess(**model_cfg_latent, ode_func=dynamics)