import os
from datasets import load_dataset
from huggingface_hub import snapshot_download
from huggingface_hub.utils import HfHubHTTPError

# --- Configuration Section ---
# 1. Root directory for local datasets (relative paths resolve against the
#    current working directory); each dataset gets a subfolder named after it.
LOCAL_DATASET_DIR = "./dataset/CardinalOperations"
# 2. Dataset author/organization on Hugging Face
REPO_OWNER = "CardinalOperations"
# 3. List of datasets to download and load
DATASETS_TO_MANAGE = ['IndustryOR', 'MAMO', 'NL4OPT']
# 4. Endpoint for downloading. This value is passed explicitly as `endpoint=`
#    to snapshot_download, which would otherwise hide any HF_ENDPOINT the user
#    exported, so an HF_ENDPOINT environment variable takes precedence here.
#    An unset or empty variable falls back to the mirror.
HF_ENDPOINT = os.environ.get("HF_ENDPOINT") or "https://hf-mirror.com"

def _has_local_copy(path):
    """Return True when *path* is a directory containing at least one entry.

    ``os.path.isdir`` is used rather than ``os.path.exists``. A regular file
    sitting at *path* is then treated as "not downloaded" instead of making
    ``os.listdir`` raise ``NotADirectoryError``.
    """
    return os.path.isdir(path) and bool(os.listdir(path))


def download_datasets_with_hf_hub():
    """
    Ensure every dataset in DATASETS_TO_MANAGE exists under LOCAL_DATASET_DIR.

    For each name, the folder ``LOCAL_DATASET_DIR/<name>`` is checked. If it is
    missing or empty, the dataset repository ``REPO_OWNER/<name>`` is fetched
    into it with ``huggingface_hub.snapshot_download`` via ``HF_ENDPOINT``.

    Download errors are printed and the loop moves on to the next dataset;
    they are not raised to the caller.

    Any non-empty folder is treated as a complete download and skipped. That
    includes a folder left behind by an interrupted run, so delete the folder
    to force a fresh download.
    """
    print("--- Step 1: Check and download datasets using native API ---")

    for d_name in DATASETS_TO_MANAGE:
        local_path = os.path.join(LOCAL_DATASET_DIR, d_name)

        if _has_local_copy(local_path):
            print(f"✅ Dataset '{d_name}' already exists at '{local_path}', skipping download.")
            continue

        print(f"⏳ Dataset '{d_name}' does not exist or is empty, downloading from mirror...")

        repo_id = f"{REPO_OWNER}/{d_name}"

        try:
            snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",       # look the repo up as a dataset, not a model
                local_dir=local_path,      # download target folder
                endpoint=HF_ENDPOINT,      # route the download through the configured endpoint
                # Avoid symlinks in local_dir (they are problematic on Windows).
                # NOTE(review): newer huggingface_hub releases deprecate this
                # argument; confirm against the installed version.
                local_dir_use_symlinks=False,
            )
        except HfHubHTTPError as e:
            # HTTP-level failures reported by the Hub client
            print(f"❌ Failed to download '{d_name}'. Network error: {e}")
            print("   Please check your network connection or if the mirror address is correct.")
        except Exception as e:
            # Anything else (including connection errors that are not HfHubHTTPError)
            print(f"❌ Unknown error occurred while downloading '{d_name}': {e}")
        else:
            print(f"✅ Successfully downloaded '{d_name}' to '{local_path}'.")

def load_datasets_from_local():
    """
    Load each dataset named in DATASETS_TO_MANAGE from LOCAL_DATASET_DIR/<name>.

    Returns:
        dict: maps every name that loaded successfully to the object returned
        by ``datasets.load_dataset``. Names whose path does not exist, or whose
        load raised an exception, are reported on stdout and omitted.
    """
    print("\n--- Step 2: Load datasets from local directory ---")

    results = {}
    for name in DATASETS_TO_MANAGE:
        folder = os.path.join(LOCAL_DATASET_DIR, name)
        if os.path.exists(folder):
            try:
                print(f"Loading: {folder}...")
                # trust_remote_code=True lets load_dataset execute a custom
                # loading script (.py) shipped inside the folder.
                # NOTE(review): such a script would come from the download
                # mirror; confirm the source is trusted, or drop the flag if
                # these datasets are plain data files.
                results[name] = load_dataset(folder, trust_remote_code=True)
                print(f"✅ Successfully loaded: {name}")
            except Exception as err:
                print(f"❌ Failed to load local dataset {folder}. Error: {err}")
        else:
            print(f"⚠️ Warning: Directory '{folder}' does not exist, cannot load '{name}'.")

    return results


# --- Main Program Entry ---
if __name__ == "__main__":
    # Step 1: Ensure all datasets are available locally
    download_datasets_with_hf_hub()

    # Step 2: Load datasets from local
    all_datasets = load_datasets_from_local()

    # Step 3: Use datasets
    # Both steps above report failures per dataset and carry on, so list here
    # which requested datasets did not end up loaded.
    missing = [d_name for d_name in DATASETS_TO_MANAGE if d_name not in all_datasets]
    if missing:
        print(f"\n⚠️ Not loaded: {', '.join(missing)}")

    if all_datasets:
        print("\n--- All operations completed ---")
        # Example: print information about the last loaded dataset
        # (dicts preserve insertion order, so this is the last one loaded).
        last_ds_name = list(all_datasets)[-1]
        print(f"\nViewing '{last_ds_name}' dataset details:")
        print(all_datasets[last_ds_name])
    else:
        print("\n❌ No datasets could be loaded; see the messages above.")