"""Multi-GPU support for Expected GradCAM.

This module provides utilities for distributing computations across multiple GPUs:
- DataParallel wrapper for simple multi-GPU usage
- Distributed Data Parallel (DDP) setup for scalable training
- Device management utilities

For Expected GradCAM, multi-GPU is primarily useful for:
1. Processing multiple images in parallel
2. Handling very large batch sizes (M > 10000)
3. Feature extraction from large models
"""

from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Callable

import torch
from torch import Tensor, nn

if TYPE_CHECKING:
    from torch.distributed import ProcessGroup


def is_multi_gpu_available() -> bool:
    """Report whether this process can see at least two CUDA devices.

    Returns:
        False when CUDA is unavailable; otherwise True exactly when
        ``torch.cuda.device_count()`` is greater than one.
    """
    if not torch.cuda.is_available():
        return False
    return torch.cuda.device_count() >= 2


def get_device_count() -> int:
    """Return how many CUDA devices are visible to this process.

    Returns:
        ``torch.cuda.device_count()`` when CUDA is available, else 0.
    """
    return torch.cuda.device_count() if torch.cuda.is_available() else 0


def get_device_names() -> list[str]:
    """Collect the name of every visible CUDA device.

    Returns:
        One name per device, ordered by device index. The list is empty
        when CUDA is not available.
    """
    names: list[str] = []
    if torch.cuda.is_available():
        for index in range(torch.cuda.device_count()):
            names.append(torch.cuda.get_device_name(index))
    return names


def setup_distributed(
    rank: int | None = None,
    world_size: int | None = None,
    backend: str = "nccl",
    init_method: str | None = None,
) -> bool:
    """Setup distributed training environment.

    This initializes the process group for DDP training. In most cases,
    this should be called at the start of each worker process. If a process
    group is already initialized it is left untouched.

    Args:
        rank: Rank of current process. If None, reads the RANK env var,
            falling back to LOCAL_RANK and then to 0.
        world_size: Total number of processes. If None, reads WORLD_SIZE
            (defaulting to 1).
        backend: Communication backend passed to ``init_process_group``
            (e.g. "nccl").
        init_method: URL for process group initialization. If None, uses
            default env:// method.

    Returns:
        True if a process group is initialized when the function returns.
        False (without initializing anything) when the distributed package
        is not available in this PyTorch build, when CUDA is not available
        (regardless of ``backend``), or when the world size is 1 or less.

    Example:
        >>> # In each worker process:
        >>> if setup_distributed():
        ...     model = DDP(model, device_ids=[local_rank])
    """
    # Builds compiled without distributed support expose no other
    # torch.distributed API, so calling is_initialized() there would raise
    # AttributeError instead of reporting "not set up".
    if not torch.distributed.is_available():
        return False

    # This module targets multi-GPU runs; without CUDA there is nothing to do.
    if not torch.cuda.is_available():
        return False

    # Get rank and world size from environment if not provided
    if rank is None:
        rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
    if world_size is None:
        world_size = int(os.environ.get("WORLD_SIZE", 1))

    # A single process does not need a process group.
    if world_size <= 1:
        return False

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            backend=backend,
            init_method="env://" if init_method is None else init_method,
            rank=rank,
            world_size=world_size,
        )

    return True


def cleanup_distributed() -> None:
    """Cleanup distributed training environment.

    Should be called at the end of distributed training. Destroys the
    default process group if one is initialized; otherwise does nothing,
    including on PyTorch builds without distributed support.
    """
    # is_initialized() only exists when the distributed package is available,
    # so check availability first to keep this a safe no-op everywhere.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


