"""Condition number and related metrics for linear system analysis.

This module provides metrics for analyzing the numerical properties
of the second moment matrix M_I = E[I I^T], which is critical for
the stability of the optimal weights computation.

Metrics:
- ConditionNumber: κ(M_I) = λ_max / λ_min
- EffectiveRank: Number of significant eigenvalues
- EigenvalueSpread: log10(λ_max / λ_min)

Mathematical Background:
    The condition number indicates how sensitive the solution α* is
    to perturbations in M_I or b. High condition numbers (> 1e6)
    suggest numerical instability.

Example:
    >>> from expected_gradcam.metrics.solver import ConditionNumber
    >>>
    >>> metric = ConditionNumber()
    >>> cond = metric.compute(M_I=second_moment_matrix)
    >>> print(f"Condition number: {cond:.2e}")
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from expected_gradcam.metrics.base import BaseMetric, no_grad, timed
from expected_gradcam.metrics.exceptions import InvalidMetricInputError
from expected_gradcam.metrics.registry import register_metric

if TYPE_CHECKING:
    from torch import Tensor


@register_metric(
    "condition_number",
    display_name="Condition Number",
    lower_is_better=True,
    streamable=True,
    category="solver",
)
class ConditionNumber(BaseMetric):
    """Condition number of the second moment matrix M_I.

    Computes κ(M_I) = λ_max / λ_min, where λ_max is the largest eigenvalue
    and λ_min is the smallest eigenvalue strictly above the cutoff
    ``max(threshold, threshold * |λ_max|)``. Eigenvalues at or below the
    cutoff are ignored, so for a rank-deficient matrix the result is the
    condition number restricted to the retained eigenvalues rather than inf.

    A well-conditioned matrix has condition number < 1e6.
    High condition numbers indicate numerical instability and
    suggest using regularization (solver_method='adaptive_reg').

    Attributes:
        threshold: Default eigenvalue threshold used to build the cutoff.

    Example:
        >>> metric = ConditionNumber()
        >>> cond = metric.compute(M_I=M_I)
        >>> if cond > 1e6:
        ...     print("Warning: Ill-conditioned matrix")
    """

    def __init__(self, threshold: float = 1e-10) -> None:
        """Initialize the metric.

        Args:
            threshold: Default eigenvalue threshold. It acts as an absolute
                floor and is additionally scaled by |λ_max| (see ``compute``).
        """
        self.threshold = threshold

    def validate_inputs(
        self,
        M_I: "Tensor | None" = None,
        **kwargs,
    ) -> None:
        """Validate inputs for condition number computation.

        Args:
            M_I: Second moment matrix [K, K].
            **kwargs: Accepted and ignored.

        Raises:
            InvalidMetricInputError: If M_I is None, is not 2D, or is not
                square.
        """
        if M_I is None:
            raise InvalidMetricInputError(
                "condition_number",
                "M_I",
                "Tensor [K, K]",
                "None",
                suggestion="Pass the second moment matrix M_I = E[I I^T]",
            )

        if M_I.ndim != 2:
            raise InvalidMetricInputError(
                "condition_number",
                "M_I",
                "2D tensor [K, K]",
                f"{M_I.ndim}D tensor",
            )

        if M_I.shape[0] != M_I.shape[1]:
            raise InvalidMetricInputError(
                "condition_number",
                "M_I",
                "square matrix [K, K]",
                f"shape {list(M_I.shape)}",
            )

    @no_grad
    @timed
    def compute(
        self,
        M_I: "Tensor",
        threshold: float | None = None,
        **kwargs,
    ) -> float:
        """Compute the condition number of M_I.

        Args:
            M_I: Second moment matrix [K, K] (symmetric, positive semi-definite).
            threshold: Override for the instance threshold. ``None`` selects
                ``self.threshold``; an explicit ``0.0`` is honoured.
            **kwargs: Accepted and ignored.

        Returns:
            λ_max divided by the smallest eigenvalue above the cutoff, or
            inf when no eigenvalue exceeds the cutoff.

        Raises:
            InvalidMetricInputError: If M_I fails ``validate_inputs``.
        """
        self.validate_inputs(M_I=M_I)
        # Test against None explicitly: `threshold or self.threshold` would
        # silently replace a caller-supplied 0.0 with the instance default.
        if threshold is None:
            threshold = self.threshold

        # eigvalsh is the symmetric/Hermitian eigenvalue routine, so M_I is
        # expected to be symmetric.
        eigenvalues = torch.linalg.eigvalsh(M_I)

        max_ev = eigenvalues.max()

        # Cutoff is relative to |λ_max| but never drops below `threshold`
        # itself, so small-magnitude matrices still get an absolute floor.
        abs_threshold = max(threshold, threshold * max_ev.abs())
        pos_mask = eigenvalues > abs_threshold

        if not pos_mask.any():
            return float("inf")

        min_pos_ev = eigenvalues[pos_mask].min()
        return float((max_ev / min_pos_ev).item())


@register_metric(
    "effective_rank",
    display_name="Effective Rank",
    lower_is_better=False,  # Higher rank is generally better (more information)
    streamable=True,
    category="solver",
)
class EffectiveRank(BaseMetric):
    """Effective rank of the second moment matrix M_I.

    Counts the eigenvalues that are strictly greater than
    ``threshold * λ_max``. Ideally this equals K (full rank), but
    data-aware perturbations may produce lower effective rank.

    Attributes:
        threshold: Relative threshold (eigenvalue > threshold * max_eigenvalue).

    Example:
        >>> metric = EffectiveRank()
        >>> rank = metric.compute(M_I=M_I)
        >>> print(f"Effective rank: {rank}/{K}")
    """

    def __init__(self, threshold: float = 1e-6) -> None:
        """Initialize the metric.

        Args:
            threshold: Relative threshold for counting significant eigenvalues.
        """
        self.threshold = threshold

    def validate_inputs(
        self,
        M_I: "Tensor | None" = None,
        **kwargs,
    ) -> None:
        """Validate inputs for effective rank computation.

        Args:
            M_I: Second moment matrix [K, K].
            **kwargs: Accepted and ignored.

        Raises:
            InvalidMetricInputError: If M_I is None, or is not a 2D square
                matrix.
        """
        if M_I is None:
            raise InvalidMetricInputError(
                "effective_rank",
                "M_I",
                "Tensor [K, K]",
                "None",
            )

        if M_I.ndim != 2 or M_I.shape[0] != M_I.shape[1]:
            raise InvalidMetricInputError(
                "effective_rank",
                "M_I",
                "square matrix [K, K]",
                f"shape {list(M_I.shape)}",
            )

    @no_grad
    @timed
    def compute(
        self,
        M_I: "Tensor",
        threshold: float | None = None,
        **kwargs,
    ) -> float:
        """Compute the effective rank of M_I.

        Args:
            M_I: Second moment matrix [K, K].
            threshold: Override for the instance relative threshold. ``None``
                selects ``self.threshold``; an explicit ``0.0`` is honoured
                and counts every eigenvalue strictly above zero (given a
                non-negative λ_max).
            **kwargs: Accepted and ignored.

        Returns:
            Number of eigenvalues above the threshold (as float for consistency).

        Raises:
            InvalidMetricInputError: If M_I fails ``validate_inputs``.
        """
        self.validate_inputs(M_I=M_I)
        # Test against None explicitly: `threshold or self.threshold` would
        # silently replace a caller-supplied 0.0 with the instance default.
        if threshold is None:
            threshold = self.threshold

        eigenvalues = torch.linalg.eigvalsh(M_I)
        max_ev = eigenvalues.max()

        # Count eigenvalues above relative threshold
        significant = eigenvalues > (threshold * max_ev)
        return float(significant.sum().item())


@register_metric(
    "eigenvalue_spread",
    display_name="Eigenvalue Spread",
    lower_is_better=True,  # Lower spread means better conditioning
    streamable=True,
    category="solver",
)
class EigenvalueSpread(BaseMetric):
    """Eigenvalue spread (log scale) of M_I.

    Computes log10(λ_max / λ_min) over the eigenvalues that are strictly
    greater than ``threshold``, which gives a more interpretable measure of
    the eigenvalue range than the raw condition number. The threshold is
    compared against the eigenvalues directly (it is absolute, not scaled
    by λ_max).

    A spread of 3 means eigenvalues span 3 orders of magnitude.

    Attributes:
        threshold: Default absolute eigenvalue threshold.

    Example:
        >>> metric = EigenvalueSpread()
        >>> spread = metric.compute(M_I=M_I)
        >>> print(f"Eigenvalues span {spread:.1f} orders of magnitude")
    """

    def __init__(self, threshold: float = 1e-10) -> None:
        """Initialize the metric.

        Args:
            threshold: Minimum eigenvalue threshold.
        """
        self.threshold = threshold

    def validate_inputs(
        self,
        M_I: "Tensor | None" = None,
        **kwargs,
    ) -> None:
        """Validate inputs for eigenvalue spread computation.

        Args:
            M_I: Second moment matrix [K, K].
            **kwargs: Accepted and ignored.

        Raises:
            InvalidMetricInputError: If M_I is None, or is not a 2D square
                matrix.
        """
        if M_I is None:
            raise InvalidMetricInputError(
                "eigenvalue_spread",
                "M_I",
                "Tensor [K, K]",
                "None",
            )

        # Reject bad shapes up front with the same check the other solver
        # metrics in this module perform, instead of surfacing a raw torch
        # error (or silently pooling eigenvalues across a batch dimension).
        if M_I.ndim != 2 or M_I.shape[0] != M_I.shape[1]:
            raise InvalidMetricInputError(
                "eigenvalue_spread",
                "M_I",
                "square matrix [K, K]",
                f"shape {list(M_I.shape)}",
            )

    @no_grad
    @timed
    def compute(
        self,
        M_I: "Tensor",
        threshold: float | None = None,
        **kwargs,
    ) -> float:
        """Compute the eigenvalue spread.

        Args:
            M_I: Second moment matrix [K, K].
            threshold: Override for the instance threshold. ``None`` selects
                ``self.threshold``; an explicit ``0.0`` is honoured.
            **kwargs: Accepted and ignored.

        Returns:
            log10 of the ratio between the largest and smallest eigenvalue
            above the threshold, or inf when no eigenvalue exceeds it.

        Raises:
            InvalidMetricInputError: If M_I fails ``validate_inputs``.
        """
        self.validate_inputs(M_I=M_I)
        # Test against None explicitly: `threshold or self.threshold` would
        # silently replace a caller-supplied 0.0 with the instance default.
        if threshold is None:
            threshold = self.threshold

        eigenvalues = torch.linalg.eigvalsh(M_I)
        pos_mask = eigenvalues > threshold

        if not pos_mask.any():
            return float("inf")

        pos_eigs = eigenvalues[pos_mask]
        ratio = pos_eigs.max() / pos_eigs.min()
        return float(torch.log10(ratio).item())


def analyze_condition(
    M_I: "Tensor",
    threshold: float = 1e-10,
) -> dict[str, float]:
    """Summarize the conditioning of M_I in one pass over its eigenvalues.

    The eigenvalues are computed once with ``torch.linalg.eigvalsh`` and
    every reported quantity is derived from them. ``threshold`` is compared
    against the eigenvalues directly (absolute, not scaled by λ_max).

    Args:
        M_I: Second moment matrix [K, K].
        threshold: Eigenvalues must be strictly greater than this value to
            be counted and to be eligible as the minimum.

    Returns:
        Dictionary with:
            - condition_number: λ_max divided by the smallest retained
              eigenvalue, or inf when that eigenvalue is not positive or
              nothing is retained
            - effective_rank: Number of eigenvalues above ``threshold``
            - eigenvalue_spread: log10 of the same ratio, or inf
            - eigenvalue_min: Smallest eigenvalue above ``threshold``
              (0.0 when none is retained)
            - eigenvalue_max: Largest eigenvalue
            - is_well_conditioned: bool, condition_number < 1e6

    Example:
        >>> analysis = analyze_condition(M_I)
        >>> if not analysis["is_well_conditioned"]:
        ...     print(f"Warning: κ = {analysis['condition_number']:.2e}")
    """
    spectrum = torch.linalg.eigvalsh(M_I)
    largest = spectrum.max()

    # Boolean mask of eigenvalues that clear the absolute threshold.
    retained = spectrum > threshold
    if retained.any():
        smallest = spectrum[retained].min()
    else:
        # Zero placeholder: forces the "inf" branch below and reports 0.0
        # as the minimum.
        smallest = torch.tensor(0.0)

    if smallest > 0:
        ratio = largest / smallest
        kappa = float(ratio.item())
        log_spread = float(torch.log10(ratio).item())
    else:
        kappa = float("inf")
        log_spread = float("inf")

    summary = {
        "condition_number": kappa,
        "effective_rank": float(retained.sum().item()),
        "eigenvalue_spread": log_spread,
        "eigenvalue_min": float(smallest.item()),
        "eigenvalue_max": float(largest.item()),
        "is_well_conditioned": kappa < 1e6,
    }
    return summary
