#!/usr/bin/env python3
"""
Unified parallel executor for fine-tuning, TextFooler attacks, and A2T attacks.

This script provides a common framework for running multiple independent experiments
across multiple GPUs using dynamic GPU assignment to prevent GPU idling.

Features:
- Auto-discovery of checkpoints (searches for 'final_model' subfolders)
- Dynamic GPU assignment (prevents GPU starvation)
- Support for three workload types: training, textfooler, a2t
- Graceful shutdown with Ctrl-C
- Per-job logging and summary reports

Usage Examples:
  # Run TextFooler attacks on all checkpoints in parallel
  python parallel_executor.py textfooler \
    --search-root data_files/nlp_training_faster_learning_rate \
    --gpus 0,1,2,3

  # Run A2T attacks on specific checkpoints
  python parallel_executor.py a2t \
    --checkpoint-list checkpoints.txt \
    --gpus 0,1 \
    --num-samples 50

  # Run fine-tuning experiments in parallel
  python parallel_executor.py training \
    --config-overrides "model=[gpt2,bert] dataset=[ag_news,imdb] seed=[42,43]" \
    --gpus 0,1,2,3
"""

import argparse
import json
import os
import queue
import shlex
import signal
import subprocess
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import product
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any

# Import GPU discovery from the sibling experiment_parallel.py when it is
# importable; otherwise define a fallback that queries nvidia-smi directly.
sys.path.append(str(Path(__file__).parent))
try:
    from experiment_parallel import get_available_gpus
except ImportError:
    def get_available_gpus() -> List[int]:
        """Fallback GPU discovery used when experiment_parallel cannot be imported.

        Lists GPU indices reported by ``nvidia-smi``. If nvidia-smi is missing,
        exits with an error, produces unparsable output, or does not answer
        within 30 seconds, ``[0, 1]`` is returned instead.
        """
        try:
            out = subprocess.check_output(
                ["nvidia-smi", "--query-gpu=index", "--format=csv,noheader"],
                stderr=subprocess.DEVNULL,
                # A wedged driver can make nvidia-smi block indefinitely; a
                # TimeoutExpired lands in the fallback below instead of hanging.
                timeout=30,
            ).decode()
            return [int(x.strip()) for x in out.splitlines() if x.strip().isdigit()]
        except Exception:
            # NOTE(review): assumes a two-GPU machine when discovery fails;
            # callers can pass explicit GPU ids to avoid relying on this.
            return [0, 1]  # Default fallback


class GracefulKiller:
    """Convert SIGINT/SIGTERM into a cooperative ``stop`` flag.

    Creating an instance installs handlers for both signals in the current
    process (``signal.signal`` may only be called from the main thread).
    Instead of raising ``KeyboardInterrupt``, a received signal sets
    ``stop`` to ``True`` and prints a notice; callers poll ``stop``.
    """

    def __init__(self):
        # Flipped to True by the handler; polled by the job machinery.
        self.stop = False
        for signum in (signal.SIGINT, signal.SIGTERM):
            signal.signal(signum, self._signal_handler)

    def _signal_handler(self, *_):
        """Record the shutdown request; the (signum, frame) arguments are unused."""
        self.stop = True
        print("\n🛑 Shutdown signal received, stopping gracefully...")