class MultiGPUWrapper:
    """Wrapper for multi-GPU computation.

    This class provides a simple interface for running computations across
    multiple GPUs using DataParallel or DistributedDataParallel. When CUDA
    is not available it degrades to a CPU-only wrapper with no device IDs.

    For Expected GradCAM, this is most useful for:
    1. Batching predictor evaluations across GPUs
    2. Parallel feature extraction from multiple images

    Attributes:
        device_ids: List of GPU device IDs to use (empty on CPU).
        primary_device: Primary device for results (first GPU, or CPU).
        use_ddp: Whether to use DistributedDataParallel (vs DataParallel).
            Only True when requested AND a process group is initialized.

    Example:
        >>> wrapper = MultiGPUWrapper(device_ids=[0, 1, 2, 3])
        >>> model = wrapper.wrap_module(model)
        >>> # Model now runs on all 4 GPUs
    """

    def __init__(
        self,
        device_ids: list[int] | None = None,
        use_ddp: bool = False,
    ) -> None:
        """Initialize multi-GPU wrapper.

        Args:
            device_ids: List of GPU device IDs to use. If None, uses all
                available GPUs. Ignored when CUDA is not available.
            use_ddp: If True, use DistributedDataParallel. If False, use
                DataParallel. DDP is more efficient but requires distributed
                setup; the flag is silently dropped when no process group
                is initialized.

        Raises:
            ValueError: If CUDA is available and ``device_ids`` is empty.
        """
        if not torch.cuda.is_available():
            self.device_ids: list[int] = []
            self.primary_device = torch.device("cpu")
            self.use_ddp = False
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if not device_ids:
            # Fail with a clear message instead of an IndexError below.
            raise ValueError("device_ids must contain at least one GPU index")

        # Copy so later mutation of the caller's list cannot change the wrapper.
        self.device_ids = list(device_ids)
        self.primary_device = torch.device(f"cuda:{self.device_ids[0]}")
        # is_initialized() is only exposed when the distributed package is
        # available, so check availability first.
        self.use_ddp = bool(
            use_ddp
            and torch.distributed.is_available()
            and torch.distributed.is_initialized()
        )

    @property
    def is_multi_gpu(self) -> bool:
        """Check if using multiple GPUs."""
        return len(self.device_ids) > 1

    def wrap_module(
        self,
        module: nn.Module,
        output_device: int | None = None,
    ) -> nn.Module:
        """Wrap a module for multi-GPU execution.

        The module is always moved to the primary device first.

        Args:
            module: PyTorch module to wrap.
            output_device: Device for gathering outputs. Defaults to first device.

        Returns:
            A DistributedDataParallel module when DDP is active, a
            DataParallel module when several GPUs are configured without
            DDP, and otherwise the module itself on the primary device.
        """
        module = module.to(self.primary_device)

        if self.use_ddp:
            # DDP runs one process per device, so it must be applied even when
            # this process was given a single device ID.
            return nn.parallel.DistributedDataParallel(
                module,
                device_ids=[self.device_ids[0]],
                output_device=(
                    self.device_ids[0] if output_device is None else output_device
                ),
            )

        if not self.is_multi_gpu:
            return module

        # DataParallel requires parameters and buffers to already live on
        # device_ids[0]; that is guaranteed by the move above.
        return nn.DataParallel(
            module,
            device_ids=self.device_ids,
            output_device=(
                self.device_ids[0] if output_device is None else output_device
            ),
        )

    def scatter_tensor(
        self,
        tensor: Tensor,
        dim: int = 0,
    ) -> list[Tensor]:
        """Scatter a tensor across GPUs.

        Args:
            tensor: Tensor to scatter.
            dim: Dimension to scatter along.

        Returns:
            List of tensor chunks, each on its own GPU. ``Tensor.chunk`` may
            produce fewer chunks than devices, in which case the trailing
            devices receive nothing. With a single device (or CPU) the list
            holds the whole tensor on the primary device.
        """
        if not self.is_multi_gpu:
            return [tensor.to(self.primary_device)]

        chunks = tensor.chunk(len(self.device_ids), dim=dim)
        return [
            chunk.to(f"cuda:{device_id}")
            for chunk, device_id in zip(chunks, self.device_ids)
        ]

    def gather_tensors(
        self,
        tensors: list[Tensor],
        dim: int = 0,
    ) -> Tensor:
        """Gather tensors from multiple GPUs.

        Args:
            tensors: List of tensors from different GPUs.
            dim: Dimension to concatenate along.

        Returns:
            Concatenated tensor on primary device.
        """
        if len(tensors) == 1:
            return tensors[0].to(self.primary_device)

        return torch.cat(
            [t.to(self.primary_device) for t in tensors],
            dim=dim,
        )

    def parallel_apply(
        self,
        fn: Callable[[Tensor], Tensor],
        inputs: list[Tensor],
    ) -> list[Tensor]:
        """Apply function in parallel across GPUs.

        Args:
            fn: Function to apply to each input.
            inputs: List of input tensors (one per GPU).

        Returns:
            List of output tensors, one per input and in the same order.
        """
        if not self.is_multi_gpu:
            # Single device / CPU: run sequentially, but still process every
            # input rather than only the first one.
            return [fn(chunk) for chunk in inputs]

        # Use parallel_apply from torch.nn.parallel
        from torch.nn.parallel import parallel_apply

        # parallel_apply expects modules, so expose the function as one.
        class FnModule(nn.Module):
            def __init__(self, fn: Callable[[Tensor], Tensor]) -> None:
                super().__init__()
                self._fn = fn

            def forward(self, x: Tensor) -> Tensor:
                return self._fn(x)

        modules = [FnModule(fn) for _ in inputs]
        return parallel_apply(modules, inputs)

    def compute_batched(
        self,
        fn: Callable[[Tensor], Tensor],
        input_tensor: Tensor,
        dim: int = 0,
    ) -> Tensor:
        """Compute function on tensor split across GPUs.

        This is a convenience method that handles scattering, parallel
        computation, and gathering.

        Args:
            fn: Function to apply.
            input_tensor: Input tensor to process.
            dim: Dimension to split along (also used to re-join the results).

        Returns:
            Result tensor on primary device.
        """
        scattered = self.scatter_tensor(input_tensor, dim)
        results = self.parallel_apply(fn, scattered)
        return self.gather_tensors(results, dim)


