#!/usr/bin/env python3
import argparse
import getpass
import os
import shlex
import subprocess
import sys

# ==============================================================================
# 1. CLUSTER CONFIGURATION
# ==============================================================================
ACCOUNT_NAME = "raiselab"
GPU_PARTITION = "gpu"
CPU_PARTITION = "standard"
CONDA_ENV_NAME = "deop_env"

# Resources
TIME_LIMIT_DATA = "04:00:00"  # Data gen is usually faster
TIME_LIMIT_TRAIN = "12:00:00"
MEM_CPU = "32G"
MEM_GPU = "16G"

# ==============================================================================
# 2. EXPERIMENT CONFIGURATION
# ==============================================================================
# One independent pipeline is launched per basis count.
BASIS_COUNTS = [2, 4, 6, 8, 10]

# Paths are resolved relative to the directory the manager is launched from.
PROJECT_ROOT = os.getcwd()
CONFIG_PATH = os.path.join(PROJECT_ROOT, "config/base.yaml")
# Prefer $USER, but fall back to getpass (LOGNAME / password database) so an
# unset $USER cannot silently produce "/scratch/None/...".
_SCRATCH_USER = os.environ.get("USER") or getpass.getuser()
OUTPUT_BASE_DIR = f"/scratch/{_SCRATCH_USER}/pde_op_ablation_study"
LOG_DIR = os.path.join(PROJECT_ROOT, "slurm_logs_ablation")

# Script Paths
SCRIPT_GEN_DATA = os.path.join(PROJECT_ROOT, "src/generate_data.py")
SCRIPT_TRAIN_PROP = os.path.join(PROJECT_ROOT, "src/train_propagator_deeponet.py")
SCRIPT_TRAIN_CTRL = os.path.join(PROJECT_ROOT, "src/train_recurrent_controller.py")
SCRIPT_EVAL = os.path.join(PROJECT_ROOT, "src/ablation_study_basis_v2.py")

# Run ID Templates (formatted with the basis count m)
PROP_ID_TEMPLATE = "heat_ablation_prop_m{}"
CTRL_ID_TEMPLATE = "heat_ablation_ctrl_m{}"

# Hyperparameters (Static parts)
# Note: --num_basis_functions is injected dynamically in the loop
PROP_ARGS = (
    "--learning_rate 1e-3 --optimizer adamw "
    "--latent_dim 512 --branch_depth 4 --branch_width 512 "
    "--trunk_depth 4 --trunk_width 512 --activation_fn relu"
)

# Note: --num_basis_functions and --deeponet_run_id are injected in the loop
CTRL_ARGS = (
    "--learning_rate 1e-4 --optimizer adamw "
    "--hidden_dim 256 --num_layers 2 --activation_fn relu "
    "--terminal_weight 2.0 --running_weight 0.1 --effort_weight 1e-5"
)

