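"""Probe callback protocol for intermediate value collection.

This module defines ``ProbeCallback``, the interface that all probe
implementations must follow. It declares one callback method per stage of
the E-GradCAM computation. The module also defines ``NullProbe``, a no-op
implementation that can serve as the default when no probe is configured.
"""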

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

if TYPE_CHECKING:
    from torch import Tensor


@runtime_checkable
class ProbeCallback(Protocol):
    """Protocol for probe callbacks.

    Probe callbacks are invoked at various stages of E-GradCAM computation
    to collect intermediate values for analysis.

    Implementations can choose to:
    - Store values immediately (SyncProbe)
    - Queue values for background storage (AsyncProbe)
    - Aggregate values in memory
    - Filter/downsample values

    All methods receive metadata that can be used for organizing stored data.

    Example implementing a custom probe:
        >>> class MyProbe:
        ...     def on_perturbation_sampled(self, I_samples, metadata):
        ...         print(f"Sampled {I_samples.shape[0]} perturbations")
        ...
        ...     def on_attribution_computed(self, phi_samples, M_I, eigenvalues, metadata):
        ...         print(f"Computed attributions, condition number: {eigenvalues.max()/eigenvalues.min()}")
        ...
        ...     def on_weights_computed(self, alpha_opt, solver_diagnostics, metadata):
        ...         print(f"Optimal weights: mean={alpha_opt.mean():.4f}")
        ...
        ...     def flush(self):
        ...         pass
    """

    def on_perturbation_sampled(
        self,
        I_samples: "Tensor",
        metadata: dict[str, Any],
    ) -> None:
        """Called when perturbation samples are generated.

        Args:
            I_samples: Perturbation samples [M, K].
            metadata: Additional context (e.g., sample_id, config).
        """
        ...

    def on_attribution_computed(
        self,
        phi_samples: "Tensor",
        M_I: "Tensor",
        eigenvalues: "Tensor",
        metadata: dict[str, Any],
    ) -> None:
        """Called when attributions are computed.

        Args:
            phi_samples: Attribution samples [M, K].
            M_I: Second moment matrix [K, K].
            eigenvalues: Eigenvalues of M_I [K].
            metadata: Additional context.
        """
        ...

    def on_weights_computed(
        self,
        alpha_opt: "Tensor",
        solver_diagnostics: dict[str, Any] | None,
        metadata: dict[str, Any],
    ) -> None:
        """Called when optimal weights are computed.

        Args:
            alpha_opt: Optimal weights [K].
            solver_diagnostics: Solver information (condition number, rank, etc.).
            metadata: Additional context.
        """
        ...

    def on_heatmap_generated(
        self,
        heatmap: "Tensor",
        raw_heatmap: "Tensor",
        metadata: dict[str, Any],
    ) -> None:
        """Called when heatmap is generated.

        Args:
            heatmap: Final normalized heatmap [H, W].
            raw_heatmap: Pre-normalization heatmap [H, W].
            metadata: Additional context.
        """
        ...

    def flush(self) -> None:
        """Flush any buffered data.

        Called to ensure all collected data is persisted.
        For async probes, this waits for the queue to empty.
        """
        ...


class NullProbe:
    """A no-op probe that discards all data.

    Useful as a default when no probe is configured.
    """

    def on_perturbation_sampled(
        self,
        I_samples: "Tensor",
        metadata: dict[str, Any],
    ) -> None:
        pass

    def on_attribution_computed(
        self,
        phi_samples: "Tensor",
        M_I: "Tensor",
        eigenvalues: "Tensor",
        metadata: dict[str, Any],
    ) -> None:
        pass

    def on_weights_computed(
        self,
        alpha_opt: "Tensor",
        solver_diagnostics: dict[str, Any] | None,
        metadata: dict[str, Any],
    ) -> None:
        pass

    def on_heatmap_generated(
        self,
        heatmap: "Tensor",
        raw_heatmap: "Tensor",
        metadata: dict[str, Any],
    ) -> None:
        pass

    def flush(self) -> None:
        pass
