from ltsgns_mp.envs.train_iterator.abstract_train_iterator import AbstractTrainIterator
from ltsgns_mp.util.own_types import ConfigDict
from ltsgns_mp.envs.trajectory_collection import TrajectoryCollection


def get_train_iterator(config: ConfigDict, train_trajs: TrajectoryCollection, device: str) -> AbstractTrainIterator:
    """Build the train iterator selected by ``config.name``.

    Recognized names are ``"step_iterator"``, ``"trajectory_iterator"``,
    ``"ltsgns_step_iterator"`` and ``"cnp_iterator"``. The matching iterator
    class is imported inside its branch, so only the module of the selected
    iterator is loaded.

    Args:
        config: Iterator configuration; its ``name`` attribute picks the class
            and the whole object is forwarded to the constructor.
        train_trajs: Training trajectories, forwarded to the constructor.
        device: Device identifier, forwarded to the constructor.

    Returns:
        The iterator instance created as ``cls(config, train_trajs, device)``.

    Raises:
        ValueError: If ``config.name`` is not one of the recognized names.
    """
    selected = config.name

    # Resolve the class first; every variant shares the same constructor call below.
    if selected == "step_iterator":
        from ltsgns_mp.envs.train_iterator.step_train_iterator import StepTrainIterator as iterator_cls
    elif selected == "trajectory_iterator":
        from ltsgns_mp.envs.train_iterator.trajectory_train_iterator import TrajectoryTrainIterator as iterator_cls
    elif selected == "ltsgns_step_iterator":
        from ltsgns_mp.envs.train_iterator.ltsgns_step_train_iterator import LTSGNSStepTrainIterator as iterator_cls
    elif selected == "cnp_iterator":
        from ltsgns_mp.envs.train_iterator.cnp_train_iterator import CNPTrainIterator as iterator_cls
    else:
        raise ValueError(f"Train Iterator {config.name} unknown.")

    return iterator_cls(config, train_trajs, device)
