#!/usr/bin/env python3
"""
Comprehensive C4 dataset download and verification script.
This script will properly download and cache C4 dataset files.
"""

import os
import sys
import time
from pathlib import Path
from datasets import load_dataset, Dataset
from datasets.utils.file_utils import cached_path

def setup_environment(cache_dir="/data"):
    """Export Hugging Face environment variables and create the cache directory.

    Sets generous download timeouts, points both ``HF_HOME`` and
    ``HF_DATASETS_CACHE`` at ``cache_dir``, and switches the offline flags
    off. The directory (and any missing parents) is created if needed.

    Args:
        cache_dir: Cache root as a string or path-like object. Defaults to
            ``"/data"``, the location this script has always used.

    Raises:
        OSError: If the cache directory cannot be created.
    """
    # NOTE(review): `datasets` is imported at module top, before this runs.
    # Libraries that read these variables at import time would presumably not
    # see the values set here — confirm against the installed versions.
    cache_root = os.fspath(cache_dir)

    os.environ.update(
        {
            "HF_HUB_DOWNLOAD_TIMEOUT": "1200",  # 20 minutes
            "REQUESTS_TIMEOUT": "1200",
            "HF_DATASETS_CACHE": cache_root,
            "HF_HOME": cache_root,
            "HF_DATASETS_OFFLINE": "0",
            "TRANSFORMERS_OFFLINE": "0",
        }
    )

    # Create the cache directory if it doesn't exist yet.
    cache_path = Path(cache_root)
    cache_path.mkdir(parents=True, exist_ok=True)
    print(f"✅ Cache directory: {cache_path}")

def download_c4_non_streaming(max_retries=3):
    """Load a slice of C4 ("en") without streaming, retrying on failure.

    Each attempt requests ``train[:1%]`` with ``cache_dir="/data"`` and then
    reads the first sample to confirm the data is accessible. After a failed
    attempt the function sleeps 30 s, 60 s, ... (30 s times the attempt
    number) before trying again.

    Args:
        max_retries: Maximum number of attempts.

    Returns:
        True as soon as one attempt succeeds; False when every attempt
        failed (or when ``max_retries`` allows no attempt at all).
    """
    setup_environment()

    for attempt_no in range(1, max_retries + 1):
        try:
            print(f"🔄 Attempt {attempt_no}/{max_retries}: Downloading C4 dataset (non-streaming)...")

            # NOTE(review): the percent slice is intended as a small
            # connectivity test; in non-streaming mode the library may still
            # fetch the whole split before slicing — confirm.
            subset = load_dataset(
                "allenai/c4",
                name="en",
                split="train[:1%]",
                trust_remote_code=True,
                cache_dir="/data",
            )

            print(f"✅ Successfully downloaded subset! Size: {len(subset)} samples")

            # Touch one record to prove the data is readable.
            first_row = subset[0]
            print(f"✅ Sample text length: {len(first_row['text'])}")

            return True

        except Exception as err:
            print(f"❌ Attempt {attempt_no} failed: {err}")

            if attempt_no >= max_retries:
                print("❌ All non-streaming download attempts failed!")
                return False

            # Linear back-off: 30 s after the first failure, 60 s after the second, ...
            delay = attempt_no * 30
            print(f"⏳ Waiting {delay} seconds before retry...")
            time.sleep(delay)

    # Only reached when the loop body never ran.
    return False

def verify_streaming_dataset():
    """Check that C4 ("en", train split) can be opened and read in streaming mode.

    Opens the dataset with ``streaming=True`` and pulls the first five
    samples, printing the text length of each.

    Returns:
        True if the dataset opened and five samples were read; False if any
        step raised an exception (the error is printed, not re-raised).
    """
    setup_environment()

    try:
        print("🔄 Testing streaming dataset access...")

        # Streaming access is the mode torchtitan actually uses.
        stream = load_dataset(
            "allenai/c4",
            name="en",
            split="train",
            streaming=True,
            trust_remote_code=True,
        )

        print("✅ Streaming dataset loaded successfully")

        # Pull the first five records one at a time.
        rows = iter(stream)
        for position in range(1, 6):
            row = next(rows)
            print(f"  Sample {position}: text length = {len(row['text'])}")

        print("✅ Streaming dataset verification successful!")

    except Exception as err:
        print(f"❌ Streaming dataset verification failed: {err}")
        return False

    else:
        return True

