"""Memory estimation and management utilities for Expected GradCAM.

This module provides utilities for:
- Estimating GPU memory requirements
- Computing optimal batch sizes
- Monitoring memory usage during computation
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    pass


@dataclass
class MemoryInfo:
    """GPU memory information.

    All fields are byte counts. The ``*_gb`` properties divide by 1024**3,
    i.e. they report binary gigabytes (GiB).

    Attributes:
        total: Total GPU memory in bytes.
        allocated: Currently allocated memory in bytes.
        reserved: Reserved memory in bytes (includes cached).
        available: Available memory in bytes.
    """

    total: int
    allocated: int
    reserved: int
    available: int

    @property
    def total_gb(self) -> float:
        """Total memory in GB."""
        return self.total / (1024**3)

    @property
    def allocated_gb(self) -> float:
        """Allocated memory in GB."""
        return self.allocated / (1024**3)

    @property
    def reserved_gb(self) -> float:
        """Reserved memory in GB."""
        return self.reserved / (1024**3)

    @property
    def available_gb(self) -> float:
        """Available memory in GB."""
        return self.available / (1024**3)

    @property
    def utilization(self) -> float:
        """Allocated memory as a fraction of total (nominally 0-1).

        Returns 0.0 when ``total`` is 0 (e.g. no GPU) to avoid dividing by zero.
        """
        return self.allocated / self.total if self.total > 0 else 0.0


def get_available_memory(device: torch.device | int | None = None) -> MemoryInfo:
    """Get current GPU memory information.

    Args:
        device: GPU device to query. An ``int`` is used as a CUDA device
            index. ``None``, or a ``torch.device("cuda")`` without an explicit
            index, selects the current CUDA device. A non-CUDA
            ``torch.device`` (e.g. ``torch.device("cpu")``) has no GPU memory
            to report.

    Returns:
        MemoryInfo with current memory state. All fields are 0 when CUDA is
        not available or when ``device`` is a non-CUDA ``torch.device``;
        callers can test ``total == 0`` to detect this case.
    """
    if not torch.cuda.is_available():
        # No CUDA runtime: report zeros instead of raising.
        return MemoryInfo(total=0, allocated=0, reserved=0, available=0)

    if isinstance(device, torch.device):
        if device.type != "cuda":
            # Reporting a GPU's memory for a CPU device would describe the
            # wrong hardware; treat it like the no-CUDA case.
            return MemoryInfo(total=0, allocated=0, reserved=0, available=0)
        # An index-less "cuda" device means the *current* device, not device 0;
        # a None index falls through to current_device() below.
        device = device.index

    if device is None:
        device = torch.cuda.current_device()

    total = torch.cuda.get_device_properties(device).total_memory
    allocated = torch.cuda.memory_allocated(device)
    reserved = torch.cuda.memory_reserved(device)
    # NOTE: only PyTorch's own reservations are subtracted here; memory held by
    # the CUDA context or by other processes is not accounted for
    # (torch.cuda.mem_get_info reports driver-level free memory if needed).
    available = total - reserved

    return MemoryInfo(
        total=total,
        allocated=allocated,
        reserved=reserved,
        available=available,
    )


class MemoryEstimator:
    """Estimate memory requirements for Expected GradCAM computations.

    This class provides methods to estimate memory requirements for different
    batch sizes and configurations, helping to avoid OOM errors. All results
    are rough, element-count based estimates expressed in bytes.

    Attributes:
        K: Number of feature channels.
        dtype: Tensor dtype the estimates are computed for.
        bytes_per_element: Bytes per tensor element (4 for FP32, 2 for FP16).
        safety_margin: Fraction of available memory to leave unused (0.1 = 10%),
            clamped to the range [0.0, 0.5].
    """

    def __init__(
        self,
        K: int,
        dtype: torch.dtype = torch.float32,
        safety_margin: float = 0.1,
    ) -> None:
        """Initialize memory estimator.

        Args:
            K: Number of feature channels.
            dtype: Tensor dtype for memory calculation.
            safety_margin: Fraction of memory to leave unused; values outside
                0.0-0.5 are clamped into that range.
        """
        self.K = K
        self.dtype = dtype
        # Ask torch for the element size so any dtype is handled uniformly.
        self.bytes_per_element = torch.tensor([], dtype=dtype).element_size()
        # Clamp so that at least half of the available memory stays usable.
        self.safety_margin = max(0.0, min(0.5, safety_margin))

    def estimate_perturbation_memory(self, M: int) -> int:
        """Estimate memory for storing perturbation samples.

        Args:
            M: Number of perturbation samples.

        Returns:
            Memory requirement in bytes.
        """
        # I_samples: [M, K]
        return M * self.K * self.bytes_per_element

    def estimate_attribution_memory(self, M: int) -> int:
        """Estimate memory for storing attribution samples.

        Args:
            M: Number of perturbation samples.

        Returns:
            Memory requirement in bytes.
        """
        # phi_samples: [M, K]
        return M * self.K * self.bytes_per_element

    def estimate_moment_matrix_memory(self) -> int:
        """Estimate memory for second moment matrix.

        Returns:
            Memory requirement in bytes.
        """
        # M_I: [K, K]
        return self.K * self.K * self.bytes_per_element

    def _persistent_memory(self, M: int) -> int:
        """Bytes for tensors that live for the whole computation."""
        return (
            self.estimate_perturbation_memory(M)
            + self.estimate_attribution_memory(M)
            + self.estimate_moment_matrix_memory()
        )

    def estimate_batch_forward_memory(
        self,
        batch_size: int,
        N: int = 20,
        T: int = 50,
    ) -> int:
        """Estimate memory for a single batch forward pass.

        This includes interpolation points, outputs, and gradients, for a
        total of ``(6 * K + 1)`` elements per point.

        Args:
            batch_size: Number of points in forward pass.
            N: Baseline samples per perturbation (currently unused).
            T: Integration steps (currently unused).

        Returns:
            Memory requirement in bytes.
        """
        # z_flat: [batch_size, K] - interpolation points
        points_mem = batch_size * self.K * self.bytes_per_element

        # outputs: [batch_size] - scalar outputs
        outputs_mem = batch_size * self.bytes_per_element

        # grads: [batch_size, K] - gradients (same size as inputs)
        grads_mem = batch_size * self.K * self.bytes_per_element

        # Autograd graph overhead (rough estimate: 2x tensor size)
        autograd_overhead = 2 * (points_mem + grads_mem)

        return points_mem + outputs_mem + grads_mem + autograd_overhead

    def estimate_total_memory(
        self,
        M: int,
        N: int = 20,
        T: int = 50,
        batch_size: int | None = None,
    ) -> int:
        """Estimate total memory for full computation.

        Args:
            M: Number of perturbation samples.
            N: Baseline samples per perturbation.
            T: Integration steps.
            batch_size: Batch size for forward passes. If None, uses M*N*T
                (all points in a single pass).

        Returns:
            Total memory requirement in bytes (persistent + one batch).
        """
        if batch_size is None:
            batch_size = M * N * T

        per_batch = self.estimate_batch_forward_memory(batch_size, N, T)
        return self._persistent_memory(M) + per_batch

    def _batch_size_for_available(
        self,
        M: int,
        N: int,
        T: int,
        available: int,
    ) -> int:
        """Largest batch size (>= 1, <= max(M*N*T, 1)) fitting ``available`` bytes."""
        total_points = M * N * T

        # Available memory with safety margin
        usable_memory = int(available * (1 - self.safety_margin))
        remaining = usable_memory - self._persistent_memory(M)
        if remaining <= 0:
            # Not enough memory even for persistent data
            return 1

        # Per-point heuristic: (3K + 1) elements, tripled for safety. This is
        # at least the (6K + 1) elements of estimate_batch_forward_memory, so
        # the returned batch also fits that finer-grained estimate.
        per_element = (3 * self.K + 1) * self.bytes_per_element * 3

        optimal_batch = remaining // per_element
        # Never return 0 (e.g. when M*N*T == 0): a zero batch size breaks
        # chunking loops such as range(0, total, batch_size).
        return max(1, min(optimal_batch, total_points))

    def compute_optimal_batch_size(
        self,
        M: int,
        N: int = 20,
        T: int = 50,
        device: torch.device | int | None = None,
    ) -> int:
        """Compute optimal batch size for available memory.

        Args:
            M: Number of perturbation samples.
            N: Baseline samples per perturbation.
            T: Integration steps.
            device: GPU device to query.

        Returns:
            Optimal batch size that fits in available memory; always >= 1.
        """
        mem_info = get_available_memory(device)
        if mem_info.total == 0:
            # No GPU memory reported (CPU): return a conservative default.
            return max(1, min(1024, M * N * T))

        return self._batch_size_for_available(M, N, T, mem_info.available)


def estimate_batch_size(
    K: int,
    M: int,
    N: int = 20,
    T: int = 50,
    device: torch.device | int | None = None,
    dtype: torch.dtype = torch.float32,
    safety_margin: float = 0.1,
) -> int:
    """Estimate an optimal batch size from the memory currently on ``device``.

    One-shot shortcut for building a ``MemoryEstimator`` and calling its
    ``compute_optimal_batch_size``.

    Args:
        K: Number of feature channels.
        M: Number of perturbation samples.
        N: Baseline samples per perturbation.
        T: Integration steps.
        device: GPU device to query.
        dtype: Tensor dtype.
        safety_margin: Fraction of memory to leave unused.

    Returns:
        Optimal batch size.
    """
    return MemoryEstimator(
        K,
        dtype=dtype,
        safety_margin=safety_margin,
    ).compute_optimal_batch_size(M, N=N, T=T, device=device)


def get_optimal_batch_size(
    K: int,
    target_memory_gb: float = 8.0,
    N: int = 20,
    T: int = 50,
    dtype: torch.dtype = torch.float32,
) -> int:
    """Compute a batch size for an explicit memory budget.

    Useful when a memory limit should be specified instead of auto-detected.
    Only the per-point cost is budgeted: no safety margin is applied and
    persistent tensors are not subtracted. ``N`` and ``T`` are accepted but
    not used by the calculation.

    Args:
        K: Number of feature channels.
        target_memory_gb: Target memory usage in GB (1 GB = 1024**3 bytes).
        N: Baseline samples per perturbation.
        T: Integration steps.
        dtype: Tensor dtype.

    Returns:
        Batch size that fits in target memory; at least 1.
    """
    element_bytes = torch.tensor([], dtype=dtype).element_size()
    budget_bytes = int(target_memory_gb * 1024**3)

    # Rough per-point cost: (3K + 1) elements, tripled for safety.
    cost_per_point = 3 * (3 * K + 1) * element_bytes

    fitted = budget_bytes // cost_per_point
    return fitted if fitted > 1 else 1


def clear_memory_cache(device: torch.device | int | None = None) -> None:
    """Release cached CUDA memory held by PyTorch's allocator.

    This can help reclaim memory between computations. Does nothing when
    CUDA is unavailable.

    Args:
        device: GPU device made current while clearing. If None, the cache
            is cleared without changing the current device (documented as
            clearing all devices).
    """
    # NOTE(review): torch.cuda.empty_cache may release cached blocks on every
    # device regardless of the selected one — confirm for the installed
    # PyTorch version before relying on per-device behaviour.
    if not torch.cuda.is_available():
        return
    if device is None:
        torch.cuda.empty_cache()
        return
    with torch.cuda.device(device):
        torch.cuda.empty_cache()


def synchronize_device(device: torch.device | int | None = None) -> None:
    """Block until all queued CUDA work on a device has finished.

    Does nothing when CUDA is unavailable.

    Args:
        device: GPU device to synchronize. If None, syncs current device.
    """
    if not torch.cuda.is_available():
        return
    # torch.cuda.synchronize treats device=None as "the current device",
    # so passing the argument straight through covers both cases.
    torch.cuda.synchronize(device)
