#!/usr/bin/env python3
"""
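Script to pre-download and cache torchvision model weights locally.

Run this script on a machine with internet access; each model's ImageNet
weights are saved as a state-dict ``.pth`` file in the ``encoders/`` directory
(relative to the current working directory) for later offline use.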
"""

import os
import torch
import torchvision.models as models

def cache_model(model_name, cache_dir):
    """Download a torchvision model's ImageNet weights and save its state dict.

    The weights are written to ``<cache_dir>/<model_name>.pth``. If that file
    already exists, nothing is downloaded. Failures are reported on stdout
    instead of being raised, so a caller can keep going with other models.

    Args:
        model_name: One of ``"resnet18"``, ``"mobilenet_v3_small"`` or
            ``"efficientnet_b0"``.
        cache_dir: Directory to write the ``.pth`` file into; created if it
            does not exist.

    Returns:
        True if the weights are cached (already present or freshly saved),
        False if the model is unsupported or downloading/saving failed.
    """
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, f"{model_name}.pth")

    if os.path.exists(cache_path):
        print(f"Model {model_name} already cached at {cache_path}")
        return True

    # Lambdas keep construction lazy: only the requested model's weights are
    # fetched, and ``models`` is not touched for unsupported names.
    builders = {
        "resnet18": lambda: models.resnet18(
            weights=models.ResNet18_Weights.IMAGENET1K_V1
        ),
        "mobilenet_v3_small": lambda: models.mobilenet_v3_small(
            weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
        ),
        "efficientnet_b0": lambda: models.efficientnet_b0(
            weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1
        ),
    }

    print(f"Downloading and caching {model_name}...")

    # Save to a temporary sibling and rename it into place only after
    # torch.save succeeds. Writing straight to cache_path would let a failed
    # or interrupted save leave a truncated .pth, which the exists() check
    # above would then treat as a valid cache on every later run.
    tmp_path = f"{cache_path}.tmp"
    try:
        builder = builders.get(model_name)
        if builder is None:
            raise ValueError(f"Unsupported model: {model_name}")
        model = builder()

        torch.save(model.state_dict(), tmp_path)
        os.replace(tmp_path, cache_path)  # atomic on the same filesystem
        print(f"✓ Cached {model_name} to {cache_path}")

        # Print model info
        total_params = sum(p.numel() for p in model.parameters())
        print(f"  Model size: {total_params:,} parameters")
        return True

    except Exception as e:  # best-effort per model: report and let caller continue
        print(f"✗ Failed to cache {model_name}: {e}")
        return False

    finally:
        # Remove any leftover partial file (save failed or was interrupted).
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass

def main():
    """Cache all supported models into ``./encoders`` and print a size summary.

    Each model is cached via ``cache_model``. After that, every expected
    ``.pth`` file is listed with its size in MB, or flagged as missing if
    caching it did not succeed.
    """
    print("Caching torchvision models for offline use...")

    models_to_cache = [
        "resnet18",
        "mobilenet_v3_small",
        "efficientnet_b0",
    ]

    cache_dir = "encoders"
    print(f"Cache directory: {os.path.abspath(cache_dir)}")

    for model_name in models_to_cache:
        cache_model(model_name, cache_dir)

    # Name the directory that was actually written to, so the copy hint
    # cannot drift from cache_dir.
    print(f"\nDone! You can now copy the '{cache_dir}' directory to your cluster.")
    print("File sizes:")

    for model_name in models_to_cache:
        cache_path = os.path.join(cache_dir, f"{model_name}.pth")
        if os.path.exists(cache_path):
            size_mb = os.path.getsize(cache_path) / (1024 * 1024)
            print(f"  {model_name}.pth: {size_mb:.1f} MB")
        else:
            # Surface failures instead of silently omitting them from the summary.
            print(f"  {model_name}.pth: MISSING (caching failed)")

if __name__ == "__main__":
    # Run the caching workflow only when this file is executed as a script,
    # not when it is imported.
    main()