def find_final_model_paths(
    roots: List[str],
    require_files: Optional[List[str]] = None
) -> List[str]:
    """
    Find all 'final_model' directories under the given roots.

    Args:
        roots: Directory roots to search recursively. Roots that do not exist
            are reported with a warning and skipped.
        require_files: A directory is kept if it contains AT LEAST ONE of these
            file names (not all of them). Defaults to
            ``["pytorch_model.bin", "config.json"]``; pass an empty list to keep
            every ``final_model`` directory.

    Returns:
        Sorted, de-duplicated list of resolved absolute paths to
        ``final_model`` directories.
    """
    if require_files is None:
        require_files = ["pytorch_model.bin", "config.json"]

    found = set()  # a set: overlapping roots can yield the same directory twice
    for root in roots:
        root_path = Path(root).resolve()
        if not root_path.exists():
            print(f"Warning: Search root does not exist: {root_path}")
            continue

        for path in root_path.rglob("final_model"):
            # Skip regular files that happen to be named "final_model".
            if not path.is_dir():
                continue
            # Defensive exact-name check in case glob matching is
            # case-insensitive on the current platform.
            if path.name != "final_model":
                continue

            # "any" rather than "all": presumably so checkpoints lacking
            # pytorch_model.bin (e.g. other weight formats) still qualify via
            # config.json — confirm before tightening.
            if require_files and not any((path / fname).exists() for fname in require_files):
                continue

            found.add(str(path.resolve()))

    return sorted(found)


def read_checkpoint_list(file_path: str) -> List[str]:
    """Read checkpoint directories from a text file, one path per line.

    Blank lines and lines starting with ``#`` are ignored. Each entry is passed
    through ``Path.resolve()``, so relative entries are interpreted against the
    current working directory (not the list file's directory). Entries that
    are not existing directories are reported and skipped. Repeated entries
    are kept only once (first occurrence wins), matching the de-duplication
    done for auto-discovered checkpoints, so one checkpoint is never run twice
    in the same batch.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
    """
    path = Path(file_path).resolve()
    if not path.exists():
        raise FileNotFoundError(f"Checkpoint list file not found: {path}")

    checkpoints: List[str] = []
    seen = set()
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue

            ckpt_path = Path(line).resolve()
            if not ckpt_path.is_dir():
                if ckpt_path.exists():
                    print(f"Warning: Checkpoint path is not a directory: {ckpt_path}")
                else:
                    print(f"Warning: Checkpoint path does not exist: {ckpt_path}")
                continue

            ckpt = str(ckpt_path)
            if ckpt in seen:
                print(f"Warning: Duplicate checkpoint path ignored: {ckpt}")
                continue
            seen.add(ckpt)
            checkpoints.append(ckpt)

    return checkpoints


def parse_config_overrides(override_str: str) -> List[Dict[str, Any]]:
    """
    Expand a sweep specification into one config dict per combination.

    Format: "model=[gpt2,bert] dataset=[ag_news,imdb] seed=[42,43]"
    Keys may be dotted (e.g. "trainer.lr=[1e-4,3e-4]"). Values are kept as
    stripped strings; empty entries (e.g. from a trailing comma) are dropped.
    Combinations follow ``itertools.product`` order, with keys in the order
    they appear in ``override_str``.

    Returns: List of job configurations

    Raises:
        ValueError: if no ``key=[...]`` group is found, a key is repeated, or a
            group contains no non-empty values.
    """
    import re

    # Dots are allowed in keys so nested overrides like "trainer.lr" keep their
    # full path instead of being silently truncated to "lr".
    pattern = r'([\w.]+)=\[([^\]]+)\]'
    matches = re.findall(pattern, override_str)

    if not matches:
        raise ValueError(f"Invalid config override format: {override_str}")

    params: Dict[str, List[str]] = {}
    for key, values_str in matches:
        if key in params:
            # A repeated key would otherwise silently discard the earlier values.
            raise ValueError(f"Duplicate config override key: {key}")
        values = [v.strip() for v in values_str.split(',') if v.strip()]
        if not values:
            raise ValueError(f"No values given for config override key: {key}")
        params[key] = values

    keys = list(params)
    return [dict(zip(keys, combination)) for combination in product(*params.values())]


