from ltsgns_mp.envs.data_loader.hydraulic_press_low_res import HydraulicPressLowResDataloader
from ltsgns_mp.envs.data_loader.isaac_sim_dataloader import IsaacSimDataloader
from ltsgns_mp.envs.data_loader.planar_bending_dataloader import PlanarBendingDataloader
from ltsgns_mp.envs.data_loader.sofa_dataloader import SofaDataloader
from ltsgns_mp.envs.data_loader.toy_task_dataloader import ToyTaskDataloader
from ltsgns_mp.envs.env import Env
from ltsgns_mp.util.own_types import ConfigDict


def get_env(config: ConfigDict, train_iterator_config: ConfigDict, evaluation_config: ConfigDict, device: str) -> Env:
    env_name = config.name

    match env_name:
        case "deformable_plate":
            dataloader = SofaDataloader
        case "deformable_plate_v2":
            dataloader = SofaDataloader
        case "tissue_manipulation":
            dataloader = SofaDataloader
        case "cavity_grasping":
            dataloader = SofaDataloader
        case "hydraulic_press_low_res":
            dataloader = HydraulicPressLowResDataloader
        case "sphere_fall":
            dataloader = IsaacSimDataloader
        case "parabolic_toy_task":
            dataloader = ToyTaskDataloader
        case "teddy_fall_nopc":
            dataloader = IsaacSimDataloader
        case "mixed_objects_fall":
            dataloader = IsaacSimDataloader
        case "multi_objects_fall":
            dataloader = IsaacSimDataloader
        case "multi_objects_fall_varied_material":
            dataloader = IsaacSimDataloader
        case "planar_bending":
            dataloader = PlanarBendingDataloader
        # case "multi_decision":
        #     raise NotImplementedError()
        #     # Dataloader = MultiDecisionDataloader
        case _:
            raise ValueError(f"Environment {env_name} unknown.")

    dataloader = dataloader(config=config)
    traj_dict = dataloader.load()
    return Env(config=config, train_iterator_config=train_iterator_config,
               evaluation_config=evaluation_config,
               traj_dict=traj_dict, device=device, )
