"""Heatmap entropy metric for attribution spread analysis.

This module provides the HeatmapEntropy metric, which measures how
spread out or focused an attribution heatmap is, together with the
HeatmapGini metric and the compute_heatmap_stats helper.

Mathematical Specification:
    H(heatmap) = -Σ p_i * log(p_i)

    where p_i = max(heatmap_i, 0) / Σ_j max(heatmap_j, 0) is the normalized
    distribution (negative attributions are clipped to zero first)

Higher entropy indicates more spread attribution (less focused).
Lower entropy indicates more concentrated attribution (more focused).

Example:
    >>> from expected_gradcam.metrics.heatmap import HeatmapEntropy
    >>>
    >>> metric = HeatmapEntropy()
    >>> entropy = metric.compute(heatmap=heatmap)
    >>> print(f"Heatmap entropy: {entropy:.4f}")
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from expected_gradcam.metrics.base import BaseMetric, no_grad, timed
from expected_gradcam.metrics.exceptions import InvalidMetricInputError
from expected_gradcam.metrics.registry import register_metric

if TYPE_CHECKING:
    from torch import Tensor


@register_metric(
    "heatmap_entropy",
    display_name="Heatmap Entropy",
    lower_is_better=None,  # Neither - depends on task
    streamable=True,
    category="heatmap",
)
class HeatmapEntropy(BaseMetric):
    """Shannon entropy of the heatmap distribution.

    Negative attributions are clipped to zero, the remaining values are
    normalized to sum to one, and the entropy ``-Σ p * log(p)`` of that
    distribution is returned. Higher entropy means a more uniform
    (less focused) heatmap; lower entropy means a more concentrated
    (more focused) one.

    When normalization is enabled the entropy is divided by the maximum
    possible entropy (log of the number of pixels in one heatmap), which
    gives a value between 0 and 1.

    For a ``[B, H, W]`` batch the entropy is computed per sample and the
    mean over the batch is returned.

    Attributes:
        normalize: Whether to normalize by maximum entropy.
        eps: Small constant to avoid log(0) and division by zero.

    Example:
        >>> metric = HeatmapEntropy()
        >>> entropy = metric.compute(heatmap=heatmap)
        >>> print(f"Normalized entropy: {entropy:.4f}")
    """

    # NOTE(review): the registry entry above declares ``lower_is_better=None``
    # (no inherent direction) while this attribute says ``False``. The value
    # is left unchanged because how BaseMetric consumes it is not visible
    # from this module - confirm which of the two is authoritative.
    _lower_is_better = False

    def __init__(self, normalize: bool = True, eps: float = 1e-10) -> None:
        """Initialize the metric.

        Args:
            normalize: Whether to normalize entropy to [0, 1] range.
            eps: Small constant to avoid log(0).
        """
        # Let the base class run whatever setup it has before adding state.
        super().__init__()
        self.normalize = normalize
        self.eps = eps

    def validate_inputs(
        self,
        heatmap: "Tensor | None" = None,
        **kwargs,
    ) -> None:
        """Validate inputs.

        Args:
            heatmap: Attribution heatmap [H, W] or [B, H, W].

        Raises:
            InvalidMetricInputError: If heatmap is missing, is not a
                tensor, or is not 2D/3D.
        """
        if heatmap is None:
            raise InvalidMetricInputError(
                "heatmap_entropy",
                "heatmap",
                "Tensor [H, W] or [B, H, W]",
                "None",
            )

        # Fail with the metric's own error type instead of an
        # AttributeError/TypeError further down the pipeline.
        if not isinstance(heatmap, torch.Tensor):
            raise InvalidMetricInputError(
                "heatmap_entropy",
                "heatmap",
                "Tensor [H, W] or [B, H, W]",
                type(heatmap).__name__,
            )

        if heatmap.ndim not in (2, 3):
            raise InvalidMetricInputError(
                "heatmap_entropy",
                "heatmap",
                "2D or 3D tensor",
                f"{heatmap.ndim}D tensor",
            )

    @no_grad
    @timed
    def compute(
        self,
        heatmap: "Tensor",
        normalize: bool | None = None,
        **kwargs,
    ) -> float:
        """Compute the heatmap entropy.

        Args:
            heatmap: Attribution heatmap [H, W] or [B, H, W].
            normalize: Override default normalization setting.

        Returns:
            Entropy value (optionally normalized to [0, 1]). For batched
            input this is the mean of the per-sample entropies. Heatmaps
            with at most one pixel (and empty batches) yield 0.0.

        Raises:
            InvalidMetricInputError: If ``heatmap`` fails validation.
        """
        self.validate_inputs(heatmap=heatmap)
        use_normalization = self.normalize if normalize is None else normalize

        # Arrange the data as one row per sample: [H, W] -> [1, H*W],
        # [B, H, W] -> [B, H*W]. Each row is its own distribution, so
        # samples in a batch never leak probability mass into each other.
        if heatmap.ndim == 3:
            rows = heatmap.flatten(start_dim=1)
        else:
            rows = heatmap.flatten().unsqueeze(0)

        n_samples, n_pixels = rows.shape

        # A distribution over <= 1 outcome has zero entropy; returning early
        # also avoids dividing by log(1) == 0 in the normalization step.
        if n_samples == 0 or n_pixels <= 1:
            return 0.0

        # In float16/bfloat16 the default eps underflows to 0, which turns
        # zero pixels into 0 * log(0) = NaN; integer/bool inputs cannot be
        # fed to log at all. Work in float32 for those dtypes.
        if rows.dtype not in (torch.float32, torch.float64):
            rows = rows.to(torch.float32)

        # Negative attributions carry no probability mass.
        rows = torch.relu(rows)

        # Normalize each row to a probability distribution.
        totals = rows.sum(dim=1, keepdim=True) + self.eps
        probs = rows / totals

        # Per-sample entropy: -Σ p * log(p)
        per_sample = -(probs * torch.log(probs + self.eps)).sum(dim=1)
        entropy_value = float(per_sample.mean().item())

        # Normalize by the entropy of a uniform distribution over one heatmap.
        if use_normalization:
            max_entropy = torch.log(torch.tensor(float(n_pixels)))
            entropy_value = entropy_value / float(max_entropy.item())

        return entropy_value


@register_metric(
    "heatmap_gini",
    display_name="Heatmap Gini",
    lower_is_better=False,  # Higher Gini = more inequality = more focused
    streamable=True,
    category="heatmap",
)
class HeatmapGini(BaseMetric):
    """Gini coefficient of the heatmap distribution.

    Measures inequality in attribution values. Higher Gini means
    more concentrated attribution (more focused), while lower Gini
    means more uniform distribution.

    Gini = 0 means perfect equality (uniform).
    With ``n`` pixels the maximum, reached when a single pixel holds all
    attribution, is ``(n - 1) / n`` and therefore approaches 1 for large
    heatmaps.

    For a ``[B, H, W]`` batch the coefficient is computed per sample and
    the mean over the batch is returned.

    Example:
        >>> metric = HeatmapGini()
        >>> gini = metric.compute(heatmap=heatmap)
        >>> print(f"Gini coefficient: {gini:.4f}")
    """

    def validate_inputs(
        self,
        heatmap: "Tensor | None" = None,
        **kwargs,
    ) -> None:
        """Validate inputs.

        Args:
            heatmap: Attribution heatmap tensor.

        Raises:
            InvalidMetricInputError: If heatmap is missing or is not a
                tensor.
        """
        if heatmap is None:
            raise InvalidMetricInputError(
                "heatmap_gini",
                "heatmap",
                "Tensor [H, W]",
                "None",
            )

        # Fail with the metric's own error type instead of a TypeError
        # raised from inside the torch calls below.
        if not isinstance(heatmap, torch.Tensor):
            raise InvalidMetricInputError(
                "heatmap_gini",
                "heatmap",
                "Tensor [H, W]",
                type(heatmap).__name__,
            )

    @no_grad
    @timed
    def compute(
        self,
        heatmap: "Tensor",
        **kwargs,
    ) -> float:
        """Compute the Gini coefficient.

        Args:
            heatmap: Attribution heatmap [H, W] or [B, H, W]. Tensors of
                any other rank are flattened and treated as one sample.

        Returns:
            Gini coefficient (0 to 1). For batched input this is the mean
            of the per-sample coefficients. Samples with at most one pixel
            or with (near-)zero total attribution contribute 0.0.

        Raises:
            InvalidMetricInputError: If ``heatmap`` fails validation.
        """
        self.validate_inputs(heatmap=heatmap)

        # One row per sample so that batch members are ranked independently.
        if heatmap.ndim == 3:
            rows = heatmap.flatten(start_dim=1)
        else:
            rows = heatmap.flatten().unsqueeze(0)

        n_samples, n = rows.shape
        if n_samples == 0 or n <= 1:
            return 0.0

        # float16/bfloat16 cannot represent the rank vector exactly (integers
        # above 2048 are rounded in float16) and the weighted sum overflows
        # easily; bool tensors are rejected by relu. Work in float32 there.
        if rows.dtype not in (torch.float32, torch.float64):
            rows = rows.to(torch.float32)

        rows = torch.relu(rows)  # Negative values carry no attribution mass

        # Ascending sort within each sample.
        sorted_vals, _ = torch.sort(rows, dim=1)

        # G = (2 * Σ i * x_i) / (n * Σ x_i) - (n + 1) / n, with 1-based ranks i
        ranks = torch.arange(1, n + 1, dtype=rows.dtype, device=rows.device)
        totals = sorted_vals.sum(dim=1)
        weighted = (sorted_vals * ranks).sum(dim=1)

        # Samples without attribution mass are defined as perfectly equal.
        # Their denominator is swapped for 1 so the division stays finite
        # before the result is overwritten with 0.
        degenerate = totals < 1e-10
        safe_totals = torch.where(degenerate, torch.ones_like(totals), totals)
        gini = 2 * weighted / (n * safe_totals) - (n + 1) / n
        gini = torch.where(degenerate, torch.zeros_like(gini), gini)

        return float(gini.mean().item())


def compute_heatmap_stats(heatmap: "Tensor") -> dict[str, float]:
    """Compute comprehensive heatmap statistics.

    A leading batch dimension of size one is dropped. Integer, bool and
    half-precision heatmaps are converted to float32 first, because
    ``mean``/``std`` reject integer tensors and the ``1e-10`` guard used
    for ``peak_ratio`` underflows to zero in half precision.

    Args:
        heatmap: Attribution heatmap [H, W] (or [1, H, W]).

    Returns:
        Dictionary with:
            - entropy: Normalized entropy (HeatmapEntropy defaults)
            - gini: Gini coefficient
            - mean: Mean attribution value
            - std: Standard deviation
            - min: Minimum value
            - max: Maximum value
            - positive_area: Fraction of positive pixels
            - peak_ratio: max / (max - min), focus measure

    Example:
        >>> stats = compute_heatmap_stats(heatmap)
        >>> print(f"Entropy: {stats['entropy']:.4f}, Gini: {stats['gini']:.4f}")
    """
    entropy_metric = HeatmapEntropy()
    gini_metric = HeatmapGini()

    if heatmap.ndim == 3:
        heatmap = heatmap.squeeze(0)

    # Upcast once so every statistic below (and both metrics) sees a dtype
    # that supports mean/std and keeps the 1e-10 guard representable.
    if heatmap.dtype not in (torch.float32, torch.float64):
        heatmap = heatmap.to(torch.float32)

    flat = heatmap.flatten()
    max_val = flat.max()
    min_val = flat.min()
    # The small constant keeps the ratio finite for a constant heatmap.
    value_range = max_val - min_val + 1e-10
    positive_fraction = (flat > 0).sum().item() / flat.numel()

    return {
        "entropy": entropy_metric.compute(heatmap=heatmap),
        "gini": gini_metric.compute(heatmap=heatmap),
        "mean": float(flat.mean().item()),
        "std": float(flat.std().item()),
        "min": float(min_val.item()),
        "max": float(max_val.item()),
        "positive_area": float(positive_fraction),
        "peak_ratio": float((max_val / value_range).item()),
    }