def run_textfooler_job(checkpoint_path: str, gpu_id: int, args: argparse.Namespace) -> Dict[str, Any]:
    """Run a single TextFooler attack against one checkpoint on one GPU.

    The dataset name and seed come from ``experiment_config.json`` in the
    checkpoint's parent directory: ``dataset_info.name`` (falling back to
    ``args.dataset_name``) and ``args.seed`` (defaulting to 42). A missing,
    unreadable or oddly shaped config is reported, not raised. The attack
    script runs under ``args.textfooler_python`` from the parent of
    ``args.textfooler_root``, with ``CUDA_VISIBLE_DEVICES`` set to ``gpu_id``
    and stdout/stderr written to a per-job log under ``args.logdir``.

    Args:
        checkpoint_path: Path to a ``final_model`` directory.
        gpu_id: GPU index exposed to the child process.
        args: Parsed CLI namespace (TextFooler paths, attack parameters,
            ``logdir``). An optional truthy ``skip_existing`` attribute skips
            checkpoints that already have ``max_attack_changes_*`` results.

    Returns:
        Dict with ``job_type``, ``checkpoint``, ``gpu_id``, ``return_code``,
        ``success`` and ``log_file`` (plus ``error`` or ``skipped`` when set).
    """
    run_dir = Path(checkpoint_path).parent
    job_name = f"textfooler_{run_dir.name}"
    project_root = Path(args.textfooler_root).parent

    def _result(return_code: int, success: bool, **extra: Any) -> Dict[str, Any]:
        # One key order for every summary row produced by this job type.
        return {
            "job_type": "textfooler",
            "checkpoint": checkpoint_path,
            "gpu_id": gpu_id,
            "return_code": return_code,
            "success": success,
            **extra,
        }

    # Set up environment with GPU isolation and the TextFooler virtual environment
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    env["PYTHONPATH"] = str(project_root)  # Add project root to PYTHONPATH
    # Assumes textfooler_python is <venv>/bin/python, so two parents up is the venv root.
    textfooler_env_path = Path(args.textfooler_python).parent.parent
    env["VIRTUAL_ENV"] = str(textfooler_env_path)
    env["PATH"] = f"{textfooler_env_path}/bin:{env.get('PATH', '')}"

    # Extract dataset name and seed from experiment_config.json
    dataset_name = None
    seed = None
    experiment_config_path = run_dir / "experiment_config.json"
    if experiment_config_path.exists():
        try:
            with open(experiment_config_path, 'r', encoding="utf-8") as f:
                config = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and undecodable bytes.
            print(f"⚠️  [GPU {gpu_id}] Could not read config from {experiment_config_path}: {e}")
        else:
            if not isinstance(config, dict):
                print(f"⚠️  [GPU {gpu_id}] Could not read config from {experiment_config_path}: expected a JSON object")
            else:
                # Sections may be absent or null; treat anything but an object as empty.
                dataset_info = config.get('dataset_info')
                if not isinstance(dataset_info, dict):
                    dataset_info = {}
                config_args = config.get('args')
                if not isinstance(config_args, dict):
                    config_args = {}

                # Prefer dataset_info.name (new format), then args.dataset_name (old format)
                dataset_name = dataset_info.get('name') or config_args.get('dataset_name')
                seed = config_args.get('seed')

                if dataset_name:
                    print(f"🎯 [GPU {gpu_id}] Using dataset '{dataset_name}' from experiment config")
                if seed is not None:
                    print(f"🎯 [GPU {gpu_id}] Using seed {seed} from experiment config")

    if not dataset_name:
        print(f"❌ [GPU {gpu_id}] Could not determine dataset name for {checkpoint_path}")
        return _result(-1, False,
                       error="Could not determine dataset name from experiment config",
                       log_file=None)

    # Optional skip of already-attacked checkpoints; off unless args.skip_existing is truthy.
    if getattr(args, "skip_existing", False):
        # NOTE(review): "TextFoolder" is spelled as in the original code; presumably
        # it matches the directory the attack script writes — confirm before renaming.
        attack_results_dir = run_dir / "adv_attack_results" / "TextFoolder"
        if attack_results_dir.is_dir() and any(attack_results_dir.glob("max_attack_changes_*")):
            print(f"⏭️  [GPU {gpu_id}] Skipping TextFooler attack for {checkpoint_path} (results already exist)")
            return _result(0, True, skipped=True, log_file=None)

    # Falsy CLI values (None / 0) fall back to these defaults.
    attack_sample_size = getattr(args, 'attack_sample_size', None) or 100
    max_attack_changes = getattr(args, 'max_attack_changes', None) or 10

    # Build command to directly call the TextFooler attack script.
    # NOTE(review): the script and data paths are relative to project_root and
    # assume the TextFooler directory is literally named "TextFooler".
    cmd = [
        args.textfooler_python,
        "TextFooler/attack_classification_with_budget.py",
        "--dataset_name", dataset_name,
        "--target_model", "seq_classifier",
        "--target_model_path", checkpoint_path,
        "--counter_fitting_embeddings_path", "data_files/TextFooler/embeddings/counter-fitted-vectors.txt",
        "--counter_fitting_cos_sim_path", "data_files/TextFooler/vocab_cosine_sim/ag_cosine_sim_ag.npy",
        "--USE_cache_path", "data_files/TextFooler/USE",
        "--device", "cuda",
        "--batch_size", "16",
        "--max_seq_length", "256",
        "--use_amp",
        "--seed", str(seed) if seed is not None else "42",  # Default seed 42
        "--attack_sample_size", str(attack_sample_size),
        "--max_attack_changes", str(max_attack_changes),
    ]

    # Add any extra TextFooler arguments (shell-style quoting is honoured)
    extra_args = getattr(args, 'textfooler_args', "")
    if extra_args:
        cmd.extend(shlex.split(extra_args))

    # Set up logging
    log_dir = Path(args.logdir) / "textfooler" / f"gpu_{gpu_id}"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"{job_name}.log"

    print(f"🚀 [GPU {gpu_id}] Starting TextFooler: {checkpoint_path}")

    with open(log_file, "w") as f:
        f.write(f"Command: {shlex.join(cmd)}\n")
        f.write(f"Working directory: {project_root}\n")
        f.write(f"Environment: CUDA_VISIBLE_DEVICES={gpu_id}\n")
        f.write(f"Environment: VIRTUAL_ENV={env['VIRTUAL_ENV']}\n")
        f.write(f"Environment: PATH={env['PATH']}\n")
        f.write(f"Environment: PYTHONPATH={env['PYTHONPATH']}\n\n")
        # Flush the header before the child starts writing to the same file.
        f.flush()

        try:
            completed = subprocess.run(
                cmd,
                env=env,
                cwd=project_root,  # Run from project root
                stdout=f,
                stderr=subprocess.STDOUT,
                check=False
            )
        except Exception as e:
            f.write(f"\nException: {e}\n")
            return _result(-1, False, error=str(e), log_file=str(log_file))

    return _result(completed.returncode, completed.returncode == 0, log_file=str(log_file))


