from torch_geometric.data import Batch

from ltsgns_mp.architectures.prodmp.prodmp import ProDMPPredictor
from ltsgns_mp.util import keys
from ltsgns_mp.util.own_types import ConfigDict


def build_prodmp(example_input_batch: Batch, simulator_config: ConfigDict, trajectory_length: int, device: str) -> ProDMPPredictor:
    """Construct a ``ProDMPPredictor`` sized from an example input batch.

    Args:
        example_input_batch: A representative batch. Only its ``keys.POSITIONS``
            entry is read, and only to get the size of dimension 1.
        simulator_config: Simulator configuration. Its ``mp`` sub-config is
            forwarded unchanged as ``mp_config``.
        trajectory_length: Forwarded as ``num_time_steps``.
        device: Forwarded unchanged to the predictor.

    Returns:
        The newly constructed ``ProDMPPredictor``.
    """
    # Read the MP sub-config first, matching the original evaluation order.
    movement_primitive_config = simulator_config.mp
    # The degrees of freedom are taken from dimension 1 of the positions entry.
    # NOTE(review): presumably positions are laid out as (num_nodes, dof); confirm.
    dof_count = example_input_batch[keys.POSITIONS].shape[1]
    return ProDMPPredictor(
        num_dof=dof_count,
        mp_config=movement_primitive_config,
        num_time_steps=trajectory_length,
        device=device,
    )
