#!/usr/bin/env python3
"""
Parallel MLP Beta-DAgger Training Launcher - 20 Seeds 3-layer (256, 512, 128)

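Train standalone MLP student models in parallel, one per random seed, across the
configured GPUs. The actual seed count (NUM_SEEDS) and GPU layout (GPU_IDS,
JOBS_PER_GPU) are set in the Configuration section below — check NUM_SEEDS, as the
seed count in the title above may be stale.
These runs serve as a baseline for comparison with GHN-generated weights.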

Usage: python 4_train_mlp_parallel_20seeds.py
"""

import subprocess
import random
import time
import signal
import os
import sys
from pathlib import Path
from typing import Dict, List

# ============================================================================
# Configuration
# ============================================================================
TEACHER_CHECKPOINT = "runs/PickCube-v1__ppo_teacher__1__1766181047/final_ckpt.pt"
ENV_ID = "PickCube-v1"
TRAIN_DIR = "runs_dagger/pickcube_3x256_mlp_baseline"

# Training hyperparameters, forwarded verbatim as CLI flags to
# beta_dagger_mlp_student.py (intended to match that script's defaults).
TOTAL_ITERATIONS = 200         # Number of DAgger iterations
BETA_DECAY_RATE = 0.97         # Exponential decay: beta = BETA_DECAY_RATE ** iteration
BC_UPDATES_PER_ITER = 50       # BC gradient steps per iteration
BATCH_SIZE = 4096              # BC training batch size (--bc-batch-size)
BUFFER_SIZE = 1_000_000        # Replay buffer size
LEARNING_RATE = 1e-3
MIN_LEARNING_RATE = 7e-4       # Floor for cosine annealing
NUM_ENVS = 1024
NUM_STEPS = 50                 # Steps per rollout
EVAL_FREQ = 10                 # Evaluation frequency (iterations)
LOG_FREQ = 1                   # Logging frequency (iterations)

# GPU settings
GPU_IDS = [0]
NUM_GPUS = len(GPU_IDS)        # Derived so it can never disagree with GPU_IDS
JOBS_PER_GPU = 5               # Max concurrent training processes per GPU

# Training seeds
# Number of MLP models to train (one per seed).
# NOTE(review): the module docstring/filename mention 20 seeds; confirm which
# count is intended.
NUM_SEEDS = 40

# ============================================================================
# Validation
# ============================================================================
# Fail fast (on stderr, exit status 1) if the teacher checkpoint file is missing.
if not os.path.isfile(TEACHER_CHECKPOINT):
    sys.exit(f"ERROR: Teacher not found: {TEACHER_CHECKPOINT}")

# Generate fixed 6-digit random seeds (deterministic).
# A dedicated Random(1) instance yields exactly the same sequence as
# random.seed(1) followed by module-level randint calls, but leaves the
# global RNG state untouched.
_seed_rng = random.Random(1)
seeds = [_seed_rng.randint(100000, 999999) for _ in range(NUM_SEEDS)]

# Create output directory
Path(TRAIN_DIR).mkdir(parents=True, exist_ok=True)

TOTAL_JOBS = NUM_SEEDS
print("=" * 80)
# NOTE(review): no architecture flag is passed to the training script, so the
# model uses its default hidden sizes; confirm they match this banner.
print("Parallel MLP (Baseline) Beta-DAgger Training (3-layer: 256, 512, 128)")
print("=" * 80)
print(f"Jobs: {TOTAL_JOBS} MLP models ({NUM_SEEDS} different seeds)")
print(f"Teacher: {TEACHER_CHECKPOINT}")
print(f"Environment: {ENV_ID}")
print(f"Output: {TRAIN_DIR}")
print(f"GPUs: {GPU_IDS} ({NUM_GPUS} × {JOBS_PER_GPU} jobs per GPU)")
print("=" * 80)

# ============================================================================
# Job Management
# ============================================================================
class JobManager:
    """Run training subprocesses with a bounded number of concurrent jobs.

    At most ``len(gpus) * jobs_per_gpu`` jobs run at once. Each new job is
    pinned, via ``CUDA_VISIBLE_DEVICES``, to the GPU that currently has the
    fewest running jobs. Constructing a manager installs SIGINT/SIGTERM
    handlers that terminate every tracked child before exiting.
    """

    def __init__(self, gpus: List[int], jobs_per_gpu: int):
        """Create a manager for ``gpus`` with ``jobs_per_gpu`` slots each.

        Args:
            gpus: GPU indices to schedule jobs on.
            jobs_per_gpu: Maximum number of concurrent jobs per GPU.

        Raises:
            ValueError: If ``gpus`` is empty or ``jobs_per_gpu < 1``. Either
                would give zero job slots, and ``wait_for_slot`` would then
                spin forever.
        """
        if not gpus:
            raise ValueError("gpus must contain at least one GPU id")
        if jobs_per_gpu < 1:
            raise ValueError(f"jobs_per_gpu must be >= 1, got {jobs_per_gpu}")

        self.gpus = gpus
        self.jobs_per_gpu = jobs_per_gpu
        self.max_parallel = len(gpus) * jobs_per_gpu

        # Running jobs keyed by pid:
        #   {'exp_name': str, 'gpu': int, 'start_time': float, 'proc': Popen}
        # The Popen handle is kept so completion is detected with
        # Popen.poll(), which reaps the child and decodes its exit status
        # (including death by signal, reported as a negative return code).
        self.running_jobs: Dict[int, dict] = {}

        # Running-job count per GPU, used for least-loaded placement.
        self.gpu_job_count = {gpu: 0 for gpu in gpus}

        # Outcome counters.
        self.completed = 0
        self.failed = 0
        self.skipped = 0

        # Setup signal handlers
        signal.signal(signal.SIGINT, self._cleanup)
        signal.signal(signal.SIGTERM, self._cleanup)

    def _cleanup(self, signum, frame):
        """Signal handler: send SIGTERM to all tracked jobs, then exit(1)."""
        print("\n")
        print("Caught interrupt signal. Killing all running jobs...")
        for job_info in list(self.running_jobs.values()):
            try:
                job_info['proc'].terminate()
            except ProcessLookupError:
                pass  # Child already exited.
        print("Cleanup complete. Exiting.")
        sys.exit(1)

    def get_available_gpu(self) -> int:
        """Return the GPU with the fewest running jobs (first one on ties)."""
        return min(self.gpus, key=lambda g: self.gpu_job_count[g])

    def wait_for_slot(self):
        """Block until fewer than ``max_parallel`` jobs are running."""
        while len(self.running_jobs) >= self.max_parallel:
            self._check_completed()
            time.sleep(1)

    def _check_completed(self):
        """Reap finished jobs, free their GPU slot, and update the counters."""
        for pid, job_info in list(self.running_jobs.items()):
            returncode = job_info['proc'].poll()
            if returncode is None:
                continue  # Still running.

            del self.running_jobs[pid]
            self.gpu_job_count[job_info['gpu']] -= 1

            if returncode == 0:
                self.completed += 1
                status_msg = "OK"
            else:
                self.failed += 1
                # Popen reports termination by signal N as returncode -N;
                # decoding the raw wait status with ``>> 8`` would have
                # misread such deaths as exit code 0 (success).
                if returncode < 0:
                    status_msg = f"FAIL (signal {-returncode})"
                else:
                    status_msg = f"FAIL (exit {returncode})"

            elapsed = time.monotonic() - job_info['start_time']
            done = self.completed + self.failed + self.skipped
            print(f"[{done}/{TOTAL_JOBS}] {status_msg} {job_info['exp_name']} ({elapsed:.1f}s)")

    @staticmethod
    def _find_finished_run(name_prefix: str):
        """Return the first run directory in TRAIN_DIR that starts with
        ``name_prefix`` and contains ``mlp_final_ckpt.pt``, or None.

        Assumes the training script writes ``mlp_final_ckpt.pt`` directly
        into ``TRAIN_DIR/<exp_name>/`` -- TODO confirm against
        beta_dagger_mlp_student.py.
        """
        train_dir = Path(TRAIN_DIR)
        if not train_dir.is_dir():
            return None
        for run_dir in sorted(train_dir.iterdir()):
            if run_dir.name.startswith(name_prefix) and (run_dir / "mlp_final_ckpt.pt").exists():
                return run_dir
        return None

    def launch_job(self, job_id: int, seed: int) -> bool:
        """Launch one MLP Beta-DAgger training job unless it already finished.

        Blocks until a job slot is free, then starts
        ``beta_dagger_mlp_student.py`` as a child process pinned to the
        least-loaded GPU.

        Args:
            job_id: Zero-based job index, used as a zero-padded name prefix.
            seed: Random seed forwarded to the training script.

        Returns:
            True if a subprocess was started; False if the job was skipped
            because a finished run for this (job_id, seed) already exists.
        """
        name_prefix = f"{job_id:03d}_{ENV_ID}__beta_dagger_mlp_baseline__{seed}__"

        # Resume capability. Run names end in a launch timestamp, so the exact
        # name this launch would get never exists yet; instead look for any
        # earlier run with the same job/seed prefix that has a final checkpoint.
        finished_run = self._find_finished_run(name_prefix)
        if finished_run is not None:
            print(f"[{job_id+1}/{TOTAL_JOBS}] Skipping {finished_run.name} (already exists)")
            self.skipped += 1
            return False

        # Wait for available slot, then pick the least-loaded GPU.
        self.wait_for_slot()
        gpu = self.get_available_gpu()

        # Timestamp is taken after the slot wait, i.e. at actual launch time.
        exp_name = f"{name_prefix}{int(time.time())}"
        exp_dir = Path(TRAIN_DIR) / exp_name
        exp_dir.mkdir(parents=True, exist_ok=True)

        cmd = [
            sys.executable, 'beta_dagger_mlp_student.py',
            '--env-id', ENV_ID,
            '--seed', str(seed),
            '--teacher-checkpoint', TEACHER_CHECKPOINT,
            '--total-iterations', str(TOTAL_ITERATIONS),
            '--beta-decay-rate', str(BETA_DECAY_RATE),
            '--bc-updates-per-iter', str(BC_UPDATES_PER_ITER),
            '--bc-batch-size', str(BATCH_SIZE),
            '--buffer-size', str(BUFFER_SIZE),
            '--learning-rate', str(LEARNING_RATE),
            '--min-learning-rate', str(MIN_LEARNING_RATE),
            '--num-envs', str(NUM_ENVS),
            '--num-steps', str(NUM_STEPS),
            '--eval-freq', str(EVAL_FREQ),
            '--log-freq', str(LOG_FREQ),
            '--train-dir', TRAIN_DIR,
            '--exp-name', exp_name,
            '--save-model',
            # NOTE(review): no architecture flag is passed, so the model uses
            # beta_dagger_mlp_student.py's default hidden sizes; confirm they
            # match the architecture advertised in this file's banner.
        ]

        # Restrict the child to its assigned GPU.
        env = os.environ.copy()
        env['CUDA_VISIBLE_DEVICES'] = str(gpu)

        proc = subprocess.Popen(
            cmd,
            env=env,
            cwd=os.getcwd()
        )

        self.running_jobs[proc.pid] = {
            'exp_name': exp_name,
            'gpu': gpu,
            'start_time': time.monotonic(),
            'proc': proc,
        }
        self.gpu_job_count[gpu] += 1

        print(f"[{job_id+1}/{TOTAL_JOBS}] Launching Seed {seed} on GPU {gpu}")

        return True

    def wait_for_all(self):
        """Block until every tracked job has finished, polling every 3s."""
        if self.running_jobs:
            print("\nAll jobs launched. Waiting for completion...")

        while self.running_jobs:
            self._check_completed()
            time.sleep(3)

# ============================================================================
# Main Training Loop
# ============================================================================
def main():
    """Launch one training job per seed, wait for all of them, print a summary."""
    # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
    start_time = time.monotonic()
    job_manager = JobManager(GPU_IDS, JOBS_PER_GPU)

    # Train MLP models with different seeds
    for job_id, seed in enumerate(seeds):
        launched = job_manager.launch_job(job_id, seed)
        if launched:
            # Stagger consecutive launches; skipped jobs need no delay.
            time.sleep(2)

    # Wait for remaining jobs
    job_manager.wait_for_all()

    # Summary
    elapsed = time.monotonic() - start_time
    hours = int(elapsed // 3600)
    minutes = int((elapsed % 3600) // 60)

    print()
    print("=" * 80)
    print("Complete! Summary:")
    print(f"  Total:     {TOTAL_JOBS} MLP models (Baseline)")
    print(f"  Completed: {job_manager.completed}")
    print(f"  Failed:    {job_manager.failed}")
    print(f"  Skipped:   {job_manager.skipped}")
    print(f"  Time:      {hours}h {minutes}m")
    print(f"  Output:    {TRAIN_DIR}")
    print("=" * 80)


if __name__ == "__main__":
    main()