def run_a2t_job(checkpoint_path: str, gpu_id: int, args: argparse.Namespace) -> Dict[str, Any]:
    """Run a single A2T attack/evaluation against one checkpoint on one GPU.

    The seed is read from ``args.seed`` in ``experiment_config.json`` next to
    the checkpoint when available. The seed is optional, so a missing,
    unreadable or oddly shaped config only produces a warning and the job runs
    without ``--seed``. ``evaluate_device_generic.py`` runs under
    ``args.a2t_python`` from ``args.a2t_root``, with ``CUDA_VISIBLE_DEVICES``
    set to ``gpu_id`` and stdout/stderr written to a per-job log.

    Returns:
        Dict with ``job_type``, ``checkpoint``, ``gpu_id``, ``return_code``,
        ``success`` and ``log_file`` (plus ``error`` on launch failure).
    """
    job_name = f"a2t_{Path(checkpoint_path).parent.name}"

    def _result(return_code: int, success: bool, **extra: Any) -> Dict[str, Any]:
        # One key order for every summary row produced by this job type.
        return {
            "job_type": "a2t",
            "checkpoint": checkpoint_path,
            "gpu_id": gpu_id,
            "return_code": return_code,
            "success": success,
            **extra,
        }

    # Set up environment with GPU isolation
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # Extract seed from experiment_config.json
    seed = None
    experiment_config_path = Path(checkpoint_path).parent / "experiment_config.json"
    if experiment_config_path.exists():
        try:
            with open(experiment_config_path, 'r', encoding="utf-8") as f:
                config = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and undecodable bytes.
            print(f"⚠️  [GPU {gpu_id}] Could not read seed from {experiment_config_path}: {e}")
        else:
            # Tolerate a non-object document or a null/absent "args" section.
            config_args = config.get('args') if isinstance(config, dict) else None
            if isinstance(config_args, dict):
                seed = config_args.get('seed')
            if seed is not None:
                print(f"🎯 [GPU {gpu_id}] Using seed {seed} from experiment config")

    # Build command
    cmd = [
        args.a2t_python,
        "evaluate_device_generic.py",
        "--checkpoint-paths", checkpoint_path,
        "--device", "cuda",  # Use "cuda" instead of "cuda:0"
        "--robustness",
        "--attacks", "a2t",
        "--save-log",
        "--accuracy",
    ]

    if seed is not None:
        cmd.extend(["--seed", str(seed)])

    # Add configurable A2T parameters (falsy values mean "use the script default")
    if args.max_words_changed:
        cmd.extend(["--max-words-changed"] + args.max_words_changed.split())
    if args.num_samples:
        cmd.extend(["--num-samples", str(args.num_samples)])
    if args.query_budget:
        cmd.extend(["--query-budget", str(args.query_budget)])

    # Add any extra A2T arguments (shell-style quoting is honoured)
    extra_args = getattr(args, 'a2t_args', "")
    if extra_args:
        cmd.extend(shlex.split(extra_args))

    # Set up logging
    log_dir = Path(args.logdir) / "a2t" / f"gpu_{gpu_id}"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"{job_name}.log"

    print(f"🚀 [GPU {gpu_id}] Starting A2T: {checkpoint_path}")

    with open(log_file, "w") as f:
        f.write(f"Command: {shlex.join(cmd)}\n")
        f.write(f"Working directory: {args.a2t_root}\n")
        f.write(f"Environment: CUDA_VISIBLE_DEVICES={gpu_id}\n\n")
        # Flush the header before the child starts writing to the same file.
        f.flush()

        try:
            completed = subprocess.run(
                cmd,
                env=env,
                cwd=args.a2t_root,
                stdout=f,
                stderr=subprocess.STDOUT,
                check=False
            )
        except Exception as e:
            f.write(f"\nException: {e}\n")
            return _result(-1, False, error=str(e), log_file=str(log_file))

    return _result(completed.returncode, completed.returncode == 0, log_file=str(log_file))