# ==============================================================================
# 3. SLURM UTILITIES
# ==============================================================================
def submit_sbatch(job_name, output_log, partition, time, mem, gres, cmd_string, dependency=None):
    """Build a SLURM batch script and submit it to ``sbatch`` via stdin.

    Args:
        job_name: Value for ``#SBATCH --job-name``; also echoed when the job starts.
        output_log: Path for ``#SBATCH --output``.
        partition: SLURM partition to submit to.
        time: Wall-clock limit string, e.g. ``"12:00:00"``.
        mem: Memory request string, e.g. ``"16G"``.
        gres: Generic-resource request (e.g. ``"gpu:1"``); omitted when falsy.
        cmd_string: Shell command line run inside the job after env setup.
        dependency: A job ID, or an iterable of job IDs, that must all finish
            successfully (``afterok``) before this job may start; omitted
            when falsy.

    Returns:
        The submitted job's ID as a string.

    Raises:
        subprocess.CalledProcessError: if ``sbatch`` exits non-zero (its
            stderr is printed to this process's stderr first).
        RuntimeError: if ``sbatch`` succeeds but reports no job ID.
    """
    # SLURM header
    lines = [
        "#!/bin/bash",
        f"#SBATCH --job-name={job_name}",
        f"#SBATCH --output={output_log}",
        f"#SBATCH --account={ACCOUNT_NAME}",
        f"#SBATCH --partition={partition}",
        f"#SBATCH --time={time}",
        f"#SBATCH --mem={mem}",
        "#SBATCH --cpus-per-task=4",
    ]
    if gres:
        lines.append(f"#SBATCH --gres={gres}")
    if dependency:
        # A single ID (str/int) or several IDs joined as afterok:id1:id2.
        if isinstance(dependency, (str, int)):
            dep_ids = str(dependency)
        else:
            dep_ids = ":".join(str(d) for d in dependency)
        lines.append(f"#SBATCH --dependency=afterok:{dep_ids}")

    # Environment setup, then the command itself.
    lines += [
        "",
        "module purge",
        "module load miniforge cuda",
        f"source activate {CONDA_ENV_NAME}",
        "export PYTHONPATH=$PYTHONPATH:.",
        "",
        f"echo '--- Running: {job_name} ---'",
        cmd_string,
    ]
    sbatch_script = "\n".join(lines) + "\n"

    try:
        res = subprocess.run(
            ["sbatch", "--parsable"],
            input=sbatch_script,
            text=True,
            capture_output=True,
            check=True,
        )
    except subprocess.CalledProcessError as err:
        # capture_output swallows sbatch's diagnostics; show them before
        # propagating so the reason for the rejection is visible.
        print(
            f"sbatch failed for job '{job_name}' (exit {err.returncode}):\n{err.stderr}",
            file=sys.stderr,
        )
        raise

    # With --parsable, sbatch prints "jobid" or "jobid;cluster_name".
    job_id = res.stdout.strip().split(";", 1)[0]
    if not job_id:
        # An empty ID would make dependent jobs silently skip their
        # dependency (the `if dependency:` guard above), so fail loudly.
        raise RuntimeError(f"sbatch returned no job ID for job '{job_name}'")
    return job_id

# ==============================================================================
# 4. MODES
# ==============================================================================

def _submit_data_job(m):
    """Submit the CPU data-generation job for ``m`` basis functions and return its job ID.

    Shared by ``mode_gen_data`` and ``mode_train_don`` so both modes submit
    exactly the same command, job name and log path.
    """
    data_cmd = (
        f"python -u {SCRIPT_GEN_DATA} "
        f"--config_path {CONFIG_PATH} "
        f"--output_base_dir {OUTPUT_BASE_DIR} "
        f"--num_basis_functions {m}"
    )

    jid = submit_sbatch(
        job_name=f"data_m{m}",
        output_log=f"{LOG_DIR}/data_m{m}.out",
        partition=CPU_PARTITION,
        time=TIME_LIMIT_DATA,
        mem=MEM_CPU,
        gres=None,
        cmd_string=data_cmd,
    )
    print(f"[M={m}] Data Gen submitted. Job ID: {jid}")
    return jid


def mode_gen_data():
    """
    Submits one Data Generation job (CPU) per entry in BASIS_COUNTS, in parallel.
    Useful if you just want to create datasets without training yet.
    """
    print(f"\n{'='*60}")
    print("MODE: DATA GENERATION ONLY (CPU)")
    print(f"{'='*60}")

    for m in BASIS_COUNTS:
        _submit_data_job(m)


def mode_train_don():
    """
    Submits one chain per entry in BASIS_COUNTS, in parallel:
    [Data Gen (CPU) -> Train Propagator (GPU)].
    The propagator job uses an `afterok` dependency, so it starts only if
    its data-generation job succeeds.
    """
    print(f"\n{'='*60}")
    print("MODE: DYNAMICS PIPELINE (Data Gen -> Train Propagator)")
    print(f"{'='*60}")

    for m in BASIS_COUNTS:
        prop_run_id = PROP_ID_TEMPLATE.format(m)

        # --- A. Submit Data Generation (CPU) ---
        data_jid = _submit_data_job(m)

        # --- B. Submit Propagator Training (GPU) - Depends on Data ---
        prop_cmd = (
            f"python -u {SCRIPT_TRAIN_PROP} "
            f"--config_path {CONFIG_PATH} "
            f"--output_base_dir {OUTPUT_BASE_DIR} "
            f"--run_id {prop_run_id} "
            f"--num_basis_functions {m} "
            f"{PROP_ARGS}"
        )

        prop_jid = submit_sbatch(
            job_name=f"prop_m{m}",
            output_log=f"{LOG_DIR}/{prop_run_id}.out",
            partition=GPU_PARTITION,
            time=TIME_LIMIT_TRAIN,
            mem=MEM_GPU,
            gres="gpu:1",
            cmd_string=prop_cmd,
            dependency=data_jid,
        )
        print(f"[M={m}] Propagator submitted. Job ID: {prop_jid} (Waits for {data_jid})")

