"""
Minimal distributed training utilities for torchrun.

This module provides lightweight utilities for multi-GPU training with PyTorch DDP.
Multi-GPU runs must be launched with torchrun (no custom launcher wrappers or
abstractions); a plain `python` launch runs in single-process mode.

Usage:
    # Single GPU
    python src/run/main.py --epochs 10
    
    # Multi-GPU (4 GPUs)
    torchrun --nproc_per_node=4 src/run/main.py --epochs 10
"""

import os
from typing import Dict

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


# ============================================================================
# Distributed Setup & Status
# ============================================================================

def setup_distributed() -> None:
    """
    Bring up the NCCL process group when running under torchrun.

    torchrun exports RANK, WORLD_SIZE and LOCAL_RANK. When RANK/WORLD_SIZE are
    absent (plain ``python`` launch, single-GPU mode) this returns immediately.
    Repeated calls are harmless: an already-initialized process group is left
    untouched, so this can be reused across multiple run() calls.
    """
    # Nothing to do outside torchrun, or when a previous call already set things up.
    if not is_distributed_launch() or dist.is_initialized():
        return

    env = os.environ
    global_rank = int(env['RANK'])
    device_index = int(env['LOCAL_RANK'])

    # Pin this process to its local GPU before any CUDA operations run.
    torch.cuda.set_device(device_index)

    # Single-node NCCL init; no explicit device_id argument is passed.
    dist.init_process_group(backend="nccl")

    if global_rank == 0:
        print(f"Initialized distributed training: {get_world_size()} GPUs")


def is_distributed_launch() -> bool:
    """Return True when both torchrun variables RANK and WORLD_SIZE are present in the environment."""
    return all(name in os.environ for name in ('RANK', 'WORLD_SIZE'))


def is_distributed() -> bool:
    """Return True only if torch.distributed is built in and a process group is active."""
    if not dist.is_available():
        return False
    return dist.is_initialized()


def get_rank() -> int:
    """Return this process's global rank, or 0 when no process group is active (single-GPU)."""
    return dist.get_rank() if is_distributed() else 0


def get_world_size() -> int:
    """Return how many processes take part in the job; 1 when no process group is active."""
    if not is_distributed():
        return 1
    return dist.get_world_size()


def is_main_process() -> bool:
    """Return True for the rank-0 process (always True in single-GPU mode)."""
    rank = get_rank()
    return rank == 0


def barrier() -> None:
    """
    Block until every process in the default group reaches this call.

    Does nothing in single-GPU mode.

    With the NCCL backend, the current CUDA device is passed explicitly via
    ``device_ids``. This avoids PyTorch's warning about guessing which device
    to use.

    Any other backend gets a plain barrier. That covers, for example, a gloo
    group created outside ``setup_distributed``. ``device_ids`` is an NCCL
    feature, and older PyTorch versions raise for it on other backends.
    Querying the CUDA device would also fail on CPU-only hosts.
    """
    if not is_distributed():
        return

    if dist.get_backend() == dist.Backend.NCCL:
        dist.barrier(device_ids=[torch.cuda.current_device()])
    else:
        dist.barrier()


def cleanup_distributed() -> None:
    """Tear down the default process group if one is active; a no-op in single-GPU mode."""
    if not is_distributed():
        return
    dist.destroy_process_group()


# ============================================================================
# Model Unwrapping Helper
# ============================================================================