def run_training_job(job_config: Dict[str, Any], gpu_id: int, args: argparse.Namespace) -> Dict[str, Any]:
    """Run one training job: ``experiment.py`` with ``key=value`` overrides.

    The command is ``<training_python> experiment.py k=v ... device=cuda``.
    It runs from the current working directory with ``CUDA_VISIBLE_DEVICES``
    set to ``gpu_id``, and its output goes to
    ``<logdir>/training/gpu_<id>/<job_name>.log``. The log name encodes every
    override, so jobs that differ only in, for example, ``seed`` do not
    overwrite each other's logs.

    Returns:
        Dict with ``job_type``, ``config``, ``gpu_id``, ``return_code``,
        ``success`` and ``log_file`` (plus ``error`` on launch failure).
    """
    import hashlib
    import re

    # Keep the historical "train_<model>_<dataset>" prefix, then append every
    # other override so distinct configs get distinct log files.
    name_parts = [str(job_config.get('model', 'unknown')), str(job_config.get('dataset', 'unknown'))]
    name_parts += [f"{key}={value}" for key, value in job_config.items() if key not in ("model", "dataset")]
    job_name = "train_" + "_".join(name_parts)
    # Values such as "org/model" contain path separators; make the name filename-safe.
    job_name = re.sub(r"[^A-Za-z0-9._=+-]+", "_", job_name)
    if len(job_name) > 200:
        # Stay well under common 255-byte filename limits while keeping names unique.
        digest = hashlib.sha1(job_name.encode("utf-8")).hexdigest()[:10]
        job_name = f"{job_name[:180]}_{digest}"

    # Set up environment with GPU isolation
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # Build command for training with job-specific overrides
    cmd = [args.training_python, "experiment.py"]
    cmd.extend(f"{key}={value}" for key, value in job_config.items())
    # Force device to cuda (will use cuda:0 due to isolation)
    cmd.append("device=cuda")

    # Set up logging
    log_dir = Path(args.logdir) / "training" / f"gpu_{gpu_id}"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"{job_name}.log"

    print(f"🚀 [GPU {gpu_id}] Starting training: {job_config}")

    with open(log_file, "w") as f:
        f.write(f"Command: {shlex.join(cmd)}\n")
        f.write(f"Job config: {job_config}\n")
        f.write(f"Environment: CUDA_VISIBLE_DEVICES={gpu_id}\n\n")
        # Flush the header before the child starts writing to the same file.
        f.flush()

        try:
            completed = subprocess.run(
                cmd,
                env=env,
                stdout=f,
                stderr=subprocess.STDOUT,
                check=False
            )
        except Exception as e:
            f.write(f"\nException: {e}\n")
            return {
                "job_type": "training",
                "config": job_config,
                "gpu_id": gpu_id,
                "return_code": -1,
                "success": False,
                "error": str(e),
                "log_file": str(log_file)
            }

    return {
        "job_type": "training",
        "config": job_config,
        "gpu_id": gpu_id,
        "return_code": completed.returncode,
        "success": completed.returncode == 0,
        "log_file": str(log_file)
    }


