"""Weight norm and sparsity metrics for analyzing optimal weights.

This module provides metrics for analyzing the optimal weights α
computed by the Expected GradCAM algorithm.

Example:
    >>> from expected_gradcam.metrics.heatmap import WeightNorm
    >>>
    >>> metric = WeightNorm()
    >>> norm = metric.compute(alpha=alpha)
    >>> print(f"Weight norm: {norm:.4f}")
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from expected_gradcam.metrics.base import BaseMetric, no_grad, timed
from expected_gradcam.metrics.exceptions import InvalidMetricInputError
from expected_gradcam.metrics.registry import register_metric

if TYPE_CHECKING:
    from torch import Tensor


@register_metric(
    "weight_norm",
    display_name="Weight Norm ||α||",
    lower_is_better=None,  # Neither - informational metric
    streamable=True,
    category="heatmap",
)
class WeightNorm(BaseMetric):
    """L_p norm of optimal weights (L2 by default).

    Computes ||α||_p, which indicates the overall magnitude of the
    optimal weights. Can be used as a regularization indicator.

    High weight norms may indicate overfitting to specific features,
    while very low norms might suggest the solution is near zero.

    Attributes:
        ord: Default norm order (2 for L2). Any order accepted by
            ``torch.linalg.vector_norm`` is allowed, e.g. ``0`` (count of
            non-zero entries), ``1``, ``2`` or ``float("inf")``.

    Example:
        >>> metric = WeightNorm()
        >>> norm = metric.compute(alpha=alpha)
        >>> print(f"||α||₂ = {norm:.4f}")
    """

    # Neither lower nor higher is inherently better; True is kept for
    # consistency (lower is reported as "stable").
    # NOTE(review): the registry entry above declares lower_is_better=None
    # while this attribute says True — confirm which one consumers read.
    _lower_is_better = True

    def __init__(self, ord: int | float = 2) -> None:
        """Initialize the metric.

        Args:
            ord: Norm order (0 for non-zero count, 1 for L1, 2 for L2,
                inf for max).
        """
        self.ord = ord

    def validate_inputs(
        self,
        alpha: "Tensor | None" = None,
        **kwargs,
    ) -> None:
        """Validate inputs.

        Args:
            alpha: Optimal weights [K].
            **kwargs: Ignored; accepted for interface compatibility.

        Raises:
            InvalidMetricInputError: If alpha is missing or not 1-D.
        """
        if alpha is None:
            raise InvalidMetricInputError(
                "weight_norm",
                "alpha",
                "Tensor [K]",
                "None",
            )

        if alpha.ndim != 1:
            raise InvalidMetricInputError(
                "weight_norm",
                "alpha",
                "1D tensor [K]",
                f"{alpha.ndim}D tensor",
            )

    @no_grad
    @timed
    def compute(
        self,
        alpha: "Tensor",
        ord: int | float | None = None,
        **kwargs,
    ) -> float:
        """Compute the weight norm.

        Args:
            alpha: Optimal weights [K].
            ord: Override default norm order. ``None`` uses ``self.ord``;
                any other value, including ``0``, is used as given.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            ||α||_p where p is the norm order.

        Raises:
            InvalidMetricInputError: If alpha is missing or not 1-D.
        """
        self.validate_inputs(alpha=alpha)
        # Explicit None check: ``ord or self.ord`` would silently replace a
        # legitimate ord=0 with the default order.
        p = self.ord if ord is None else ord

        # torch.linalg.vector_norm is the non-deprecated replacement for
        # torch.norm and gives the same result for 1-D inputs.
        return float(torch.linalg.vector_norm(alpha, ord=p).item())


@register_metric(
    "weight_sparsity",
    display_name="Weight Sparsity",
    lower_is_better=False,  # Higher sparsity = more focused attribution
    streamable=True,
    category="heatmap",
)
class WeightSparsity(BaseMetric):
    """Sparsity of optimal weights.

    Computes the fraction of weights whose magnitude is below a threshold
    relative to the largest magnitude. This indicates how focused the
    attribution is on a subset of feature channels.

    Higher sparsity means fewer channels contribute significantly.

    Attributes:
        threshold: Relative threshold for considering a weight "zero".

    Example:
        >>> metric = WeightSparsity()
        >>> sparsity = metric.compute(alpha=alpha)
        >>> print(f"Sparsity: {sparsity*100:.1f}% of weights near zero")
    """

    def __init__(self, threshold: float = 0.01) -> None:
        """Initialize the metric.

        Args:
            threshold: Relative threshold (weight < threshold * max_weight is "zero").
        """
        self.threshold = threshold

    def validate_inputs(
        self,
        alpha: "Tensor | None" = None,
        **kwargs,
    ) -> None:
        """Validate inputs.

        Args:
            alpha: Optimal weights [K].
            **kwargs: Ignored; accepted for interface compatibility.

        Raises:
            InvalidMetricInputError: If alpha is missing or empty.
        """
        if alpha is None:
            raise InvalidMetricInputError(
                "weight_sparsity",
                "alpha",
                "Tensor [K]",
                "None",
            )

        # An empty tensor has no maximum and no elements to divide by;
        # reject it here instead of failing inside torch's max().
        if alpha.numel() == 0:
            raise InvalidMetricInputError(
                "weight_sparsity",
                "alpha",
                "non-empty Tensor [K]",
                "empty tensor",
            )

    @no_grad
    @timed
    def compute(
        self,
        alpha: "Tensor",
        threshold: float | None = None,
        **kwargs,
    ) -> float:
        """Compute weight sparsity.

        The fraction is taken over all elements of ``alpha``.

        Args:
            alpha: Optimal weights [K].
            threshold: Override default threshold. ``None`` uses
                ``self.threshold``; any other value, including ``0.0``, is
                used as given.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            Fraction of weights below threshold (0 to 1). Returns 1.0 when
            every weight magnitude is below 1e-10.

        Raises:
            InvalidMetricInputError: If alpha is missing or empty.
        """
        self.validate_inputs(alpha=alpha)
        # Explicit None check: ``threshold or self.threshold`` would turn an
        # explicit threshold of 0.0 into the default.
        thr = self.threshold if threshold is None else threshold

        magnitudes = alpha.abs()
        max_weight = magnitudes.max()
        if max_weight < 1e-10:
            return 1.0  # All weights are effectively zero

        sparse_count = (magnitudes < thr * max_weight).sum()
        return float(sparse_count.item() / alpha.numel())


def compute_weight_stats(alpha: "Tensor") -> dict[str, float]:
    """Compute comprehensive weight statistics.

    Statistics are computed over all elements of ``alpha``. Integer or
    boolean tensors are promoted to the default floating dtype first, and
    the input is detached so no autograd graph is recorded.

    Args:
        alpha: Optimal weights [K].

    Returns:
        Dictionary with:
            - l1_norm: ||α||₁
            - l2_norm: ||α||₂
            - linf_norm: ||α||∞
            - mean: Mean value
            - std: Standard deviation (unbiased; 0.0 for a single element)
            - min: Minimum value
            - max: Maximum value
            - positive_fraction: Fraction of positive weights

    Raises:
        ValueError: If ``alpha`` has no elements.

    Example:
        >>> stats = compute_weight_stats(alpha)
        >>> print(f"L2 norm: {stats['l2_norm']:.4f}")
    """
    if alpha.numel() == 0:
        raise ValueError("compute_weight_stats requires a non-empty alpha tensor")

    alpha = alpha.detach()
    # Norms and means are only defined for floating/complex tensors; promote
    # integer/bool inputs so they are accepted instead of raising.
    if not (alpha.is_floating_point() or alpha.is_complex()):
        alpha = alpha.to(torch.get_default_dtype())

    # The unbiased std of a single element is NaN (zero degrees of freedom)
    # and triggers a warning; report zero spread instead.
    std = float(alpha.std().item()) if alpha.numel() > 1 else 0.0

    return {
        "l1_norm": float(torch.linalg.vector_norm(alpha, ord=1).item()),
        "l2_norm": float(torch.linalg.vector_norm(alpha, ord=2).item()),
        "linf_norm": float(torch.linalg.vector_norm(alpha, ord=float("inf")).item()),
        "mean": float(alpha.mean().item()),
        "std": std,
        "min": float(alpha.min().item()),
        "max": float(alpha.max().item()),
        "positive_fraction": float((alpha > 0).sum().item() / alpha.numel()),
    }
