"""Feature extraction and caching utilities.

This module provides utilities for efficient batch feature extraction
and caching. Pre-extracted features can significantly speed up repeated
Expected GradCAM computations.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np
import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader

if TYPE_CHECKING:
    from numpy.typing import NDArray


class BatchedFeatureExtractor:
    """Extract features from a dataset in batches.

    A forward hook is registered on ``target_layer``; every forward pass
    of ``model`` stores that layer's (detached) output, which is then
    returned to the caller. The extracted features can be used for
    baseline sampling in Expected Gradients.

    The extractor owns a hook on the model. Call :meth:`close` (or use the
    extractor as a context manager) to remove it when done.

    Attributes:
        model: The model to extract features from.
        target_layer: The layer to extract features from.
        device: Device that input batches are moved to. The model itself
            is *not* moved; it is expected to already live on this device.

    Example:
        >>> extractor = BatchedFeatureExtractor(model, model.layer4)
        >>> loader = DataLoader(dataset, batch_size=32)
        >>> features = extractor.extract_all(loader)
        >>> extractor.close()
    """

    def __init__(
        self,
        model: nn.Module,
        target_layer: nn.Module,
        device: torch.device | str | None = None,
    ) -> None:
        """Initialize feature extractor and register the forward hook.

        Args:
            model: The model to extract features from.
            target_layer: The layer to extract features from. It must be
                executed during ``model``'s forward pass.
            device: Device for computation. If ``None``, the device of the
                model's first parameter is used, falling back to CPU for
                models without parameters.
        """
        # Hook state is set first so close()/__del__ stay safe even if the
        # remainder of __init__ raises.
        self._features: Tensor | None = None
        self._hook = None

        self.model = model
        self.target_layer = target_layer

        if device is None:
            # A parameterless model would make a bare next() raise
            # StopIteration; fall back to CPU instead.
            first_param = next(model.parameters(), None)
            device = (
                first_param.device
                if first_param is not None
                else torch.device("cpu")
            )
        elif isinstance(device, str):
            device = torch.device(device)
        self.device = device

        self._register_hook()

    def _register_hook(self) -> None:
        """Register forward hook on target layer."""

        def hook_fn(module: nn.Module, inputs: tuple, output: Tensor) -> None:
            # Detach so the stored activation never keeps a graph alive.
            self._features = output.detach()

        self._hook = self.target_layer.register_forward_hook(hook_fn)

    @torch.no_grad()
    def extract_batch(self, images: Tensor) -> Tensor:
        """Extract features from a batch of images.

        The model is switched to eval mode and left in eval mode.

        Args:
            images: Batch of images [B, C, H, W].

        Returns:
            Feature maps [B, K, U, V].

        Raises:
            RuntimeError: If the hook has been removed via :meth:`close`,
                or if ``target_layer`` was not executed during the
                forward pass.
        """
        if self._hook is None:
            raise RuntimeError(
                "Feature hook has been removed; create a new "
                "BatchedFeatureExtractor to extract features."
            )

        self.model.eval()
        images = images.to(self.device)

        # Clear any previous activation so a forward pass that skips the
        # target layer cannot silently return stale features.
        self._features = None
        self.model(images)  # forward pass triggers the hook

        features = self._features
        # Do not keep the activation referenced between calls.
        self._features = None

        if features is None:
            raise RuntimeError(
                "target_layer was not executed during the forward pass; "
                "make sure it is a submodule used by the model."
            )
        return features

    @torch.no_grad()
    def extract_all(
        self,
        dataloader: DataLoader,
        max_batches: int | None = None,
        progress: bool = True,
    ) -> Tensor:
        """Extract features from entire dataset.

        Args:
            dataloader: DataLoader (or any iterable of batches) for the
                dataset. A batch may be a tensor, or a list/tuple whose
                first element is the image tensor.
            max_batches: Maximum number of batches to process. ``None``
                processes everything.
            progress: Whether to show a progress bar (requires ``tqdm``;
                silently skipped if it is not installed).

        Returns:
            All features concatenated on CPU [N, K, U, V].

        Raises:
            RuntimeError: If no batches were processed (empty dataloader
                or ``max_batches <= 0``).
        """
        from itertools import islice

        if max_batches is not None and max_batches <= 0:
            # Nothing to do; avoid even creating the dataloader iterator.
            raise RuntimeError("No batches were processed; cannot extract features.")

        batches = dataloader
        if max_batches is not None:
            # islice stops *before* fetching batch number max_batches + 1,
            # so no batch is loaded just to be thrown away.
            batches = islice(dataloader, max_batches)

        if progress:
            try:
                from tqdm import tqdm
            except ImportError:
                # tqdm is optional; fall back to silent iteration.
                pass
            else:
                try:
                    total = len(dataloader)
                except TypeError:
                    # Iterables without a length (e.g. generators).
                    total = None
                if total is not None and max_batches is not None:
                    total = min(total, max_batches)
                batches = tqdm(batches, total=total, desc="Extracting features")

        all_features = []
        for batch in batches:
            # Handle (images, labels, ...) style batches as well as bare tensors.
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            # Move to CPU right away so device memory does not accumulate.
            all_features.append(self.extract_batch(images).cpu())

        if not all_features:
            raise RuntimeError("No batches were processed; cannot extract features.")

        return torch.cat(all_features, dim=0)

    def close(self) -> None:
        """Remove the forward hook. Safe to call more than once."""
        # getattr guards against instances whose __init__ never completed.
        hook = getattr(self, "_hook", None)
        if hook is not None:
            hook.remove()
            self._hook = None

    def __enter__(self) -> BatchedFeatureExtractor:
        """Return the extractor itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Remove the hook when leaving the ``with`` block."""
        self.close()

    def __del__(self) -> None:
        self.close()


class FeatureCache:
    """Cache for pre-extracted features.

    This class manages a cache of pre-extracted feature maps that can
    be used for baseline sampling. Uncompressed ``.npy`` caches are opened
    as read-only memory maps; compressed NPZ caches cannot be memory-mapped
    and are read fully into memory instead.

    The on-disk format is detected from the file contents, not from the
    file extension.

    Attributes:
        cache_path: Path to the cache file.
        features: Read-only feature array (memory-mapped for ``.npy``).

    Example:
        >>> # Create cache from dataset
        >>> cache = FeatureCache.create(
        ...     "features.npy", extractor, dataloader
        ... )
        >>>
        >>> # Load existing cache
        >>> cache = FeatureCache("features.npy")
        >>> baselines = cache.sample(n=20)
    """

    def __init__(
        self,
        cache_path: str | Path,
    ) -> None:
        """Initialize feature cache.

        Args:
            cache_path: Path to the cache file (.npy or .npz).

        Raises:
            FileNotFoundError: If ``cache_path`` does not exist.
            KeyError: If an NPZ cache has no ``"features"`` entry.
        """
        self.cache_path = Path(cache_path)

        if not self.cache_path.exists():
            raise FileNotFoundError(f"Cache not found: {cache_path}")

        # mmap_mode only takes effect for plain .npy files.
        loaded = np.load(self.cache_path, mmap_mode="r")

        if isinstance(loaded, np.lib.npyio.NpzFile):
            # An NpzFile keeps the archive open; pull the array out and
            # close the handle so the file descriptor is not leaked.
            with loaded:
                loaded = loaded["features"]

        self._features = loaded

    @property
    def features(self) -> "NDArray[np.floating]":
        """Underlying cached feature array [N, K, U, V]."""
        return self._features

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of cached features [N, K, U, V]."""
        return self._features.shape

    @property
    def n_samples(self) -> int:
        """Number of cached samples."""
        return self._features.shape[0]

    @property
    def n_channels(self) -> int:
        """Number of feature channels."""
        return self._features.shape[1]

    def __len__(self) -> int:
        """Number of cached samples."""
        return self.n_samples

    def get(self, indices: list[int] | "NDArray[np.int64]") -> "NDArray[np.floating]":
        """Get features at specific indices.

        Args:
            indices: Indices of features to get.

        Returns:
            Feature arrays [len(indices), K, U, V], as an in-memory copy
            that is independent of the cache file.
        """
        # np.array() materialises the selection as a plain, writable ndarray.
        return np.array(self._features[indices])

    def sample(
        self,
        n: int,
        replace: bool = False,
        rng: np.random.Generator | None = None,
    ) -> "NDArray[np.floating]":
        """Sample random features.

        Args:
            n: Number of features to sample.
            replace: Whether to sample with replacement.
            rng: Random number generator. A fresh, unseeded generator is
                used when ``None``.

        Returns:
            Sampled features [n, K, U, V].

        Raises:
            ValueError: If ``replace`` is False and ``n`` exceeds the
                number of cached samples (raised by ``rng.choice``).
        """
        if rng is None:
            rng = np.random.default_rng()

        indices = rng.choice(self.n_samples, size=n, replace=replace)
        return self.get(indices)

    def sample_tensor(
        self,
        n: int,
        device: torch.device | str = "cpu",
        replace: bool = False,
        rng: np.random.Generator | None = None,
    ) -> Tensor:
        """Sample random features as PyTorch tensor.

        Args:
            n: Number of features to sample.
            device: Device for tensor.
            replace: Whether to sample with replacement.
            rng: Random number generator, forwarded to :meth:`sample` so
                tensor sampling can be made reproducible.

        Returns:
            Sampled features [n, K, U, V].
        """
        features = self.sample(n, replace=replace, rng=rng)
        return torch.from_numpy(features).to(device)

    @classmethod
    def create(
        cls,
        cache_path: str | Path,
        extractor: BatchedFeatureExtractor,
        dataloader: DataLoader,
        max_batches: int | None = None,
        compress: bool = False,
    ) -> "FeatureCache":
        """Create cache from dataset.

        The cache is always written to exactly ``cache_path``. A compressed
        NPZ archive is written when ``compress`` is True or the path ends
        in ``.npz``; otherwise a plain ``.npy`` array is written.

        Args:
            cache_path: Path to save cache.
            extractor: Feature extractor to use.
            dataloader: DataLoader for the dataset.
            max_batches: Maximum number of batches to process.
            compress: Whether to use compression (slower but smaller, and
                the cache is loaded into memory instead of memory-mapped).

        Returns:
            Created FeatureCache instance.
        """
        # Extract all features
        features = extractor.extract_all(dataloader, max_batches)
        features_np = features.detach().cpu().numpy()

        cache_path = Path(cache_path)
        use_npz = compress or cache_path.suffix == ".npz"

        # Write through an explicit file handle: when given a *path*, NumPy
        # appends ".npy"/".npz" if the suffix differs, which would put the
        # data in a different file than the one reopened below.
        with open(cache_path, "wb") as handle:
            if use_npz:
                np.savez_compressed(handle, features=features_np)
            else:
                np.save(handle, features_np)

        return cls(cache_path)