def worker_with_dynamic_gpu(job: Any, job_runner: callable, gpu_queue: queue.Queue, args: argparse.Namespace, killer: GracefulKiller) -> Dict[str, Any]:
    """Borrow a GPU from ``gpu_queue``, run ``job_runner(job, gpu_id, args)``, then hand the GPU back.

    Returns the runner's result dict. If ``killer.stop`` is already set, returns
    a "Cancelled" error dict without touching the queue. If no GPU can be taken
    within 30 seconds, returns a "No GPU available" error dict. The borrowed GPU
    goes back to the queue even when the runner raises.
    """
    # Shutdown already requested: do not start new work.
    if killer.stop:
        return {"success": False, "error": "Cancelled"}

    try:
        assigned_gpu = gpu_queue.get(timeout=30)
    except queue.Empty:
        return {"success": False, "error": "No GPU available within 30 seconds"}

    try:
        outcome = job_runner(job, assigned_gpu, args)
    finally:
        gpu_queue.put(assigned_gpu)
    return outcome


def run_parallel_jobs(jobs: List[Any], job_runner: callable, gpu_ids: List[int], args: argparse.Namespace) -> List[Dict[str, Any]]:
    """Run ``jobs`` concurrently, at most one job per GPU slot, and collect their results.

    A thread pool with one worker per entry in ``gpu_ids`` runs
    ``worker_with_dynamic_gpu`` for every job. GPUs are handed out through a
    shared queue, so a GPU is reused as soon as its previous job finishes.

    After SIGINT/SIGTERM, jobs that have not started are reported as
    cancelled and already-running jobs are waited for. Every submitted job
    contributes exactly one entry, so the result has ``len(jobs)`` entries.
    If ``jobs`` or ``gpu_ids`` is empty, ``[]`` is returned.
    """
    if not jobs:
        print("No jobs to run!")
        return []

    if not gpu_ids:
        print("No GPUs available!")
        return []

    print(f"🎯 Running {len(jobs)} jobs across {len(gpu_ids)} GPUs")
    print(f"📋 GPUs: {gpu_ids}")

    # Set up GPU queue for dynamic assignment
    gpu_queue = queue.Queue()
    for gpu_id in gpu_ids:
        gpu_queue.put(gpu_id)

    killer = GracefulKiller()
    results = []
    stop_reported = False

    with ThreadPoolExecutor(max_workers=len(gpu_ids)) as executor:
        futures = {
            executor.submit(worker_with_dynamic_gpu, job, job_runner, gpu_queue, args, killer): job
            for job in jobs
        }

        # Keep draining after a shutdown signal instead of breaking out. The
        # executor waits for every future on exit anyway, and workers return
        # "Cancelled" immediately once killer.stop is set. Breaking would drop
        # those jobs from the results and make an interrupted run look complete.
        for future in as_completed(futures):
            if killer.stop and not stop_reported:
                print("🛑 Stopping due to shutdown signal")
                stop_reported = True

            job = futures[future]
            try:
                result = future.result()
            except Exception as e:
                results.append({"success": False, "error": str(e), "job": job})
                print(f"❌ Exception in job {job}: {e}")
                continue

            # Worker-level failures (cancelled / no GPU) carry no job identity;
            # attach it so the summary says which job they belong to.
            if "checkpoint" not in result and "config" not in result:
                result = {**result, "job": job}
            results.append(result)

            status = "✅" if result.get("success") else f"❌ (code: {result.get('return_code', 'unknown')})"
            job_text = str(job)
            job_desc = job_text[:60] + "..." if len(job_text) > 60 else job_text
            print(f"[{status}] {job_desc}")

    return results


