import os
import time
import logging
import fcntl
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, Tuple, Optional
import multiprocessing as mp
from pathlib import Path    
import json

# Make 'spawn' the process start method for this whole interpreter so worker
# processes start fresh instead of forking a parent that may already hold CUDA
# state (PyTorch documents that CUDA does not support the 'fork' start method).
# NOTE: this executes at import time and, because force=True, silently replaces
# any start method the importing program had already chosen.
mp.set_start_method('spawn', force=True)


def update_configs_json(configs_file: Path, ref_str: str, config: Dict[str, Any]) -> None:
    """
    Insert or replace one entry of a shared configs.json, safely across processes.

    An exclusive ``fcntl.flock`` lock is held for the whole read-modify-write
    cycle, so concurrent callers serialise their updates instead of overwriting
    each other's entries.

    Args:
        configs_file: Path to the configs.json file. The file and its parent
            directory are created if missing.
        ref_str: Key under which ``config`` is stored; an existing entry with the
            same key is replaced.
        config: Configuration dictionary to store. Values that are not
            JSON-serialisable (e.g. ``Path``) are written via ``str()``.

    Raises:
        ValueError / TypeError: If ``config`` cannot be serialised (e.g. it is
            self-referential). In that case the file is left untouched.

    Note:
        If the existing file is not valid JSON, a warning is logged and the file
        is rewritten containing only the new entry.
    """
    configs_file = Path(configs_file)
    configs_file.parent.mkdir(parents=True, exist_ok=True)

    # O_CREAT without O_TRUNC creates or opens the file in one atomic call and
    # never truncates before the lock is held. Opening with 'w+' would truncate
    # data another process had just written whenever two callers raced on creation.
    fd = os.open(configs_file, os.O_RDWR | os.O_CREAT, 0o666)
    with os.fdopen(fd, "r+", encoding="utf-8") as f:
        fcntl.flock(f.fileno(), fcntl.LOCK_EX)  # released when the file is closed
        content = f.read()
        try:
            configs_dict = json.loads(content) if content.strip() else {}
        except json.JSONDecodeError:
            logging.warning(
                f"{configs_file} is not valid JSON; rewriting it with only '{ref_str}'."
            )
            configs_dict = {}

        configs_dict[ref_str] = config
        # Serialise before truncating so a serialisation error cannot leave a
        # half-written (corrupt) file behind.
        payload = json.dumps(configs_dict, indent=2, default=str)

        f.seek(0)
        f.truncate()
        f.write(payload)
        # Flush while still holding the lock so the next reader sees full content.
        f.flush()


# Environment variable in which a worker process records the CUDA_VISIBLE_DEVICES
# value it inherited from the orchestrating process, before worker_run overwrites it.
_INHERITED_CUDA_ENV = "ORCHESTRATE_INHERITED_CUDA_VISIBLE_DEVICES"


def _physical_gpu_id(gpu_id: int, inherited_visible: str) -> str:
    """
    Map a logical GPU index to the token CUDA_VISIBLE_DEVICES should hold.

    ``gpu_id`` indexes the devices the parent could see. If the parent was itself
    restricted (e.g. ``"2,3"``), index 1 must become ``"3"``, not ``"1"``. When no
    restriction was inherited, or the index is out of range, the bare index is used.
    """
    devices = [token.strip() for token in inherited_visible.split(",") if token.strip()]
    if 0 <= gpu_id < len(devices):
        return devices[gpu_id]
    return str(gpu_id)


