from ltsgns_mp.envs.train_iterator.step_train_iterator import StepTrainBatch


def get_simulator(algorithm_config, env, loading_config, device):
    """Build the simulator selected by ``algorithm_config.simulator.name``.

    The first batch of ``env.train_iterator`` is drawn and handed to the
    simulator constructor so it can size its network inputs and outputs.
    Afterwards ``env.train_iterator.refresh_iterator()`` is called (presumably
    to reset the iterator before training uses it — confirm against the
    iterator implementation).

    Args:
        algorithm_config: Algorithm config. ``.simulator`` is forwarded to the
            simulator and its ``.name`` selects the class. ``.name`` is checked
            for the ``"mgn_cnp_iterator"`` special case. The LTSGNS simulators
            also read ``.posterior_learner.d_z``.
        env: Environment exposing ``train_iterator`` (iterable with a
            ``refresh_iterator()`` method) and, for the CNP/NP simulators,
            ``trajectory_length``.
        loading_config: Forwarded unchanged to the simulator constructor.
        device: Forwarded unchanged to the simulator constructor.

    Returns:
        The constructed simulator instance.

    Raises:
        ValueError: If the train iterator yields no batch (or yields ``None``
            as its first item).
        NotImplementedError: If the simulator name is not recognised.
    """
    # Draw a single batch only to infer network input/output sizes.
    example_batch = next(iter(env.train_iterator), None)
    env.train_iterator.refresh_iterator()
    if example_batch is None:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ValueError("The train iterator of the environment must not be empty.")

    simulator_config = algorithm_config.simulator
    simulator_name = simulator_config.name
    # Simulator modules are imported lazily, so only the selected one is loaded.
    if simulator_name == "step_simulator":
        if algorithm_config.name == "mgn_cnp_iterator":
            # hack to get the correct example batch for the step simulator:
            # re-wrap the batch's ``target_batch`` as a StepTrainBatch.
            example_batch = StepTrainBatch(example_batch.target_batch)
        from ltsgns_mp.architectures.simulators.step_simulator import StepSimulator
        return StepSimulator(simulator_config, example_batch, loading_config, device)
    elif simulator_name == "ltsgns_mp_simulator":
        from ltsgns_mp.architectures.simulators.ltsgns_mp_simulator import LTSGNS_MP_Simulator
        return LTSGNS_MP_Simulator(simulator_config, example_batch, algorithm_config.posterior_learner.d_z,
                                   loading_config, device)
    elif simulator_name == "ltsgns_step_simulator":
        from ltsgns_mp.architectures.simulators.ltsgns_step_simulator import LTSGNS_Step_Simulator
        return LTSGNS_Step_Simulator(simulator_config, example_batch, algorithm_config.posterior_learner.d_z,
                                     loading_config, device)
    elif simulator_name == "cnp_simulator":
        from ltsgns_mp.architectures.simulators.cnp_simulator import CNPSimulator
        return CNPSimulator(simulator_config, example_batch, loading_config, device, env.trajectory_length)
    elif simulator_name == "np_simulator":
        from ltsgns_mp.architectures.simulators.np_simulator import NPSimulator
        return NPSimulator(simulator_config, example_batch, loading_config, device, env.trajectory_length)
    raise NotImplementedError(f"Unknown simulator name: {simulator_name}")
