from datasets import load_dataset 

def get_dataset(
    dataset_name,
    subset_name,
    split=None,
):
    """Load a Hugging Face dataset and return it with readable class names.

    Args:
        dataset_name: Dataset path/identifier forwarded to ``load_dataset``.
        subset_name: Configuration (subset) name forwarded to ``load_dataset``.
        split: Optional split specifier (e.g. ``"train"``). When truthy it is
            forwarded to ``load_dataset``, which then returns a single split;
            otherwise all splits are loaded.

    Returns:
        A ``(dataset, classes)`` tuple. ``dataset`` is exactly what
        ``load_dataset`` returned. ``classes`` is the list of names of the
        ``"label"`` feature with underscores replaced by spaces.

    Note:
        Assumes the dataset has a ``"label"`` feature exposing ``.names``
        (i.e. a ``ClassLabel``) -- TODO confirm for every dataset used.
    """
    # split=None is load_dataset's default, so one call covers both the
    # "specific split" and "all splits" cases (falsy split -> None, as before).
    dataset = load_dataset(dataset_name, subset_name, split=split or None)

    if split:
        # With an explicit split, load_dataset returns a single Dataset, which
        # has no "train" key: dataset["train"] would look up a *column* named
        # "train" and fail. Read the schema from the Dataset itself.
        features = dataset.features
    else:
        # Without a split we get a DatasetDict keyed by split name. All splits
        # share one schema, so prefer "train" but fall back to any split when
        # the dataset does not ship a "train" split.
        if "train" in dataset:
            reference_split = dataset["train"]
        else:
            reference_split = next(iter(dataset.values()))
        features = reference_split.features

    # Turn label identifiers like "golden_retriever" into "golden retriever".
    classes = [c.replace("_", " ") for c in features["label"].names]
    return dataset, classes





