"""Residual norm metric for solution quality assessment.

This module provides the ResidualNorm metric that measures how well
the computed optimal weights α satisfy the linear system M_I @ α = b.

Mathematical Specification:
    r = ||M_I @ α - b|| / ||b||

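A low residual norm indicates that α accurately solves the system.
High residual norms suggest numerical issues or regularization effects.
When ||b|| is numerically zero the relative residual is undefined, so
ResidualNorm reports the absolute residual ||M_I @ α - b|| instead.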

Example:
    >>> from expected_gradcam.metrics.solver import ResidualNorm
    >>>
    >>> metric = ResidualNorm()
    >>> residual = metric.compute(M_I=M_I, alpha=alpha, b=b)
    >>> print(f"Residual norm: {residual:.6f}")
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from expected_gradcam.metrics.base import BaseMetric, no_grad, timed
from expected_gradcam.metrics.exceptions import InvalidMetricInputError
from expected_gradcam.metrics.registry import register_metric

if TYPE_CHECKING:
    from torch import Tensor


@register_metric(
    "residual_norm",
    display_name="Residual Norm",
    lower_is_better=True,
    streamable=True,
    category="solver",
)
class ResidualNorm(BaseMetric):
    """Residual norm of the linear system solution.

    Computes the relative residual ||M_I @ α - b|| / ||b||
    which measures how accurately the solution satisfies the
    linear system. When ||b|| is below ``eps`` the relative value is
    undefined, so the absolute residual ||M_I @ α - b|| is returned.

    Lower values indicate better solution quality. Values > 0.1
    may indicate numerical issues or strong regularization effects.

    Example:
        >>> metric = ResidualNorm()
        >>> residual = metric.compute(M_I=M_I, alpha=alpha, b=b)
        >>> if residual > 0.1:
        ...     print("Warning: High residual - check solver settings")
    """

    def validate_inputs(
        self,
        M_I: "Tensor | None" = None,
        alpha: "Tensor | None" = None,
        b: "Tensor | None" = None,
        **kwargs,
    ) -> None:
        """Validate inputs for residual norm computation.

        Args:
            M_I: Second moment matrix [K, K].
            alpha: Optimal weights [K] (must be 1-D).
            b: Right-hand side vector [K] (must be 1-D).

        Raises:
            InvalidMetricInputError: If an input is missing, ``M_I`` is not
                a square 2-D matrix, or ``alpha``/``b`` are not 1-D vectors
                of length K.
        """
        if M_I is None:
            raise InvalidMetricInputError(
                "residual_norm",
                "M_I",
                "Tensor [K, K]",
                "None",
            )

        if alpha is None:
            raise InvalidMetricInputError(
                "residual_norm",
                "alpha",
                "Tensor [K]",
                "None",
            )

        if b is None:
            raise InvalidMetricInputError(
                "residual_norm",
                "b",
                "Tensor [K]",
                "None",
            )

        # Check the rank before touching shape[0]: a 0-d tensor has no
        # shape[0] and would otherwise escape as an IndexError.
        if M_I.ndim != 2 or M_I.shape[0] != M_I.shape[1]:
            raise InvalidMetricInputError(
                "residual_norm",
                "M_I",
                "square matrix [K, K]",
                f"shape {list(M_I.shape)}",
            )

        K = M_I.shape[0]

        # alpha must be 1-D for torch.mv; checking here yields a metric
        # error rather than an opaque torch failure.
        if alpha.ndim != 1 or alpha.shape[0] != K:
            raise InvalidMetricInputError(
                "residual_norm",
                "alpha",
                f"shape [{K}]",
                f"shape {list(alpha.shape)}",
            )

        # b must be 1-D: a column vector [K, 1] would broadcast against the
        # [K] product into a [K, K] tensor and silently yield a wrong norm.
        if b.ndim != 1 or b.shape[0] != K:
            raise InvalidMetricInputError(
                "residual_norm",
                "b",
                f"shape [{K}]",
                f"shape {list(b.shape)}",
            )

    @no_grad
    @timed
    def compute(
        self,
        M_I: "Tensor",
        alpha: "Tensor",
        b: "Tensor",
        eps: float = 1e-10,
        **kwargs,
    ) -> float:
        """Compute the relative residual norm.

        Args:
            M_I: Second moment matrix [K, K].
            alpha: Optimal weights [K].
            b: Right-hand side vector [K].
            eps: Threshold on ||b|| below which the absolute residual is
                returned instead of dividing by a near-zero norm.

        Returns:
            Relative residual ||M_I @ α - b|| / ||b||, or the absolute
            residual ||M_I @ α - b|| when ||b|| < eps.

        Raises:
            InvalidMetricInputError: If the inputs fail ``validate_inputs``.
        """
        self.validate_inputs(M_I=M_I, alpha=alpha, b=b)

        # Compute residual r = M_I @ α - b
        residual = torch.mv(M_I, alpha) - b

        # Relative norm
        residual_norm = torch.norm(residual)
        b_norm = torch.norm(b)

        if b_norm < eps:
            # If b is near zero, return absolute residual
            return float(residual_norm.item())

        return float((residual_norm / b_norm).item())


@register_metric(
    "absolute_residual",
    display_name="Absolute Residual",
    lower_is_better=True,
    streamable=True,
    category="solver",
)
class AbsoluteResidual(BaseMetric):
    """Absolute residual norm of the linear system solution.

    Computes ||M_I @ α - b|| without normalization by ||b||.
    Useful when b might be near zero.

    Example:
        >>> metric = AbsoluteResidual()
        >>> residual = metric.compute(M_I=M_I, alpha=alpha, b=b)
    """

    def validate_inputs(
        self,
        M_I: "Tensor | None" = None,
        alpha: "Tensor | None" = None,
        b: "Tensor | None" = None,
        **kwargs,
    ) -> None:
        """Validate inputs for absolute residual computation.

        Args:
            M_I: Second moment matrix [K, K].
            alpha: Optimal weights [K] (must be 1-D).
            b: Right-hand side vector [K] (must be 1-D).

        Raises:
            InvalidMetricInputError: If an input is missing, ``M_I`` is not
                a square 2-D matrix, or ``alpha``/``b`` are not 1-D vectors
                of length K.
        """
        # Report the specific missing argument rather than all three.
        for name, value, expected in (
            ("M_I", M_I, "Tensor [K, K]"),
            ("alpha", alpha, "Tensor [K]"),
            ("b", b, "Tensor [K]"),
        ):
            if value is None:
                raise InvalidMetricInputError(
                    "absolute_residual",
                    name,
                    expected,
                    "None",
                )

        # Rank is checked before shape[0] so 0-d tensors raise a metric error.
        if M_I.ndim != 2 or M_I.shape[0] != M_I.shape[1]:
            raise InvalidMetricInputError(
                "absolute_residual",
                "M_I",
                "square matrix [K, K]",
                f"shape {list(M_I.shape)}",
            )

        K = M_I.shape[0]

        # Both vectors must be 1-D: a [K, 1] b would broadcast against the
        # [K] product into a [K, K] tensor and silently give a wrong norm.
        for name, vec in (("alpha", alpha), ("b", b)):
            if vec.ndim != 1 or vec.shape[0] != K:
                raise InvalidMetricInputError(
                    "absolute_residual",
                    name,
                    f"shape [{K}]",
                    f"shape {list(vec.shape)}",
                )

    @no_grad
    @timed
    def compute(
        self,
        M_I: "Tensor",
        alpha: "Tensor",
        b: "Tensor",
        **kwargs,
    ) -> float:
        """Compute the absolute residual norm.

        Args:
            M_I: Second moment matrix [K, K].
            alpha: Optimal weights [K].
            b: Right-hand side vector [K].

        Returns:
            Absolute residual ||M_I @ α - b||.

        Raises:
            InvalidMetricInputError: If the inputs fail ``validate_inputs``.
        """
        self.validate_inputs(M_I=M_I, alpha=alpha, b=b)

        residual = torch.mv(M_I, alpha) - b
        return float(torch.norm(residual).item())


def compute_residual_stats(
    M_I: "Tensor",
    alpha: "Tensor",
    b: "Tensor",
    eps: float = 1e-10,
) -> dict[str, float]:
    """Compute comprehensive residual statistics.

    Args:
        M_I: Second moment matrix [K, K].
        alpha: Optimal weights [K].
        b: Right-hand side vector [K].
        eps: Threshold on ||b|| below which ``relative_residual`` falls
            back to the absolute residual (the same rule the
            ``ResidualNorm`` metric uses) instead of dividing by a
            near-zero norm.

    Returns:
        Dictionary with:
            - absolute_residual: ||M_I @ α - b||
            - relative_residual: ||M_I @ α - b|| / ||b||, or the absolute
              residual when ||b|| < eps
            - max_component_error: max|r_i| (0.0 when K == 0)
            - mean_component_error: mean|r_i| (0.0 when K == 0)

    Example:
        >>> stats = compute_residual_stats(M_I, alpha, b)
        >>> print(f"Relative residual: {stats['relative_residual']:.2e}")
    """
    residual = torch.mv(M_I, alpha) - b

    if residual.numel() == 0:
        # K == 0: max() over an empty tensor raises and mean() gives NaN,
        # so use the empty-vector convention of zero error throughout.
        return {
            "absolute_residual": 0.0,
            "relative_residual": 0.0,
            "max_component_error": 0.0,
            "mean_component_error": 0.0,
        }

    abs_residual = torch.norm(residual)
    b_norm = torch.norm(b)
    # Dividing by a clamped 1e-10 would inflate the value by ~1e10 for a
    # zero b; report the absolute residual instead, like ResidualNorm does.
    rel_residual = abs_residual if b_norm < eps else abs_residual / b_norm

    abs_components = residual.abs()
    return {
        "absolute_residual": float(abs_residual.item()),
        "relative_residual": float(rel_residual.item()),
        "max_component_error": float(abs_components.max().item()),
        "mean_component_error": float(abs_components.mean().item()),
    }