def get_optimal_device_assignment(
    batch_sizes: list[int],
    device_memories_gb: list[float] | None = None,
) -> list[int]:
    """Compute a greedy device assignment for variable batch sizes.

    Useful when different computations have different memory requirements.
    Batch sizes are converted to memory estimates by scaling them so that
    the largest batch needs exactly the memory of the smallest device. Each
    computation, in input order, then goes to the device with the most
    remaining capacity that can still fit it. If no device can fit it, the
    device with the most remaining capacity is used anyway, so overflow is
    spread out rather than piled onto device 0. Ties go to the lowest index.

    Args:
        batch_sizes: Memory requirement (relative) for each computation.
        device_memories_gb: Available memory on each device. If None,
            queries from CUDA. An explicit list is used as given, even when
            CUDA is unavailable.

    Returns:
        List of device IDs (indices into the memory list), one per entry of
        ``batch_sizes``. All zeros when there are no devices to choose from.
    """
    if device_memories_gb is None:
        n_devices = get_device_count()
        if n_devices == 0:
            return [0] * len(batch_sizes)
        device_memories_gb = [
            torch.cuda.get_device_properties(i).total_memory / (1024**3)
            for i in range(n_devices)
        ]

    # Nothing to assign, or no device to assign to.
    if not batch_sizes or not device_memories_gb:
        return [0] * len(batch_sizes)

    largest_batch = max(batch_sizes)
    # Guard against division by zero when every batch size is 0.
    memory_per_unit = (
        min(device_memories_gb) / largest_batch if largest_batch > 0 else 0.0
    )

    remaining_capacity = list(device_memories_gb)
    all_devices = range(len(remaining_capacity))

    assignments: list[int] = []
    for batch in batch_sizes:
        required_memory = batch * memory_per_unit
        fitting = [
            i for i in all_devices if remaining_capacity[i] >= required_memory
        ]
        # max() keeps the first maximum, so ties resolve to the lowest index.
        best_device = max(fitting or all_devices, key=remaining_capacity.__getitem__)
        assignments.append(best_device)
        remaining_capacity[best_device] -= required_memory

    return assignments