def mode_train_ctrl():
    """Submit one GPU controller-training job per value in BASIS_COUNTS.

    No SLURM dependency is attached. Each job is pointed at the propagator
    run named ``PROP_ID_TEMPLATE.format(m)``, so those runs (and their data)
    must already be complete before this mode is used.
    """
    rule = "=" * 60
    for line in (
        f"\n{rule}",
        "MODE: CONTROLLER TRAINING (PARALLEL)",
        rule,
        "⚠️  ENSURE PROPAGATOR TRAINING IS FINISHED FIRST! ⚠️\n",
    ):
        print(line)

    for n_basis in BASIS_COUNTS:
        ctrl_id = CTRL_ID_TEMPLATE.format(n_basis)
        cli = " ".join([
            f"python -u {SCRIPT_TRAIN_CTRL}",
            f"--config_path {CONFIG_PATH}",
            f"--output_base_dir {OUTPUT_BASE_DIR}",
            f"--run_id {ctrl_id}",
            # Pair the controller with the propagator trained on the same basis count.
            f"--deeponet_run_id {PROP_ID_TEMPLATE.format(n_basis)}",
            f"--num_basis_functions {n_basis}",
            CTRL_ARGS,
        ])

        job_id = submit_sbatch(
            job_name=f"ctrl_m{n_basis}",
            output_log=f"{LOG_DIR}/{ctrl_id}.out",
            partition=GPU_PARTITION,
            time=TIME_LIMIT_TRAIN,
            mem=MEM_GPU,
            gres="gpu:1",
            cmd_string=cli,
        )
        print(f"[M={n_basis}] Controller submitted. Job ID: {job_id}")

def mode_evaluate():
    """
    Runs the evaluation locally (not through SLURM).

    Invokes SCRIPT_EVAL once, passing every value in BASIS_COUNTS and the
    controller run-ID template.

    Raises:
        SystemExit: with the evaluator's exit code if it fails, so this
            manager's own exit status reflects the failure.
    """
    print(f"\n{'='*60}")
    print("MODE: EVALUATION (LOCAL)")
    print(f"{'='*60}")

    # Pass an argv list (no shell). Paths with spaces or shell metacharacters
    # are forwarded verbatim, and the "{}" template needs no quoting.
    cmd = [
        "python", SCRIPT_EVAL,
        "--config_path", CONFIG_PATH,
        "--output_base_dir", OUTPUT_BASE_DIR,
        "--run_id_template", CTRL_ID_TEMPLATE,
        "--basis_counts", *(str(m) for m in BASIS_COUNTS),
    ]

    print(f"Running: {' '.join(shlex.quote(arg) for arg in cmd)}")
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Evaluation failed with error code {e.returncode}", file=sys.stderr)
        # Propagate the failure instead of letting the manager exit 0.
        raise SystemExit(e.returncode) from e

# ==============================================================================
# MAIN
# ==============================================================================
if __name__ == "__main__":
    # Dispatch table: CLI mode name -> function implementing that phase.
    # Also the single source of truth for argparse's allowed choices.
    modes = {
        "gen_data": mode_gen_data,
        "train_don": mode_train_don,
        "train_ctrl": mode_train_ctrl,
        "evaluate": mode_evaluate,
    }

    parser = argparse.ArgumentParser(description="Parallel SLURM Manager for Ablation")
    parser.add_argument("mode", choices=list(modes), help="Phase to run.")
    args = parser.parse_args()

    # exist_ok=True makes these idempotent; no separate existence check is needed.
    os.makedirs(LOG_DIR, exist_ok=True)
    # Create exactly the directory string that is passed to the jobs.
    os.makedirs(OUTPUT_BASE_DIR, exist_ok=True)

    modes[args.mode]()