"""
GPU-accelerated parallel training script for fine-tuning experiments.

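GPU Configuration Options (checked in this order; the first one that is set is used):
1. Config file: Set 'gpus: [0, 1, 2]' in config.yaml
2. Command line: Use '+gpus=[0,1]' to add the key, or 'gpus=[0,1]' if config.yaml
   already defines it; a command-line value overrides the config file
3. Environment: Set CUDA_VISIBLE_DEVICES=0,1
4. Auto-detection: Query pynvml, then nvidia-smi; if both fail, fall back to GPUs [0, 1]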

Examples:
  # Use specific GPUs via command line
  python experiment_parallel.py +gpus='[0,2]' train.strategy='[full,lora]'
  
  # Use environment variable
  CUDA_VISIBLE_DEVICES=1 python experiment_parallel.py
  
  # Auto-detect all available GPUs
  python experiment_parallel.py train.strategy='[full,head_only,lora]'
"""

import os
import sys
import subprocess
from itertools import product
from concurrent.futures import ThreadPoolExecutor, as_completed
import queue
import threading

import hydra
from omegaconf import DictConfig, OmegaConf, ListConfig


def get_available_gpus() -> list[int]:
    """
    Return the GPU indices this launcher should schedule jobs on.

    Sources are tried in order; the first one that applies wins:

    1. ``CUDA_VISIBLE_DEVICES``: if the variable is set at all (even to an
       empty string) its comma-separated integer entries are returned. An
       empty or blank value yields ``[]`` (no GPUs). Entries that are not
       plain non-negative integers (e.g. GPU UUIDs or ``-1``) are skipped
       with a warning, because each job's child environment is given a
       single integer index.
    2. Auto-detection via ``pynvml`` (nvidia-ml-py): ``range(device_count)``.
    3. ``nvidia-smi --list-gpus``: ``range(number of listed GPUs)``.
    4. Default fallback: ``[0, 1]``.

    Returns:
        List of integer GPU indices (possibly empty).
    """
    # Option 1: CUDA_VISIBLE_DEVICES. Test for presence, not truthiness: an
    # explicitly empty value hides every GPU and must not fall through to
    # auto-detection.
    cuda_visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if cuda_visible is not None:
        gpu_ids: list[int] = []
        ignored: list[str] = []
        for entry in (part.strip() for part in cuda_visible.split(',')):
            if not entry:
                continue
            # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
            if entry.isdecimal():
                gpu_ids.append(int(entry))
            else:
                ignored.append(entry)
        if ignored:
            print(f"Warning: Invalid CUDA_VISIBLE_DEVICES format: {cuda_visible} "
                  f"(ignored entries: {ignored})")
        return gpu_ids

    # Option 2: Try to auto-detect using nvidia-ml-py3
    try:
        import pynvml
        pynvml.nvmlInit()
        try:
            gpu_count = pynvml.nvmlDeviceGetCount()
        finally:
            # Release the NVML handle acquired by nvmlInit().
            pynvml.nvmlShutdown()
    except ImportError:
        print("pynvml not available, trying nvidia-smi...")
    except Exception as e:
        print(f"pynvml detection failed: {e}")
    else:
        available_gpus = list(range(gpu_count))
        print(f"Auto-detected {gpu_count} GPUs: {available_gpus}")
        return available_gpus

    # Option 3: Try nvidia-smi as fallback. The timeout keeps a wedged driver
    # from hanging the launcher. OSError covers a missing binary and
    # permission problems.
    try:
        result = subprocess.run(
            ['nvidia-smi', '--list-gpus'],
            capture_output=True, text=True, check=True, timeout=30,
        )
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError) as e:
        print(f"nvidia-smi detection failed: {e}")
    else:
        # nvidia-smi prints one line per GPU.
        gpu_count = sum(1 for line in result.stdout.splitlines() if line.strip())
        available_gpus = list(range(gpu_count))
        print(f"Detected {gpu_count} GPUs via nvidia-smi: {available_gpus}")
        return available_gpus

    # Option 4: Default fallback
    default_gpus = [0, 1]
    print(f"Falling back to default GPUs: {default_gpus}")
    print("To override, set CUDA_VISIBLE_DEVICES environment variable or install pynvml")
    return default_gpus


def build_train_cmd(cfg: DictConfig, output_dir: str) -> list[str]:
    """
    Build the argv for the training subprocess that runs one job.

    Strategies ``"full"`` and ``"head_only"`` use ``nlp_training/finetune_v2.py``
    (``head_only`` also passes ``--freeze_base``). Every other strategy uses
    ``nlp_training/lora_v2.py``. Script paths are resolved relative to the
    directory containing this file.

    Args:
        cfg: Per-job config. Reads ``train.strategy``, ``train.num_epochs``,
            ``train.learning_rate``, ``train.max_seq_length``, ``dataset.name``,
            ``model.name``, ``model.batch_size``, ``seed`` and ``hf_token``;
            optionally ``debug`` and, for the ``lora`` strategy, the
            ``train.lora_*`` keys.
        output_dir: Directory forwarded to the child as ``--output_dir``.

    Returns:
        Command list starting with the current Python interpreter, suitable
        for ``subprocess.run`` without a shell.
    """
    project_root = os.path.dirname(os.path.abspath(__file__))
    strategy = cfg.train.strategy
    script_name = "finetune_v2.py" if strategy in ("full", "head_only") else "lora_v2.py"
    script_path = os.path.join(project_root, "nlp_training", script_name)

    print(f"DEBUG: Strategy '{strategy}' -> Using script: {script_path}")

    args = [
        sys.executable,
        script_path,
        "--dataset_name", cfg.dataset.name,
        "--model_name", cfg.model.name,
        "--output_dir", output_dir,
        "--num_epochs", str(cfg.train.num_epochs),
        "--batch_size", str(cfg.model.batch_size),
        "--learning_rate", str(cfg.train.learning_rate),
        "--seed", str(cfg.seed),
        "--max_seq_length", str(cfg.train.max_seq_length),
    ]
    # Forward the token only when one is configured. str(None) would otherwise
    # hand the child the literal token "None".
    # NOTE(review): assumes the child scripts treat --hf_token as optional; confirm.
    hf_token = cfg.hf_token
    if hf_token is not None:
        args += ["--hf_token", str(hf_token)]

    if strategy == "head_only":
        args.append("--freeze_base")

    if strategy == "lora":
        train = cfg.train
        for key in ("lora_r", "lora_alpha", "lora_dropout"):
            value = getattr(train, key, None)
            if value is not None:
                args += [f"--{key}", str(value)]

        modules = getattr(train, "lora_target_modules", None)
        if modules is not None:
            # A str is iterable too, so check for it first: a single value such
            # as "all-linear" must stay one argument, not be split into characters.
            if isinstance(modules, str) or not hasattr(modules, "__iter__"):
                args += ["--lora_target_modules", str(modules)]
            else:
                # ListConfig / list / tuple: one argv entry per module name.
                args += ["--lora_target_modules", *(str(m) for m in modules)]

        # LoRA initialization parameters
        for key in ("lora_init_type", "lora_init_scale"):
            value = getattr(train, key, None)
            if value is not None:
                args += [f"--{key}", str(value)]

    if getattr(cfg, 'debug', False):
        args.append("--debug")

    return args


def _redact_cmd_for_log(cmd: list[str]) -> str:
    """Join ``cmd`` for logging, masking the value that follows ``--hf_token``."""
    shown = list(cmd)
    for i, arg in enumerate(shown[:-1]):
        if arg == "--hf_token":
            shown[i + 1] = "***"
    return " ".join(shown)


def worker_with_dynamic_gpu(job_params: dict, base_cfg: DictConfig, gpu_queue: queue.Queue) -> tuple[dict, int]:
    """
    Run one training job as a subprocess on a GPU borrowed from ``gpu_queue``.

    The GPU id is taken with ``gpu_queue.get`` (``queue.Queue`` is thread-safe,
    so two workers never receive the same queued id). The id is always put
    back once the job finishes or fails.

    Args:
        job_params: Job description with keys ``strategy``, ``dataset``,
            ``model`` and ``seed``. It is not mutated.
        base_cfg: The launcher's config. Each job works on a resolved deep copy.
        gpu_queue: Pool of free GPU ids shared by all workers.

    Returns:
        ``(job, returncode)``. ``job`` is a copy of ``job_params`` with
        ``"gpu_id"`` added. ``returncode`` is the child process's exit code.

    Raises:
        RuntimeError: If no GPU becomes free within 30 seconds.
        Errors from loading the per-job YAML files or creating the output
        directory propagate to the caller.
    """
    try:
        gpu_id = gpu_queue.get(timeout=30)  # timeout prevents infinite blocking
    except queue.Empty:
        raise RuntimeError("No GPU became available within 30 seconds") from None

    job_params = {**job_params, "gpu_id": gpu_id}  # copy; the caller's dict is untouched

    # Hydra's config tree ("conf/") and the child process's working directory
    # both live next to this file. Resolve relative paths against it instead of
    # the launcher's current working directory, so config loading, makedirs and
    # the child all agree on locations.
    project_root = os.path.dirname(os.path.abspath(__file__))

    try:
        # Deep-copy the base config so each job is isolated
        cfg = OmegaConf.create(OmegaConf.to_container(base_cfg, resolve=True))

        # Per-model YAML supplies the actual HuggingFace model name and batch size.
        model_cfg = OmegaConf.load(
            os.path.join(project_root, "conf", "model", f"{job_params['model']}.yaml"))
        # Per-strategy YAML supplies the complete training section.
        train_cfg = OmegaConf.load(
            os.path.join(project_root, "conf", "train", f"{job_params['strategy']}.yaml"))

        cfg.train = train_cfg
        cfg.dataset.name = job_params["dataset"]
        cfg.model.name = model_cfg.name
        cfg.model.batch_size = model_cfg.batch_size
        cfg.seed = job_params["seed"]

        # Unique output directory per job. The GPU id is part of the path, so it
        # varies with dynamic assignment. An absolute cfg.output_dir is kept
        # as-is by os.path.join.
        model_name_safe = job_params["model"].replace('.', '_').replace('/', '_')
        output_dir = os.path.join(
            project_root,
            f"{cfg.output_dir}/{job_params['dataset']}/"
            f"{model_name_safe}_epochs_{cfg.train.num_epochs}_seed_{job_params['seed']}/"
            f"{job_params['strategy']}/gpu_{gpu_id}",
        )
        os.makedirs(output_dir, exist_ok=True)

        cmd = build_train_cmd(cfg, output_dir)

        # Pin the child to the borrowed GPU.
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

        print(f"[→] Launching job {job_params} on GPU {gpu_id}: {_redact_cmd_for_log(cmd)}")
        result = subprocess.run(cmd, env=env, cwd=project_root)
        return job_params, result.returncode
    finally:
        # Always return the GPU to the pool, even if setup or the child failed.
        gpu_queue.put(gpu_id)


# def worker(job_params: dict, base_cfg: DictConfig) -> tuple[dict, int]:
#     """
#     Runs one training job on the assigned GPU.
#     """
#     # Deep-copy the base config so each job is isolated
#     cfg = OmegaConf.create(OmegaConf.to_container(base_cfg, resolve=True))

#     # Load the specific model config to get the actual HuggingFace model name
#     model_config_path = f"conf/model/{job_params['model']}.yaml"
#     model_cfg = OmegaConf.load(model_config_path)
    
#     # Load the specific training config to get strategy-specific parameters
#     train_config_path = f"conf/train/{job_params['strategy']}.yaml"
#     train_cfg = OmegaConf.load(train_config_path)

#     # Override parameters for this job
#     cfg.train = train_cfg  # Use the complete training config
#     cfg.dataset.name = job_params["dataset"]
#     cfg.model.name = model_cfg.name  # Use actual HuggingFace model name
#     cfg.model.batch_size = model_cfg.batch_size  # Also update batch size
#     cfg.seed = job_params["seed"]

#     # Compose a unique output directory per job
#     model_name_safe = job_params["model"].replace('.', '_').replace('/', '_')
#     output_dir = (
#         f"{cfg.output_dir}/{job_params['dataset']}/"
#         f"{model_name_safe}_epochs_{cfg.train.num_epochs}_seed_{job_params['seed']}/"
#         f"{job_params['strategy']}/gpu_{job_params['gpu_id']}"
#     )
#     os.makedirs(output_dir, exist_ok=True)

#     # Build and launch the training command
#     cmd = build_train_cmd(cfg, output_dir)
#     env = os.environ.copy()
#     env["CUDA_VISIBLE_DEVICES"] = str(job_params["gpu_id"])

#     print(f"[→] Launching job {job_params} on GPU {job_params['gpu_id']}: {' '.join(cmd)}")
#     result = subprocess.run(cmd, env=env, cwd=os.path.dirname(os.path.abspath(__file__)))
#     return job_params, result.returncode


def _as_list(value) -> list:
    """Return ``value`` as a list: list/ListConfig values are copied, anything else is wrapped."""
    return list(value) if isinstance(value, (list, ListConfig)) else [value]


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig):
    """
    Expand the config into a job grid and run the jobs in parallel across GPUs.

    ``train.strategy``, ``dataset.name``, ``model.name`` and ``seed`` may each
    be a single value or a list. One job is created for every combination
    (cartesian product), and at most one job runs per listed GPU at a time.

    GPU selection: a non-empty ``gpus`` config entry (from config.yaml or a
    command-line override) is used as given. Otherwise ``get_available_gpus()``
    decides.

    Raises:
        RuntimeError: If no GPUs are available, or if any job raised before
            its training subprocess could report an exit code (e.g. a missing
            per-job YAML file). Non-zero subprocess exit codes are only
            reported, not raised.
    """
    # Only None or empty values count as "not configured". Plain truthiness
    # would silently ignore the valid single-GPU setting ``gpus: 0``.
    configured = cfg.gpus if hasattr(cfg, 'gpus') else None
    if configured is None or (isinstance(configured, str) and not configured.strip()):
        gpus = []
    else:
        gpus = _as_list(configured)

    if gpus:
        print(f"Using GPUs from config: {gpus}")
    else:
        gpus = get_available_gpus()

    if not gpus:
        raise RuntimeError("No GPUs available! Check CUDA_VISIBLE_DEVICES, config.gpus, or GPU installation.")

    print(f"Using GPUs: {gpus}")

    # Cartesian product of every (possibly list-valued) axis. No GPU is
    # pre-assigned; workers borrow one from the queue.
    jobs = [
        {"strategy": s, "dataset": d, "model": m, "seed": sd}
        for s, d, m, sd in product(
            _as_list(cfg.train.strategy),
            _as_list(cfg.dataset.name),
            _as_list(cfg.model.name),
            _as_list(cfg.seed),
        )
    ]

    print(f"Created {len(jobs)} jobs to be distributed across {len(gpus)} GPUs")

    # Pool of free GPU ids shared by all worker threads.
    gpu_queue = queue.Queue()
    for gpu_id in gpus:
        gpu_queue.put(gpu_id)

    failures: list[tuple[dict, BaseException]] = []
    with ThreadPoolExecutor(max_workers=len(gpus)) as executor:
        futures = {executor.submit(worker_with_dynamic_gpu, job, cfg, gpu_queue): job for job in jobs}
        for fut in as_completed(futures):
            try:
                job, code = fut.result()
            except Exception as exc:
                # Keep reporting the remaining jobs instead of letting the
                # first launcher-side error abort this loop.
                failures.append((futures[fut], exc))
                print(f"[❌ (error: {exc!r})] Job {futures[fut]} failed before finishing")
                continue
            status = "✅" if code == 0 else f"❌ (exit {code})"
            print(f"[{status}] Finished job {job} on GPU {job['gpu_id']}")

    if failures:
        raise RuntimeError(
            f"{len(failures)} of {len(jobs)} job(s) raised an error: {[job for job, _ in failures]}"
        ) from failures[0][1]


if __name__ == "__main__":
    main()
