"""Automatic Mixed Precision (AMP) management for Expected GradCAM.

This module provides utilities for managing AMP (FP16/BF16) during gradient
computation. AMP provides:
- 1.5-2x speedup on compatible GPUs
- ~50% memory reduction
- Maintained numerical accuracy for gradients

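Key design: only the forward pass runs under autocast (FP16/BF16); gradients
are computed outside the autocast context. Note that autograd still executes
each backward op in the dtype autocast selected for the matching forward op.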
"""

from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator

import torch

if TYPE_CHECKING:
    from torch import dtype as TorchDtype

# AMP is considered usable only when torch exposes the ``torch.cuda.amp``
# namespace and a CUDA device is actually reachable. Evaluated once at import.
_AMP_AVAILABLE = torch.cuda.is_available() if hasattr(torch.cuda, "amp") else False

# BF16 needs compute capability 8.0 or higher (Ampere+ GPUs). The capability
# query is only issued once CUDA is known to be present, since
# ``_AMP_AVAILABLE`` already implies ``torch.cuda.is_available()``.
if _AMP_AVAILABLE:
    _BF16_AVAILABLE = torch.cuda.get_device_capability()[0] >= 8
else:
    _BF16_AVAILABLE = False


def is_amp_available() -> bool:
    """Report whether CUDA automatic mixed precision can be used.

    The answer comes from the module-level ``_AMP_AVAILABLE`` flag, which is
    evaluated a single time when this module is imported.

    Returns:
        True when ``torch.cuda`` exposes ``amp`` and a CUDA device is available.
    """
    amp_ready: bool = _AMP_AVAILABLE
    return amp_ready


def is_bf16_available() -> bool:
    """Report whether BF16 autocast can be used on this system.

    The answer comes from the module-level ``_BF16_AVAILABLE`` flag, which is
    evaluated a single time when this module is imported. It requires AMP to
    be available and a GPU with compute capability >= 8.0 (Ampere or newer).

    Returns:
        True when BF16 is supported.
    """
    bf16_ready: bool = _BF16_AVAILABLE
    return bf16_ready


def get_amp_dtype(prefer_bf16: bool = True) -> "TorchDtype":
    """Pick the reduced-precision dtype to use for autocast.

    Args:
        prefer_bf16: Request BF16 when the system supports it. When False, or
            when BF16 is unavailable, FP16 is chosen instead.

    Returns:
        ``torch.bfloat16`` if BF16 was preferred and is available, otherwise
        ``torch.float16``.
    """
    # ``prefer_bf16`` is checked first so the availability flag is only
    # consulted when BF16 was actually requested.
    return torch.bfloat16 if (prefer_bf16 and _BF16_AVAILABLE) else torch.float16


class AMPContext:
    """Bundle of AMP settings that hands out autocast-aware context managers.

    The intended pattern is to run the model's forward pass inside
    ``forward_context()`` and to request gradients only after that context has
    exited.

    Example:
        >>> amp_ctx = AMPContext(enabled=True, dtype=torch.float16)
        >>> with amp_ctx.forward_context():
        ...     output = model(input)
        >>> # Gradients are requested after leaving the autocast region
        >>> grads = torch.autograd.grad(output.sum(), input)

    Attributes:
        enabled: True only when AMP was requested and is available.
        dtype: The dtype handed to autocast (float16 or bfloat16 unless an
            explicit dtype was supplied).
        device_type: The device type handed to autocast; always "cuda".
    """

    def __init__(
        self,
        enabled: bool = True,
        dtype: "TorchDtype | None" = None,
        prefer_bf16: bool = True,
    ) -> None:
        """Store the AMP configuration.

        Args:
            enabled: Request AMP. The request is only honoured when AMP is
                available; otherwise the contexts do not enter autocast.
            dtype: Explicit autocast dtype. When None, one is chosen via
                ``get_amp_dtype(prefer_bf16)``.
            prefer_bf16: Only consulted when ``dtype`` is None; prefers BF16
                over FP16 when BF16 is available.
        """
        if dtype is None:
            dtype = get_amp_dtype(prefer_bf16)

        self.device_type = "cuda"
        self.dtype = dtype
        self.enabled = enabled and _AMP_AVAILABLE

    @contextmanager
    def forward_context(self) -> Iterator[None]:
        """Run the enclosed forward pass under autocast when AMP is enabled.

        When AMP is disabled the body runs without entering autocast at all.
        Gradients are meant to be computed after this context has exited.

        Yields:
            None
        """
        if not self.enabled:
            yield
            return

        with torch.amp.autocast(device_type=self.device_type, dtype=self.dtype):
            yield

    @contextmanager
    def inference_context(self) -> Iterator[None]:
        """Run the enclosed code with gradients off and, if enabled, autocast.

        ``torch.no_grad()`` is entered first, then the same autocast handling
        as ``forward_context()`` is applied inside it.

        Yields:
            None
        """
        with torch.no_grad(), self.forward_context():
            yield

    def __repr__(self) -> str:
        return f"AMPContext(enabled={self.enabled}, dtype={self.dtype})"


@contextmanager
def amp_autocast(
    enabled: bool = True,
    dtype: "TorchDtype | None" = None,
    device: torch.device | str | None = None,
) -> Iterator[None]:
    """Enter CUDA autocast when it is requested, applicable and available.

    Thin convenience layer over ``torch.amp.autocast``. Autocast is entered
    only if ``enabled`` is true, AMP is available, and ``device`` (when given)
    is a CUDA device. In every other case the body runs without entering
    autocast.

    Args:
        enabled: Whether autocast is requested.
        dtype: Autocast dtype. Falls back to ``torch.float16`` when None.
        device: Optional target device. A string is converted with
            ``torch.device``; any non-CUDA device turns autocast off. When
            None, no device check is performed.

    Yields:
        None

    Example:
        >>> with amp_autocast(enabled=True):
        ...     output = model(input)
    """
    # Strings are normalised up front so malformed device strings fail here
    # regardless of ``enabled``.
    if isinstance(device, str):
        device = torch.device(device)

    targets_cuda_autocast = (
        enabled and (device is None or device.type == "cuda") and _AMP_AVAILABLE
    )

    if not targets_cuda_autocast:
        yield
        return

    chosen_dtype = torch.float16 if dtype is None else dtype
    with torch.amp.autocast(device_type="cuda", dtype=chosen_dtype):
        yield


class GradScaler:
    """Availability-aware wrapper around ``torch.amp.GradScaler`` for CUDA.

    When scaling is disabled or AMP is unavailable, no underlying scaler is
    created and every method degrades to a pass-through: ``scale`` returns the
    loss untouched, ``step`` calls ``optimizer.step()`` directly, ``update``
    does nothing and ``get_scale`` reports 1.0.

    For Expected GradCAM, gradient scaling is typically unnecessary because
    gradients are obtained via torch.autograd.grad() rather than
    loss.backward(). The class exists for completeness and for users with
    custom training loops.

    Example:
        >>> scaler = GradScaler(enabled=True)
        >>> with amp_autocast():
        ...     output = model(input)
        ...     loss = criterion(output, target)
        >>> scaler.scale(loss).backward()
        >>> scaler.step(optimizer)
        >>> scaler.update()
    """

    def __init__(self, enabled: bool = True, init_scale: float = 2.0**16) -> None:
        """Create the underlying scaler if scaling is enabled and available.

        Args:
            enabled: Request gradient scaling; honoured only if AMP is
                available.
            init_scale: Initial scale factor passed to the underlying scaler.
        """
        self.enabled = enabled and _AMP_AVAILABLE
        self._scaler = (
            torch.amp.GradScaler("cuda", init_scale=init_scale)
            if self.enabled
            else None
        )

    def scale(self, loss: torch.Tensor) -> torch.Tensor:
        """Return the scaled loss, or the loss itself when scaling is off."""
        if self._scaler is None:
            return loss
        return self._scaler.scale(loss)

    def step(self, optimizer: torch.optim.Optimizer) -> None:
        """Step via the scaler, or step the optimizer directly when off."""
        if self._scaler is None:
            optimizer.step()
            return
        self._scaler.step(optimizer)

    def update(self) -> None:
        """Update the scaler's scale factor; a no-op when scaling is off."""
        if self._scaler is None:
            return
        self._scaler.update()

    def get_scale(self) -> float:
        """Return the current scale factor, or 1.0 when scaling is off."""
        return 1.0 if self._scaler is None else self._scaler.get_scale()
