"""Second moment matrix computation for Expected GradCAM.

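The (uncentered) second moment matrix M_I = E[I I^T] captures the second-order
structure of the perturbations (it coincides with their covariance only when
the perturbations have zero mean) and is used to compute the optimal weights
α* = M_I^{-1} b.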
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch


if TYPE_CHECKING:
    from torch import Tensor


def compute_second_moment_matrix(I_samples: "Tensor") -> "Tensor":
    """Return the empirical second moment matrix M_I = E[I I^T].

    Given M perturbation samples I^{(1)}, ..., I^{(M)} stacked as rows, this
    evaluates

        M_I = (1/M) Σ_{m=1}^M I^{(m)} (I^{(m)})^T,

    i.e. entry (i, j) is the sample mean of I_i · I_j.

    Args:
        I_samples: Perturbation samples [M, K] (M samples, K feature maps).

    Returns:
        The [K, K] second moment matrix.

    Note:
        The result is positive semi-definite by construction. It is singular
        whenever M < K or the samples do not span R^K.
    """
    n_samples = I_samples.shape[0]

    # Sum the per-sample outer products over the sample axis (this is I^T I)
    # and then turn the sum into an average.
    gram = torch.einsum("mi,mj->ij", I_samples, I_samples)
    return gram / n_samples


def compute_second_moment_matrix_stable(
    I_samples: "Tensor",
    regularization_eps: float = 1e-6,
) -> "Tensor":
    """Return the ridge-regularized second moment matrix M_I + ε · Id.

    Adding ε to the diagonal shifts every eigenvalue of the (positive
    semi-definite) second moment matrix up by ε. In exact arithmetic the
    result is therefore positive definite, and hence invertible, even when:
    - there are fewer samples than dimensions (M < K),
    - the samples are (nearly) colinear,
    - some feature channels are constant.

    Args:
        I_samples: Perturbation samples [M, K].
        regularization_eps: Ridge constant ε added to the diagonal.

    Returns:
        The regularized [K, K] matrix.

    Note:
        The identity is built with torch's default dtype on the same device
        as M_I. Ordinary type promotion then determines the result dtype; for
        example, half/bfloat16 inputs yield a default-dtype result.
    """
    second_moment = compute_second_moment_matrix(I_samples)
    dim = second_moment.shape[0]

    ridge = regularization_eps * torch.eye(dim, device=second_moment.device)
    return second_moment + ridge


def analyze_second_moment_matrix(
    M_I: "Tensor",
    threshold: float = 1e-10,
    max_condition_number: float = 1e12,
) -> tuple[bool, float, int, "Tensor"]:
    """Analyze spectral properties of the second moment matrix.

    Eigenvalues are obtained with ``torch.linalg.eigvalsh``, which assumes a
    symmetric (Hermitian) input. That holds for any matrix produced by
    ``compute_second_moment_matrix``.

    An eigenvalue is treated as zero when its magnitude does not exceed
    ``max(threshold, threshold * |λ_max|)``. This is a relative cut-off scaled
    by the largest eigenvalue, floored at the absolute ``threshold``.

    Args:
        M_I: Second moment matrix [K, K].
        threshold: Relative/absolute tolerance for treating an eigenvalue as
            zero (see above).
        max_condition_number: Condition numbers at or above this value mark
            the matrix as not (numerically) invertible. Defaults to 1e12.

    Returns:
        Tuple of:
        - is_invertible: True when the numerical rank is K and the condition
          number is below ``max_condition_number``.
        - condition_number: ``λ_max / λ_min`` where ``λ_min`` is the smallest
          eigenvalue *above* the zero cut-off. When the matrix is
          rank-deficient this is the effective condition number on the
          non-null subspace. It is ``inf`` if no eigenvalue exceeds the cut-off.
        - rank: Number of eigenvalues whose magnitude exceeds the cut-off.
        - eigenvalues: All eigenvalues [K], as returned by ``eigvalsh``.
    """
    eigenvalues = torch.linalg.eigvalsh(M_I)
    K = M_I.shape[0]

    max_eig = eigenvalues.max()
    # Plain Python float, so the cut-off has a single well-defined type
    # instead of being "float or 0-d tensor" depending on which argument
    # the builtin max() happened to pick.
    abs_threshold = max(threshold, threshold * float(max_eig.abs()))

    rank = int((eigenvalues.abs() > abs_threshold).sum().item())

    # The condition number only considers eigenvalues that are clearly
    # positive; near-zero and negative ones are excluded.
    significant = eigenvalues[eigenvalues > abs_threshold]
    if significant.numel() > 0:
        condition_number = float((max_eig / significant.min()).item())
    else:
        condition_number = float("inf")

    is_invertible = rank == K and condition_number < max_condition_number

    return is_invertible, condition_number, rank, eigenvalues


def compute_cross_moment(
    I_samples: "Tensor",
    phi_samples: "Tensor",
) -> "Tensor":
    """Return the cross-moment vector b = E[I · <I, φ>].

    Together with the second moment matrix this yields the optimal weights
    α* = M_I^{-1} b. Empirically,

        b = (1/M) Σ_m I^{(m)} · <I^{(m)}, φ^{(m)}>,

    where <I, φ> = Σ_k I_k φ_k.

    Args:
        I_samples: Perturbation samples [M, K].
        phi_samples: Attribution vectors [M, K], row-aligned with I_samples.

    Returns:
        The cross-moment vector b [K].
    """
    # Per-sample scalar: projection of each perturbation onto its attribution.
    projections = torch.sum(I_samples * phi_samples, dim=1)

    # Weight each perturbation row by its projection, then average the rows.
    weighted_rows = I_samples * projections[:, None]
    return weighted_rows.mean(dim=0)


class IncrementalMomentComputer:
    """Incrementally compute M_I and b without storing all samples.

    Enables streaming computation of optimal weights with O(K² + K) memory
    instead of O(M×K) for storing all phi_samples.

    Usage:
        >>> computer = IncrementalMomentComputer(K, device)
        >>> for I_chunk, phi_chunk in data_stream:
        ...     computer.update(I_chunk, phi_chunk)
        >>> M_I, b = computer.finalize()

    Attributes:
        K: Feature dimension.
        device: Torch device.
        num_samples: Number of samples accumulated.
    """

    def __init__(self, K: int, device: torch.device) -> None:
        """Initialize incremental moment computer.

        Args:
            K: Dimension of feature space.
            device: Torch device for computations.
        """
        self.K = K
        self.device = device
        self._M_I_sum = torch.zeros(K, K, device=device)
        self._b_sum = torch.zeros(K, device=device)
        self._n_samples = 0

    @property
    def num_samples(self) -> int:
        """Number of samples accumulated."""
        return self._n_samples

    def update(self, I_batch: "Tensor", phi_batch: "Tensor") -> None:
        """Update with a batch of perturbations and attributions.

        Args:
            I_batch: Perturbation samples [batch_size, K].
            phi_batch: Attribution vectors [batch_size, K].
        """
        batch_size = I_batch.shape[0]

        # Incremental M_I: M_I_sum += I^T @ I
        self._M_I_sum += torch.mm(I_batch.T, I_batch)

        # Incremental b: b_sum += Σ I^(m) * <I^(m), φ^(m)>
        inner_prods = (I_batch * phi_batch).sum(dim=1)  # [batch_size]
        self._b_sum += (I_batch * inner_prods.unsqueeze(1)).sum(dim=0)

        self._n_samples += batch_size

    def finalize(self) -> tuple["Tensor", "Tensor"]:
        """Finalize and return normalized M_I and b.

        Returns:
            Tuple of (M_I [K, K], b [K]).

        Raises:
            ValueError: If no samples have been added.
        """
        if self._n_samples == 0:
            raise ValueError("No samples have been added. Call update() first.")

        M_I = self._M_I_sum / self._n_samples
        b = self._b_sum / self._n_samples
        return M_I, b

    def get_current_M_I(self) -> "Tensor":
        """Get current (partial) M_I for convergence checking.

        Returns:
            Current M_I estimate [K, K].
        """
        if self._n_samples == 0:
            return torch.zeros(self.K, self.K, device=self.device)
        return self._M_I_sum / self._n_samples

    def get_current_b(self) -> "Tensor":
        """Get current (partial) b for convergence checking.

        Returns:
            Current b estimate [K].
        """
        if self._n_samples == 0:
            return torch.zeros(self.K, device=self.device)
        return self._b_sum / self._n_samples

    def reset(self) -> None:
        """Reset the accumulator for reuse."""
        self._M_I_sum.zero_()
        self._b_sum.zero_()
        self._n_samples = 0
