"""Directory-based baseline provider.

Scans a directory for images and extracts GAP features during initialization.
Supports caching for faster subsequent loads.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

import numpy as np
import torch
from torch import Tensor

from expected_gradcam.baselines.base import CacheableBaseProvider
from expected_gradcam.baselines.registry import baseline_provider
from expected_gradcam.exceptions.baseline import (
    DirectoryNotFoundError,
    EmptyBaselineDatasetError,
    InvalidBaselineImageError,
    ProviderInitializationError,
)
from expected_gradcam.hooks import FeatureMapHook

if TYPE_CHECKING:
    from torch import nn


def _default_imagenet_transform() -> Callable[[Any], Tensor]:
    """Build the default ImageNet preprocessing transform.

    The pipeline resizes to 256, center-crops to 224, converts to a tensor
    and normalizes with the ImageNet channel mean/std.

    Returns:
        A torchvision ``Compose`` transform.

    Raises:
        ProviderInitializationError: If torchvision is not installed. The
            original ``ImportError`` is attached as ``__cause__``.
    """
    # Only the import can raise ImportError, so keep the try body minimal and
    # chain the original error so the real cause stays visible in tracebacks.
    try:
        from torchvision import transforms
    except ImportError as err:
        raise ProviderInitializationError(
            "directory",
            reason="torchvision required for default transform. "
            "Install with: pip install torchvision",
        ) from err

    return transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )


@baseline_provider(
    "directory",
    full_name="Directory Provider",
    description="Load baseline images from a directory of image files",
    aliases=("dir", "folder", "path"),
    supports_caching=True,
    supports_streaming=False,
    default=True,
)
class DirectoryProvider(CacheableBaseProvider):
    """Baseline provider that loads images from a directory.

    Scans the specified directory (recursively) for image files, extracts GAP
    features during initialization, and keeps them in memory for sampling.
    When ``cache_path`` is given, the extracted features are written there
    with ``numpy.save`` and reloaded on later initializations.

    Attributes:
        path: Path to the image directory.
        extensions: File extensions to include.
        max_images: Maximum number of images to load.
        transform: Image preprocessing transform.
        shuffle: Whether to shuffle images before loading.

    Example::

        provider = DirectoryProvider(
            path="/data/imagenet/train",
            extensions=(".jpg", ".png"),
            max_images=1000,
        )
        provider.initialize(model, layer, device)
        samples = provider.get_baseline_samples(n=20, device=device)
    """

    def __init__(
        self,
        path: str | Path,
        extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".JPEG"),
        max_images: int | None = None,
        transform: Callable[[Any], Tensor] | None = None,
        cache_path: str | Path | None = None,
        shuffle: bool = True,
    ) -> None:
        """Initialize directory provider.

        Args:
            path: Path to directory containing images.
            extensions: File extensions to search for, with or without the
                leading dot.
            max_images: Maximum number of images to use.
            transform: Image preprocessing transform. If None, uses
                default ImageNet preprocessing.
            cache_path: Optional path to save/load feature cache.
            shuffle: Whether to shuffle images before loading.
        """
        super().__init__()

        self.path = Path(path)
        self.extensions = extensions
        self.max_images = max_images
        self.transform = transform
        self._cache_path = str(cache_path) if cache_path else None
        self.shuffle = shuffle

        # Internal state
        self._image_paths: list[Path] = []
        self._gap_cache: Tensor | None = None

    @property
    def provider_type(self) -> str:
        """Return provider type identifier."""
        return "directory"

    @staticmethod
    def _npy_path(cache_path: str) -> Path:
        """Return the file name ``numpy.save`` actually writes for a path.

        ``numpy.save`` appends ``.npy`` to a string/path file name that does
        not already end with it, so the on-disk name may differ from the one
        the caller supplied.
        """
        text = str(cache_path)
        return Path(text if text.endswith(".npy") else text + ".npy")

    def _do_initialize(self) -> None:
        """Initialize by discovering images and extracting features.

        Raises:
            DirectoryNotFoundError: If ``path`` does not exist or is a file.
            EmptyBaselineDatasetError: If no matching image file is found.
            ProviderInitializationError: If no image yields features.
        """
        # Reuse a previously saved cache when one can be loaded. The loader
        # itself handles a missing file (including the ``.npy`` suffix that
        # ``numpy.save`` may have appended) by returning False.
        if self._cache_path and self._load_cache_data(self._cache_path):
            return

        # Check directory exists
        if not self.path.exists():
            raise DirectoryNotFoundError(self.path)

        if self.path.is_file():
            raise DirectoryNotFoundError(self.path, is_file=True)

        # Discover images
        self._image_paths = self._discover_images()

        if not self._image_paths:
            raise EmptyBaselineDatasetError(
                source=self.path,
                reason=f"No images found with extensions {self.extensions}",
                extensions=self.extensions,
            )

        # Shuffle if requested
        if self.shuffle:
            import random

            random.shuffle(self._image_paths)

        # Limit number of images
        if self.max_images is not None:
            self._image_paths = self._image_paths[: self.max_images]

        # Extract features
        self._extract_features()

        # Save cache if path provided
        if self._cache_path:
            self._save_cache_data(self._cache_path)

    def _discover_images(self) -> list[Path]:
        """Recursively collect unique image files matching ``extensions``.

        Returns:
            Matching regular files in discovery order, each listed once.
        """
        # A bare string would otherwise be iterated character by character.
        extensions = (
            (self.extensions,)
            if isinstance(self.extensions, str)
            else self.extensions
        )

        # A dict keeps discovery order while dropping duplicates, which occur
        # when extensions overlap (e.g. "jpg" and ".jpg") or when matching is
        # case-insensitive (".jpeg" and ".JPEG").
        found: dict[Path, None] = {}
        for ext in extensions:
            # Accept extensions with and without leading dot
            pattern = f"*{ext}" if ext.startswith(".") else f"*.{ext}"
            for candidate in self.path.rglob(pattern):
                # Skip directories whose name happens to match the pattern.
                if candidate.is_file():
                    found.setdefault(candidate, None)
        return list(found)

    def _extract_features(self) -> None:
        """Extract GAP features from all discovered images.

        Each image is run through the model individually; the hooked feature
        map is averaged over dims 2 and 3 and stored on the CPU. Images that
        fail at any step are skipped and reported in a single warning.

        Raises:
            ProviderInitializationError: If every image failed.
        """
        from PIL import Image

        if self.transform is None:
            self.transform = _default_imagenet_transform()

        gap_values = []
        failed_images = []

        with FeatureMapHook(self._target_layer) as hook:
            for img_path in self._image_paths:
                try:
                    # Close the file handle as soon as the pixels are loaded;
                    # convert() returns an independent in-memory image.
                    with Image.open(img_path) as raw:
                        img = raw.convert("RGB")
                    x = self.transform(img).unsqueeze(0).to(self._device)

                    with torch.no_grad():
                        _ = self._model(x)

                    # Compute GAP
                    features = hook.features
                    gap = features.mean(dim=(2, 3)).squeeze(0)  # [K]
                    gap_values.append(gap.cpu())

                except Exception as e:
                    # Deliberate best-effort: one bad image must not abort
                    # the whole scan. Failures are summarized below.
                    failed_images.append((img_path, str(e)))
                    continue

        if not gap_values:
            raise ProviderInitializationError(
                "directory",
                reason=f"Failed to extract features from any images. "
                f"First error: {failed_images[0][1] if failed_images else 'unknown'}",
            )

        self._gap_cache = torch.stack(gap_values)  # [N, K]
        self._n_channels = self._gap_cache.shape[1]

        if failed_images:
            # Log warning but continue
            import warnings

            warnings.warn(
                f"DirectoryProvider: {len(failed_images)} images failed to load. "
                f"Using {len(gap_values)} valid images.",
                stacklevel=2,
            )

    def _get_raw_samples(self, n: int) -> Tensor:
        """Sample ``n`` rows from the GAP cache.

        Rows are drawn without replacement when ``n`` does not exceed the
        cache size, and with replacement otherwise.

        Raises:
            RuntimeError: If the cache has not been populated.
        """
        if self._gap_cache is None:
            raise RuntimeError("GAP cache not initialized")

        cache_size = len(self._gap_cache)

        # Sample with replacement if n > cache_size
        if n <= cache_size:
            indices = torch.randperm(cache_size)[:n]
        else:
            indices = torch.randint(0, cache_size, (n,))

        return self._gap_cache[indices]

    def __len__(self) -> int:
        """Return number of cached samples, or discovered images if none."""
        if self._gap_cache is not None:
            return len(self._gap_cache)
        return len(self._image_paths)

    def _save_cache_data(self, cache_path: str) -> None:
        """Save GAP cache to file with ``numpy.save``.

        Note that ``numpy.save`` appends ``.npy`` to ``cache_path`` when it is
        missing; ``_load_cache_data`` accounts for that.

        Raises:
            RuntimeError: If there is no cache to save.
        """
        if self._gap_cache is None:
            raise RuntimeError("No cache data to save")

        np.save(cache_path, self._gap_cache.detach().cpu().numpy())

    def _load_cache_data(self, cache_path: str) -> bool:
        """Load GAP cache from file.

        Looks for ``cache_path`` exactly as given and, failing that, for the
        ``.npy``-suffixed name ``numpy.save`` would have written. Instance
        state is only updated once the data is known to be a non-empty 2-D
        array, so a failed load never leaves a half-initialized cache behind.

        Returns:
            True if the cache was loaded, False if it is missing, unreadable
            or malformed.
        """
        try:
            candidates = (Path(cache_path), self._npy_path(cache_path))
            source = next((p for p in candidates if p.is_file()), None)
            if source is None:
                return False

            loaded = torch.from_numpy(np.load(source))
        except Exception:
            # Best-effort: an unreadable cache simply triggers re-extraction.
            return False

        # Expect [N, K] with at least one sample; anything else is unusable.
        if loaded.ndim != 2 or loaded.shape[0] == 0:
            return False

        self._gap_cache = loaded
        self._n_channels = loaded.shape[1]
        return True


# Public API of this module: only the provider class is exported.
__all__ = ["DirectoryProvider"]
