"""HuggingFace datasets baseline provider.

Loads baseline images from HuggingFace datasets.
Requires the 'datasets' package: pip install datasets
"""

from __future__ import annotations

import random
from typing import TYPE_CHECKING, Any, Callable

import numpy as np
import torch
from torch import Tensor

from expected_gradcam.baselines.base import CacheableBaseProvider
from expected_gradcam.baselines.registry import baseline_provider
from expected_gradcam.exceptions.baseline import (
    EmptyBaselineDatasetError,
    HuggingFaceLoadError,
    ProviderInitializationError,
)
from expected_gradcam.hooks import FeatureMapHook

if TYPE_CHECKING:
    from torch import nn


def _default_imagenet_transform() -> Callable[[Any], Tensor]:
    """Build the default ImageNet preprocessing transform.

    The pipeline resizes to 256, center-crops to 224, converts the image to
    a tensor and normalizes it with the standard ImageNet channel mean/std.

    Returns:
        A callable that maps an input image to a normalized tensor.

    Raises:
        ProviderInitializationError: If torchvision is not installed. The
            original ``ImportError`` is attached as the cause.
    """
    # Only the import can raise ImportError, so keep the try body minimal.
    try:
        from torchvision import transforms
    except ImportError as err:
        raise ProviderInitializationError(
            "huggingface",
            reason="torchvision required for image transforms. "
            "Install with: pip install torchvision",
        ) from err

    return transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )


@baseline_provider(
    "huggingface",
    full_name="HuggingFace Datasets Provider",
    description="Load baselines from HuggingFace datasets",
    aliases=("hf", "datasets"),
    supports_caching=True,
    supports_streaming=True,
    requires_packages=("datasets",),
)
class HuggingFaceProvider(CacheableBaseProvider):
    """Baseline provider for HuggingFace datasets.

    Loads images from HuggingFace Hub datasets, runs them through the model
    and stores the globally-average-pooled (GAP) activations of the target
    layer. Supports both regular and streaming mode for large datasets.

    Requires the datasets package: pip install datasets

    Attributes:
        dataset_name: Name of the HuggingFace dataset.
        split: Dataset split to use.
        image_column: Column name containing images.
        streaming: Whether to use streaming mode.
        max_samples: Maximum number of samples to load. ``None`` or ``0``
            means no limit.
        trust_remote_code: Forwarded to ``datasets.load_dataset``.
        transform: Image preprocessing transform. When ``None``, the default
            ImageNet transform is created lazily during feature extraction.

    Example::

        provider = HuggingFaceProvider(
            dataset_name="imagenet-1k",
            split="train",
            max_samples=1000,
        )
        provider.initialize(model, layer, device)
        samples = provider.get_baseline_samples(n=20, device=device)

    Note:
        Some datasets (like imagenet-1k) require authentication.
        Run `huggingface-cli login` first.
    """

    # Feature extraction aborts once more than this many samples have failed;
    # that many failures usually means ``image_column`` is wrong.
    _MAX_FAILED_SAMPLES = 100

    def __init__(
        self,
        dataset_name: str,
        split: str = "train",
        image_column: str = "image",
        streaming: bool = False,
        max_samples: int | None = 1000,
        trust_remote_code: bool = False,
        transform: Callable[[Any], Tensor] | None = None,
        cache_path: str | None = None,
    ) -> None:
        """Initialize HuggingFace provider.

        Args:
            dataset_name: HuggingFace dataset name (e.g., "imagenet-1k").
            split: Dataset split ("train", "validation", "test").
            image_column: Name of the image column in the dataset.
            streaming: Use streaming mode for large datasets.
            max_samples: Maximum number of samples to use. ``None`` or ``0``
                means no limit.
            trust_remote_code: Whether to trust remote code in datasets.
            transform: Image preprocessing transform.
            cache_path: Optional path to save/load feature cache.
        """
        super().__init__()

        self.dataset_name = dataset_name
        self.split = split
        self.image_column = image_column
        self.streaming = streaming
        self.max_samples = max_samples
        self.trust_remote_code = trust_remote_code
        self.transform = transform
        self._cache_path = cache_path

        # Internal state
        self._dataset: Any = None
        self._gap_cache: Tensor | None = None

    @property
    def provider_type(self) -> str:
        """Return provider type identifier."""
        return "huggingface"

    @property
    def supports_streaming(self) -> bool:
        """Check if streaming mode is enabled."""
        return self.streaming

    def _do_initialize(self) -> None:
        """Load the dataset and extract features, using the cache if present.

        If ``cache_path`` was given and a usable cache file is found, the
        dataset is not loaded at all. Otherwise the dataset is loaded,
        features are extracted and, if ``cache_path`` was given, saved.

        Raises:
            ProviderInitializationError: If the ``datasets`` package is
                missing.
            HuggingFaceLoadError: If ``load_dataset`` fails for any reason.
        """
        # Check if we can load from cache
        if self._cache_path and self._load_cache_data(self._cache_path):
            return

        # Import datasets
        try:
            from datasets import load_dataset
        except ImportError as err:
            raise ProviderInitializationError(
                "huggingface",
                reason="datasets package required. Install with: pip install datasets",
            ) from err

        # Load dataset
        try:
            self._dataset = load_dataset(
                self.dataset_name,
                split=self.split,
                streaming=self.streaming,
                trust_remote_code=self.trust_remote_code,
            )
        except Exception as e:
            # Heuristic: the error text is the only hint that login is needed.
            error_str = str(e).lower()
            requires_auth = "authentication" in error_str or "login" in error_str

            raise HuggingFaceLoadError(
                self.dataset_name,
                reason=str(e),
                split=self.split,
                requires_auth=requires_auth,
                cause=e,
            ) from e

        # Extract features
        self._extract_features()

        # Save cache if path provided
        if self._cache_path:
            self._save_cache_data(self._cache_path)

    def _sample_to_gap(self, sample: Any, hook: Any) -> Tensor:
        """Run one dataset sample through the model and return its GAP vector.

        Args:
            sample: A dataset row, indexed by ``self.image_column``.
            hook: The active feature-map hook on the target layer.

        Returns:
            CPU tensor holding the spatial mean (dims 2 and 3) of the hooked
            feature map, with the leading dimension squeezed out.
        """
        img = sample[self.image_column]

        # PIL images may be grayscale/RGBA; the transform expects RGB.
        if hasattr(img, "convert"):
            img = img.convert("RGB")

        x = self.transform(img)
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x = x.to(self._device)

        # The forward pass is only needed for its side effect on the hook.
        with torch.no_grad():
            self._model(x)

        return hook.features.mean(dim=(2, 3)).squeeze(0).cpu()

    def _extract_features(self) -> None:
        """Extract GAP features from the HuggingFace dataset.

        Iterates the dataset until ``max_samples`` samples were processed
        successfully (or the dataset is exhausted). Samples that fail are
        skipped and counted; a warning reports the number of skipped samples.

        Raises:
            ProviderInitializationError: If more than
                ``_MAX_FAILED_SAMPLES`` samples fail. The last failure is
                attached as the cause.
            EmptyBaselineDatasetError: If no sample could be processed.
        """
        if self.transform is None:
            self.transform = _default_imagenet_transform()

        # A falsy max_samples (None or 0) means "no limit".
        limit = self.max_samples or None
        gap_values: list[Tensor] = []
        failed_count = 0

        with FeatureMapHook(self._target_layer) as hook:
            for sample in self._dataset:
                try:
                    gap_values.append(self._sample_to_gap(sample, hook))
                except Exception as exc:
                    failed_count += 1
                    if failed_count > self._MAX_FAILED_SAMPLES:
                        # Too many failures, probably wrong column name.
                        # Chain the last error so the root cause is visible.
                        raise ProviderInitializationError(
                            "huggingface",
                            reason=f"Failed to process {self._MAX_FAILED_SAMPLES}+ samples. "
                            f"Check image_column='{self.image_column}' is correct.",
                        ) from exc
                    continue

                # Stop right after the last needed sample so that no extra
                # sample is fetched (matters for streaming datasets).
                if limit is not None and len(gap_values) >= limit:
                    break

        if not gap_values:
            raise EmptyBaselineDatasetError(
                source=f"huggingface://{self.dataset_name}",
                reason="Failed to extract features from any samples",
            )

        self._gap_cache = torch.stack(gap_values)
        self._n_channels = self._gap_cache.shape[1]

        if failed_count > 0:
            import warnings

            warnings.warn(
                f"HuggingFaceProvider: {failed_count} samples failed to process. "
                f"Using {len(gap_values)} valid samples.",
                stacklevel=2,
            )

    def _get_raw_samples(self, n: int) -> Tensor:
        """Sample ``n`` rows from the GAP cache.

        Rows are drawn without replacement when ``n`` does not exceed the
        cache size, and with replacement otherwise.

        Args:
            n: Number of rows to draw.

        Returns:
            Tensor with the selected rows of the GAP cache.

        Raises:
            RuntimeError: If the GAP cache has not been populated.
        """
        if self._gap_cache is None:
            raise RuntimeError("GAP cache not initialized")

        cache_size = len(self._gap_cache)

        if n <= cache_size:
            indices = torch.randperm(cache_size)[:n]
        else:
            indices = torch.randint(0, cache_size, (n,))

        return self._gap_cache[indices]

    def __len__(self) -> int:
        """Return number of cached samples.

        Before the cache is populated this falls back to ``max_samples``
        (or 0 when no limit is set).
        """
        if self._gap_cache is not None:
            return len(self._gap_cache)
        return self.max_samples or 0

    def _save_cache_data(self, cache_path: str) -> None:
        """Save the GAP cache to ``cache_path`` in NumPy ``.npy`` format.

        Args:
            cache_path: Destination file. It is written exactly at this path;
                missing parent directories are created.

        Raises:
            RuntimeError: If there is no cache data to save.
        """
        if self._gap_cache is None:
            raise RuntimeError("No cache data to save")

        from pathlib import Path

        target = Path(cache_path)
        target.parent.mkdir(parents=True, exist_ok=True)

        # np.save appends ".npy" to bare file names, which would make the
        # file invisible to _load_cache_data. Writing through an open handle
        # keeps the on-disk name identical to cache_path.
        with target.open("wb") as fh:
            np.save(fh, self._gap_cache.detach().cpu().numpy())

    def _load_cache_data(self, cache_path: str) -> bool:
        """Load the GAP cache from file.

        Besides ``cache_path`` itself, ``cache_path + ".npy"`` is also tried,
        because np.save adds that suffix when handed a bare file name.
        Unreadable or malformed files are reported with a warning and
        ignored; provider state is only modified on success.

        Args:
            cache_path: Path of the cache file.

        Returns:
            True if a non-empty 2-D array was loaded, False otherwise.
        """
        import warnings
        from pathlib import Path

        path_str = str(cache_path)
        candidates = [path_str]
        if not path_str.endswith(".npy"):
            candidates.append(path_str + ".npy")

        for candidate in candidates:
            try:
                if not Path(candidate).is_file():
                    continue
                cached = torch.from_numpy(np.load(candidate))
            except Exception as exc:
                # Best effort: a broken cache must not block initialization.
                warnings.warn(
                    f"HuggingFaceProvider: ignoring unreadable cache file "
                    f"'{candidate}': {exc}",
                    stacklevel=2,
                )
                continue

            # Validate before touching state so a bad file leaves no trace.
            if cached.dim() != 2 or cached.shape[0] == 0:
                warnings.warn(
                    f"HuggingFaceProvider: ignoring cache file '{candidate}' "
                    f"with unexpected shape {tuple(cached.shape)}.",
                    stacklevel=2,
                )
                continue

            self._gap_cache = cached
            self._n_channels = cached.shape[1]
            return True

        return False

    def iter_batches(
        self,
        batch_size: int,
        device: torch.device,
        shuffle: bool = True,
    ):
        """Iterate over batches of baseline samples.

        Each batch is taken from the GAP cache, passed through
        ``center_samples`` and moved to ``device``.

        Args:
            batch_size: Number of samples per batch. Must be positive.
            device: Device for output tensors.
            shuffle: Whether to visit the cached samples in random order.

        Yields:
            Batches of shape [batch_size, K]; the final batch may be smaller.

        Raises:
            RuntimeError: If the provider has not been initialized.
            ValueError: If ``batch_size`` is not positive.
        """
        if self._gap_cache is None:
            raise RuntimeError("Provider not initialized")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        # Imported lazily (once per iteration, not once per batch).
        from expected_gradcam.sampling.utils import center_samples

        n_samples = len(self._gap_cache)
        indices = list(range(n_samples))

        if shuffle:
            random.shuffle(indices)

        for start in range(0, n_samples, batch_size):
            batch_indices = indices[start : start + batch_size]
            batch = center_samples(self._gap_cache[batch_indices])
            yield batch.to(device)


# Public API of this module: only the provider class is exported.
__all__ = ["HuggingFaceProvider"]