def main():
    """Command-line entry point: plan jobs, run them across GPUs, write a summary.

    Returns:
        Process exit code. 0 on a dry run or when every planned job succeeded.
        1 on invalid arguments, when no checkpoints are found, or when any
        planned job failed, was cancelled, or produced no result.
    """
    parser = argparse.ArgumentParser(
        description="Unified parallel executor for ML experiments",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Job type
    parser.add_argument("job_type", choices=["textfooler", "a2t", "training"],
                       help="Type of jobs to run")

    # Input sources (exactly one is required)
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--search-root", nargs="+",
                           help="Search for 'final_model' directories under these roots")
    input_group.add_argument("--checkpoint-list", type=str,
                           help="File containing checkpoint paths (one per line)")
    input_group.add_argument("--config-overrides", type=str,
                           help="Training config overrides (e.g., 'model=[gpt2,bert] dataset=[ag_news]')")

    # GPU configuration
    parser.add_argument("--gpus", type=str, default=None,
                       help="Comma-separated GPU IDs (default: auto-detect)")

    # Paths
    parser.add_argument("--textfooler-root", type=str, default="TextFooler",
                       help="TextFooler project root directory")
    parser.add_argument("--textfooler-python", type=str,
                       default=str(Path.cwd() / "envs/textfooler/bin/python"),
                       help="Python interpreter for TextFooler")
    parser.add_argument("--a2t-root", type=str, default="TextAttack-A2T",
                       help="A2T project root directory")
    parser.add_argument("--a2t-python", type=str,
                       default=str(Path.cwd() / "envs/a2t/bin/python"),
                       help="Python interpreter for A2T")
    parser.add_argument("--training-python", type=str,
                       default=str(Path.cwd() / "envs/nlp_training/bin/python"),
                       help="Python interpreter for training")

    # A2T specific parameters
    parser.add_argument("--max-words-changed", type=str, default="1 2 3",
                       help="Max words changed for A2T attacks")
    parser.add_argument("--num-samples", type=int, default=100,
                       help="Number of samples for A2T attacks")
    parser.add_argument("--query-budget", type=int, default=50,
                       help="Query budget for A2T attacks")
    # run_a2t_job forwards args.a2t_args; without this flag there was no way to set it.
    parser.add_argument("--a2t-args", type=str, default="",
                       help="Additional arguments for A2T (shell-quoted string, split with shlex)")

    # TextFooler specific parameters
    parser.add_argument("--attack-sample-size", type=int, default=100,
                       help="Number of samples to attack for TextFooler")
    parser.add_argument("--max-attack-changes", type=int, default=10,
                       help="Maximum number of changes allowed in TextFooler attack")
    parser.add_argument("--textfooler-args", type=str, default="",
                       help="Additional arguments for TextFooler (shell-quoted string, split with shlex)")

    # Output
    parser.add_argument("--logdir", type=str, default="logs_parallel",
                       help="Directory for logs and results")
    parser.add_argument("--dry-run", action="store_true",
                       help="Show what would be run without executing")

    args = parser.parse_args()

    # Discover GPUs
    if args.gpus:
        try:
            # Skip empty entries so inputs like "0, 1," are accepted.
            gpu_ids = [int(x) for x in args.gpus.split(",") if x.strip()]
        except ValueError:
            print(f"❌ Invalid --gpus value {args.gpus!r}: expected comma-separated integers")
            return 1
    else:
        gpu_ids = get_available_gpus()

    print(f"🔧 Using GPUs: {gpu_ids}")

    # Prepare jobs based on job type
    if args.job_type in ["textfooler", "a2t"]:
        # The parser accepts --config-overrides for every job type, but attacks
        # need checkpoints. Without this check the list branch below would call
        # read_checkpoint_list(None) and crash.
        if args.config_overrides:
            print(f"❌ --config-overrides is only valid for training jobs; "
                  f"use --search-root or --checkpoint-list for {args.job_type}")
            return 1

        if args.search_root:
            checkpoints = find_final_model_paths(args.search_root)
        else:
            try:
                checkpoints = read_checkpoint_list(args.checkpoint_list)
            except FileNotFoundError as e:
                print(f"❌ {e}")
                return 1

        if not checkpoints:
            print("❌ No checkpoints found!")
            return 1

        print(f"📁 Found {len(checkpoints)} checkpoints:")
        for ckpt in checkpoints[:5]:  # Show first 5
            print(f"  {ckpt}")
        if len(checkpoints) > 5:
            print(f"  ... and {len(checkpoints) - 5} more")

        if args.dry_run:
            print("🏃 Dry run mode - not executing jobs")
            return 0

        planned = len(checkpoints)
        runner = run_textfooler_job if args.job_type == "textfooler" else run_a2t_job
        results = run_parallel_jobs(checkpoints, runner, gpu_ids, args)

    else:  # training
        if not args.config_overrides:
            print("❌ --config-overrides required for training jobs")
            return 1

        try:
            train_jobs = parse_config_overrides(args.config_overrides)
        except ValueError as e:
            print(f"❌ Error parsing config overrides: {e}")
            return 1

        print(f"🔬 Generated {len(train_jobs)} training jobs:")
        for job in train_jobs[:5]:  # Show first 5
            print(f"  {job}")
        if len(train_jobs) > 5:
            print(f"  ... and {len(train_jobs) - 5} more")

        if args.dry_run:
            print("🏃 Dry run mode - not executing jobs")
            return 0

        planned = len(train_jobs)
        results = run_parallel_jobs(train_jobs, run_training_job, gpu_ids, args)

    # Write summary (one JSON object per line)
    summary_file = Path(args.logdir) / "summary.jsonl"
    summary_file.parent.mkdir(parents=True, exist_ok=True)

    with open(summary_file, "w") as f:
        for result in results:
            f.write(json.dumps(result) + "\n")

    # Print final statistics
    successful = sum(1 for r in results if r.get("success"))
    print(f"\n📊 Results: {successful}/{len(results)} successful")
    if len(results) < planned:
        print(f"⚠️  {planned - len(results)} of {planned} planned jobs produced no result")
    print(f"📄 Full summary written to: {summary_file}")

    # Compare against the planned job count, not len(results). An empty result
    # list (e.g. no usable GPUs) or a truncated run must not report success.
    return 0 if successful == planned else 1


if __name__ == "__main__":
    sys.exit(main())
