#!/usr/bin/env python
# slurm_batch_cpu.py - Script to generate and submit CPU-based SLURM jobs for scale experiments

# Example: run only a selected subset of configurations
# python slurm_batch_cpu.py --training-modes minimal --scales 1.0 --seeds 42 123 --optimizers adamw

import argparse
import itertools
import os
import socket
import subprocess
import sys
from datetime import datetime
from pathlib import Path

import numpy as np

def _build_arg_parser():
    """Build the command-line parser for the batch generator."""
    parser = argparse.ArgumentParser(description='Generate and submit SLURM jobs for scale experiments')

    # Configuration options
    parser.add_argument('--training-modes', type=str, nargs='+',
                        default=['minimal', 'balanced_16', 'balanced_32', 'balanced_48', 'maximal'],
                        help='Training modes to test')

    # Five log-spaced scale values between 0.01 and 3 (inclusive).
    default_scales = np.logspace(np.log10(0.01), np.log10(3), 5).tolist()
    parser.add_argument('--scales', type=float, nargs='+', default=default_scales,
                        help='Scale parameters to test')
    parser.add_argument('--seeds', type=int, nargs='+',
                        default=[42, 123, 234, 345, 456, 567, 678, 789, 890, 901],
                        help='Seeds to use')
    parser.add_argument('--optimizers', type=str, nargs='+', default=['sgd'],
                        choices=['adamw', 'sgd'], help='Optimizers to use')
    parser.add_argument('--results-parent-dir', type=str, default='./results',
                        help='Parent directory for results')

    # SLURM specific parameters
    parser.add_argument('--time', type=str, default='8:00:00',
                        help='Maximum runtime in format HH:MM:SS')
    parser.add_argument('--mem', type=int, default=64000,
                        help='Memory to request (in MB)')
    parser.add_argument('--email', type=str, default=None,
                        help='Email address for job notifications')
    parser.add_argument('--no-submit', action='store_true',
                        help='Generate job scripts but do not submit to SLURM')
    parser.add_argument('--cpu-partition', type=str, default='price',
                        help='CPU partition to use (default: price)')
    parser.add_argument('--cpus-per-task', type=int, default=16,
                        help='Number of CPU cores per task (default: 16)')
    return parser


def _short_hostname():
    """Return the short name of the current host.

    Tries ``hostname -s`` first; if the command is missing, fails, or prints
    nothing, falls back to ``socket.gethostname()`` cut at the first dot.
    May raise ``OSError`` if the fallback lookup fails.
    """
    try:
        proc = subprocess.run(['hostname', '-s'], stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, universal_newlines=True)
    except OSError:
        proc = None
    if proc is not None and proc.returncode == 0 and proc.stdout.strip():
        return proc.stdout.strip()
    return socket.gethostname().split('.')[0].strip()


def _render_job_script(args, job_name, hostname, slurm_dir, exp_dir, config):
    """Return the full text of the sbatch script for one experiment.

    ``config`` is the ``(training_mode, scale, seed, optimizer)`` tuple.
    """
    training_mode, scale, seed, optimizer = config
    parts = [
        "#!/bin/bash\n",
        "\n# ---- Resource Allocation ----\n",
        f"#SBATCH --job-name={job_name}\n",
        "#SBATCH --nodes=1\n",
        "#SBATCH --ntasks=1\n",
        f"#SBATCH --cpus-per-task={args.cpus_per_task}\n",
        f"#SBATCH --mem={args.mem}\n",
        f"#SBATCH --time={args.time}\n",
        f"#SBATCH --partition={args.cpu_partition}\n",
        f"#SBATCH --exclude={hostname}\n",
    ]

    # Add email notifications if specified
    if args.email:
        parts.append("#SBATCH --mail-type=BEGIN,END,FAIL\n")
        parts.append(f"#SBATCH --mail-user={args.email}\n")

    parts += [
        # Output and error files
        f"#SBATCH --output={slurm_dir}/{job_name}.out\n",
        f"#SBATCH --error={slurm_dir}/{job_name}.err\n",
        "#SBATCH --export=ALL\n\n",
        # Environment setup (adjust according to your cluster)
        "\n# ---- Environment Setup ----\n",
        "# Load necessary modules\n",
        "module purge\n",
        "\n",
        "# Activate conda environment\n",
        "source $HOME/.bashrc\n",
        "conda activate cprornn2\n\n",
        # The actual command that runs the experiment
        "\n# ---- Run Experiment ----\n",
        "python run_scale_experiment.py \\\n",
        f"    --training-mode {training_mode} \\\n",
        f"    --scale {scale} \\\n",
        f"    --seed {seed} \\\n",
        f"    --optimizer {optimizer} \\\n",
        f"    --results-dir {exp_dir} \\\n",
        f"    --description \"SLURM batch: {training_mode} training, scale {scale}, {optimizer} optimizer, seed {seed}\"\n",
    ]
    return "".join(parts)


def _write_executable(path, text):
    """Write ``text`` to ``path`` and mark the file executable (mode 0o755)."""
    with open(path, 'w') as f:
        f.write(text)
    os.chmod(path, 0o755)


def _submit_job(script_path):
    """Submit ``script_path`` with ``sbatch`` and return the job ID as a string.

    Raises ``RuntimeError`` if sbatch cannot be started, exits non-zero, or
    prints no numeric job ID.
    """
    try:
        # Argument list (no shell) so paths with spaces or metacharacters are safe.
        proc = subprocess.run(['sbatch', str(script_path)], stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, universal_newlines=True)
    except OSError as e:
        raise RuntimeError(f"could not run sbatch: {e}") from e
    if proc.returncode != 0:
        raise RuntimeError(proc.stderr.strip() or f"sbatch exited with status {proc.returncode}")
    # Expected output: "Submitted batch job <id>"; take the first numeric token.
    job_id = next((tok for tok in proc.stdout.split() if tok.isdigit()), None)
    if job_id is None:
        raise RuntimeError(f"unexpected sbatch output: {proc.stdout.strip()!r}")
    return job_id


