import os
from datasets import load_dataset, DatasetDict

# Usage: python3 mllmsd/datamodules/download-datasets.py
# 0. Config
# Root output directory; every dataset is saved into its own sub-directory here.
save_dir = "/XXXX-5/home-XXXX-3/data/MSD/datasets/"
# Create the root up front; exist_ok=True means an existing directory is not an error.
os.makedirs(save_dir, exist_ok=True)


# 1. Single-image benchmarks
repos = [
    "lmms-lab/MMVet",
    "lmms-lab/POPE",
    "lmms-lab/HallusionBench",
]

# Repos that are loaded with an explicit config name; any repo not listed here
# is loaded without one.
config_by_repo = {
    "lmms-lab/LiveBench": '2024-05',
}

# Repos that are re-wrapped before saving: the named source split becomes the
# only split of the saved DatasetDict, stored under the key 'test'.
test_source_split_by_repo = {
    "lmms-lab/llava-bench-in-the-wild": 'train',
    "lmms-lab/HallusionBench": 'image',
}

for repo in repos:
    # A config name of None is load_dataset's default, i.e. same as omitting it.
    datasets = load_dataset(repo, config_by_repo.get(repo))

    # The output folder is named after the last component of the repo id.
    dataset_name = repo.split("/")[-1]
    save_path = os.path.join(save_dir, dataset_name)

    # Print the schema of the first split only (nothing if there are no splits).
    first_split = next(iter(datasets.values()), None)
    if first_split is not None:
        print(first_split.features)

    source_split = test_source_split_by_repo.get(repo)
    if source_split is not None:
        datasets = DatasetDict({'test': datasets[source_split]})

    datasets.save_to_disk(save_path)
    

# 2. Multi-image benchmark
repo = "lmms-lab/LLaVA-NeXT-Interleave-Bench"
datasets = load_dataset(repo, "in_domain")

# Sub-tasks to extract; each one is saved as its own dataset directory.
dataset_openended = [
    "QBench",
    "NLVR2_Mantis",
    "OCR-VQA",
]

# Print the schema of the first split only (nothing if there are no splits).
first_split = next(iter(datasets.values()), None)
if first_split is not None:
    print(first_split.features)

# Keep only the rows belonging to the selected sub-tasks.
# input_columns hands the predicate just the 'sub_task' value instead of the
# whole row, so the remaining columns (presumably including images -- TODO
# confirm against the printed features) are not materialised for every row.
datasets = datasets.filter(
    lambda sub_task: sub_task in dataset_openended,
    input_columns=["sub_task"],
)

for dataset_name in dataset_openended:
    save_path = os.path.join(save_dir, dataset_name)
    # Bind the current name as a default argument so the predicate carries its
    # own value rather than reading the loop variable at call time.
    sub_datasets = datasets.filter(
        lambda sub_task, name=dataset_name: sub_task == name,
        input_columns=["sub_task"],
    )
    sub_datasets.save_to_disk(save_path)