def get_raw_model(model):
    """
    Get the underlying model, unwrapping DDP / DataParallel and torch.compile.

    Wrappers are peeled off repeatedly until none is left, so every wrapping
    order is handled:
    - DDP(CompiledModel(BaseModel))   (compile before DDP)
    - CompiledModel(DDP(BaseModel))   (DDP before compile; the DDP docs
      recommend this order when using TorchDynamo)
    - nn.DataParallel(...) and nested combinations of the above
    A model with no wrapper is returned as-is.

    Use this when you need to access:
    - model.config
    - Custom methods like model.get_params(), model.ablate()
    - Model attributes like model.model_type

    Examples:
        raw_model = get_raw_model(model)
        embed_dim = raw_model.config.embed_dim
        params = raw_model.get_params('core')
    """
    while True:
        if isinstance(model, (DDP, nn.DataParallel)):
            # Data-parallel wrappers keep the wrapped module in `.module`.
            model = model.module
        elif hasattr(model, '_orig_mod'):
            # torch.compile's wrapper keeps the original module in `_orig_mod`.
            model = model._orig_mod
        else:
            return model


# ============================================================================
# Collective Operations
# ============================================================================

def reduce_tensor(tensor, average: bool = True):
    """
    All-reduce a tensor across processes (sum, optionally averaged).

    Use this to aggregate metrics (loss, accuracy) across GPUs.
    Does nothing in single-GPU mode: the input is returned unchanged.

    The all-reduce happens in place, so ``tensor`` itself receives the sum.

    Floating-point and complex tensors are also divided in place, and the
    same tensor object is returned.

    Integer and bool tensors cannot be true-divided in place. When
    ``average`` is True, they are left holding the sum, and a new
    floating-point tensor holding the average is returned. Always use the
    return value.

    Args:
        tensor: Tensor to reduce, on a device the process group's backend
            supports (CUDA for the NCCL group created by setup_distributed)
        average: If True, divide the sum by the world size

    Returns:
        The reduced tensor

    Example:
        loss_tensor = torch.tensor(loss, device="cuda")
        avg_loss = reduce_tensor(loss_tensor).item()
    """
    if not is_distributed():
        return tensor

    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    if not average:
        return tensor

    world_size = get_world_size()
    if tensor.is_floating_point() or tensor.is_complex():
        # In place, as before: the caller's tensor ends up holding the mean.
        tensor /= world_size
        return tensor
    # `int_tensor /= n` raises (true division yields float), so divide out of place.
    return tensor / world_size


def broadcast_object(obj, src: int = 0):
    """
    Send a picklable object from rank ``src`` to every rank and return it.

    Handy for keeping random choices, sampled labels, etc. identical on all
    processes. In single-GPU mode the object is returned unchanged.

    Example:
        # Main process samples, others receive the same value
        label = random.choice(labels) if is_main_process() else None
        label = broadcast_object(label)  # Now all ranks have same label
    """
    if not is_distributed():
        return obj

    # Only the source rank fills the slot; the other ranks receive into it.
    payload = [None]
    if get_rank() == src:
        payload[0] = obj
    dist.broadcast_object_list(payload, src=src)
    return payload[0]


def all_reduce_dict(data_dict: Dict) -> Dict:
    """
    All-reduce (average) every tensor value in a dictionary across processes.

    Each tensor value is summed across ranks with ``dist.all_reduce``, and its
    dictionary entry is replaced by ``sum / world_size``. Non-tensor values
    are left untouched. Does nothing in single-GPU mode or when the world
    size is 1.

    Side effects:
    - The dictionary is updated in place and also returned.
    - ``dist.all_reduce`` works in place, so the original tensor objects are
      left holding the cross-rank SUM. Only the dictionary entries hold the
      average.

    NOTE: each tensor key issues its own collective. Every rank must
    therefore pass the same tensor keys in the same insertion order;
    otherwise the collectives can mismatch or hang.

    Args:
        data_dict: Dictionary whose tensor values should be reduced

    Returns:
        The same dictionary object, with tensor values replaced by their
        cross-process average
    """
    # Same "is a process group active?" test the other collectives use.
    if not is_distributed():
        return data_dict

    world_size = get_world_size()
    if world_size == 1:
        return data_dict

    for key, value in data_dict.items():
        if not torch.is_tensor(value):
            continue
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        # Out-of-place division also works for integer tensors (result is float).
        data_dict[key] = value / world_size

    return data_dict