"""PyTorch Dataset wrapper baseline provider.

Wraps existing PyTorch datasets to extract GAP features.
This allows using any custom dataset with the baseline provider system.
"""

from __future__ import annotations

import random
from typing import TYPE_CHECKING, Any

import torch
from torch import Tensor

from expected_gradcam.baselines.base import CacheableBaseProvider
from expected_gradcam.baselines.registry import baseline_provider
from expected_gradcam.exceptions.baseline import (
    EmptyBaselineDatasetError,
    ProviderInitializationError,
)
from expected_gradcam.hooks import FeatureMapHook

if TYPE_CHECKING:
    from torch import nn
    from torch.utils.data import Dataset

import numpy as np


@baseline_provider(
    "torch_dataset",
    full_name="PyTorch Dataset Provider",
    description="Wrap existing PyTorch datasets for baseline extraction",
    aliases=("dataset", "torch"),
    supports_caching=True,
    supports_streaming=False,
)
class TorchDatasetProvider(CacheableBaseProvider):
    """Baseline provider that wraps PyTorch datasets.

    Takes any PyTorch Dataset, runs a random subset of its samples through
    the model and stores the global-average-pooled (GAP) activations captured
    by a hook on the target layer. Baseline samples are then drawn from that
    in-memory GAP cache.

    The dataset should return either:
    - (image_tensor, label) tuples (or lists)
    - Just image tensors

    Attributes:
        dataset: The PyTorch dataset to wrap.
        max_samples: Maximum number of samples to use (``None`` = all).
        batch_size: Batch size for feature extraction.

    Example::

        from torchvision.datasets import CIFAR10

        dataset = CIFAR10(root="/data", train=True, transform=transform)
        provider = TorchDatasetProvider(dataset, max_samples=1000)
        provider.initialize(model, layer, device)
        samples = provider.get_baseline_samples(n=20, device=device)
    """

    def __init__(
        self,
        dataset: "Dataset",
        max_samples: int | None = None,
        batch_size: int = 32,
        cache_path: str | None = None,
    ) -> None:
        """Initialize PyTorch dataset provider.

        Args:
            dataset: PyTorch Dataset to wrap. Must return tensors
                or (tensor, label) tuples.
            max_samples: Maximum number of samples to use.
            batch_size: Batch size for feature extraction.
            cache_path: Optional path to save/load feature cache.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 or
                ``max_samples`` is negative.
        """
        # Fail fast: a zero batch size would otherwise surface as an
        # arithmetic error deep inside feature extraction, and a negative
        # max_samples would silently slice from the end of the index list.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if max_samples is not None and max_samples < 0:
            raise ValueError(
                f"max_samples must be non-negative or None, got {max_samples}"
            )

        super().__init__()

        self.dataset = dataset
        self.max_samples = max_samples
        self.batch_size = batch_size
        self._cache_path = cache_path

        # Internal state
        self._gap_cache: Tensor | None = None
        self._n_samples: int = 0

    @property
    def provider_type(self) -> str:
        """Return provider type identifier."""
        return "torch_dataset"

    def _do_initialize(self) -> None:
        """Populate the GAP cache, from disk if possible, else from the dataset.

        Raises:
            EmptyBaselineDatasetError: If the wrapped dataset has length 0.
            ProviderInitializationError: If no sample yielded features.
        """
        # A usable cache file short-circuits the (expensive) extraction.
        if self._cache_path and self._load_cache_data(self._cache_path):
            return

        dataset_size = len(self.dataset)
        if dataset_size == 0:
            raise EmptyBaselineDatasetError(
                source="torch_dataset",
                reason="Dataset is empty",
            )

        # Determine number of samples to use
        self._n_samples = dataset_size
        if self.max_samples is not None:
            self._n_samples = min(self._n_samples, self.max_samples)

        # Draw a random subset without replacement; sampling from the range
        # avoids materialising and shuffling the full index list.
        indices = random.sample(range(dataset_size), self._n_samples)

        self._extract_features(indices)

        # Save cache if path provided
        if self._cache_path:
            self._save_cache_data(self._cache_path)

    def _load_image(self, idx: int) -> Tensor:
        """Fetch dataset item ``idx`` and return its image with a batch dim."""
        item = self.dataset[idx]

        # (image, label, ...) style samples carry the image first.
        image = item[0] if isinstance(item, (tuple, list)) else item

        # Un-batched images get a leading batch dimension so that all
        # samples can be concatenated along dim 0.
        if image.dim() == 3:
            image = image.unsqueeze(0)
        return image

    def _extract_features(self, indices: list[int]) -> None:
        """Extract GAP features from dataset samples.

        Samples that fail to load are skipped (best effort). The number of
        skipped samples is logged, and if nothing could be extracted the
        last loading error is attached to the raised exception.

        Args:
            indices: Dataset indices to run through the model.

        Raises:
            ProviderInitializationError: If no features were extracted.
        """
        gap_values: list[Tensor] = []
        skipped = 0
        last_error: Exception | None = None

        with FeatureMapHook(self._target_layer) as hook:
            for start in range(0, len(indices), self.batch_size):
                batch_images = []
                for idx in indices[start : start + self.batch_size]:
                    try:
                        batch_images.append(self._load_image(idx))
                    except Exception as exc:  # best effort: skip bad samples
                        skipped += 1
                        last_error = exc

                if not batch_images:
                    continue

                batch_tensor = torch.cat(batch_images, dim=0).to(self._device)

                # The forward pass only serves to fire the hook; the model
                # output itself is not needed.
                with torch.no_grad():
                    self._model(batch_tensor)

                # Average over the two trailing dims: [B, K, U, V] -> [B, K]
                gap_values.append(hook.features.mean(dim=(2, 3)).cpu())

        if not gap_values:
            reason = "Failed to extract features from any samples"
            if last_error is not None:
                reason = (
                    f"{reason} (last error: "
                    f"{type(last_error).__name__}: {last_error})"
                )
            raise ProviderInitializationError(
                "torch_dataset",
                reason=reason,
            ) from last_error

        if skipped:
            import logging

            logging.getLogger(__name__).warning(
                "Skipped %d of %d dataset samples that failed to load "
                "(last error: %r)",
                skipped,
                len(indices),
                last_error,
            )

        self._gap_cache = torch.cat(gap_values, dim=0)  # [N, K]
        self._n_channels = self._gap_cache.shape[1]
        self._n_samples = len(self._gap_cache)

    def _get_raw_samples(self, n: int) -> Tensor:
        """Sample ``n`` rows from the GAP cache.

        Rows are drawn without replacement while ``n`` does not exceed the
        cache size, and with replacement otherwise.

        Raises:
            RuntimeError: If the GAP cache has not been populated yet.
        """
        if self._gap_cache is None:
            raise RuntimeError("GAP cache not initialized")

        cache_size = len(self._gap_cache)

        # Sample with replacement if n > cache_size
        if n <= cache_size:
            indices = torch.randperm(cache_size)[:n]
        else:
            indices = torch.randint(0, cache_size, (n,))

        return self._gap_cache[indices]

    def __len__(self) -> int:
        """Return number of cached samples."""
        if self._gap_cache is not None:
            return len(self._gap_cache)
        return self._n_samples

    def _save_cache_data(self, cache_path: str) -> None:
        """Save GAP cache to ``cache_path`` in NumPy ``.npy`` format.

        Raises:
            RuntimeError: If there is no GAP cache to save.
        """
        if self._gap_cache is None:
            raise RuntimeError("No cache data to save")

        from pathlib import Path

        target = Path(cache_path)
        target.parent.mkdir(parents=True, exist_ok=True)

        # Write through an open handle: given a file *name*, np.save appends
        # ".npy" when it is missing, so the file would end up at a different
        # path than the one _load_cache_data checks.
        with target.open("wb") as fh:
            np.save(fh, self._gap_cache.numpy())

    def _load_cache_data(self, cache_path: str) -> bool:
        """Load GAP cache from file.

        Returns:
            True if a valid, non-empty 2-D cache was loaded; False if the
            file is missing, unreadable or malformed. Provider state is only
            modified on success.
        """
        from pathlib import Path

        path = Path(cache_path)
        if not path.exists():
            # Caches written via np.save(<name>) got ".npy" appended when
            # the name lacked it; accept those files as well.
            legacy = Path(f"{cache_path}.npy")
            if str(cache_path).endswith(".npy") or not legacy.exists():
                return False
            path = legacy

        try:
            cache = torch.from_numpy(np.load(path))
        except Exception:  # best effort: any unreadable cache means "no cache"
            return False

        # Validate before touching instance state so a bad file cannot leave
        # a half-initialised provider behind.
        if cache.dim() != 2 or cache.shape[0] == 0:
            return False

        self._gap_cache = cache
        self._n_channels = cache.shape[1]
        self._n_samples = len(cache)
        return True


# Public API of this module: only the provider class is exported.
__all__ = ["TorchDatasetProvider"]
