"""Tensor type aliases for documentation and type checking.

These type aliases document the expected shapes and semantics of tensors
used throughout the Expected GradCAM implementation. Python's type system
does not enforce tensor shapes, so these aliases serve primarily as
documentation: each shape is attached as plain string metadata via Annotated.
Actual shape checking requires a dedicated tool such as torchtyping.

Notation:
    - N: Batch size (typically 1)
    - K: Number of feature channels
    - U, V: Feature map spatial dimensions (height, width)
    - H, W: Input image spatial dimensions
    - M: Number of perturbation samples
    - T: Number of integration steps
    - S: Number of segments
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Annotated

import torch


if TYPE_CHECKING:
    from torch import Tensor


# =============================================================================
# Feature Maps and Activations
# =============================================================================

# torch.Tensor is referenced directly rather than through the string "Tensor":
# a string forward reference inside an alias is resolved in the *consumer's*
# namespace by typing.get_type_hints(), which fails in any module that does
# not itself import Tensor. torch is already imported at runtime here.
FeatureMaps = Annotated[
    torch.Tensor,
    "Shape: [N, K, U, V] - Feature maps from target layer",
]
"""Activations taken from the target layer.

Shape: [N, K, U, V] where:
    - N: Batch size (typically 1)
    - K: Number of channels (e.g., 2048 for ResNet-50 layer4)
    - U: Feature map height (e.g., 7 for 224x224 input)
    - V: Feature map width (e.g., 7 for 224x224 input)

These are the activations A^k that enter the predictor function:
    g(z; A) = y^c(z_1 * A^1, ..., z_K * A^K)
"""


# =============================================================================
# Perturbations and Baselines
# =============================================================================

# Uses the real torch.Tensor class instead of the string "Tensor" so the alias
# resolves at runtime (e.g. typing.get_type_hints()) from any importing module.
Perturbation = Annotated[
    torch.Tensor,
    "Shape: [K] - Single perturbation vector I",
]
"""One perturbation vector in feature-multiplier space.

Shape: [K] where K is the number of feature channels.

The perturbation I is the direction from the reference point z_0 to a
baseline point: I = z_0 - z' = z_0 - α * h(x')

For data-aware perturbations:
    h(x') = GAP(A') / GAP(A)  (normalized feature activation)
"""

# Uses the real torch.Tensor class instead of the string "Tensor" so the alias
# resolves at runtime (e.g. typing.get_type_hints()) from any importing module.
PerturbationBatch = Annotated[
    torch.Tensor,
    "Shape: [M, K] - Batch of M perturbations",
]
"""Stack of perturbation vectors, one per row.

Shape: [M, K] where:
    - M: Number of perturbation samples
    - K: Number of feature channels

Used in batched computation of the second moment matrix M_I = E[I * I^T].
"""


# =============================================================================
# Attributions (Integrated/Expected Gradients)
# =============================================================================

# Uses the real torch.Tensor class instead of the string "Tensor" so the alias
# resolves at runtime (e.g. typing.get_type_hints()) from any importing module.
Attribution = Annotated[
    torch.Tensor,
    "Shape: [K] - Single attribution vector φ",
]
"""One attribution vector produced by Integrated/Expected Gradients.

Shape: [K] where K is the number of feature channels.

Computed via path integration:
    φ^{IG} = ∫_0^1 ∇_z g(z' + t * I) dt

For Expected Gradients, this is averaged over baselines:
    φ^{EG} = E_{z' ~ D}[φ^{IG}]
"""

# Uses the real torch.Tensor class instead of the string "Tensor" so the alias
# resolves at runtime (e.g. typing.get_type_hints()) from any importing module.
AttributionBatch = Annotated[
    torch.Tensor,
    "Shape: [M, K] - Batch of M attributions",
]
"""Stack of attribution vectors, one per row.

Shape: [M, K] where:
    - M: Number of perturbation samples
    - K: Number of feature channels

Each row φ_m corresponds to the attribution for perturbation I_m.
"""


# =============================================================================
# Second Moment Matrix and Weights
# =============================================================================

# Uses the real torch.Tensor class instead of the string "Tensor" so the alias
# resolves at runtime (e.g. typing.get_type_hints()) from any importing module.
SecondMomentMatrix = Annotated[
    torch.Tensor,
    "Shape: [K, K] - Second moment matrix M_I",
]
"""Second moment matrix of the perturbations.

Shape: [K, K] where K is the number of feature channels.

Computed as: M_I = E[I * I^T] = (1/M) * Σ_m I_m * I_m^T

This matrix appears in the optimal weights formula:
    α* = M_I^{-1} * E[I * <I, φ>]
"""

# Uses the real torch.Tensor class instead of the string "Tensor" so the alias
# resolves at runtime (e.g. typing.get_type_hints()) from any importing module.
Weights = Annotated[
    torch.Tensor,
    "Shape: [K] - Optimal feature map weights α*",
]
"""Optimal weights used to combine the feature maps.

Shape: [K] where K is the number of feature channels.

These weights minimize explanation infidelity and are computed as:
    α* = M_I^{-1} * b
    where b = E[I * <I, φ>]

The final heatmap is: L^c = ReLU(Σ_k α*_k * A^k)
"""


# =============================================================================
# Heatmaps
# =============================================================================

# Uses the real torch.Tensor class instead of the string "Tensor" so the alias
# resolves at runtime (e.g. typing.get_type_hints()) from any importing module.
Heatmap = Annotated[
    torch.Tensor,
    "Shape: [H, W] - Final heatmap at input resolution",
]
"""Final attribution heatmap at the resolution of the input image.

Shape: [H, W] where H, W are the input image dimensions.

Values are normalized to [0, 1] range using the specified normalization method.
"""

# Uses the real torch.Tensor class instead of the string "Tensor" so the alias
# resolves at runtime (e.g. typing.get_type_hints()) from any importing module.
CoarseHeatmap = Annotated[
    torch.Tensor,
    "Shape: [U, V] - Heatmap at feature resolution",
]
"""Attribution heatmap at feature-map resolution, i.e. before upsampling.

Shape: [U, V] where U, V are the feature map dimensions.

This is the raw output: L^c = ReLU(Σ_k α*_k * A^k) before bilinear upsampling.
"""


# =============================================================================
# Segment-Constrained Types
# =============================================================================

# Uses the real torch.Tensor class instead of the string "Tensor" so the alias
# resolves at runtime (e.g. typing.get_type_hints()) from any importing module.
SegmentMasks = Annotated[
    torch.Tensor,
    "Shape: [S, H, W] - Binary segment masks",
]
"""Per-segment masks for the S segments produced by SAM segmentation.

Shape: [S, H, W] where:
    - S: Number of segments
    - H, W: Input image dimensions

Each mask[s] is a binary tensor where 1 indicates pixels belonging to segment s.
For soft boundaries, values may be in [0, 1] after Gaussian smoothing.
"""


# =============================================================================
# Type aliases for common patterns
# =============================================================================

# Device type for tensor placement
Device = str | torch.device
"""Device specification (e.g., "cpu", "cuda", "cuda:0", torch.device("cuda"))."""

# Numeric types for configuration
Scalar = int | float
"""Numeric scalar value for configuration parameters."""

# Variable-length tuple of ints describing tensor dimensions
Shape = tuple[int, ...]
"""Shape of a tensor, expressed as a tuple of integers."""
