#!/usr/bin/env python3
"""
Setup and Run Script for VQ-Diffusion

This script helps you set up and run VQ-Diffusion image generation.
It checks if the required model checkpoints are downloaded and provides clear instructions.
"""

import os
import subprocess
import sys

def check_model_files():
    """Report which required checkpoint and config files are absent.

    Every path in the fixed list below is probed with ``os.path.exists``.

    Returns:
        list[str]: The paths that do not exist, in the order they are listed.
        The list is empty when every file is present.
    """
    # NOTE(review): the checkpoint paths are absolute (rooted at /checkpoints)
    # while the config paths are relative to the current working directory --
    # confirm this mix is intended.
    required_paths = (
        '/checkpoints/pretrained_model/ithq_learnable.pth',
        '/checkpoints/pretrained_model/coco_learnable.pth',
        '/checkpoints/pretrained_model/cub_learnable.pth',
        '/checkpoints/pretrained_model/cc_learnable.pth',
        '/checkpoints/pretrained_model/imagenet_learnable.pth',
        'configs/ithq.yaml',
        'configs/imagenet.yaml',
    )
    return [path for path in required_paths if not os.path.exists(path)]

def download_checkpoints():
    """Run the checkpoint download shell script and report the outcome.

    The script ``vqdiffusion_download_checkpoints.sh`` is looked up relative
    to the current working directory and executed with ``bash``. Its output
    is captured; stderr is printed only when the script fails.

    Returns:
        bool: True if the script ran and exited with status 0. False if the
        script file is missing, ``bash`` cannot be started, or the script
        exits with a non-zero status.
    """
    script = 'vqdiffusion_download_checkpoints.sh'

    # Check for the script explicitly: when the script is missing, bash itself
    # still starts and merely exits non-zero, so subprocess.run does not raise
    # FileNotFoundError for that case.
    if not os.path.isfile(script):
        print("❌ Download script not found: vqdiffusion_download_checkpoints.sh")
        return False

    print("Downloading VQ-Diffusion model checkpoints...")
    print("This may take a while depending on your internet connection.")

    try:
        result = subprocess.run(['bash', script],
                                capture_output=True, text=True)
    except FileNotFoundError:
        # Raised when the 'bash' executable itself cannot be found.
        print("❌ Could not run the download script: 'bash' was not found on PATH")
        return False

    if result.returncode == 0:
        print("✅ Checkpoints downloaded successfully!")
        return True

    print("❌ Error downloading checkpoints:")
    print(result.stderr)
    return False

def _confirm(prompt):
    """Ask a yes/no question on stdin; return True only for 'y' or 'yes'.

    The answer is lower-cased and stripped before comparison. A closed or
    exhausted stdin (EOFError) is treated as "no" so the caller falls through
    to its printed instructions instead of ending in a traceback.
    """
    try:
        answer = input(prompt)
    except EOFError:
        print()  # finish the prompt line that input() left open
        return False
    return answer.lower().strip() in ('y', 'yes')


def main():
    """Interactive entry point: verify model files, then download or run.

    If any required file is missing, the missing paths are listed and the
    user is offered the download. Declining prints manual instructions and
    exits with status 1; a failed download also exits with status 1.

    If every file is present, the user is offered to launch
    ``generate_images.py`` with the current interpreter. A non-zero exit
    status from that script is reported and used as this script's exit
    status.

    Raises:
        SystemExit: On declined or failed download, or when the generation
            script fails.
    """
    print("=== VQ-Diffusion Setup and Run ===")
    print()

    # Check if model files exist
    missing_files = check_model_files()

    if missing_files:
        print("❌ Some required model files are missing:")
        for file_path in missing_files:
            print(f"   - {file_path}")
        print()

        print("📥 You need to download the model checkpoints first.")
        if not _confirm("Would you like to download them now? (y/n): "):
            print("\n📋 Manual setup instructions:")
            print("1. Run: bash vqdiffusion_download_checkpoints.sh")
            print("2. Wait for the download to complete")
            print("3. Run: python generate_images.py")
            sys.exit(1)

        if not download_checkpoints():
            print("\n❌ Setup failed. Please check the error messages above.")
            sys.exit(1)

        print("\n✅ Setup complete! You can now run the generation script.")
        print("Run: python generate_images.py")
        return

    print("✅ All required model files are present!")
    print("🚀 Ready to generate images!")
    print()

    if not _confirm("Would you like to run the generation script now? (y/n): "):
        print("\n📋 To run the generation script later, use:")
        print("   python generate_images.py")
        return

    print("\n🚀 Starting image generation...")
    result = subprocess.run([sys.executable, 'generate_images.py'])
    # Surface a failed generation run instead of silently exiting with 0.
    if result.returncode != 0:
        print(f"\n❌ Image generation failed (exit code {result.returncode}).")
        sys.exit(result.returncode)

# Run the interactive setup only when executed as a script, not on import.
if __name__ == "__main__":
    main()
