#!/usr/bin/env python3
"""
Download HuggingFace models used by KernelBench level4 to a local cache.

This script pre-downloads all models so they can be bind-mounted into Docker
containers for offline inference.

Usage:
    python scripts/download_hf_models.py
    python scripts/download_hf_models.py --cache-dir /path/to/cache
"""

import argparse
import os
import sys
from pathlib import Path

# HuggingFace Hub repo IDs used by the KernelBench level4 problems, listed in
# the order they are downloaded. One ID per line; str.split() turns this
# whitespace-separated block into a plain list of strings.
KERNELBENCH_MODELS = """
    gpt2
    EleutherAI/gpt-neo-2.7B
    facebook/opt-1.3b
    facebook/bart-large
    google/bigbird-roberta-base
    google/electra-small-discriminator
    google/reformer-enwik8
""".split()


def download_models(cache_dir: str, models: "list[str] | None" = None) -> bool:
    """Download config, tokenizer and weights for each model into ``cache_dir``.

    Every artifact is fetched through ``transformers``' ``from_pretrained``
    with ``cache_dir=cache_dir``, so the files are stored directly under
    ``cache_dir`` in the HuggingFace Hub cache layout. A later (offline) run
    finds them when ``HF_HUB_CACHE`` points at ``cache_dir``.

    A missing tokenizer only produces a warning. A failure while fetching the
    config or the weights marks that model as failed, and the loop moves on
    to the next model.

    Args:
        cache_dir: Directory used as the Hub cache for all downloads.
        models: Repo IDs to download. ``None`` or an empty list means
            ``KERNELBENCH_MODELS``.

    Returns:
        True if every model downloaded successfully, False otherwise.

    Note:
        Calls ``sys.exit(1)`` if the ``transformers`` package is not installed.
    """
    # huggingface_hub reads HF_HUB_CACHE once, when it is first imported, so it
    # must be set *before* importing transformers to have any effect. HF_HOME is
    # deliberately not overridden: it also locates the stored login token, and
    # redirecting it would hide an existing `huggingface-cli login`.
    os.environ["HF_HUB_CACHE"] = cache_dir

    try:
        from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
    except ImportError:
        print("[ERROR] transformers package not installed. Run: pip install transformers")
        sys.exit(1)

    target_models = models or KERNELBENCH_MODELS
    total = len(target_models)

    print(f"[INFO] Downloading {total} models to: {cache_dir}")
    print(f"[INFO] Models: {target_models}")
    print()

    success = []
    failed = []

    for index, model_name in enumerate(target_models, start=1):
        print(f"[{index}/{total}] Downloading: {model_name}")
        try:
            # Config first: it is the lightweight part of the download.
            print("  - Downloading config...")
            config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)

            # Some repos ship no tokenizer; that should not fail the model.
            print("  - Downloading tokenizer...")
            try:
                AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
            except Exception as e:
                print(f"  - [WARN] Tokenizer not available: {e}")

            # Weights are the heavy part. Instantiating the model is only a
            # means of fetching the files, so drop it immediately.
            print("  - Downloading model weights...")
            model = AutoModelForCausalLM.from_pretrained(
                model_name, config=config, cache_dir=cache_dir
            )
            del model, config
        except Exception as e:
            print(f"  - [FAILED] {model_name}: {e}")
            failed.append((model_name, str(e)))
        else:
            print(f"  - [OK] {model_name}")
            success.append(model_name)

    print()
    print("=" * 60)
    print(f"[SUMMARY] Downloaded: {len(success)}, Failed: {len(failed)}")
    print(f"[INFO] Cache directory: {cache_dir}")

    if failed:
        print("\n[FAILED MODELS]")
        for name, err in failed:
            print(f"  - {name}: {err}")

    if success:
        print("\n[SUCCESS] Models are ready for offline use.")
        # The models sit directly under cache_dir, which is what HF_HUB_CACHE
        # names. HF_HOME would not work: the Hub cache resolves to $HF_HOME/hub.
        print(f"[INFO] Set HF_HUB_CACHE={cache_dir} (and HF_HUB_OFFLINE=1) when running inference.")

    return not failed


def main():
    """Parse command-line options, download the requested models, and exit.

    Exits with status 0 when every model downloaded and 1 otherwise. argparse
    exits with status 2 on invalid arguments, including a ``--models`` value
    that names no model (e.g. ``--models ","``).
    """
    parser = argparse.ArgumentParser(
        description="Download HuggingFace models for KernelBench level4"
    )
    parser.add_argument(
        '--cache-dir', type=str,
        default='/mnt/cache/huggingface_models',
        help='Directory to store downloaded models (default: /mnt/cache/huggingface_models)'
    )
    parser.add_argument(
        '--models', type=str, default=None,
        help='Comma-separated list of model names to download (default: all KernelBench models)'
    )

    args = parser.parse_args()

    # Validate --models before touching the filesystem. Blank entries from
    # stray or trailing commas ("gpt2,,facebook/opt-1.3b,") are dropped rather
    # than attempted as an empty repo ID. An empty/unset value keeps the
    # original meaning: download all KernelBench models.
    models = None
    if args.models:
        models = [name.strip() for name in args.models.split(',') if name.strip()]
        if not models:
            parser.error("--models must name at least one model")

    # expanduser so a quoted "~/..." path still resolves to the home directory.
    cache_dir = os.path.abspath(os.path.expanduser(args.cache_dir))
    os.makedirs(cache_dir, exist_ok=True)

    success = download_models(cache_dir, models)
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    # Entry point only when run as a script; importing this module does not
    # start any downloads.
    main()
