#!/usr/bin/env python
# coding: utf-8

import os
from datasets import load_dataset

def download_gqa(dataset_name, cache_dir=None, configs=None):
    """Download (or load from cache) configurations of the GQA dataset.

    Every requested configuration is fetched with ``load_dataset``, which
    downloads the data only when it is not already cached. Progress and
    errors are printed to stdout. Loading errors are not re-raised: loading
    stops at the first configuration that fails, and the configurations
    loaded before it are still returned. Compare the keys of the result
    with the requested configs to detect a partial or failed download.

    Args:
        dataset_name (str): The name of the dataset on Hugging Face Hub (e.g., "lmms-lab/GQA").
        cache_dir (str, optional): Directory to cache the datasets; created if
            it does not exist. Defaults to Hugging Face default.
        configs (str or iterable of str, optional): Configuration name(s) to
            load. Defaults to the balanced train/val image configurations.
            An empty iterable loads nothing.

    Returns:
        dict: Maps each successfully loaded configuration name to the object
        returned by ``load_dataset`` for it.

    Raises:
        OSError: If ``cache_dir`` is given and cannot be created.
    """
    if configs is None:
        # The original script selected these after load_dataset rejected a
        # call without an explicit config name for lmms-lab/GQA.
        configs_to_load = ["train_balanced_images", "val_balanced_images"]
    elif isinstance(configs, str):
        # A bare string would otherwise be iterated character by character.
        configs_to_load = [configs]
    else:
        configs_to_load = list(configs)

    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
        print(f"Using cache directory: {cache_dir}")
    else:
        print("Using default Hugging Face cache directory.")

    print(f"\nDownloading GQA dataset: {dataset_name}...")
    datasets = {}
    try:
        for config_name in configs_to_load:
            print(f"  Loading config: {config_name}...")
            datasets[config_name] = load_dataset(dataset_name, name=config_name, cache_dir=cache_dir)
            print(f"  Successfully loaded config: {config_name}")
    except Exception as e:
        # Deliberately broad: this is a best-effort script boundary, so the
        # failure is reported and whatever was loaded so far is kept.
        # NOTE: other GQA variants (e.g. Graphcore/gqa) may need
        # trust_remote_code=True; that retry is not implemented here.
        print(f"Error downloading GQA dataset {dataset_name}: {e}")
    else:
        print(f"Successfully downloaded/loaded GQA dataset: {dataset_name} (configs: {list(datasets.keys())})")
        print(f"Dataset info: {datasets}")

    return datasets

if __name__ == "__main__":
    # Hub repository to fetch; the lmms-lab version was picked from the
    # search results.
    hub_repo = "lmms-lab/GQA"

    # None selects the default Hugging Face cache. To keep the data inside
    # the project, assign a path instead, e.g.
    # "/home/ubuntu/ecam_project/data/hf_cache".
    hf_cache = None

    print(f"Starting download for GQA dataset: {hub_repo}")
    download_gqa(hub_repo, cache_dir=hf_cache)
    print("\nGQA dataset download process finished.")

