"""Internal (feature-space) infidelity metric.

This module implements the internal infidelity metric computed in feature
map space. This is the primary infidelity metric for Expected GradCAM because
it can be computed without additional forward passes.

Mathematical Specification:
    INFD(α) = E[(I^T α - (g(z_0) - g(z_0 - I)))²]

    where:
    - α: Optimal weights [K]
    - I: Perturbation in feature space [K] (from I_samples [M, K])
    - g(z_0): Model output at reference point (scalar)
    - g(z_0 - I): Model output after perturbation [M]

The infidelity measures how well the linear approximation (I^T α) predicts
the actual change in model output (g(z_0) - g(z_0 - I)).

Key Insight:
    This metric can be computed **without any additional forward passes**
    because I_samples and g_perturbed are already computed during the
    main Expected GradCAM algorithm.

Example:
    >>> from expected_gradcam.metrics.infidelity import InternalInfidelity
    >>>
    >>> metric = InternalInfidelity()
    >>> infidelity = metric.compute(
    ...     alpha=alpha,           # [K] from solver
    ...     I_samples=I_samples,   # [M, K] from sampler
    ...     g_z0=g_z0,             # scalar from predictor
    ...     g_perturbed=g_perturbed,  # [M] from predictor
    ... )
    >>> print(f"Infidelity: {infidelity:.6f}")
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from expected_gradcam.metrics.base import (
    BaseMetric,
    StreamingMetric,
    check_finite,
    no_grad,
    timed,
    validate_inputs,
)
from expected_gradcam.metrics.exceptions import (
    InfidelityComputationError,
    InsufficientSamplesError,
    InvalidMetricInputError,
)
from expected_gradcam.metrics.registry import register_metric

if TYPE_CHECKING:
    from torch import Tensor


@register_metric(
    "internal_infidelity",
    display_name="Internal Infidelity",
    lower_is_better=True,
    streamable=True,
    category="infidelity",
)
class InternalInfidelity(BaseMetric):
    """Feature-space infidelity metric.

    Measures how well the linear approximation ``I^T α`` predicts the actual
    change in model output ``g(z_0) - g(z_0 - I)`` when a perturbation ``I``
    is applied, as the mean squared error over all perturbation samples.

    The metric only consumes values that the Expected GradCAM algorithm has
    already produced, so it needs no additional forward passes:
    - I_samples: Feature-space perturbations from the sampler
    - g_z0, g_perturbed: Model outputs from the predictor

    Attributes:
        min_samples: Minimum samples required for reliable computation.

    Example:
        >>> metric = InternalInfidelity()
        >>> infidelity = metric.compute(
        ...     alpha=alpha,
        ...     I_samples=I_samples,
        ...     g_z0=g_z0,
        ...     g_perturbed=g_perturbed,
        ... )
    """

    # Class-level default; overridden per instance in __init__.
    min_samples: int = 10

    def __init__(self, min_samples: int = 10) -> None:
        """Initialize the metric.

        Args:
            min_samples: Minimum number of samples required for computation.
        """
        self.min_samples = min_samples

    def validate_inputs(
        self,
        alpha: "Tensor | None" = None,
        I_samples: "Tensor | None" = None,
        g_z0: "Tensor | float | None" = None,
        g_perturbed: "Tensor | None" = None,
        **kwargs,
    ) -> None:
        """Validate inputs for infidelity computation.

        Checks run in this order: presence of every input, tensor ranks and
        shape agreement, and finally the sample count. Ranks are verified
        before any ``shape[...]`` indexing so that malformed tensors (for
        example 0-D ones) are reported as ``InvalidMetricInputError`` rather
        than surfacing as a bare ``IndexError``.

        Args:
            alpha: Optimal weights [K].
            I_samples: Perturbation samples [M, K].
            g_z0: Reference model output (scalar).
            g_perturbed: Perturbed model outputs [M].
            **kwargs: Accepted and ignored.

        Raises:
            InvalidMetricInputError: If inputs are missing or have wrong shape.
            InsufficientSamplesError: If not enough samples provided.
        """
        if alpha is None:
            raise InvalidMetricInputError(
                "internal_infidelity",
                "alpha",
                "Tensor [K]",
                "None",
                suggestion="Pass alpha weights from the solver",
            )

        if I_samples is None:
            raise InvalidMetricInputError(
                "internal_infidelity",
                "I_samples",
                "Tensor [M, K]",
                "None",
                suggestion="Pass I_samples from the perturbation sampler",
            )

        if g_z0 is None:
            raise InvalidMetricInputError(
                "internal_infidelity",
                "g_z0",
                "Tensor or float (scalar)",
                "None",
                suggestion="Pass g_z0 (reference output) from the predictor",
            )

        if g_perturbed is None:
            raise InvalidMetricInputError(
                "internal_infidelity",
                "g_perturbed",
                "Tensor [M]",
                "None",
                suggestion="Pass g_perturbed from the predictor",
            )

        # Rank validation must precede shape indexing: shape[0] on a 0-D
        # tensor raises IndexError, and a non-1D alpha would otherwise only
        # fail later inside torch.mv with an opaque RuntimeError.
        if alpha.ndim != 1:
            raise InvalidMetricInputError(
                "internal_infidelity",
                "alpha",
                "1D tensor [K]",
                f"{alpha.ndim}D tensor",
            )

        if I_samples.ndim != 2:
            raise InvalidMetricInputError(
                "internal_infidelity",
                "I_samples",
                "2D tensor [M, K]",
                f"{I_samples.ndim}D tensor",
            )

        # Safe to index now that both ranks are known.
        K = alpha.shape[0]
        M = I_samples.shape[0]

        if I_samples.shape[1] != K:
            raise InvalidMetricInputError(
                "internal_infidelity",
                "I_samples",
                f"shape [M, {K}] (matching alpha)",
                f"shape {list(I_samples.shape)}",
            )

        if g_perturbed.ndim != 1:
            raise InvalidMetricInputError(
                "internal_infidelity",
                "g_perturbed",
                "1D tensor [M]",
                f"{g_perturbed.ndim}D tensor",
            )

        if g_perturbed.shape[0] != M:
            raise InvalidMetricInputError(
                "internal_infidelity",
                "g_perturbed",
                f"shape [{M}] (matching I_samples)",
                f"shape {list(g_perturbed.shape)}",
            )

        # Sample count validation
        if M < self.min_samples:
            raise InsufficientSamplesError(
                "internal_infidelity",
                required=self.min_samples,
                actual=M,
            )

    # NOTE(review): no_grad, timed and check_finite are imported from
    # expected_gradcam.metrics.base; their exact behavior is not visible in
    # this file. The decorator order is kept as-is.
    @no_grad
    @timed
    @check_finite
    def compute(
        self,
        alpha: "Tensor",
        I_samples: "Tensor",
        g_z0: "Tensor | float",
        g_perturbed: "Tensor",
        **kwargs,
    ) -> float:
        """Compute internal (feature-space) infidelity.

        Mathematical formula:
            INFD(α) = E[(I^T α - (g(z_0) - g(z_0 - I)))²]

        This is the mean squared error between:
        - Predicted change: I_samples @ α (linear approximation)
        - Actual change: g(z_0) - g(z_0 - I)

        Args:
            alpha: Optimal weights [K] from the solver.
            I_samples: Perturbation samples [M, K] from the sampler.
            g_z0: Reference model output g(z_0) (scalar).
            g_perturbed: Perturbed model outputs g(z_0 - I) [M].
            **kwargs: Accepted and ignored.

        Returns:
            Infidelity value (lower is better).

        Raises:
            InvalidMetricInputError: If inputs are missing or have wrong shape.
            InsufficientSamplesError: If fewer than ``min_samples`` samples
                are provided.
            InfidelityComputationError: If the result is NaN or Inf.
        """
        self.validate_inputs(
            alpha=alpha,
            I_samples=I_samples,
            g_z0=g_z0,
            g_perturbed=g_perturbed,
        )

        # Promote a Python scalar to a tensor matching the samples so the
        # subtraction below stays on one dtype/device.
        if isinstance(g_z0, (int, float)):
            g_z0 = torch.tensor(g_z0, dtype=I_samples.dtype, device=I_samples.device)

        # Predicted change per sample: [M, K] @ [K] -> [M]
        predicted_change = torch.mv(I_samples, alpha)

        # Actual change: g(z_0) - g(z_0 - I)
        actual_change = g_z0 - g_perturbed

        # Mean squared error, kept as a tensor until the finiteness check.
        mse = ((predicted_change - actual_change) ** 2).mean()

        # Check finiteness in the tensor's own dtype. Rebuilding a tensor
        # from the Python float would use the default dtype and could
        # misreport a large-but-finite float64 result as Inf.
        if not bool(torch.isfinite(mse)):
            raise InfidelityComputationError(
                "Computed value is NaN or Inf",
                suggestion="Check that alpha and I_samples are finite",
            )

        return float(mse.item())


@register_metric(
    "streaming_infidelity",
    display_name="Streaming Infidelity",
    lower_is_better=True,
    streamable=True,
    category="infidelity",
)
class StreamingInfidelity(StreamingMetric):
    """Streaming (incremental) infidelity metric.

    Computes infidelity incrementally as batches of samples arrive, useful
    for real-time visualization during computation. Each call to ``update``
    records the batch's summed squared error and sample count;
    ``get_current_value`` returns the mean squared error over every sample
    seen since the last ``reset``.

    Example:
        >>> metric = StreamingInfidelity()
        >>> for batch in batches:
        ...     metric.update({
        ...         "alpha": alpha,
        ...         "I_batch": batch_I,
        ...         "g_z0": g_z0,
        ...         "g_perturbed_batch": batch_g,
        ...     })
        ...     print(f"Current infidelity: {metric.get_current_value():.6f}")
        >>> metric.reset()
    """

    def __init__(self) -> None:
        """Initialize the streaming metric with empty accumulators."""
        super().__init__()
        # One entry per non-empty batch: that batch's summed squared error.
        self._squared_errors: list[float] = []
        # Total number of samples accumulated across all batches.
        self._sample_count: int = 0

    def update(self, batch_data: dict) -> None:
        """Update with a batch of data.

        Empty batches (``B == 0``) are ignored. The mean of an empty tensor
        is NaN, and recording it would make every later call to
        ``get_current_value`` return NaN.

        Args:
            batch_data: Dictionary with keys:
                - alpha: Optimal weights [K]
                - I_batch: Perturbation batch [B, K]
                - g_z0: Reference output (scalar)
                - g_perturbed_batch: Perturbed outputs [B]

        Raises:
            KeyError: If any of the required keys is missing.
        """
        alpha = batch_data["alpha"]
        I_batch = batch_data["I_batch"]
        g_z0 = batch_data["g_z0"]
        g_perturbed_batch = batch_data["g_perturbed_batch"]

        batch_size = I_batch.shape[0]
        if batch_size == 0:
            # Nothing to accumulate; skipping keeps the running sum finite.
            return

        # Promote a Python scalar to a tensor matching the batch.
        if isinstance(g_z0, (int, float)):
            g_z0 = torch.tensor(g_z0, dtype=I_batch.dtype, device=I_batch.device)

        # Compute for this batch
        with torch.no_grad():
            predicted_change = torch.mv(I_batch, alpha)
            actual_change = g_z0 - g_perturbed_batch
            squared_error = (predicted_change - actual_change) ** 2
            batch_mse = float(squared_error.mean().item())

        # Store the batch total (mean * size) so batches of different sizes
        # are weighted correctly in the overall mean.
        self._squared_errors.append(batch_mse * batch_size)
        self._sample_count += batch_size

    def get_current_value(self) -> float:
        """Get current infidelity estimate.

        Returns:
            Current mean squared error across all samples, or 0.0 if no
            samples have been accumulated yet.
        """
        if self._sample_count == 0:
            return 0.0
        return sum(self._squared_errors) / self._sample_count

    def reset(self) -> None:
        """Reset the streaming state."""
        self._squared_errors.clear()
        self._sample_count = 0


def compute_internal_infidelity(
    alpha: "Tensor",
    I_samples: "Tensor",
    g_z0: "Tensor | float",
    g_perturbed: "Tensor",
) -> float:
    """Compute internal infidelity through a plain function call.

    Shorthand for building an ``InternalInfidelity`` with its default
    settings and calling ``compute`` on it, for callers that do not want
    to manage a metric object themselves.

    Mathematical formula:
        INFD(α) = E[(I^T α - (g(z_0) - g(z_0 - I)))²]

    Args:
        alpha: Optimal weights [K].
        I_samples: Perturbation samples [M, K].
        g_z0: Reference model output g(z_0) (scalar).
        g_perturbed: Perturbed model outputs g(z_0 - I) [M].

    Returns:
        Infidelity value (lower is better).

    Example:
        >>> infidelity = compute_internal_infidelity(
        ...     alpha=alpha,
        ...     I_samples=I_samples,
        ...     g_z0=g_z0,
        ...     g_perturbed=g_perturbed,
        ... )
    """
    # A fresh, default-configured metric is created on every call.
    return InternalInfidelity().compute(
        alpha=alpha,
        I_samples=I_samples,
        g_z0=g_z0,
        g_perturbed=g_perturbed,
    )
