"""Base implementation for baseline providers.

This module provides the abstract base class for all baseline providers.
Subclasses must implement the abstract members for the provider type,
initialization, raw sample retrieval, and length; validation is optional.

The BaseProvider class handles:
- Lazy initialization pattern
- Automatic centering of samples (E[z'] = 0)
- Device management
- Common iteration and length protocols
"""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Iterator

import torch
from torch import Tensor

from expected_gradcam.sampling.utils import center_samples

if TYPE_CHECKING:
    from torch import nn


class BaseProvider(ABC):
    """Abstract base class for baseline providers.

    Provides common functionality:
    - Lazy initialization: provider-specific setup runs only when
      initialize() is called, and only once
    - Automatic centering: samples returned by get_baseline_samples() are
      passed through center_samples() (intended to satisfy E[z'] = 0)
    - Device management: samples are moved to the requested device

    Subclasses must implement:
    - provider_type: Return the type identifier
    - _do_initialize(): Provider-specific initialization logic
    - _get_raw_samples(): Get samples before centering
    - __len__(): Return dataset size

    Subclasses may override:
    - validate(): Compatibility checks (no-op by default)

    Example subclass::

        @baseline_provider("my_provider")
        class MyProvider(BaseProvider):
            def __init__(self, path: str):
                super().__init__()
                self.path = path

            @property
            def provider_type(self) -> str:
                return "my_provider"

            def _do_initialize(self) -> None:
                self._data = load_data(self.path)
                self._n_channels = self._data.shape[1]

            def _get_raw_samples(self, n: int) -> Tensor:
                indices = torch.randint(0, len(self._data), (n,))
                return self._data[indices]

            def __len__(self) -> int:
                return len(self._data)
    """

    def __init__(self) -> None:
        """Set up the uninitialized provider state."""
        self._initialized: bool = False
        self._model: nn.Module | None = None
        self._target_layer: nn.Module | None = None
        self._device: torch.device | None = None
        self._n_channels: int | None = None

    @property
    def is_initialized(self) -> bool:
        """Check if provider is initialized.

        Returns:
            True once the provider has been marked as initialized.
        """
        return self._initialized

    @property
    @abstractmethod
    def provider_type(self) -> str:
        """Return the type identifier for this provider.

        Returns:
            Provider type string (e.g., "directory", "huggingface").
        """
        ...

    @property
    def n_channels(self) -> int | None:
        """Return the number of feature channels (K).

        Returns:
            Number of channels, or None if it has not been set yet.
        """
        return self._n_channels

    @property
    def device(self) -> torch.device | None:
        """Return the device stored by initialize().

        Returns:
            Torch device, or None if initialize() has not stored one.
        """
        return self._device

    def initialize(
        self,
        model: nn.Module,
        target_layer: nn.Module,
        device: torch.device,
    ) -> None:
        """Initialize the provider with model context.

        Stores the model context and calls _do_initialize() for the
        provider-specific logic. If _do_initialize() raises, the model,
        target layer, device and channel count are restored to the values
        they had before the call, so a failed attempt leaves no
        half-configured state behind and can simply be retried.

        Args:
            model: The CNN model for feature extraction.
            target_layer: Layer from which to extract features.
            device: Computation device.

        Raises:
            Exception: Whatever _do_initialize() raises is propagated
                unchanged after the state has been rolled back.

        Note:
            Calling initialize() on an already initialized provider is a
            no-op. The provider is only initialized once.
        """
        if self._initialized:
            return

        # Snapshot the fields owned by this base class so that a failing
        # _do_initialize() cannot leave them partially updated.
        previous_state = (
            self._model,
            self._target_layer,
            self._device,
            self._n_channels,
        )

        self._model = model
        self._target_layer = target_layer
        self._device = device

        try:
            self._do_initialize()
        except BaseException:
            # Roll back (also on KeyboardInterrupt) and let the error surface.
            (
                self._model,
                self._target_layer,
                self._device,
                self._n_channels,
            ) = previous_state
            raise

        self._initialized = True

    @abstractmethod
    def _do_initialize(self) -> None:
        """Provider-specific initialization logic.

        Subclasses should implement this to:
        1. Load or prepare data
        2. Extract features if needed
        3. Set self._n_channels

        Raises:
            ProviderInitializationError: If initialization fails.
        """
        ...

    def validate(
        self,
        model: nn.Module,
        target_layer: nn.Module,
    ) -> None:
        """Validate that provider data is compatible with model.

        Default implementation is a no-op. Subclasses can override
        to add validation logic.

        Args:
            model: The CNN model.
            target_layer: Layer for feature extraction.

        Raises:
            DimensionMismatchError: If dimensions don't match.
            BaselineValidationError: For other validation issues.
        """

    def get_baseline_samples(
        self,
        n: int,
        device: torch.device,
    ) -> Tensor:
        """Get n centered baseline samples.

        This is the main method for obtaining baselines. It:
        1. Calls _get_raw_samples() to get uncentered samples
        2. Passes them through center_samples() (to satisfy E[z'] = 0)
        3. Moves the result to the requested device

        Args:
            n: Number of samples to return.
            device: Device for output tensor.

        Returns:
            Centered baseline samples [n, K].

        Raises:
            RuntimeError: If provider not initialized.
        """
        if not self._initialized:
            raise RuntimeError(
                f"Provider '{self.provider_type}' not initialized. "
                "Call initialize() first or use auto-initialization via config."
            )

        raw_samples = self._get_raw_samples(n)

        # CRITICAL: centering happens before the device transfer and must
        # not be skipped; downstream code relies on E[z'] = 0.
        centered = center_samples(raw_samples)

        return centered.to(device)

    @abstractmethod
    def _get_raw_samples(self, n: int) -> Tensor:
        """Get n raw samples before centering.

        Subclasses should implement this to return samples from
        their data source. The samples do not need to be centered.

        Args:
            n: Number of samples to return.

        Returns:
            Uncentered samples [n, K].
        """
        ...

    @abstractmethod
    def __len__(self) -> int:
        """Return total number of available samples.

        Returns:
            Number of baseline samples available.
        """
        ...

    def __iter__(self) -> Iterator[Tensor]:
        """Yield len(self) individual raw baseline vectors.

        Each vector comes from its own ``_get_raw_samples(1)`` call with the
        leading dimension squeezed away. Which samples are produced (and
        whether every stored sample is visited exactly once) therefore
        depends on how the subclass implements ``_get_raw_samples``.
        Note: Samples are NOT centered when iterating.

        Yields:
            Individual baseline vectors [K].
        """
        for _ in range(len(self)):
            yield self._get_raw_samples(1).squeeze(0)

    def __repr__(self) -> str:
        """Return string representation of provider.

        Returns:
            ``ClassName(type='...', initialized=...)``, with
            ``, channels=K`` appended once the channel count is known.
        """
        channels = (
            f", channels={self._n_channels}"
            if self._n_channels is not None
            else ""
        )
        return (
            f"{self.__class__.__name__}("
            f"type='{self.provider_type}'"
            f", initialized={self._initialized}"
            f"{channels})"
        )


