from .robotics_transformer import RoboticsTransformer
from .robotics_lslotformer import RoboticsLSlotFormer

def build_model(params):
    """Instantiate the model selected by ``params.model``.

    Args:
        params: Config object exposing a ``model`` attribute naming the model
            to build. Both models also read ``resolution``, ``enc_dict``,
            ``act_dec_dict`` and ``loss_dict``. ``RoboticsLSlotFormer``
            additionally reads ``input_frames``, ``slot_dict``, ``dec_dict``,
            ``pred_dict`` and ``rollout_dict``.

    Returns:
        A ``RoboticsTransformer`` or ``RoboticsLSlotFormer`` instance, both
        constructed with ``eps=1e-6``.

    Raises:
        ValueError: If ``params.model`` is not one of the supported names.
    """
    if params.model == "RoboticsTransformer":
        return RoboticsTransformer(
            resolution=params.resolution,
            enc_dict=params.enc_dict,
            act_dec_dict=params.act_dec_dict,
            loss_dict=params.loss_dict,
            eps=1e-6,
        )

    if params.model == "RoboticsLSlotFormer":
        return RoboticsLSlotFormer(
            resolution=params.resolution,
            # The config names this field ``input_frames``; the model
            # constructor calls the same quantity ``clip_len``.
            clip_len=params.input_frames,
            slot_dict=params.slot_dict,
            enc_dict=params.enc_dict,
            dec_dict=params.dec_dict,
            pred_dict=params.pred_dict,
            rollout_dict=params.rollout_dict,
            act_dec_dict=params.act_dec_dict,
            loss_dict=params.loss_dict,
            eps=1e-6,
        )

    # Fail loudly instead of implicitly returning None, so a misspelled or
    # unsupported model name is reported here rather than at first use.
    raise ValueError(
        f"Unknown model {params.model!r}; expected one of "
        "'RoboticsTransformer', 'RoboticsLSlotFormer'."
    )
    
    