def main():
    """Generate one sbatch script per experiment configuration and submit them.

    Command-line options are read from ``sys.argv``. A timestamped batch
    directory is created under ``--results-parent-dir`` containing one results
    directory per experiment and a ``slurm_scripts`` directory with the job
    scripts and ``master_job_list.txt``. Unless ``--no-submit`` is given, each
    script is submitted with ``sbatch`` and, if at least one submission
    succeeds, ``check_job_status.sh`` and ``cancel_all_jobs.sh`` are written.

    Exits with status 1 if the current host name cannot be determined.
    """
    args = _build_arg_parser().parse_args()

    # Create a timestamped directory for this batch of experiments
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    batch_dir = Path(args.results_parent_dir) / f"slurm_cpu_batch_{timestamp}"
    batch_dir.mkdir(exist_ok=True, parents=True)

    # Create a directory for SLURM scripts
    slurm_dir = batch_dir / "slurm_scripts"
    slurm_dir.mkdir(exist_ok=True)

    print(f"Batch experiments will be saved in: {batch_dir}")
    print(f"SLURM scripts will be saved in: {slurm_dir}")

    # Get current hostname for exclusion (mandatory for this cluster).
    # An empty name is treated as a failure so "--exclude=" is never written blank.
    try:
        current_hostname = _short_hostname()
        if not current_hostname:
            raise OSError("host name is empty")
    except OSError as e:
        print(f"Error: Could not get hostname for exclusion: {e}")
        print("Aborting job submission as login node exclusion is mandatory.")
        sys.exit(1)
    print(f"Will exclude current login node: {current_hostname}")

    # Generate all experiment combinations
    experiment_configs = list(itertools.product(
        args.training_modes,
        args.scales,
        args.seeds,
        args.optimizers,
    ))

    total_experiments = len(experiment_configs)
    print(f"Total experiments to run: {total_experiments}")

    # Track all job IDs for the helper scripts
    all_job_ids = []

    # Create a master list to track all jobs
    master_script_path = slurm_dir / "master_job_list.txt"
    with open(master_script_path, 'w') as master_file:
        master_file.write(f"# SLURM batch generated on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        master_file.write(f"# Total jobs: {total_experiments}\n\n")
        master_file.write(f"# Excluding login node: {current_hostname}\n")
        master_file.write("\n")

    # Generate a SLURM script for each experiment configuration
    for config in experiment_configs:
        training_mode, scale, seed, optimizer = config

        # Create a specific results directory for this experiment
        exp_dir = batch_dir / f"cpu_experiment_{training_mode}_scale{scale}_{optimizer}_seed{seed}"
        exp_dir.mkdir(exist_ok=True)

        # The optimizer is part of the name so that configurations differing
        # only in optimizer do not overwrite each other's script/.out/.err files.
        job_name = f"cpu_scale{scale}_{training_mode}_{optimizer}_{seed}"

        slurm_script_path = slurm_dir / f"{job_name}.sh"
        _write_executable(
            slurm_script_path,
            _render_job_script(args, job_name, current_hostname, slurm_dir, exp_dir, config),
        )

        if args.no_submit:
            print(f"Generated script for: {job_name}")
            continue

        print(f"Submitting job: {job_name}")
        try:
            job_id = _submit_job(slurm_script_path)
        except RuntimeError as e:
            # Best effort: report the failure and continue with the next job.
            print(f"Error submitting job: {e}")
            continue

        all_job_ids.append(job_id)

        # Append immediately so the list is usable even if the run is interrupted.
        with open(master_script_path, 'a') as master_file:
            master_file.write(f"Job ID: {job_id} - {job_name}\n")

        print(f"Submitted job ID: {job_id}")

    status_script_path = slurm_dir / "check_job_status.sh"
    cancel_script_path = slurm_dir / "cancel_all_jobs.sh"

    # Helper scripts are only useful when at least one job was submitted.
    if all_job_ids and not args.no_submit:
        # squeue takes a comma-separated job list via -j, which restricts the
        # listing to exactly this batch.
        _write_executable(
            status_script_path,
            "#!/bin/bash\n\n"
            "# Check the status of all jobs in the batch\n"
            "echo 'Checking status of all batch jobs...'\n"
            f"squeue -u $USER -j {','.join(all_job_ids)}\n",
        )
        print(f"\nCreated job status check script: {status_script_path}")

        _write_executable(
            cancel_script_path,
            "#!/bin/bash\n\n"
            "# Cancel all jobs in the batch\n"
            "echo 'Cancelling all batch jobs...'\n"
            f"scancel {' '.join(all_job_ids)}\n",
        )
        print(f"Created job cancellation script: {cancel_script_path}")

    # Final message
    if args.no_submit:
        print(f"\nGenerated {total_experiments} job scripts. To submit, run them individually with 'sbatch' or use:")
        print(f"for script in {slurm_dir}/*.sh; do [ -f \"$script\" ] && sbatch \"$script\"; done")
    else:
        print(f"\nSubmitted {len(all_job_ids)} out of {total_experiments} CPU-based jobs to SLURM")
        if all_job_ids:
            # Point at the real script locations; they live in slurm_dir, not the cwd.
            print(f"To check job status: {status_script_path}")
            print(f"To cancel all jobs: {cancel_script_path}")

    print(f"\nResults will be saved in: {batch_dir}")

# Run the batch generator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