def worker_run(
    task: Tuple[Dict[str, Any], int, str, Optional[mp.Queue], Any]
) -> Tuple[Dict[str, Any], float]:
    """
    Worker function that runs a single configuration on a specific GPU.

    Args:
        task: Tuple of (config dict, gpu_id, log_level, log_queue, res_root).
            ``gpu_id`` is an index into the GPUs visible to the orchestrating
            process. ``res_root`` may be a str or a Path.

    Returns:
        Tuple of (config dict, execution time in seconds). The returned config
        has ``log_level``, ``process_id``, ``res_dir`` and ``timestamp`` set.
        The same config is also recorded in ``<res_root>/configs.json``.

    Raises:
        Any exception raised while preparing or running the configuration; it is
        logged (with traceback) and then re-raised.
    """
    config, gpu_id, log_level, log_queue, res_root = task

    # CRITICAL: Set CUDA_VISIBLE_DEVICES BEFORE importing torch-dependent modules!
    # CUDA reads it when it initialises, so it has no effect in a process that has
    # already initialised CUDA (e.g. a worker process reused for a later task).
    # setdefault remembers the list this process *inherited*, so the logical ->
    # physical mapping stays correct even after the variable has been overwritten.
    inherited_visible = os.environ.setdefault(
        _INHERITED_CUDA_ENV, os.environ.get("CUDA_VISIBLE_DEVICES", "")
    )
    os.environ['CUDA_VISIBLE_DEVICES'] = _physical_gpu_id(gpu_id, inherited_visible)

    # Now safe to import modules that use torch
    from src.run.utils import get_timestamp
    from src.run.main import run
    from src.run.logger import setup_logger

    # Setup worker logger
    worker_logger = setup_logger(
        name=f"worker_gpu_{gpu_id}",
        level=log_level,
        process_id=gpu_id,
        multiprocessing_queue=log_queue,
    )

    start_time = time.time()
    # Bound before the try so the error handler can always reference it, even if
    # the failure happens before the result directory has been named.
    ref_str = f"gpu_{gpu_id}"

    try:
        # Add log_level and process_id to config
        config['log_level'] = log_level
        config['process_id'] = gpu_id

        # Results go to <res_root>/gpu_<id>/results_<timestamp>
        main_res_root = Path(res_root)
        gpu_res_root = main_res_root / f"gpu_{gpu_id}"

        timestamp = get_timestamp()
        res_dir = gpu_res_root / f"results_{timestamp}"

        config['res_dir'] = res_dir
        config['timestamp'] = timestamp

        ref_str = f"gpu_{gpu_id}/{res_dir.name}"

        worker_logger.info(f"Starting run at {ref_str}")

        # Run the actual training
        run(**config)

        duration = time.time() - start_time
        worker_logger.info(f"Completed run at {ref_str}, took {format_time(duration)}")

        # Add ref_str->config to main_res_root/configs.json (process-safe via file locking)
        configs_file = main_res_root / "configs.json"
        update_configs_json(configs_file, ref_str, config)

        return config, duration

    except Exception as e:
        worker_logger.error(f"ERROR in run at {ref_str}")
        # exc_info attaches the traceback, which the message alone would lose.
        worker_logger.error(f"Error details: {e}", exc_info=True)
        raise


