import os
from datasets import load_dataset, DatasetDict

# Usage: python3 mllmsd/datamodules/download-datasets.py
#
# 0. Config
# Root directory that every benchmark below is saved into (one sub-directory
# per dataset). Set the MSD_DATASETS_DIR environment variable to override the
# default, e.g. when running on a machine without this mount.
save_dir = os.environ.get(
    "MSD_DATASETS_DIR", "/XXXX-5/home-XXXX-3/data/MSD/datasets/"
)
os.makedirs(save_dir, exist_ok=True)


# 1. Single-image benchmarks.
# Each Hub repo is downloaded and written with DatasetDict.save_to_disk to
# save_dir/<repo name without the "org/" prefix>.
repos = [
    "lmms-lab/DC100_EN",
    "lmms-lab/llava-bench-in-the-wild",
    "RekaAI/VibeEval",
    "lmms-lab/LiveBench",
]

# Repos that must be loaded with an explicit config name; every other repo is
# loaded with load_dataset's default config.
config_names = {"lmms-lab/LiveBench": "2024-05"}

for repo in repos:
    if repo in config_names:
        datasets = load_dataset(repo, config_names[repo])
    else:
        datasets = load_dataset(repo)

    dataset_name = repo.rpartition("/")[2]
    save_path = os.path.join(save_dir, dataset_name)

    # Print the schema of the first split (if any) as a quick sanity check.
    first_split = next(iter(datasets.values()), None)
    if first_split is not None:
        print(first_split.features)

    # llava-bench-in-the-wild is loaded with its rows under "train"; store them
    # as the sole "test" split instead (any other split is not kept).
    if repo == "lmms-lab/llava-bench-in-the-wild":
        datasets = DatasetDict({"test": datasets["train"]})

    datasets.save_to_disk(save_path)
    

# 2. Multi-image benchmark: LLaVA-NeXT-Interleave-Bench, "in_domain" config.
# Only the open-ended sub-tasks listed below are kept, and each one is saved
# to its own directory save_dir/<sub_task>.
repo = "lmms-lab/LLaVA-NeXT-Interleave-Bench"
datasets = load_dataset(repo, "in_domain")

dataset_openended = [
    'Spot-the-Diff',
    'Birds-to-Words',
    'CLEVR-Change',
    'HQ-Edit',
    'MagicBrush',
    'IEdit',
    'AESOP',
    'FlintstonesSV',
    'PororoSV',
    'VIST',
    'WebQA',
]
# Hoisted set for O(1) membership tests inside the filter predicate.
openended_set = frozenset(dataset_openended)

# Print the schema of the first split (if any) as a quick sanity check.
first_split = next(iter(datasets.values()), None)
if first_split is not None:
    print(first_split.features)

# Keep only the open-ended sub-tasks. With input_columns="sub_task" the
# predicate receives just that column's value, so the other (image) columns
# are never read or decoded while filtering.
datasets = datasets.filter(
    lambda sub_task: sub_task in openended_set,
    input_columns="sub_task",
)

for dataset_name in dataset_openended:
    save_path = os.path.join(save_dir, dataset_name)
    # The target sub-task is passed via fn_kwargs so it is an explicit argument
    # of each filter call rather than a global read inside the lambda.
    sub_datasets = datasets.filter(
        lambda sub_task, target: sub_task == target,
        input_columns="sub_task",
        fn_kwargs={"target": dataset_name},
    )
    sub_datasets.save_to_disk(save_path)