"""Cached feature baseline provider.

Loads pre-extracted GAP features from numpy files.
This is the most efficient provider when features have already been extracted.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import torch
from torch import Tensor

from expected_gradcam.baselines.base import BaseProvider
from expected_gradcam.baselines.registry import baseline_provider
from expected_gradcam.exceptions.baseline import (
    CacheCorruptedError,
    ProviderInitializationError,
)

if TYPE_CHECKING:
    from torch import nn


@baseline_provider(
    "cached",
    full_name="Cached Feature Provider",
    description="Load pre-extracted GAP features from numpy files",
    aliases=("cache", "features", "npy"),
    supports_caching=False,  # Already cached
    supports_streaming=False,
)
class CachedFeatureProvider(BaseProvider):
    """Baseline provider that loads pre-extracted features.

    Loads GAP features from a numpy file (.npy or .npz). No model forward
    passes are performed during initialization; the provider only reads the
    file and converts its contents to a float32 tensor.

    The cache file should contain features of shape [N, K] where:
    - N: Number of baseline samples
    - K: Number of feature channels

    For ``.npz`` archives the array is looked up under the keys
    ``"features"``, ``"gap"``, ``"data"`` and ``"arr_0"`` (in that order);
    if none is present, the first array in the archive is used.

    Attributes:
        cache_path: Path to the numpy cache file.
        mmap_mode: Memory-map mode forwarded to ``np.load`` for ``.npy``
            files.

    Example::

        # First, extract features and save
        provider = DirectoryProvider("/data/imagenet/train")
        provider.initialize(model, layer, device)
        provider.save_cache("/cache/imagenet_gap.npy")

        # Later, load from cache (fast)
        cached = CachedFeatureProvider("/cache/imagenet_gap.npy")
        cached.initialize(model, layer, device)  # Just loads cache
        samples = cached.get_baseline_samples(n=20, device=device)
    """

    def __init__(
        self,
        cache_path: str | Path,
        mmap_mode: str | None = None,
    ) -> None:
        """Initialize cached feature provider.

        Nothing is read from disk here; the file is loaded by
        ``_do_initialize``.

        Args:
            cache_path: Path to numpy cache file (.npy or .npz).
            mmap_mode: Memory-map mode passed to ``np.load`` when loading a
                ``.npy`` file. It is not used for ``.npz`` archives.
                Options: None, 'r', 'r+', 'w+', 'c'.
                Use 'r' for read-only memory mapping of large files.
                NOTE(review): ``numpy.memmap`` documents 'w+' as
                "create or overwrite existing file", so it looks unsuitable
                for reading an existing cache -- confirm before using it.
                The loaded array is always copied into a float32 tensor, so
                memory mapping only affects how the file is read, not the
                final memory footprint.
        """
        super().__init__()

        self.cache_path = Path(cache_path)
        self.mmap_mode = mmap_mode

        # Feature matrix [N, K]; stays None until _do_initialize() succeeds.
        self._features: Tensor | None = None

    @property
    def provider_type(self) -> str:
        """Return provider type identifier."""
        return "cached"

    def _do_initialize(self) -> None:
        """Load features from the cache file into ``self._features``.

        On success ``self._features`` holds a float32 tensor of shape
        [N, K] and ``self._n_channels`` is set to K.

        Raises:
            ProviderInitializationError: If ``cache_path`` does not exist.
            CacheCorruptedError: If the file cannot be read, an ``.npz``
                archive holds no arrays, the data is not 2-D, or it cannot
                be converted to float32. The underlying exception, if any,
                is attached as ``__cause__``.
        """
        if not self.cache_path.exists():
            raise ProviderInitializationError(
                "cached",
                reason=f"Cache file not found: {self.cache_path}",
            )

        try:
            # Case-insensitive so that e.g. "features.NPZ" is also handled
            # as an archive instead of falling into the .npy branch.
            if self.cache_path.suffix.lower() == ".npz":
                # The context manager closes the archive's file handle.
                with np.load(self.cache_path) as data:
                    names = data.files
                    if not names:
                        raise CacheCorruptedError(
                            self.cache_path,
                            reason="Archive contains no arrays",
                        )
                    # Prefer well-known key names, else the first array.
                    preferred = ("features", "gap", "data", "arr_0")
                    key = next((k for k in preferred if k in data), names[0])
                    arr = data[key]
            else:
                # Regular .npy file
                arr = np.load(self.cache_path, mmap_mode=self.mmap_mode)

            # Validate shape
            if arr.ndim != 2:
                raise CacheCorruptedError(
                    self.cache_path,
                    reason=f"Expected 2D array [N, K], got shape {arr.shape}",
                )

            # astype() returns a fresh in-memory float32 copy, so the tensor
            # does not alias a (possibly read-only) memory map.
            self._features = torch.from_numpy(arr.astype(np.float32))
            self._n_channels = self._features.shape[1]

        except CacheCorruptedError:
            raise
        except Exception as e:
            # Chain the original error so the root cause stays visible.
            raise CacheCorruptedError(
                self.cache_path,
                reason=f"Failed to load cache: {e}",
            ) from e

    def _get_raw_samples(self, n: int) -> Tensor:
        """Randomly sample ``n`` rows from the loaded features.

        Rows are drawn without replacement when ``n`` does not exceed the
        number of cached rows, and with replacement otherwise.

        Args:
            n: Number of rows to draw; must be non-negative.

        Returns:
            Tensor of shape [n, K] indexed from the cached features.

        Raises:
            RuntimeError: If features have not been loaded, or if ``n > 0``
                and the cache contains no rows.
            ValueError: If ``n`` is negative.
        """
        if self._features is None:
            raise RuntimeError("Features not loaded")

        if n < 0:
            # A negative n would otherwise slice from the end of the
            # permutation and silently return the wrong number of rows.
            raise ValueError(f"n must be non-negative, got {n}")

        cache_size = len(self._features)

        if cache_size == 0 and n > 0:
            # torch.randint(0, 0, ...) would fail with an unrelated message.
            raise RuntimeError(
                f"Cannot draw {n} samples: cache at {self.cache_path} "
                "contains no features"
            )

        # Sample with replacement if n > cache_size
        if n <= cache_size:
            indices = torch.randperm(cache_size)[:n]
        else:
            indices = torch.randint(0, cache_size, (n,))

        return self._features[indices]

    def __len__(self) -> int:
        """Return number of cached features (0 if nothing is loaded yet)."""
        if self._features is None:
            return 0
        return len(self._features)

    def validate(
        self,
        model: "nn.Module",
        target_layer: "nn.Module",
    ) -> None:
        """Validate that the loaded features have at least one channel.

        Note: For cached features, we can only check that we have
        valid data. The user is responsible for ensuring features
        were extracted with the same model/layer. ``model`` and
        ``target_layer`` are accepted but not inspected here.

        Does nothing if features have not been loaded yet.

        Raises:
            DimensionMismatchError: If the loaded features have zero
                channels (second dimension is 0).
        """
        if self._features is None:
            return

        if self._features.shape[1] == 0:
            from expected_gradcam.exceptions.baseline import DimensionMismatchError

            raise DimensionMismatchError(
                expected_shape=(None, ">0"),
                actual_shape=self._features.shape,
                tensor_name="cached_features",
            )


# Public API: only the provider class is exported by `from ... import *`.
__all__ = ["CachedFeatureProvider"]