def format_time(seconds: float) -> str:
    """
    Render a duration in seconds as zero-padded ``HH:MM:SS``.

    Fractional seconds are truncated toward zero. Hours do not wrap, so long
    durations produce more than two hour digits (e.g. ``"100:00:00"``).
    """
    whole = int(seconds)
    parts = (whole // 3600, whole % 3600 // 60, whole % 60)
    return ":".join(f"{part:02d}" for part in parts)


def get_available_gpus() -> int:
    """
    Count the GPUs that parallel runs can be spread across.

    Returns:
        ``torch.cuda.device_count()`` when CUDA is available. Otherwise 1, with a
        warning on the root logger, so callers always get at least one worker slot.
        This covers both a missing PyTorch install and a CPU-only machine.
    """
    try:
        import torch

        cuda = torch.cuda
        if not cuda.is_available():
            logging.warning("No CUDA devices available. Running on CPU.")
            return 1
        return cuda.device_count()
    except ImportError:
        logging.warning("PyTorch not found. Assuming 1 GPU.")
        return 1


def _resolve_gpu_count(requested: int, logger: logging.Logger) -> int:
    """Return all detected GPUs for -1, else ``requested`` capped at the detected count."""
    available = get_available_gpus()
    if requested == -1:
        return available
    if requested > available:
        logger.warning(f"Requested {requested} GPUs, but only {available} are available.")
        return available
    return requested


def _log_progress(
    logger: logging.Logger,
    *,
    completed: int,
    failed: int,
    total: int,
    num_pending: int,
    num_active: int,
    start_time: float,
) -> None:
    """Log run counts, elapsed time and a throughput-based ETA."""
    elapsed = time.time() - start_time
    finished = completed + failed
    # Observed wall-clock time per finished run across all GPUs. It already
    # reflects the parallelism, so it is not divided by the GPU count again.
    # Failed runs count too: they also consumed GPU time.
    avg_time_per_run = elapsed / finished if finished else 0.0
    estimated_remaining = avg_time_per_run * (num_pending + num_active)

    logger.info(f"PROGRESS: {completed}/{total} completed, {failed} failed")
    logger.info(f"  Active tasks: {num_active}")
    logger.info(f"  Pending tasks: {num_pending}")
    logger.info(f"  Elapsed time: {format_time(elapsed)}")
    logger.info(f"  Average time per run: {format_time(avg_time_per_run)}")
    logger.info(f"  Estimated remaining: {format_time(estimated_remaining)}")
    logger.info(f"  Estimated total: {format_time(elapsed + estimated_remaining)}")
    logger.info("=" * 100)


def run_parallel(
    configs: list[Dict[str, Any]],
    res_root: str,
    log_level: str = "INFO",
    num_gpus: int = -1,
) -> None:
    """
    Run every configuration across GPUs, with at most one run per GPU at a time.

    Each configuration is executed by ``worker_run`` in a spawned worker process,
    together with the GPU index it should occupy. When a run finishes, whether it
    succeeded or failed, its GPU is handed the next pending configuration. Progress
    and a final summary are logged to ``<res_root>/orchestrate.log``. A failing run
    is logged and counted but does not stop the others.

    Args:
        configs: Configuration dicts, one per run.
        res_root: Root directory for results (str or Path); created if missing.
        log_level: Logging level name for the main and worker loggers.
        num_gpus: Number of GPUs to use. -1 uses every detected GPU; larger
            requests are capped at the detected count.

    Raises:
        ValueError: If ``num_gpus`` is 0 or less than -1.
    """
    import sys
    from collections import deque
    from concurrent.futures import FIRST_COMPLETED, wait

    # Validate before any side effects (directories, logging threads, manager).
    if num_gpus == 0 or num_gpus < -1:
        raise ValueError(f"num_gpus must be -1 (auto-detect) or a positive integer, got {num_gpus}")

    from src.run.logger import setup_logger, setup_multiprocess_logging

    res_root = Path(res_root)
    res_root.mkdir(parents=True, exist_ok=True)

    # Setup main process logger and multiprocessing logging
    log_file = res_root / "orchestrate.log"
    log_queue, queue_listener, manager = setup_multiprocess_logging(log_file, log_level)
    queue_listener.start()
    try:
        main_logger = setup_logger(
            name="orchestrate_main",
            log_file=log_file,
            level=log_level,
        )

        num_gpus = _resolve_gpu_count(num_gpus, main_logger)

        main_logger.info(f"Total configurations to run: {len(configs)}")
        main_logger.info(f"Using {num_gpus} GPUs for parallel execution")
        main_logger.info("=" * 100)

        executor_kwargs: Dict[str, Any] = {
            "max_workers": num_gpus,
            "mp_context": mp.get_context("spawn"),
        }
        if sys.version_info >= (3, 11):
            # One task per worker process. worker_run pins its GPU via
            # CUDA_VISIBLE_DEVICES, which CUDA only honours before it initialises.
            # A reused process would keep the GPU of its first task, so two runs
            # could end up sharing one GPU.
            executor_kwargs["max_tasks_per_child"] = 1
        else:
            main_logger.warning(
                "Python < 3.11: worker processes are reused across runs, so a run "
                "may execute on the GPU of an earlier run in the same process."
            )

        # Track progress
        completed = 0
        failed = 0
        start_time = time.time()

        with ProcessPoolExecutor(**executor_kwargs) as executor:
            pending = deque(configs)
            active: Dict[Any, int] = {}  # future -> GPU index it occupies

            def submit_next(gpu_id: int) -> None:
                """Start the next pending configuration on ``gpu_id``."""
                task = (pending.popleft(), gpu_id, log_level, log_queue, res_root)
                active[executor.submit(worker_run, task)] = gpu_id

            # Initial submission: one configuration per GPU
            for gpu_id in range(min(num_gpus, len(pending))):
                submit_next(gpu_id)

            while active:
                done, _ = wait(active, return_when=FIRST_COMPLETED)
                for future in done:
                    gpu_id = active.pop(future)
                    try:
                        future.result()
                    except Exception as e:
                        failed += 1
                        main_logger.error(f"Task failed on GPU {gpu_id}: {e}")
                    else:
                        completed += 1

                    # The GPU is free again, whether the run succeeded or not.
                    if pending:
                        submit_next(gpu_id)

                    _log_progress(
                        main_logger,
                        completed=completed,
                        failed=failed,
                        total=len(configs),
                        num_pending=len(pending),
                        num_active=len(active),
                        start_time=start_time,
                    )

        # Final summary
        main_logger.info("=" * 100)
        main_logger.info("EXECUTION COMPLETE")
        main_logger.info(f"  Total runs: {len(configs)}")
        main_logger.info(f"  Successful: {completed}")
        main_logger.info(f"  Failed: {failed}")
        main_logger.info(f"  Total time: {format_time(time.time() - start_time)}")
        main_logger.info(f"  Output directory: {res_root}")
        main_logger.info("=" * 100)
    finally:
        # Always stop the queue listener thread and the manager process, even if
        # orchestration is interrupted or a pool error propagates.
        queue_listener.stop()
        manager.shutdown()