def test_offline_mode():
    """Check whether the streaming C4 dataset can be opened with offline flags set.

    Temporarily sets ``HF_DATASETS_OFFLINE`` and ``TRANSFORMERS_OFFLINE`` to
    ``"1"``, opens C4 ("en", train split) in streaming mode and reads one
    sample. Whatever the outcome, both variables are restored to the values
    they had before the call (or removed if they were unset), so a caller
    that deliberately runs offline is not silently switched back online.

    Returns:
        True if one sample could be read; False if any step raised an
        exception (the error is printed, not re-raised).
    """
    print("🔄 Testing offline mode...")

    offline_flags = ("HF_DATASETS_OFFLINE", "TRANSFORMERS_OFFLINE")

    # Remember the caller's settings so they can be put back afterwards.
    previous = {flag: os.environ.get(flag) for flag in offline_flags}

    # Temporarily set offline mode.
    # NOTE(review): `datasets` is already imported when this runs; if it reads
    # these flags only at import time the switch may have no effect — confirm.
    for flag in offline_flags:
        os.environ[flag] = "1"

    try:
        dataset = load_dataset(
            "allenai/c4",
            name="en",
            split="train",
            streaming=True,
            trust_remote_code=True
        )

        # Reading a single sample is enough to prove the data is reachable.
        next(iter(dataset))
        print("✅ Offline mode works! Dataset is properly cached.")
        return True

    except Exception as e:
        print(f"❌ Offline mode failed: {e}")
        print("   This means the dataset files are not properly cached.")
        return False

    finally:
        # Restore the environment exactly as it was before the test.
        for flag, value in previous.items():
            if value is None:
                os.environ.pop(flag, None)
            else:
                os.environ[flag] = value

def check_cache_contents(cache_dir="/data"):
    """Print the C4-related entries found under the cache directory.

    Three locations are scanned for names matching ``*c4*``:

    * ``<cache_dir>/hub`` — for each match, every directory under its
      ``snapshots`` folder is listed with its file count and up to five
      file names;
    * ``<cache_dir>/datasets``;
    * ``<cache_dir>`` itself.

    Entries are printed in sorted order so repeated runs produce the same
    output. Locations that are missing or are not directories are skipped.

    Args:
        cache_dir: Cache root as a string or path-like object. Defaults to
            ``"/data"``.
    """
    cache_root = Path(cache_dir)
    found_any = False

    print("📁 Cache directory contents:")

    # Hub directory: one folder per repository, each with a "snapshots" folder.
    hub_dir = cache_root / "hub"
    if hub_dir.is_dir():
        for c4_dir in sorted(hub_dir.glob("*c4*")):
            found_any = True
            print(f"  📂 {c4_dir}")

            snapshots_dir = c4_dir / "snapshots"
            # is_dir() (not exists()) so a stray regular file cannot crash iterdir().
            if not snapshots_dir.is_dir():
                continue

            for snapshot in sorted(snapshots_dir.iterdir()):
                if not snapshot.is_dir():
                    continue
                files = sorted(snapshot.iterdir())
                print(f"    📁 {snapshot.name}: {len(files)} files")
                for file in files[:5]:  # Show first 5 files
                    print(f"      📄 {file.name}")

    # Datasets directory, plus the cache root itself.
    # NOTE(review): this script points HF_DATASETS_CACHE and load_dataset's
    # cache_dir at the cache root, so prepared datasets presumably land
    # directly under it rather than under "datasets/" — confirm.
    for location in (cache_root / "datasets", cache_root):
        if not location.is_dir():
            continue
        for c4_entry in sorted(location.glob("*c4*")):
            found_any = True
            print(f"  📄 {c4_entry}")

    if not found_any:
        # Make an empty cache explicit instead of printing only the header.
        print("  (no C4 entries found)")

def main():
    """Run every C4 check in sequence and print a pass/fail summary.

    Steps: list the cache contents, verify streaming access, fall back to a
    non-streaming download (and re-verify streaming) if that failed, then
    run the offline check.

    Returns:
        The streaming verification result. The offline result is reported
        in the summary but does not affect the return value.
    """

    def verdict(passed):
        """Render a check result as the PASS/FAIL label used in the summary."""
        if passed:
            return "✅ PASS"
        return "❌ FAIL"

    divider = "=" * 60

    print("🚀 Starting comprehensive C4 dataset verification...")
    print(divider)

    # Show what is already cached before touching the network.
    check_cache_contents()
    print()

    # Streaming access is what torchtitan relies on, so it is checked first.
    streaming_ok = verify_streaming_dataset()
    print()

    if not streaming_ok:
        print("⚠️  Streaming failed, trying non-streaming download...")
        downloaded = download_c4_non_streaming()
        print()

        if downloaded:
            # Give streaming a second chance once a download went through.
            streaming_ok = verify_streaming_dataset()
            print()

    offline_ok = test_offline_mode()
    print()

    # Final summary
    print(divider)
    print("📊 SUMMARY:")
    print(f"   Streaming dataset: {verdict(streaming_ok)}")
    print(f"   Offline mode: {verdict(offline_ok)}")

    if streaming_ok:
        closing_lines = (
            "\n🎉 SUCCESS! Your C4 dataset is properly configured.",
            "   You can now run your training experiments.",
        )
    else:
        closing_lines = (
            "\n❌ FAILURE! C4 dataset is not working properly.",
            "   Recommendations:",
            "   1. Check your internet connection",
            "   2. Use 'c4_test' dataset instead",
            "   3. Try a different dataset",
        )
    for line in closing_lines:
        print(line)

    return streaming_ok

if __name__ == "__main__":
    # Map main()'s verdict onto the process exit status: 0 on success, 1 on failure.
    exit_status = 0 if main() else 1
    sys.exit(exit_status)