"""Utility functions for sampling in Expected GradCAM.

This module provides helper functions for centering samples, verifying
centering conditions, and computing optimal scales for baseline sampling.

These utilities support the mathematical requirements of Expected GradCAM:
- Centering: E[z'] = 0 is required for the completeness axiom
- Scaling: Appropriate perturbation scale ensures numerical stability
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from torch import Tensor


def center_samples(samples: Tensor) -> Tensor:
    """Shift samples so that their mean along dim 0 is zero.

    Expected Gradients only satisfies the completeness axiom when the
    baseline samples are centered (E[z'] = 0). This helper enforces that
    by removing the per-column mean from every row.

    Args:
        samples: Input samples [N, K] or [M, K].

    Returns:
        A new tensor of the same shape whose mean along dim 0 is zero
        (up to floating-point error). The input is not modified.

    Example:
        >>> samples = torch.randn(100, 512)
        >>> centered = center_samples(samples)
        >>> assert centered.mean(dim=0).abs().max() < 1e-6
    """
    # keepdim=True keeps the reduced axis so the mean broadcasts over rows.
    column_mean = torch.mean(samples, dim=0, keepdim=True)
    return torch.sub(samples, column_mean)


def verify_centered(samples: Tensor, tolerance: float = 1e-6) -> bool:
    """Check that samples have (approximately) zero mean along dim 0.

    The mean is taken over dimension 0 and every resulting entry must be
    strictly smaller in magnitude than ``tolerance``.

    Args:
        samples: Samples to verify [N, K].
        tolerance: Exclusive upper bound on the absolute mean value.

    Returns:
        True if every entry of the dim-0 mean satisfies
        ``|mean| < tolerance``, otherwise False. A NaN mean fails the
        comparison and therefore yields False.

    Example:
        >>> centered = center_samples(torch.randn(100, 512))
        >>> assert verify_centered(centered)
    """
    column_means = torch.mean(samples, dim=0)
    within_tolerance = torch.abs(column_means) < tolerance
    return bool(torch.all(within_tolerance).item())


def compute_optimal_baseline_scale(
    feature_maps: Tensor,
    percentile: float = 0.1,
) -> float:
    """Compute an appropriate scale for baseline sampling based on feature maps.

    The scale is a fraction of the standard deviation of the globally
    average-pooled feature maps, so perturbations are proportional to the
    typical magnitude of feature variations without being large enough to
    cause numerical instability.

    Args:
        feature_maps: Feature maps [B, K, U, V].
        percentile: Fraction of the GAP std to use as scale. Despite the
            name this is a plain multiplier, not a statistical percentile.

    Returns:
        Recommended scale value for baseline sampling. The result is always
        a finite float of at least 1e-6: the floor is returned when the
        std is undefined (fewer than two pooled values), non-finite
        (NaN/inf in the input), or the scaled value falls below the floor.

    Example:
        >>> features = torch.randn(1, 2048, 7, 7)
        >>> scale = compute_optimal_baseline_scale(features)
        >>> print(f"Recommended scale: {scale:.4f}")
    """
    min_scale = 1e-6

    # Global average pooling over the two spatial dimensions.
    gap = feature_maps.mean(dim=(2, 3))  # [B, K]

    # The sample std of fewer than two values is NaN (and torch warns about
    # it), so there is no reference magnitude to scale by.
    if gap.numel() < 2:
        return min_scale

    # Use std of GAP values as reference.
    gap_std = gap.std()

    # NaN/inf inputs make the std non-finite; fall back to the floor rather
    # than returning an unusable scale.
    if not bool(torch.isfinite(gap_std)):
        return min_scale

    # Scale should be a fraction to ensure perturbations are meaningful
    # but not too large (which could cause numerical issues).
    scale = float(gap_std.item() * percentile)

    # Written as a comparison instead of max(): max(nan, floor) would return
    # NaN, whereas a NaN scale fails this test and yields the floor.
    return scale if scale > min_scale else min_scale


def normalize_perturbations(
    samples: Tensor,
    target_scale: float = 0.3,
    center: bool = True,
) -> Tensor:
    """Rescale perturbation samples to a target standard deviation.

    Optionally removes the dim-0 mean first, then multiplies the samples
    by ``target_scale / std`` where ``std`` is taken over all elements.

    Args:
        samples: Perturbation samples [M, K].
        target_scale: Target standard deviation.
        center: Whether to subtract the mean along dim 0 before scaling.

    Returns:
        Normalized samples with std ≈ target_scale. If the overall std is
        not greater than 1e-10 (including a NaN std), the (optionally
        centered) samples are returned without rescaling.

    Example:
        >>> perturbs = torch.randn(100, 512) * 10  # Large scale
        >>> normalized = normalize_perturbations(perturbs, target_scale=0.3)
        >>> print(f"Std after: {normalized.std():.4f}")  # ~0.3
    """
    shifted = (samples - samples.mean(dim=0, keepdim=True)) if center else samples

    spread = shifted.std()
    # Skip rescaling for (near-)constant input to avoid dividing by ~0.
    if not spread > 1e-10:
        return shifted

    return shifted * (target_scale / spread)


def safe_divide(
    numerator: Tensor,
    denominator: Tensor,
    min_threshold: float | None = None,
) -> Tensor:
    """Perform division with protection against division by small values.

    The magnitude of every denominator element is raised to at least
    ``min_threshold`` while its sign is kept, so small negative values stay
    negative. Exact zeros carry no sign and are treated as positive.

    Args:
        numerator: Numerator tensor.
        denominator: Denominator tensor (same shape as numerator).
        min_threshold: Minimum absolute value for denominator.
            If None, computed as 1% of the median absolute denominator,
            with a floor of 1e-6. The floor alone is used when the
            denominator is empty or the median is not finite (e.g. the
            denominator contains NaN), so one bad element cannot
            invalidate the threshold for all others.

    Returns:
        ``numerator / d`` where ``d`` is the denominator with its
        magnitude clamped to at least ``min_threshold``.

    Example:
        >>> num = torch.randn(512)
        >>> denom = torch.randn(512) * 0.001  # Small values
        >>> result = safe_divide(num, denom)  # No division by zero
    """
    denom_abs = denominator.abs()

    if min_threshold is None:
        floor = 1e-6
        min_threshold = floor
        # median() raises on an empty tensor, so only consult it when there
        # is data to look at.
        if denom_abs.numel() > 0:
            auto = float(denom_abs.median().item()) * 0.01
            # The chained comparison is False for NaN and +inf, which keeps
            # the floor instead of propagating a non-finite threshold.
            if floor < auto < float("inf"):
                min_threshold = auto

    # Clamp the magnitude, then restore the sign. Selecting on
    # ``denominator < 0`` (rather than multiplying by sign()) maps exact
    # zeros to +min_threshold and keeps tiny negative values negative.
    magnitude = denom_abs.clamp(min=min_threshold)
    denom_safe = torch.where(denominator < 0, -magnitude, magnitude)

    return numerator / denom_safe


# Public API of this module: the names exported by a wildcard import.
__all__ = [
    "center_samples",
    "verify_centered",
    "compute_optimal_baseline_scale",
    "normalize_perturbations",
    "safe_divide",
]
