import os
import tarfile
import numpy as np

# from .misc_data_util import transforms as trans
# from .misc_data_util.url_save import save
# from zipfile import ZipFile




def load_dataset(dataset_name, batch_size, data_path=""):
    """Build the train and test loaders for a named dataset.

    Matching is a case-insensitive substring test on ``dataset_name``.

    Args:
        dataset_name: Name of the dataset. It is lower-cased before matching.
            Any name containing ``"h2c"`` selects the heterogeneous 2-D
            clusters dataset, and the full (lower-cased) name is forwarded
            to its factory.
        batch_size: Batch size forwarded to the dataloader factory.
        data_path: Currently unused. It is kept for interface compatibility.

    Returns:
        A ``(train_loader, test_loader)`` pair as produced by the
        dataset-specific factory.

    Raises:
        ValueError: If ``dataset_name`` does not match any known dataset.
    """
    dataset_name = dataset_name.lower()
    if "h2c" in dataset_name:
        # Imported lazily so this module can be imported without pulling in
        # every dataset implementation.
        from data.datasets.heterogeneous_2d_clusters import create_heterogeneous_2d_dataloader
        train_loader, test_loader = create_heterogeneous_2d_dataloader(
            batch_size=batch_size, dataset_name=dataset_name
        )
        return train_loader, test_loader

    # ValueError (a subclass of Exception) keeps existing `except Exception`
    # handlers working while being more specific than a bare Exception.
    raise ValueError("Dataset name not found.")


def load_train_dataset(dataset_name, batch_size, data_path=""):
    """Build only the training loader for a named dataset.

    Matching is a case-insensitive substring test on ``dataset_name``. This is
    consistent with ``load_dataset``.

    Args:
        dataset_name: Name of the dataset. Any name containing ``"csi"``
            (after lower-casing) selects the CSI dataset.
        batch_size: Batch size forwarded to the dataloader factory.
        data_path: Currently unused. It is kept for interface compatibility.

    Returns:
        Whatever ``create_csi_dataloaders`` returns for the given batch size.

    Raises:
        ValueError: If ``dataset_name`` does not match any known dataset.
            Previously an unmatched name fell through to ``return`` with the
            local unbound, and raised ``UnboundLocalError``.
    """
    dataset_name = dataset_name.lower()
    if "csi" in dataset_name:
        # Imported lazily so this module can be imported without pulling in
        # every dataset implementation.
        from data.datasets.csi import create_csi_dataloaders
        return create_csi_dataloaders(batch_size=batch_size)

    raise ValueError("Dataset name not found.")