"""GPU-batched infidelity metric (input space).

This module implements a GPU-batched infidelity metric that operates in
input (image) space using random patch perturbations. While this is an
alternative to the internal (feature-space) infidelity, it requires
additional forward passes.

Note:
    The internal (feature-space) infidelity in internal.py is preferred
    because it can be computed without additional forward passes.
    Use this batched version only when input-space infidelity is specifically
    required (e.g., for Quantus compatibility).

Example:
    >>> from expected_gradcam.metrics.infidelity import BatchedInfidelity
    >>>
    >>> metric = BatchedInfidelity(n_perturbations=100)
    >>> infidelity, info = metric.compute(
    ...     model=model,
    ...     x_batch=images,
    ...     y_batch=labels,
    ...     heatmaps=heatmaps,
    ... )
"""

from __future__ import annotations

import time
from typing import TYPE_CHECKING, Any

import torch
import torch.nn.functional as F
from torch import nn

from expected_gradcam.metrics.base import BaseMetric, no_grad, timed
from expected_gradcam.metrics.config import MetricConfig
from expected_gradcam.metrics.exceptions import InvalidMetricInputError
from expected_gradcam.metrics.registry import register_metric

if TYPE_CHECKING:
    from torch import Tensor


@register_metric(
    "batched_infidelity",
    display_name="GPU-Batched Infidelity",
    lower_is_better=True,
    streamable=False,
    category="infidelity",
)
class BatchedInfidelity(BaseMetric):
    """GPU-batched infidelity using input-space perturbations.

    This metric computes infidelity using random patch perturbations
    in the input (image) space. It batches all perturbations into
    minimal forward passes for GPU efficiency.

    Performance: ~40,000x faster than naive per-pixel perturbation.

    Note:
        This requires additional forward passes through the model.
        For zero-cost infidelity, use InternalInfidelity instead.

    Attributes:
        config: Configuration for metric computation.

    Example:
        >>> config = MetricConfig(n_perturbations=100, patch_size=8)
        >>> metric = BatchedInfidelity(config)
        >>> infidelity, info = metric.compute(
        ...     model=model,
        ...     x_batch=images,     # [N, C, H, W]
        ...     y_batch=labels,     # [N]
        ...     heatmaps=heatmaps,  # [N, H, W]
        ... )
    """

    def __init__(self, config: MetricConfig | None = None) -> None:
        """Initialize the batched infidelity metric.

        Args:
            config: Configuration for computation. Uses defaults if None.
        """
        self.config = config or MetricConfig()

    def validate_inputs(
        self,
        model: nn.Module | None = None,
        x_batch: "Tensor | None" = None,
        y_batch: "Tensor | None" = None,
        heatmaps: "Tensor | None" = None,
        **kwargs,
    ) -> None:
        """Validate inputs for batched infidelity computation.

        Args:
            model: The neural network model.
            x_batch: Input images [N, C, H, W].
            y_batch: Target class indices [N].
            heatmaps: Attribution heatmaps [N, H, W].

        Raises:
            InvalidMetricInputError: If inputs are invalid.
        """
        if model is None:
            raise InvalidMetricInputError(
                "batched_infidelity",
                "model",
                "nn.Module",
                "None",
            )

        if x_batch is None:
            raise InvalidMetricInputError(
                "batched_infidelity",
                "x_batch",
                "Tensor [N, C, H, W]",
                "None",
            )

        if y_batch is None:
            raise InvalidMetricInputError(
                "batched_infidelity",
                "y_batch",
                "Tensor [N]",
                "None",
            )

        if heatmaps is None:
            raise InvalidMetricInputError(
                "batched_infidelity",
                "heatmaps",
                "Tensor [N, H, W]",
                "None",
            )

        # Shape validation
        if x_batch.ndim != 4:
            raise InvalidMetricInputError(
                "batched_infidelity",
                "x_batch",
                "4D tensor [N, C, H, W]",
                f"{x_batch.ndim}D tensor",
            )

        N = x_batch.shape[0]
        H, W = x_batch.shape[2], x_batch.shape[3]

        if y_batch.shape[0] != N:
            raise InvalidMetricInputError(
                "batched_infidelity",
                "y_batch",
                f"shape [{N}]",
                f"shape {list(y_batch.shape)}",
            )

        if heatmaps.shape[0] != N or heatmaps.shape[1] != H or heatmaps.shape[2] != W:
            raise InvalidMetricInputError(
                "batched_infidelity",
                "heatmaps",
                f"shape [{N}, {H}, {W}]",
                f"shape {list(heatmaps.shape)}",
            )

    @no_grad
    @timed
    def compute(
        self,
        model: nn.Module,
        x_batch: "Tensor",
        y_batch: "Tensor",
        heatmaps: "Tensor",
        n_perturb: int | None = None,
        patch_size: int | None = None,
        perturb_batch_size: int = 50,
        **kwargs,
    ) -> tuple["Tensor", dict[str, Any]]:
        """Compute batched infidelity in input space.

        Args:
            model: The neural network model.
            x_batch: Input images [N, C, H, W].
            y_batch: Target class indices [N].
            heatmaps: Attribution heatmaps [N, H, W].
            n_perturb: Number of perturbations (default from config).
            patch_size: Patch size for masks (default from config).
            perturb_batch_size: Batch size for perturbation processing.

        Returns:
            Tuple of:
                - Infidelity scores per sample [N]
                - Info dict with timing and statistics
        """
        self.validate_inputs(
            model=model,
            x_batch=x_batch,
            y_batch=y_batch,
            heatmaps=heatmaps,
        )

        n_perturb = n_perturb or self.config.n_perturbations
        patch_size = patch_size or self.config.patch_size

        device = x_batch.device
        N, C, H, W = x_batch.shape

        n_patches_h = H // patch_size
        n_patches_w = W // patch_size

        start_time = time.time()

        # Baseline predictions
        model.eval()
        amp_dtype = torch.float16 if self.config.use_amp else torch.float32

        with torch.amp.autocast(device_type=device.type, dtype=amp_dtype):
            baseline_logits = model(x_batch)
            baseline_probs = F.softmax(baseline_logits, dim=1)
            baseline_scores = baseline_probs.gather(1, y_batch.unsqueeze(1)).squeeze(1)

        # Accumulate squared errors
        infidelities = torch.zeros(N, device=device)
        n_forward_passes = 1  # baseline

        # Mean values for replacement
        mean_vals = x_batch.mean(dim=(2, 3), keepdim=True)

        # Process perturbations in batches
        for p_start in range(0, n_perturb, perturb_batch_size):
            p_end = min(p_start + perturb_batch_size, n_perturb)
            P = p_end - p_start

            # Generate random patch masks
            patch_masks = (
                torch.rand(P, N, n_patches_h, n_patches_w, device=device)
                < self.config.perturbation_probability
            ).float()

            # Upsample masks to image size
            masks = F.interpolate(
                patch_masks.view(P * N, 1, n_patches_h, n_patches_w),
                size=(H, W),
                mode="nearest",
            ).view(P, N, H, W)

            # Create perturbed images
            x_expanded = x_batch.unsqueeze(0).expand(P, -1, -1, -1, -1)
            masks_expanded = masks.unsqueeze(2)

            x_perturbed = (
                x_expanded * (1 - masks_expanded) + mean_vals.unsqueeze(0) * masks_expanded
            )
            x_perturbed = x_perturbed.view(P * N, C, H, W)

            # Forward pass
            with torch.amp.autocast(device_type=device.type, dtype=amp_dtype):
                perturbed_logits = model(x_perturbed)
                perturbed_probs = F.softmax(perturbed_logits, dim=1)

                y_expanded = y_batch.unsqueeze(0).expand(P, -1).reshape(P * N)
                perturbed_scores = perturbed_probs.gather(1, y_expanded.unsqueeze(1)).squeeze(1)
                perturbed_scores = perturbed_scores.view(P, N)

            n_forward_passes += 1

            # Compute changes
            actual_change = baseline_scores.unsqueeze(0) - perturbed_scores  # [P, N]
            heatmaps_expanded = heatmaps.unsqueeze(0)  # [1, N, H, W]
            predicted_change = (heatmaps_expanded * masks).sum(dim=(-2, -1))  # [P, N]

            # Accumulate squared errors
            infidelities += ((predicted_change - actual_change) ** 2).sum(dim=0)

        # Average over perturbations
        infidelities = infidelities / n_perturb
        elapsed = time.time() - start_time

        info = {
            "time": elapsed,
            "n_forward_passes": n_forward_passes,
            "n_samples": N,
            "n_perturb": n_perturb,
            "throughput_samples_per_sec": N / elapsed,
            "patch_size": patch_size,
        }

        return infidelities, info

    def __call__(
        self,
        model: nn.Module,
        x_batch: "Tensor",
        y_batch: "Tensor",
        heatmaps: "Tensor",
        **kwargs,
    ) -> tuple["Tensor", dict[str, Any]]:
        """Call the metric as a function."""
        return self.compute(
            model=model,
            x_batch=x_batch,
            y_batch=y_batch,
            heatmaps=heatmaps,
            **kwargs,
        )
