#!/usr/bin/env python3
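"""
Script to run the training and evaluation pipeline for a single k value.
Usage: python run_single_k.py <k_value>
Example: python run_single_k.py 5

Run it from the project root. The pipeline invokes the ``src.train`` and
``src.eval.pca_hidden`` modules via ``python -m``. It reads and writes files
under the relative ``outputs/`` directory.
"""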

import subprocess
import sys
from pathlib import Path


def run_command(cmd, description):
    """Run *cmd* as a subprocess and report whether it succeeded.

    The child's stdout/stderr are captured rather than streamed. After the
    process exits, both are printed on failure; only stdout is printed on
    success.

    Args:
        cmd: Argument list passed to ``subprocess.run`` (no shell). Elements
            may be ``str`` or path-like objects.
        description: Human-readable label used in the progress messages.

    Returns:
        True if the process exited with status 0. False if it exited with a
        non-zero status or could not be started at all.
    """
    print(f"Running: {description}")
    # str() each part so path-like arguments can be displayed as well as run.
    print(f"Command: {' '.join(str(part) for part in cmd)}")
    print("-" * 50)

    try:
        # errors="replace": undecodable bytes in the child's output must not
        # turn a finished run into a UnicodeDecodeError here.
        result = subprocess.run(
            cmd, capture_output=True, text=True, errors="replace"
        )
    except OSError as exc:
        # E.g. missing or non-executable program. Report it through the same
        # False return as a non-zero exit instead of crashing the caller.
        print(f"ERROR: {description} failed")
        print(f"Could not start process: {exc}")
        return False

    if result.returncode != 0:
        print(f"ERROR: {description} failed")
        print("STDOUT:", result.stdout)
        print("STDERR:", result.stderr)
        return False

    print("SUCCESS: Command completed")
    if result.stdout.strip():
        print("Output:", result.stdout.strip())
    print()
    return True


def main():
    """Run training and then PCA evaluation for the k value given on the CLI.

    Expects exactly one command-line argument, an integer k. Training runs
    via ``python -m src.train`` and evaluation via
    ``python -m src.eval.pca_hidden``. Both are child processes of the
    current interpreter.

    Exits with status 1 on any of:
        - a usage error;
        - a non-integer k;
        - a failed step;
        - a checkpoint that is missing after training.
    """
    if len(sys.argv) != 2:
        print("Usage: python run_single_k.py <k_value>")
        print("Example: python run_single_k.py 5")
        print()
        print("This will run the full pipeline for the specified k value.")
        print("Checkpoints will be saved as best_k<k>.pt")
        print("Figures will be saved as hidden_pca_k<k>.png")
        sys.exit(1)

    try:
        k_value = int(sys.argv[1])
    except ValueError:
        print("ERROR: k_value must be an integer")
        sys.exit(1)

    outdir = "outputs"
    # Locations this script relies on. The naming is assumed to match what
    # src.train / src.eval.pca_hidden write; TODO confirm if those change.
    ckpt_path = f"{outdir}/checkpoints/best_k{k_value}.pt"
    fig_path = f"{outdir}/figures/hidden_pca_k{k_value}.png"

    # Settings passed identically to training and evaluation. They are kept
    # in one list so the two commands cannot silently drift apart.
    shared_args = [
        "--Lx", "10", "--Ly", "3", "--n_colors", "10", "--obs_size", "3",
        "--k", str(k_value),
    ]

    print("=" * 60)
    print(f"Running pipeline for k={k_value}")
    print("=" * 60)
    print()

    # Training command
    train_cmd = [
        sys.executable, "-m", "src.train",
        *shared_args,
        "--hidden", "128",
        "--n_traj_train", "1000", "--n_traj_val", "100", "--T", "50",
        "--batch_size", "128", "--epochs", "10", "--lr", "0.002",
        "--outdir", outdir,
    ]

    if not run_command(train_cmd, f"Training for k={k_value}"):
        print(f"ERROR: Training failed for k={k_value}")
        sys.exit(1)

    # A zero exit status alone does not prove the checkpoint was written
    # where evaluation will look. Fail early with a clear message.
    if not Path(ckpt_path).is_file():
        print(f"ERROR: Expected checkpoint not found: {ckpt_path}")
        sys.exit(1)

    # PCA evaluation command
    pca_cmd = [
        sys.executable, "-m", "src.eval.pca_hidden",
        "--ckpt", ckpt_path,
        "--n_traj_eval", "100", "--T_eval", "50",
        *shared_args,
        "--seed", "123",
        "--outdir", outdir,
    ]

    if not run_command(pca_cmd, f"PCA evaluation for k={k_value}"):
        print(f"ERROR: PCA evaluation failed for k={k_value}")
        sys.exit(1)

    print("=" * 60)
    print(f"Pipeline completed for k={k_value}!")
    print("=" * 60)
    print()
    print("Results saved in:")
    print(f"- Checkpoint: {ckpt_path}")
    print(f"- Figure: {fig_path}")


if __name__ == "__main__":
    # Script entry point. main() returns None when it finishes normally, and
    # sys.exit(None) maps that to exit status 0. Error paths exit via
    # sys.exit(1) inside main().
    sys.exit(main())