class CacheableBaseProvider(BaseProvider):
    """Base class for providers that support feature caching.

    Extends BaseProvider with methods for saving and loading
    pre-extracted features to/from disk.

    Subclasses should implement:
    - _save_cache_data(): Save internal cache to path
    - _load_cache_data(): Load cache from path
    """

    def __init__(self) -> None:
        """Initialize cacheable provider with no cache path recorded."""
        super().__init__()
        self._cache_path: str | None = None

    @property
    def cache_path(self) -> str | None:
        """Get current cache path.

        Returns:
            Path last used by a successful save_cache() or load_cache(),
            or None if neither has succeeded yet.
        """
        return self._cache_path

    def save_cache(self, cache_path: str) -> None:
        """Save extracted features to cache file.

        Args:
            cache_path: Path to save cache (typically .npy).

        Raises:
            RuntimeError: If provider not initialized.
        """
        if not self._initialized:
            raise RuntimeError("Provider must be initialized before saving cache.")

        self._save_cache_data(cache_path)
        # Record the path only after the save call returned without raising.
        self._cache_path = cache_path

    def load_cache(self, cache_path: str) -> bool:
        """Load features from cache file.

        If _load_cache_data() reports success, records the cache path and
        marks the provider as initialized.

        Loading is best-effort: any exception raised by _load_cache_data()
        is logged as a warning and reported as a failed load instead of
        propagating, so the failure reason is not silently lost.

        Args:
            cache_path: Path to cache file.

        Returns:
            True if cache loaded successfully, False otherwise.
        """
        try:
            success = self._load_cache_data(cache_path)
        except Exception as exc:
            # Keep the "return False on failure" contract, but leave a trace
            # so a broken cache (or a missing _load_cache_data override)
            # can be told apart from an ordinary cache miss.
            logging.getLogger(__name__).warning(
                "Failed to load provider cache from %r: %s: %s",
                cache_path,
                type(exc).__name__,
                exc,
            )
            return False

        if success:
            self._cache_path = cache_path
            self._initialized = True
        return success

    def _save_cache_data(self, cache_path: str) -> None:
        """Save internal cache to path.

        Subclasses should implement this method.

        Args:
            cache_path: Path to save cache.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Subclass must implement _save_cache_data")

    def _load_cache_data(self, cache_path: str) -> bool:
        """Load cache from path.

        Subclasses should implement this method.

        Args:
            cache_path: Path to load from.

        Returns:
            True if load successful.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Subclass must implement _load_cache_data")


# Names exported by a star-import of this module.
__all__ = ["BaseProvider", "CacheableBaseProvider"]
