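"""ImageNet-specific baseline provider.

Specialized provider for the ImageNet (ILSVRC) directory structure.
Resolves the requested split directory (e.g. ``train`` or ``val``) under the
root, falling back to the root itself, and samples images from its class
subdirectories.
"""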

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

import numpy as np
import torch
from torch import Tensor

from expected_gradcam.baselines.base import CacheableBaseProvider
from expected_gradcam.baselines.registry import baseline_provider
from expected_gradcam.exceptions.baseline import (
    DirectoryNotFoundError,
    EmptyBaselineDatasetError,
    ProviderInitializationError,
)
from expected_gradcam.hooks import FeatureMapHook

if TYPE_CHECKING:
    from torch import nn


def _default_imagenet_transform() -> Callable[[Any], Tensor]:
    """Build the default ImageNet preprocessing transform.

    The pipeline resizes to 256, center-crops to 224, converts the image to
    a tensor, and normalizes it with the ImageNet per-channel mean/std.

    Returns:
        A torchvision ``Compose`` transform mapping a PIL image to a tensor.

    Raises:
        ProviderInitializationError: If torchvision is not installed. The
            original ``ImportError`` is chained as the cause.
    """
    # Only the optional import is guarded; errors while building the
    # pipeline itself should surface unchanged.
    try:
        from torchvision import transforms
    except ImportError as exc:
        raise ProviderInitializationError(
            "imagenet",
            reason="torchvision required for ImageNet transform. "
            "Install with: pip install torchvision",
        ) from exc

    return transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )


@baseline_provider(
    "imagenet",
    full_name="ImageNet Provider",
    description="Load from ImageNet directory structure (ILSVRC format)",
    aliases=("ilsvrc",),
    supports_caching=True,
    supports_streaming=False,
)
class ImageNetProvider(CacheableBaseProvider):
    """Baseline provider specialized for ImageNet.

    Handles the standard ImageNet directory structure:
    ```
    imagenet/
    ├── train/
    │   ├── n01440764/
    │   │   ├── n01440764_10026.JPEG
    │   │   └── ...
    │   └── ...
    └── val/
        ├── n01440764/
        └── ...
    ```

    Images are discovered in every class subdirectory of the chosen split.
    A subset is sampled, optionally with an equal number per class. Each
    image is run through the model, and the global-average-pooled (GAP)
    activations of the target layer are cached. Baseline samples are then
    drawn from that cache.

    Attributes:
        root: Root ImageNet directory.
        split: Dataset split ("train" or "val").
        max_images: Maximum number of images to load. ``None`` (or 0) means
            all discovered images.
        balanced: Whether to sample equally from each class. Only applies
            when ``max_images`` is set.
        transform: Preprocessing applied to each PIL image. If ``None``, the
            default ImageNet transform is installed on first use.

    Example::

        provider = ImageNetProvider(
            root="/data/imagenet",
            split="train",
            max_images=1000,
        )
        provider.initialize(model, layer, device)
        samples = provider.get_baseline_samples(n=20, device=device)
    """

    # File extensions (compared case-insensitively) treated as images.
    _IMAGE_SUFFIXES = frozenset({".jpeg", ".jpg", ".png"})

    def __init__(
        self,
        root: str | Path,
        split: str = "train",
        max_images: int | None = 1000,
        balanced: bool = False,
        transform: Callable[[Any], Tensor] | None = None,
        cache_path: str | Path | None = None,
    ) -> None:
        """Initialize ImageNet provider.

        Args:
            root: Root ImageNet directory (containing train/val), or a split
                directory directly.
            split: Dataset split ("train" or "val").
            max_images: Maximum number of images to use. ``None`` or 0 loads
                every discovered image.
            balanced: If True (and ``max_images`` is set), sample equally
                from each class.
            transform: Image preprocessing transform. Defaults to the
                standard ImageNet transform.
            cache_path: Optional path to save/load the GAP feature cache.
                It is used verbatim; no ``.npy`` suffix is appended.
        """
        super().__init__()

        self.root = Path(root)
        self.split = split
        self.max_images = max_images
        self.balanced = balanced
        self.transform = transform
        self._cache_path = str(cache_path) if cache_path else None

        # Internal state
        self._image_paths: list[Path] = []
        self._gap_cache: Tensor | None = None

    @property
    def provider_type(self) -> str:
        """Return provider type identifier."""
        return "imagenet"

    def _do_initialize(self) -> None:
        """Populate the GAP cache from disk cache or by extracting features.

        Raises:
            DirectoryNotFoundError: If neither the split directory nor a
                non-empty root directory exists.
            EmptyBaselineDatasetError: If no class directories or no images
                are found.
            ProviderInitializationError: If no image could be processed.
        """
        # _load_cache_data returns False when the file is missing or unusable.
        if self._cache_path and self._load_cache_data(self._cache_path):
            return

        split_dir = self._resolve_split_dir()
        images_per_class = self._collect_class_images(split_dir)
        self._image_paths = self._sample_image_paths(images_per_class)

        self._extract_features()

        if self._cache_path:
            self._save_cache_data(self._cache_path)

    def _resolve_split_dir(self) -> Path:
        """Return ``root/split`` or, failing that, a non-empty ``root``."""
        split_dir = self.root / self.split
        if split_dir.is_dir():
            return split_dir
        # The user may have passed the split directory itself as ``root``.
        if self.root.is_dir() and any(self.root.iterdir()):
            return self.root
        raise DirectoryNotFoundError(
            split_dir,
            searched_paths=[self.root / "train", self.root / "val", self.root],
        )

    def _collect_class_images(self, split_dir: Path) -> list[list[Path]]:
        """Return the image files of each non-empty class subdirectory.

        Each file is listed once. Extensions are matched case-insensitively,
        so separate ``*.JPEG``/``*.jpeg`` patterns cannot double-count a file
        on case-insensitive filesystems. Results are sorted so that the order
        depends only on ``random`` state, not on the filesystem.
        """
        class_dirs = sorted(d for d in split_dir.iterdir() if d.is_dir())
        if not class_dirs:
            raise EmptyBaselineDatasetError(
                source=split_dir,
                reason="No class subdirectories found (expected ImageNet format)",
            )

        images_per_class: list[list[Path]] = []
        for class_dir in class_dirs:
            class_images = sorted(
                p
                for p in class_dir.iterdir()
                if p.suffix.lower() in self._IMAGE_SUFFIXES and p.is_file()
            )
            if class_images:
                images_per_class.append(class_images)

        if not images_per_class:
            raise EmptyBaselineDatasetError(
                source=split_dir,
                reason="No images found in class directories",
            )
        return images_per_class

    def _sample_image_paths(self, images_per_class: list[list[Path]]) -> list[Path]:
        """Pick a shuffled subset of at most ``max_images`` paths."""
        import random

        if self.balanced and self.max_images:
            # Equal quota per class (at least one), capped by class size.
            per_class = max(1, self.max_images // len(images_per_class))
            selected: list[Path] = []
            for class_images in images_per_class:
                selected.extend(
                    random.sample(class_images, min(per_class, len(class_images)))
                )
        else:
            selected = [p for class_images in images_per_class for p in class_images]

        random.shuffle(selected)
        if self.max_images:
            selected = selected[: self.max_images]
        return selected

    def _extract_features(self) -> None:
        """Extract GAP features of the target layer for every sampled image.

        This is best-effort: images that fail anywhere in load / transform /
        forward / pooling are skipped and counted, and a warning is emitted.

        Raises:
            ProviderInitializationError: If no image could be processed. The
                last failure is chained as the cause.
        """
        from PIL import Image

        if self.transform is None:
            self.transform = _default_imagenet_transform()

        gap_values: list[Tensor] = []
        failed_count = 0
        last_error: Exception | None = None

        with FeatureMapHook(self._target_layer) as hook, torch.no_grad():
            for img_path in self._image_paths:
                try:
                    # Context manager closes the file even if decoding fails.
                    with Image.open(img_path) as img:
                        rgb = img.convert("RGB")
                    x = self.transform(rgb).unsqueeze(0).to(self._device)
                    self._model(x)
                    # Assumes NCHW feature maps; pool over spatial dims.
                    gap = hook.features.mean(dim=(2, 3)).squeeze(0)
                except Exception as exc:  # deliberate best-effort skip
                    failed_count += 1
                    last_error = exc
                    continue
                gap_values.append(gap.cpu())

        if not gap_values:
            raise ProviderInitializationError(
                "imagenet",
                reason="Failed to extract features from any ImageNet images",
            ) from last_error

        self._gap_cache = torch.stack(gap_values)
        self._n_channels = self._gap_cache.shape[1]

        if failed_count > 0:
            import warnings

            warnings.warn(
                f"ImageNetProvider: {failed_count} images failed to load. "
                f"Using {len(gap_values)} valid images.",
                stacklevel=2,
            )

    def _get_raw_samples(self, n: int) -> Tensor:
        """Draw ``n`` GAP vectors from the cache.

        Samples without replacement when ``n`` does not exceed the cache size,
        and with replacement otherwise.

        Raises:
            RuntimeError: If the cache has not been populated.
            ValueError: If ``n`` is negative.
        """
        if self._gap_cache is None:
            raise RuntimeError("GAP cache not initialized")
        if n < 0:
            # A negative slice bound would otherwise return cache_size + n rows.
            raise ValueError(f"n must be non-negative, got {n}")

        cache_size = len(self._gap_cache)

        if n <= cache_size:
            indices = torch.randperm(cache_size)[:n]
        else:
            indices = torch.randint(0, cache_size, (n,))

        return self._gap_cache[indices]

    def __len__(self) -> int:
        """Return number of cached samples (or sampled paths before extraction)."""
        if self._gap_cache is not None:
            return len(self._gap_cache)
        return len(self._image_paths)

    def _save_cache_data(self, cache_path: str) -> None:
        """Save the GAP cache to exactly ``cache_path`` in ``.npy`` format.

        Raises:
            RuntimeError: If there is no cache to save.
        """
        if self._gap_cache is None:
            raise RuntimeError("No cache data to save")

        path = Path(cache_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Write through an open handle: np.save() given a path string appends
        # ".npy" when missing, which would put the cache somewhere
        # _load_cache_data() never looks.
        with path.open("wb") as fh:
            np.save(fh, self._gap_cache.detach().cpu().numpy())

    def _load_cache_data(self, cache_path: str) -> bool:
        """Load a GAP cache from ``cache_path``.

        Returns:
            True if a non-empty 2-D array was loaded into the cache. Returns
            False if the file is missing, unreadable or incompatible; in that
            case the provider state is left untouched.
        """
        try:
            path = Path(cache_path)
            if not path.is_file():
                return False
            data = np.load(path)
            if not isinstance(data, np.ndarray) or data.ndim != 2 or data.shape[0] == 0:
                return False
            gap_cache = torch.from_numpy(data)
        except Exception:
            # Best effort: any failure means "recompute features instead".
            return False

        self._gap_cache = gap_cache
        self._n_channels = gap_cache.shape[1]
        return True


# Public API of this module: only the provider class is exported.
__all__ = ["ImageNetProvider"]
