from models.ctfno1d import CTFNO1d
from models.encoder_solver_decoder import Encoder_z0_RNN
from models.encoder_solver_decoder import Solver_CTFNO, Decoder
from models.latent_baseline import LatentODE


# Dataset names (compared case-insensitively) whose inputs are modelled
# directly by CTFNO1d, without an encoder into a latent space.
_NO_ENCODING_DATASETS = frozenset({'heat', 'burgers', 'ode', 'pv'})


def get_model(cfg):
    """Build the model for ``cfg.data.name`` and move it onto ``cfg.gpu``.

    Datasets listed in ``_NO_ENCODING_DATASETS`` get a plain ``CTFNO1d``.
    Every other dataset gets a ``LatentODE`` wrapper around an
    ``Encoder_z0_RNN`` (input -> latent), a ``Solver_CTFNO`` (dynamics in
    the latent space) and a ``Decoder`` (latent -> output).

    Args:
        cfg: Experiment config. This function reads ``cfg.data.name`` (a
            string, lower-cased before lookup) and ``cfg.gpu`` (passed to
            ``Module.cuda``). The whole ``cfg`` is also forwarded to each
            model constructor.

    Returns:
        The constructed model on the CUDA device ``cfg.gpu``. It carries a
        boolean attribute ``encode``: ``False`` for the CTFNO1d path and
        ``True`` for the LatentODE path.
    """
    # NOTE(review): an earlier comment also listed "ScalarFlow" as a
    # no-encoding dataset, but it is not in _NO_ENCODING_DATASETS, so it is
    # routed to the LatentODE branch below. Confirm which routing is intended.
    if cfg.data.name.lower() in _NO_ENCODING_DATASETS:
        # (1) No encoding required (Heat, Burgers, ODE, PV).
        model = CTFNO1d(cfg)
        model.encode = False
    else:
        # (2) Encoding required (e.g. MuJoCo, PhysioNet, Activity).
        # (a) Encoder: input -> latent.
        encoder = Encoder_z0_RNN(cfg)
        # (b) Diffeq solver: solve the ODE in the latent space.
        solver = Solver_CTFNO(cfg)
        # (c) Decoder: latent -> output.
        decoder = Decoder(cfg)
        # LatentODE acts as a wrapper around the three components.
        model = LatentODE(cfg, encoder=encoder, solver=solver, decoder=decoder)
        model.encode = True

    # Requires a CUDA device; cfg.gpu selects which one.
    return model.cuda(cfg.gpu)

