"""Protocol definitions for baseline providers.

This module defines the core protocols (interfaces) that all baseline providers
must implement. Using Protocols instead of abstract base classes allows for
structural subtyping (duck typing with type safety).

Protocols:
    BaselineProvider: Core interface for all providers
    CacheableProvider: Extended interface for providers with caching
    StreamingProvider: Extended interface for memory-efficient streaming
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Iterator, Protocol, runtime_checkable

if TYPE_CHECKING:
    import torch
    from torch import Tensor, nn


@runtime_checkable
class BaselineProvider(Protocol):
    """Protocol for baseline data providers.

    A BaselineProvider is responsible for:
    1. Providing access to baseline images or features
    2. Validating compatibility with model/layer configuration
    3. Supporting lazy loading for efficiency
    4. Providing iteration over samples

    Implementations can provide:
    - Raw images (requires forward passes during initialization)
    - Pre-extracted GAP features (most efficient for sampling)

    All providers MUST ensure that samples returned by get_baseline_samples()
    are centered (E[z'] = 0) to satisfy the Expected Gradients completeness axiom.

    Because this protocol is ``runtime_checkable``, ``isinstance(obj,
    BaselineProvider)`` succeeds only when *every* member declared below is
    present on ``obj`` -- including ``__len__`` and ``__iter__``. The runtime
    check looks at member presence only; signatures and return types are
    verified by static type checkers, not by ``isinstance``.

    Example implementation::

        class MyProvider:
            def __init__(self) -> None:
                # Nothing is loaded until initialize() is called.
                self._data = None

            @property
            def is_initialized(self) -> bool:
                return self._data is not None

            @property
            def provider_type(self) -> str:
                return "my_provider"

            def initialize(self, model, target_layer, device) -> None:
                self._data = load_data()

            def validate(self, model, target_layer) -> None:
                if not compatible(self._data, model):
                    raise IncompatibleDatasetError(...)

            def get_baseline_samples(self, n, device) -> Tensor:
                return sample(self._data, n).to(device)

            def __len__(self) -> int:
                return len(self._data)

            def __iter__(self) -> Iterator[Tensor]:
                return iter(self._data)
    """

    @property
    def is_initialized(self) -> bool:
        """Check if provider is initialized and ready to provide data.

        Returns:
            True if initialize() has been called and completed successfully.
        """
        ...

    @property
    def provider_type(self) -> str:
        """Return the type identifier for this provider.

        This should match the name used in the provider registry.

        Returns:
            Provider type string (e.g., "directory", "huggingface").
        """
        ...

    def initialize(
        self,
        model: "nn.Module",
        target_layer: "nn.Module",
        device: "torch.device",
    ) -> None:
        """Initialize the provider with model context.

        This method is called lazily when baselines are first needed.
        It should:
        1. Load or prepare data
        2. Extract GAP features if working with raw images
        3. Set up any caching

        Args:
            model: The CNN model for feature extraction.
            target_layer: Layer from which to extract features.
            device: Computation device (cpu, cuda, etc.).

        Raises:
            ProviderInitializationError: If initialization fails.
        """
        ...

    def validate(
        self,
        model: "nn.Module",
        target_layer: "nn.Module",
    ) -> None:
        """Validate that provider data is compatible with model.

        Should check:
        - Feature dimensions match expected layer output
        - Data can be successfully processed through the model

        Args:
            model: The CNN model.
            target_layer: Layer for feature extraction.

        Raises:
            DimensionMismatchError: If feature dimensions don't match.
            BaselineValidationError: For other validation issues.
        """
        ...

    def get_baseline_samples(
        self,
        n: int,
        device: "torch.device",
    ) -> "Tensor":
        """Get n baseline samples as centered GAP vectors.

        Returns centered baseline samples suitable for Expected Gradients.
        The samples MUST have E[z'] = 0 for the completeness axiom.

        Args:
            n: Number of samples to return.
            device: Device for output tensor.

        Returns:
            Tensor of shape [n, K] where K is number of feature channels.
            Samples are centered (mean = 0 along dimension 0).

        Raises:
            RuntimeError: If provider not initialized.
            InsufficientSamplesError: If n > len(self) and replacement not allowed.
        """
        ...

    def __len__(self) -> int:
        """Return total number of available samples.

        Returns:
            Number of baseline samples available.
        """
        ...

    def __iter__(self) -> "Iterator[Tensor]":
        """Iterate over all baseline samples.

        Yields:
            Individual baseline vectors of shape [K].
        """
        ...


@runtime_checkable
class CacheableProvider(BaselineProvider, Protocol):
    """Baseline provider that can persist its extracted features on disk.

    In addition to everything required by ``BaselineProvider``, a cacheable
    provider can write its pre-extracted features to a file and read them
    back later, which makes initialization faster on subsequent runs.

    Typical usage::

        provider = DirectoryProvider("/data/imagenet/train")
        provider.initialize(model, layer, device)

        # First run: persist the extracted features.
        provider.save_cache("/cache/imagenet_features.npy")

        # Later runs: restore them (much faster than re-extracting).
        provider.load_cache("/cache/imagenet_features.npy")
    """

    @property
    def cache_path(self) -> str | None:
        """Report which cache file, if any, is currently configured.

        Returns:
            The cache file path, or None when no cache has been configured.
        """
        ...

    def load_cache(self, cache_path: str) -> bool:
        """Restore features from a cache file.

        When the cache exists and is valid, the features are loaded and the
        provider is marked as initialized. When the cache is missing or
        invalid, nothing is loaded and False is returned.

        Args:
            cache_path: Location of the cache file to read.

        Returns:
            True when the cache was loaded successfully, False otherwise.

        Raises:
            CacheCorruptedError: If the cache file is corrupted.
        """
        ...

    def save_cache(self, cache_path: str) -> None:
        """Write the extracted features to a cache file.

        Args:
            cache_path: Destination of the cache file (typically .npy or .npz).

        Raises:
            CacheSizeExceededError: If the cache would exceed disk space.
            RuntimeError: If the provider has not been initialized.
        """
        ...


@runtime_checkable
class StreamingProvider(BaselineProvider, Protocol):
    """Baseline provider that can hand out its samples batch by batch.

    Intended for datasets too large to fit in memory: instead of loading
    everything at once, consumers walk over the data in batches.

    Typical usage::

        provider = HuggingFaceProvider("imagenet-1k", streaming=True)
        provider.initialize(model, layer, device)

        for batch in provider.iter_batches(batch_size=32, device=device):
            # batch has shape [32, K]; the final batch may hold fewer rows
            process_batch(batch)
    """

    @property
    def supports_streaming(self) -> bool:
        """Tell whether streaming mode is enabled.

        Returns:
            True when the provider is configured for streaming access.
        """
        ...

    def iter_batches(
        self,
        batch_size: int,
        device: "torch.device",
        shuffle: bool = True,
    ) -> "Iterator[Tensor]":
        """Walk over the baseline samples one batch at a time.

        This is the memory-efficient way to traverse the dataset. The final
        batch may contain fewer than ``batch_size`` samples.

        Every batch is centered on its own (mean = 0 along the batch
        dimension).

        Args:
            batch_size: Number of samples per batch.
            device: Device on which the output tensors are placed.
            shuffle: Whether the data is shuffled before batching.

        Yields:
            Batches of shape [batch_size, K] (the last one may be smaller).

        Raises:
            RuntimeError: If the provider has not been initialized.
        """
        ...


# Convenience alias covering every provider protocol declared in this module.
# Both extended protocols derive from BaselineProvider, so listing them is
# informational: it spells out which provider flavours a caller may receive.
# NOTE: this is a runtime expression (not an annotation), so the ``|`` between
# classes is evaluated at import time regardless of the __future__ import.
ProviderType = (
    BaselineProvider
    | CacheableProvider
    | StreamingProvider
)


# Public names exported by a star-import of this module.
__all__ = [
    "BaselineProvider",
    "CacheableProvider",
    "StreamingProvider",
    "ProviderType",
]
