from multinav.deploy.load_data import (
    dataset_postprocess,
    load_dataset,
)

from dlimp.dataset import DLataset

# Size of each GNM dataset, keyed by dataset name.
# NOTE(review): the unit is not stated in this file (presumably a sample
# count) -- confirm against the data loader.
dataset_sizes = {
    "gnm_cory_hall": 148680,
    "gnm_go_stanford": 194429,
    "gnm_recon": 599072,
    "gnm_sacson": 238103,
    "gnm_scand": 31970,
    "gnm_seattle": 7439,
    "gnm_tartan_drive": 17239,
}

# Dataset names, in the insertion order of the size table above.
datasets = list(dataset_sizes)

# Combined size of every dataset listed in ``dataset_sizes``.
total_size = sum(dataset_sizes.values())


def setup_datasets(has_goal: bool, dataset_config: str, discount: float, val: bool = None):
    """Build the weighted dataset mixture for a named configuration.

    Each configuration maps to a list of ``(data_dir, dataset_name, weight)``
    sources. Every source is loaded with ``load_dataset`` (always with
    ``end_is_crash=False``), the loaded datasets are combined with
    ``DLataset.sample_from_datasets`` using weights normalized to sum to 1,
    and the result is passed through ``dataset_postprocess``.

    Args:
        has_goal: Forwarded unchanged to ``load_dataset``.
        dataset_config: One of ``"gnm"``, ``"sim_pre"``, ``"sim_fine"`` or
            ``"rainbow"``.
        discount: Forwarded unchanged to ``load_dataset``.
        val: Forwarded unchanged to ``load_dataset``; defaults to ``None``.

    Returns:
        Whatever ``dataset_postprocess`` returns for the sampled mixture.

    Raises:
        ValueError: If ``dataset_config`` is not a known configuration name.
    """
    gnm_dir = "gs://gnm-data-c1/"
    sim_dir = "gs://locobot_sim/"

    # GNM sources shared by every configuration except "rainbow".
    # NOTE(review): "gnm_scand" is not part of any configuration here --
    # confirm that this omission is intentional.
    gnm_sources = [
        (gnm_dir, "gnm_cory_hall", 1.0),
        (gnm_dir, "gnm_sacson", 1.0),
        (gnm_dir, "gnm_go_stanford", 1.0),
        (gnm_dir, "gnm_recon", 1.0),
        (gnm_dir, "gnm_seattle", 1.0),
        (gnm_dir, "gnm_tartan_drive", 1.0),
    ]
    supervised_sim = (sim_dir, "supervised", 1.0)
    exploration_sim = (sim_dir, "random_exploration", 1.0)

    sources_by_config = {
        # NORMAL DATASETS
        "gnm": gnm_sources,
        # SUPERVISED SIM DATASET
        "sim_pre": gnm_sources + [supervised_sim],
        "sim_fine": gnm_sources + [supervised_sim, exploration_sim],
        # NOTE(review): "<DATA DIR>" is a placeholder that was left in the
        # source; replace it with the real location of the "sim_rainbow" data.
        "rainbow": [("<DATA DIR>", "sim_rainbow", 1.0)],
    }

    # Fail loudly on an unknown name instead of hitting an UnboundLocalError
    # further down.
    if dataset_config not in sources_by_config:
        raise ValueError(
            f"Unknown dataset_config {dataset_config!r}; "
            f"expected one of {sorted(sources_by_config)}"
        )
    sources = sources_by_config[dataset_config]

    # Normalize the sampling weights so they sum to 1.
    weight_total = sum(weight for _, _, weight in sources)
    weights = [weight / weight_total for _, _, weight in sources]

    train_dataset = DLataset.sample_from_datasets(
        [
            load_dataset(
                data_dir,
                dataset_name,
                end_is_crash=False,
                has_goal=has_goal,
                discount=discount,
                val=val,
            )
            for data_dir, dataset_name, _ in sources
        ],
        weights=weights,
    )
    return dataset_postprocess(train_dataset)