#!/usr/bin/env python3
"""
Script to run the training and evaluation pipeline for multiple k values.
Each k value will have its own checkpoint and figure files.
"""

import shlex
import subprocess
import sys
from pathlib import Path


def run_command(cmd, description):
    """Run ``cmd`` as a subprocess and report whether it succeeded.

    The child's stdout/stderr are captured (not streamed), so nothing from the
    child is shown until it exits. On success the captured stdout is echoed;
    on failure the exit code and both captured streams are printed.

    Args:
        cmd: Argument list passed to ``subprocess.run`` (no shell is used).
            Items may be ``str`` or path-like objects.
        description: Human-readable label used in progress/error messages.

    Returns:
        True if the command exited with status 0; False if it exited non-zero
        or could not be started at all (e.g. the executable does not exist).
    """
    # Quote each argument so the echoed line is unambiguous (and can be pasted
    # into a shell) even when an argument contains spaces; str() lets
    # path-like items through instead of crashing the join.
    printable = " ".join(shlex.quote(str(arg)) for arg in cmd)
    print(f"Running: {description}")
    print(f"Command: {printable}")
    print("-" * 50)

    try:
        # errors="replace": a child emitting bytes that are invalid in the
        # locale encoding must not abort the whole pipeline with a
        # UnicodeDecodeError raised from inside subprocess.run.
        result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    except OSError as exc:
        # FileNotFoundError / PermissionError etc.: the process never started.
        print(f"ERROR: {description} failed")
        print(f"Could not start command: {exc}")
        return False

    if result.returncode != 0:
        print(f"ERROR: {description} failed")
        print(f"Exit code: {result.returncode}")
        print("STDOUT:", result.stdout)
        print("STDERR:", result.stderr)
        return False

    print("SUCCESS: Command completed")
    if result.stdout.strip():
        print("Output:", result.stdout.strip())
    print()
    return True


def main(k_values=None, outdir="outputs"):
    """Train and PCA-evaluate one model per k value; stop at the first failure.

    For every k, two child processes are run with the current interpreter
    (``sys.executable``): ``src.train`` and then ``src.eval.pca_hidden``. The
    evaluation step is pointed at ``<outdir>/checkpoints/best_k{k}.pt``; this
    script assumes ``src.train`` writes its best checkpoint to exactly that
    path — TODO confirm against src.train.

    Args:
        k_values: k values to run, in order. ``None`` (the default) uses
            ``[1, 3, 5, 10, 20, 40]``.
        outdir: Directory passed as ``--outdir`` to both stages and used to
            build the checkpoint/figure paths. Defaults to ``"outputs"``.

    Exits the process with status 1 as soon as any stage fails.
    """
    if k_values is None:
        k_values = [1, 3, 5, 10, 20, 40]

    out_root = Path(outdir)
    # Environment flags shared by training and evaluation. Defined once so the
    # two stages cannot silently disagree about the environment they model.
    env_args = ["--Lx", "10", "--Ly", "3", "--n_colors", "10", "--obs_size", "3"]

    print("=" * 60)
    print("Running pipeline for multiple k values")
    print("=" * 60)
    print()

    for k in k_values:
        # Per-k artifact paths, all derived from the single output root.
        ckpt_path = out_root / "checkpoints" / f"best_k{k}.pt"
        fig_path = out_root / "figures" / f"hidden_pca_k{k}.png"

        print("=" * 50)
        print(f"Running pipeline for k={k}")
        print("=" * 50)
        print()

        train_cmd = [
            sys.executable, "-m", "src.train",
            *env_args,
            "--k", str(k), "--hidden", "128",
            "--n_traj_train", "1000", "--n_traj_val", "100", "--T", "50",
            "--batch_size", "128", "--epochs", "10", "--lr", "0.002",
            "--outdir", str(out_root),
        ]

        if not run_command(train_cmd, f"Training for k={k}"):
            print(f"ERROR: Training failed for k={k}")
            sys.exit(1)

        pca_cmd = [
            sys.executable, "-m", "src.eval.pca_hidden",
            "--ckpt", str(ckpt_path),
            "--n_traj_eval", "100", "--T_eval", "50",
            *env_args,
            "--k", str(k), "--seed", "123",
            "--outdir", str(out_root),
        ]

        if not run_command(pca_cmd, f"PCA evaluation for k={k}"):
            print(f"ERROR: PCA evaluation failed for k={k}")
            sys.exit(1)

        print(f"Pipeline completed for k={k}! Check {fig_path}")
        print()

    print("=" * 60)
    print("All pipelines completed successfully!")
    print("=" * 60)
    print()
    print("Results saved in:")
    print(f"- Checkpoints: {out_root / 'checkpoints' / 'best_k*.pt'}")
    print(f"- Figures: {out_root / 'figures' / 'hidden_pca_k*.png'}")
    print()


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
