"""Baseline dataset utilities for Expected GradCAM.

This module provides utilities for creating datasets used in baseline
sampling for Expected Gradients. The baseline distribution D should
be data-aware (sampled from real images) for theoretical guarantees.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset

if TYPE_CHECKING:
    from numpy.typing import NDArray
    from PIL import Image


class BaselineDataset(Dataset):
    """Dataset of images used for baseline sampling in Expected Gradients.

    Each item is an image loaded from disk, converted to RGB and passed
    through ``transform``. The images are meant to be fed to a model to
    extract feature maps that serve as baselines, so the set of images
    should represent the natural variation of feature activations.

    Attributes:
        image_paths: List of image file paths (as ``Path`` objects).
        transform: Callable applied to each loaded PIL image. When none is
            given at construction time, an ImageNet-style default is used.

    Example:
        >>> dataset = BaselineDataset.from_directory("imagenet/train/")
        >>> loader = DataLoader(dataset, batch_size=32, shuffle=True)
    """

    # Per-channel normalization statistics shared by both default transforms.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(
        self,
        image_paths: list[str | Path],
        transform: Callable[[Any], Tensor] | None = None,
    ) -> None:
        """Initialize baseline dataset.

        Args:
            image_paths: List of paths to image files.
            transform: Optional transform to apply to images. If ``None``,
                the transform returned by ``_default_transform`` is used.
        """
        self.image_paths = [Path(p) for p in image_paths]
        self.transform = (
            transform if transform is not None else self._default_transform()
        )

    @staticmethod
    def _default_transform() -> Callable[[Any], Tensor]:
        """Return the default ImageNet-style transform.

        Uses torchvision (resize to 256, center-crop to 224, convert to
        tensor, normalize) when it can be imported; otherwise falls back to
        a NumPy-based transform that resizes directly to 224x224.
        """
        # Keep the try body minimal: only the import can raise ImportError.
        try:
            from torchvision import transforms
        except ImportError:
            return BaselineDataset._fallback_transform

        return transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=list(BaselineDataset._IMAGENET_MEAN),
                    std=list(BaselineDataset._IMAGENET_STD),
                ),
            ]
        )

    @staticmethod
    def _fallback_transform(img: "Image.Image") -> Tensor:
        """Resize, scale and normalize an image without torchvision.

        Args:
            img: Object with a PIL-style ``resize((width, height))`` method
                whose result converts to an ``[H, W, 3]`` array.

        Returns:
            Contiguous float32 tensor of shape ``[3, 224, 224]``.
        """
        # float32 statistics keep the arithmetic in float32. Subtracting a
        # plain Python list would promote the result to float64 and produce
        # a tensor whose dtype differs from the torchvision path.
        mean = np.array(BaselineDataset._IMAGENET_MEAN, dtype=np.float32)
        std = np.array(BaselineDataset._IMAGENET_STD, dtype=np.float32)

        arr = np.asarray(img.resize((224, 224)), dtype=np.float32) / 255.0
        arr = (arr - mean) / std
        # HWC -> CHW; make the memory layout contiguous like ToTensor does.
        return torch.from_numpy(arr).permute(2, 0, 1).contiguous()

    def __len__(self) -> int:
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tensor:
        """Get transformed image.

        Args:
            idx: Index of image to get.

        Returns:
            The result of ``transform`` applied to the RGB image; with the
            default transforms this is a tensor ``[C, H, W]``. If
            ``transform`` has been set to ``None`` after construction, the
            RGB PIL image is returned unchanged.
        """
        from PIL import Image

        img_path = self.image_paths[idx]
        # The context manager closes the underlying file even if decoding
        # fails; ``convert`` returns an independent in-memory image.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img

    @classmethod
    def from_directory(
        cls,
        directory: str | Path,
        extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".JPEG"),
        max_images: int | None = None,
        shuffle: bool = True,
        transform: Callable[[Any], Tensor] | None = None,
    ) -> "BaselineDataset":
        """Create dataset from image directory.

        The directory is searched recursively. Matching paths are
        de-duplicated and sorted before the optional shuffle, so the result
        does not depend on filesystem enumeration order and a file is never
        listed twice (e.g. when two extensions match the same file on a
        case-insensitive filesystem).

        Args:
            directory: Directory containing images.
            extensions: File extensions to include.
            max_images: Maximum number of images to include.
            shuffle: Whether to shuffle the image list (uses the global
                ``random`` module state).
            transform: Optional transform to apply.

        Returns:
            BaselineDataset instance.
        """
        directory = Path(directory)

        # Collect into a set to drop duplicates, then sort for determinism.
        image_paths = sorted(
            {path for ext in extensions for path in directory.rglob(f"*{ext}")}
        )

        if shuffle:
            import random

            random.shuffle(image_paths)

        if max_images is not None:
            image_paths = image_paths[:max_images]

        return cls(image_paths, transform)

    @classmethod
    def from_imagenet(
        cls,
        imagenet_root: str | Path,
        split: str = "train",
        max_images: int | None = 1000,
        transform: Callable[[Any], Tensor] | None = None,
    ) -> "BaselineDataset":
        """Create dataset from ImageNet directory structure.

        Delegates to ``from_directory`` on ``imagenet_root / split`` with
        the default extensions and shuffling enabled.

        Args:
            imagenet_root: Root directory of ImageNet dataset.
            split: Dataset split ("train" or "val").
            max_images: Maximum number of images to include.
            transform: Optional transform to apply.

        Returns:
            BaselineDataset instance.
        """
        return cls.from_directory(
            Path(imagenet_root) / split,
            max_images=max_images,
            transform=transform,
        )


def create_baseline_dataset(
    source: str | Path | list[str | Path],
    transform: Callable[[Any], Tensor] | None = None,
    max_images: int | None = None,
) -> BaselineDataset:
    """Build a ``BaselineDataset`` from a directory, a file, or a path list.

    Args:
        source: A directory path (searched via
            ``BaselineDataset.from_directory``), a single image path, or an
            iterable of image paths.
        transform: Optional transform to apply.
        max_images: Maximum number of images. Applied to directories and
            path lists; a single-file source always yields one image.

    Returns:
        BaselineDataset instance.
    """
    # Anything that is not a lone path is treated as a collection of paths.
    if not isinstance(source, (str, Path)):
        candidates = list(source)
        selected = candidates if max_images is None else candidates[:max_images]
        return BaselineDataset(selected, transform=transform)

    location = Path(source)
    if not location.is_dir():
        # Non-directory path: wrap it as a one-element dataset.
        return BaselineDataset([location], transform=transform)

    return BaselineDataset.from_directory(
        location, max_images=max_images, transform=transform
    )
