"""Predictor function g(z; A) for Expected GradCAM.

The predictor evaluates the class score when feature maps are scaled by z:
    g(z; A) = y^c(z_1 * A^1, z_2 * A^2, ..., z_K * A^K)

At reference point z_0 = (1, 1, ..., 1), this equals the original model output.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch import nn


if TYPE_CHECKING:
    from torch import Tensor


class Predictor:
    """Predictor function g(z; A) for Expected GradCAM.

    Evaluates the class score when feature maps A^k are scaled by multipliers z_k.

    Mathematical specification:
        g(z; A) = y^c(z_1 * A^1, z_2 * A^2, ..., z_K * A^K)

    where:
        - z ∈ R^K is the vector of feature map multipliers
        - A = (A^1, ..., A^K) are feature maps, each A^k ∈ R^{U × V}
        - y^c is the score the classifier head produces for target class c

    At reference point z_0 = (1, 1, ..., 1) the feature maps are passed to the
    classifier head unscaled, so g(z_0; A) is the head's score for class c on
    the original feature maps.

    Attributes:
        classifier_head: Module mapping feature maps to class logits.
        target_class: Target class index c.
        feature_maps: Feature maps A at target layer [B, K, U, V].
        num_features: Number of feature channels K.

    Example:
        >>> predictor = Predictor(classifier_head, target_class=243, feature_maps=A)
        >>> z = torch.ones(K, device=A.device, requires_grad=True)
        >>> score = predictor(z)  # Shape: [B]
    """

    def __init__(
        self,
        classifier_head: nn.Module,
        target_class: int,
        feature_maps: "Tensor",
    ) -> None:
        """Initialize the predictor.

        Args:
            classifier_head: Module that maps [B, K, U, V] -> [B, num_classes].
            target_class: Target class index c.
            feature_maps: Feature maps A from target layer [B, K, U, V].
        """
        self.classifier_head = classifier_head
        self.target_class = target_class
        self.feature_maps = feature_maps
        # Channel count K is read from dim 1 of the feature maps.
        self._K = feature_maps.shape[1]

    @property
    def num_features(self) -> int:
        """Number of feature maps K."""
        return self._K

    @property
    def device(self) -> torch.device:
        """Device where feature maps are stored."""
        return self.feature_maps.device

    def _reference_point(self) -> "Tensor":
        """Build z_0 = (1, ..., 1) with shape [K].

        The tensor is created on the feature maps' device AND in their dtype.
        A default-dtype (float32) z_0 would promote fp16/bf16 feature maps to
        float32 during scaling and hand the classifier head the wrong dtype.
        """
        return torch.ones(
            self._K,
            device=self.feature_maps.device,
            dtype=self.feature_maps.dtype,
        )

    def __call__(self, z: "Tensor") -> "Tensor":
        """Evaluate g(z; A) = y^c(z_1 * A^1, ..., z_K * A^K).

        Args:
            z: Feature map multipliers [K] or [B, K].

        Returns:
            Class score y^c with shape [B].

        Raises:
            ValueError: If z is not 1D or 2D, or if its last dimension does
                not have exactly K entries.

        Note:
            For gradient computation, z should have requires_grad=True.
            z is used as given (no dtype/device conversion), so gradients flow
            back to the tensor the caller passed in.
        """
        # Reshape z so it broadcasts over the spatial dims of the feature maps.
        if z.dim() == 1:
            # z is [K] -> [1, K, 1, 1]
            z_expanded = z.view(1, -1, 1, 1)
        elif z.dim() == 2:
            # z is [B, K] -> [B, K, 1, 1]
            z_expanded = z.unsqueeze(-1).unsqueeze(-1)
        else:
            raise ValueError(f"z must be 1D or 2D, got {z.dim()}D")

        # A size-1 channel dim would silently broadcast one multiplier across
        # every feature map; reject anything that is not exactly K wide.
        if z.shape[-1] != self._K:
            raise ValueError(
                f"z must have {self._K} entries in its last dimension, "
                f"got {z.shape[-1]}"
            )

        # Scale feature maps: [B, K, U, V] * [1 or B, K, 1, 1] -> [B, K, U, V]
        scaled_A = self.feature_maps * z_expanded

        # Forward through classifier head -> [B, num_classes]
        logits = self.classifier_head(scaled_A)

        # Extract target class score -> [B]
        return logits[:, self.target_class]

    def evaluate_at_reference(self) -> "Tensor":
        """Evaluate g(z_0; A) where z_0 = (1, 1, ..., 1).

        Runs under torch.no_grad(), so the result carries no autograd graph.

        Returns:
            Class score at reference point [B].
        """
        with torch.no_grad():
            return self(self._reference_point())

    def evaluate_at_baseline(self, I: "Tensor") -> "Tensor":
        """Evaluate g(z_0 - I; A) at perturbed point.

        Runs under torch.no_grad(), so the result carries no autograd graph.

        Args:
            I: Perturbation vector [K]. It is converted to the feature maps'
                device and dtype before being subtracted from z_0.

        Returns:
            Class score at perturbed point [B].
        """
        with torch.no_grad():
            z0 = self._reference_point()
            # Align the perturbation with z_0 so the subtraction neither fails
            # on a device mismatch nor promotes the result to another dtype.
            perturbation = torch.as_tensor(I, device=z0.device, dtype=z0.dtype)
            return self(z0 - perturbation)

    def compute_output_difference(self, I: "Tensor") -> "Tensor":
        """Compute Δg(I) = g(z_0; A) - g(z_0 - I; A).

        This is the actual output change the attribution should predict.

        Args:
            I: Perturbation vector [K].

        Returns:
            Output difference Δg(I) [B].
        """
        g_z0 = self.evaluate_at_reference()
        g_perturbed = self.evaluate_at_baseline(I)
        return g_z0 - g_perturbed


class BatchedPredictor:
    """Batched predictor for efficient computation over multiple z values.

    Processes multiple z vectors in a single forward pass by expanding the
    single stored sample of feature maps along the batch dimension, which
    suits path integration and gradient computation.

    The forward pass can optionally be wrapped in torch.compile; this is only
    enabled when ``use_compile=True`` and CUDA is available.

    Attributes:
        classifier_head: Module mapping feature maps to class logits.
        target_class: Target class index c.
        feature_maps: Feature maps A [1, K, U, V] (batch size must be 1).
        num_features: Number of feature channels K.

    Example:
        >>> predictor = BatchedPredictor(classifier_head, target_class=243, feature_maps=A)
        >>> z_batch = torch.randn(100, K, device=A.device, requires_grad=True)
        >>> scores = predictor(z_batch)  # Shape: [100]
    """

    def __init__(
        self,
        classifier_head: nn.Module,
        target_class: int,
        feature_maps: "Tensor",
        use_compile: bool = False,
    ) -> None:
        """Initialize batched predictor.

        Args:
            classifier_head: Module mapping [N, K, U, V] -> [N, num_classes].
            target_class: Target class index c.
            feature_maps: Feature maps A [1, K, U, V]. Must have batch size 1.
            use_compile: Whether to use torch.compile for speedup. Ignored
                (treated as False) when CUDA is not available.

        Raises:
            ValueError: If feature_maps batch size is not 1.
        """
        if feature_maps.shape[0] != 1:
            raise ValueError(
                f"BatchedPredictor requires batch size 1, got {feature_maps.shape[0]}"
            )

        self.classifier_head = classifier_head
        self.target_class = target_class
        self.feature_maps = feature_maps
        # Channel count K is read from dim 1 of the feature maps.
        self._K = feature_maps.shape[1]
        self._use_compile = use_compile and torch.cuda.is_available()
        # Lazily created by _get_compiled_forward on first compiled call.
        self._compiled_forward: nn.Module | None = None

    @property
    def num_features(self) -> int:
        """Number of feature maps K."""
        return self._K

    @property
    def device(self) -> torch.device:
        """Device where feature maps are stored."""
        return self.feature_maps.device

    def _reference_point(self) -> "Tensor":
        """Build z_0 = (1, ..., 1) with shape [K].

        The tensor is created on the feature maps' device AND in their dtype.
        A default-dtype (float32) z_0 would promote fp16/bf16 feature maps to
        float32 during scaling and hand the classifier head the wrong dtype.
        """
        return torch.ones(
            self._K,
            device=self.feature_maps.device,
            dtype=self.feature_maps.dtype,
        )

    def _forward_impl(self, z_batch: "Tensor") -> "Tensor":
        """Core forward implementation.

        Args:
            z_batch: Batch of multipliers [N, K].

        Returns:
            Class scores [N].
        """
        N = z_batch.shape[0]

        # Expand feature maps to batch: [1, K, U, V] -> [N, K, U, V]
        A_expanded = self.feature_maps.expand(N, -1, -1, -1)

        # Scale: [N, K, U, V] * [N, K, 1, 1] -> [N, K, U, V]
        z_expanded = z_batch.unsqueeze(-1).unsqueeze(-1)
        scaled_A = A_expanded * z_expanded

        # Forward -> [N, num_classes]
        logits = self.classifier_head(scaled_A)

        # Extract target class -> [N]
        return logits[:, self.target_class]

    def _get_compiled_forward(self) -> nn.Module:
        """Get or create compiled forward function.

        The wrapper is created once and cached. If torch.compile itself raises
        while building the wrapper, the plain ``_forward_impl`` is cached
        instead so callers always get something callable.
        """
        if self._compiled_forward is None:
            try:
                self._compiled_forward = torch.compile(
                    self._forward_impl,
                    mode="reduce-overhead",
                    dynamic=True,
                )
            except Exception:
                # Best-effort: fall back to the non-compiled version.
                # NOTE(review): this only guards creation of the wrapper;
                # errors raised when the compiled function is first invoked
                # are not caught here.
                self._compiled_forward = self._forward_impl  # type: ignore
        return self._compiled_forward  # type: ignore

    def __call__(self, z_batch: "Tensor") -> "Tensor":
        """Evaluate g(z; A) for a batch of z vectors.

        Args:
            z_batch: Batch of multipliers [N, K] or single vector [K].

        Returns:
            Class scores [N] or [1] for single vector.

        Raises:
            ValueError: If z_batch is not 1D or 2D, or if its last dimension
                does not have exactly K entries.

        Note:
            For 1D input (single z vector), the tensor is automatically
            unsqueezed to [1, K] for consistency with batch processing.
            This ensures gradients are computed correctly when used with
            path integration methods like ExpectedGradients.
        """
        # A 1D [K] vector is one sample, not K samples: promote it to [1, K]
        # before the per-sample broadcasting in _forward_impl.
        if z_batch.dim() == 1:
            z_batch = z_batch.unsqueeze(0)  # [K] -> [1, K]
        elif z_batch.dim() != 2:
            raise ValueError(f"z_batch must be 1D or 2D, got {z_batch.dim()}D")

        # A size-1 channel dim would silently broadcast one multiplier across
        # every feature map; reject anything that is not exactly K wide.
        if z_batch.shape[1] != self._K:
            raise ValueError(
                f"z_batch must have {self._K} entries in its last dimension, "
                f"got {z_batch.shape[1]}"
            )

        if self._use_compile:
            return self._get_compiled_forward()(z_batch)
        return self._forward_impl(z_batch)

    def evaluate_at_reference(self) -> "Tensor":
        """Evaluate g(z_0; A) where z_0 = (1, ..., 1).

        Runs under torch.no_grad(), so the result carries no autograd graph.

        Returns:
            Class score at reference point [1].
        """
        with torch.no_grad():
            # Clone so the returned value owns its storage: with
            # mode="reduce-overhead" the compiled function may reuse its
            # output buffer on the next call.
            return self(self._reference_point()).clone()

    def evaluate_at_baseline(self, I: "Tensor") -> "Tensor":
        """Evaluate g(z_0 - I; A).

        Runs under torch.no_grad(), so the result carries no autograd graph.

        Args:
            I: Perturbation vector [K]. It is converted to the feature maps'
                device and dtype before being subtracted from z_0.

        Returns:
            Class score at perturbed point [1].
        """
        with torch.no_grad():
            z0 = self._reference_point()
            # Align the perturbation with z_0 so the subtraction neither fails
            # on a device mismatch nor promotes the result to another dtype.
            perturbation = torch.as_tensor(I, device=z0.device, dtype=z0.dtype)
            # Clone for the same buffer-reuse reason as evaluate_at_reference.
            return self(z0 - perturbation).clone()

    def compute_output_difference(self, I: "Tensor") -> "Tensor":
        """Compute Δg(I) = g(z_0; A) - g(z_0 - I; A).

        This is the actual output change the attribution should predict.

        Args:
            I: Perturbation vector [K].

        Returns:
            Output difference Δg(I) [1].
        """
        # Both operands are already detached copies, so the second forward
        # pass cannot invalidate the first result before the subtraction.
        g_z0 = self.evaluate_at_reference()
        g_perturbed = self.evaluate_at_baseline(I)
        return g_z0 - g_perturbed
