"""Baseline sampling for Expected Gradients.

This module provides samplers for generating baseline distributions D
used in Expected Gradients path integration.

CRITICAL: For Expected Gradients to satisfy the completeness axiom,
the baseline distribution D must be centered: E[z'] = 0.

Mathematical specification (from paper):
    φ^{EG}(g, z_0, I; A, D) = E_{z'~D}[∫₀¹ ∇_z g(z' + t(z_0 - I - z'); A) dt]

Samplers in this module:
- CenteredBaselineSampler: Gaussian/uniform/sphere distributions (simple, no model)
- DataAwareBaselineSampler: Uses pre-cached feature map GAP values
- DataAwareEGBaselineSampler: Extracts baselines from real images via forward passes
"""

from __future__ import annotations

import random
from typing import TYPE_CHECKING, Literal

import torch
from torch import nn

from expected_gradcam.hooks import FeatureMapHook
from expected_gradcam.sampling.utils import center_samples, verify_centered

if TYPE_CHECKING:
    from torch import Tensor
    from torch.utils.data import Dataset


class CenteredBaselineSampler:
    """Sampler for centered baseline distribution D.

    For Expected Gradients to satisfy the completeness axiom, the
    baseline distribution D must be centered: E[z'] = 0.

    This class generates baseline samples from various distributions
    and ensures they are centered.

    Supported distributions:
    - "gaussian": z' ~ N(0, σ²I) - standard Gaussian
    - "uniform": z' ~ U(-a, a) - uniform in hypercube
    - "sphere": z' uniformly distributed on sphere of radius r

    Attributes:
        distribution: Type of distribution to sample from.
        scale: Scale parameter (σ for Gaussian, a for uniform, r for sphere).
        device: Torch device for samples.

    Example:
        >>> sampler = CenteredBaselineSampler("gaussian", scale=0.1)
        >>> baselines = sampler.sample(K=2048, N=20)
        >>> assert baselines.shape == (20, 2048)
        >>> assert baselines.mean(dim=0).abs().max() < 1e-6
    """

    def __init__(
        self,
        distribution: Literal["gaussian", "uniform", "sphere"] = "gaussian",
        scale: float = 0.1,
        device: torch.device | str | None = None,
    ) -> None:
        """Initialize centered baseline sampler.

        Args:
            distribution: Type of distribution to sample from.
            scale: Scale parameter (σ for Gaussian, a for uniform, r for sphere).
            device: Torch device.
        """
        self.distribution = distribution
        self.scale = scale
        if device is None:
            self.device = torch.device("cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device

    def sample(self, K: int, N: int) -> Tensor:
        """Sample N centered baseline vectors of dimension K.

        CRITICAL: The returned samples are guaranteed to be centered
        (mean = 0) to ensure completeness axiom holds.

        Args:
            K: Dimensionality (number of feature maps).
            N: Number of samples.

        Returns:
            Centered baseline samples [N, K] with mean 0.
        """
        if self.distribution == "gaussian":
            samples = torch.randn(N, K, device=self.device) * self.scale

        elif self.distribution == "uniform":
            # U(-scale, scale)
            samples = (torch.rand(N, K, device=self.device) * 2 - 1) * self.scale

        elif self.distribution == "sphere":
            # Sample from Gaussian and normalize to sphere
            samples = torch.randn(N, K, device=self.device)
            norms = samples.norm(dim=1, keepdim=True)
            samples = samples / (norms + 1e-10) * self.scale

        else:
            raise ValueError(f"Unknown distribution: {self.distribution}")

        # CRITICAL: Center the samples to ensure E[z'] = 0
        samples = center_samples(samples)

        return samples

    def sample_with_stats(self, K: int, N: int) -> tuple[Tensor, Tensor, Tensor]:
        """Sample baselines and return statistics for verification.

        Args:
            K: Dimensionality.
            N: Number of samples.

        Returns:
            Tuple of (samples [N, K], mean [K], std [K]).
        """
        samples = self.sample(K, N)
        mean = samples.mean(dim=0)
        std = samples.std(dim=0)
        return samples, mean, std


class DataAwareBaselineSampler:
    """Data-aware baseline sampler using a pre-computed feature map cache.

    Instead of using random noise, baselines are resampled (with
    replacement) from cached feature-map GAP vectors extracted from real
    images, so they reflect the empirical distribution of feature
    activations in the data. Each sampled batch is centered before it is
    returned.

    Attributes:
        feature_map_cache: Pre-computed GAP of feature maps [dataset_size, K],
            kept on ``device`` (``None`` until a cache is provided).
        device: Torch device the cache lives on.

    Example:
        >>> # Pre-compute cache (once)
        >>> cache = extract_gap_cache(model, target_layer, dataset)
        >>> sampler = DataAwareBaselineSampler(cache, device="cuda")
        >>> baselines = sampler.sample(K=2048, N=20)
    """

    def __init__(
        self,
        feature_map_cache: Tensor | None = None,
        device: torch.device | str | None = None,
    ) -> None:
        """Initialize data-aware baseline sampler.

        Args:
            feature_map_cache: Pre-computed GAP of feature maps [dataset_size, K].
                If None, must be set later via set_cache(). A provided cache is
                moved to ``device``.
            device: Torch device or device string; ``None`` selects the CPU.
        """
        if device is None:
            self.device = torch.device("cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device
        # Move a constructor-supplied cache to ``device`` exactly as set_cache()
        # does, so sample() returns tensors on self.device either way.
        self.feature_map_cache = (
            None if feature_map_cache is None else feature_map_cache.to(self.device)
        )

    def set_cache(self, cache: Tensor) -> None:
        """Set the feature map cache, moving it to ``self.device``.

        Args:
            cache: Pre-computed GAP values [dataset_size, K].
        """
        self.feature_map_cache = cache.to(self.device)

    def sample(self, K: int, N: int) -> Tensor:
        """Sample baselines from cached feature maps.

        Draws N row indices uniformly with replacement, gathers those cache
        rows, and centers them.

        Args:
            K: Dimensionality; must equal the cache's second dimension.
            N: Number of samples.

        Returns:
            Centered baseline samples [N, K].

        Raises:
            ValueError: If the cache is not set, is empty, or does not have
                shape [dataset_size, K].
        """
        cache = self.feature_map_cache
        if cache is None:
            raise ValueError("Feature map cache not set. Call set_cache() first.")

        cache_size = cache.shape[0]
        if cache_size == 0:
            raise ValueError("Feature map cache is empty; cannot sample baselines.")
        # Enforce the documented contract instead of silently returning
        # vectors of the wrong width.
        if cache.dim() != 2 or cache.shape[1] != K:
            raise ValueError(
                f"Requested K={K} does not match feature map cache of shape "
                f"{tuple(cache.shape)}; expected [dataset_size, {K}]."
            )

        indices = torch.randint(0, cache_size, (N,))
        samples = cache[indices]

        # CRITICAL: Center the samples so that E[z'] = 0.
        return center_samples(samples)


class DataAwareEGBaselineSampler:
    """Data-aware baseline sampler for Expected Gradients.

    Generates baselines z' by extracting feature maps from real images,
    ensuring baselines stay on the data manifold. This is the recommended
    approach for theoretical correctness.

    From the paper specification:
        φ^{EG}(g, z_0, I; A, D) = E_{z'~D}[∫₀¹ ∇_z g(z' + t(z_0 - I - z'); A) dt]

    The baseline z' should represent a meaningful reference point. Using
    feature maps from real images gives baselines that correspond to
    actual CNN representations, unlike Gaussian noise which may be off-manifold.

    Baselines are computed as z' = GAP(A') where A' are feature maps
    from baseline images, then centered and scaled.

    Attributes:
        model: CNN model for feature extraction.
        target_layer: Layer to extract feature maps from.
        baseline_dataset: Dataset of baseline images.
        device: Torch device.
        batch_size: Batch size for forward passes.

    Example:
        >>> sampler = DataAwareEGBaselineSampler(
        ...     model, target_layer, imagenet_train, device="cuda"
        ... )
        >>> features = hook.features  # [1, 2048, 7, 7]
        >>> baselines = sampler.sample(features, N=20)
        >>> assert baselines.shape == (20, 2048)
    """

    def __init__(
        self,
        model: nn.Module,
        target_layer: nn.Module,
        baseline_dataset: "Dataset",
        device: torch.device | str,
        batch_size: int = 32,
    ) -> None:
        """Initialize data-aware baseline sampler.

        Args:
            model: The CNN model (should be in eval mode).
            target_layer: Target convolutional layer for feature extraction.
            baseline_dataset: Dataset to sample baseline images from.
            device: Torch device.
            batch_size: Batch size for forward passes (for efficiency).
        """
        self.model = model
        self.target_layer = target_layer
        self.baseline_dataset = baseline_dataset
        if isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device
        self.batch_size = batch_size

    def sample(
        self,
        input_feature_maps: Tensor,
        N: int,
        target_scale: float = 0.1,
    ) -> Tensor:
        """Sample N centered baseline vectors from real images.

        The baselines are computed by:
        1. Forward passing N random images from the dataset
        2. Computing z' = GAP(A') (raw feature activation magnitudes)
        3. Centering the samples to satisfy E[z'] = 0
        4. Scaling to target standard deviation for numerical stability

        Args:
            input_feature_maps: Feature maps A of the input image [1, K, U, V].
                Used to determine K, but baselines are sampled independently.
            N: Number of baseline samples.
            target_scale: Target standard deviation for the baselines.
                Default 0.1 matches the scale used in Gaussian baseline sampling.

        Returns:
            Centered baseline samples [N, K] with mean 0 and std ~target_scale.
        """
        K = input_feature_maps.shape[1]
        device = self.device

        # Sample random indices
        n_baselines = len(self.baseline_dataset)
        indices = [random.randint(0, n_baselines - 1) for _ in range(N)]

        z_samples = []
        num_batches = (N + self.batch_size - 1) // self.batch_size

        with FeatureMapHook(self.target_layer) as hook:
            for batch_idx in range(num_batches):
                start = batch_idx * self.batch_size
                end = min(start + self.batch_size, N)
                batch_indices = indices[start:end]

                # Load batch of images
                batch_images = []
                for idx in batch_indices:
                    x_prime, _ = self.baseline_dataset[idx]
                    if x_prime.dim() == 3:
                        x_prime = x_prime.unsqueeze(0)
                    batch_images.append(x_prime)

                batch_tensor = torch.cat(batch_images, dim=0).to(device)

                # Forward pass
                with torch.no_grad():
                    _ = self.model(batch_tensor)

                A_batch = hook.features  # [batch_size, K, U, V]
                baseline_gap = A_batch.mean(dim=(2, 3))  # [batch_size, K]

                # Use raw GAP values (not normalized by input)
                z_samples.append(baseline_gap)

        D_samples = torch.cat(z_samples, dim=0)  # [N, K]

        # CRITICAL: Center the samples to satisfy completeness axiom: E[z'] = 0
        D_samples = center_samples(D_samples)

        # Scale to target standard deviation for numerical stability
        current_std = D_samples.std()
        if current_std > 1e-10:
            D_samples = D_samples * (target_scale / current_std)

        return D_samples

    def sample_raw(self, N: int) -> Tensor:
        """Sample N centered baseline vectors without input normalization.

        This version uses z' = GAP(A') directly without scaling.
        May be useful for certain analysis or when input-relative scaling
        is not desired.

        Args:
            N: Number of baseline samples.

        Returns:
            Centered baseline samples [N, K] with mean 0.
        """
        n_baselines = len(self.baseline_dataset)
        indices = [random.randint(0, n_baselines - 1) for _ in range(N)]

        z_samples = []
        num_batches = (N + self.batch_size - 1) // self.batch_size

        with FeatureMapHook(self.target_layer) as hook:
            for batch_idx in range(num_batches):
                start = batch_idx * self.batch_size
                end = min(start + self.batch_size, N)
                batch_indices = indices[start:end]

                # Load batch
                batch_images = []
                for idx in batch_indices:
                    x_prime, _ = self.baseline_dataset[idx]
                    if x_prime.dim() == 3:
                        x_prime = x_prime.unsqueeze(0)
                    batch_images.append(x_prime)

                batch_tensor = torch.cat(batch_images, dim=0).to(self.device)

                # Forward pass
                with torch.no_grad():
                    _ = self.model(batch_tensor)

                A_batch = hook.features
                z_batch = A_batch.mean(dim=(2, 3))  # [batch_size, K]
                z_samples.append(z_batch)

        D_samples = torch.cat(z_samples, dim=0)

        # CRITICAL: Center the samples
        D_samples = center_samples(D_samples)

        return D_samples


def sample_centered_baselines(
    K: int,
    N: int,
    scale: float = 0.1,
    distribution: Literal["gaussian", "uniform", "sphere"] = "gaussian",
    device: torch.device | str | None = None,
) -> Tensor:
    """Draw ``N`` zero-mean baseline vectors of width ``K`` in a single call.

    Shorthand for constructing a ``CenteredBaselineSampler`` with the given
    configuration and calling its ``sample`` method once.

    Args:
        K: Dimensionality (number of feature maps).
        N: Number of baseline samples.
        scale: Scale parameter for the distribution.
        distribution: Type of distribution ("gaussian", "uniform", "sphere").
        device: Torch device or device string; ``None`` selects the CPU.

    Returns:
        Centered baseline samples [N, K].

    Raises:
        ValueError: If ``distribution`` is not a supported name.

    Example:
        >>> baselines = sample_centered_baselines(K=2048, N=20, scale=0.1)
        >>> assert baselines.shape == (20, 2048)
    """
    return CenteredBaselineSampler(distribution, scale, device).sample(K, N)


def sample_data_aware_baselines(
    model: nn.Module,
    target_layer: nn.Module,
    baseline_dataset: "Dataset",
    input_feature_maps: Tensor,
    N: int,
    device: torch.device | str,
    batch_size: int = 32,
    target_scale: float = 0.1,
) -> Tensor:
    """Draw centered, rescaled data-aware baselines in a single call.

    Instead of Gaussian noise (z' = N(0, scale²)), the baselines here come from
    real images: z' = GAP(A'), where A' are the target-layer feature maps of
    randomly chosen dataset images. The result is centered and then rescaled
    to ``target_scale``. This helper builds a throw-away
    ``DataAwareEGBaselineSampler`` and calls its ``sample`` method.

    Args:
        model: CNN model.
        target_layer: Target convolutional layer.
        baseline_dataset: Dataset to sample from.
        input_feature_maps: Feature maps of the input image [1, K, U, V].
        N: Number of baseline samples.
        device: Torch device.
        batch_size: Batch size for the forward passes.
        target_scale: Target standard deviation for baselines.

    Returns:
        Centered baseline samples [N, K].

    Example:
        >>> baselines = sample_data_aware_baselines(
        ...     model, layer, dataset, features, N=20, device="cuda"
        ... )

    Note:
        Consider using ``sample_from_provider`` with a BaselineProvider instead.
        The provider API offers better configuration and caching options.
    """
    sampler = DataAwareEGBaselineSampler(
        model, target_layer, baseline_dataset, device, batch_size=batch_size
    )
    return sampler.sample(input_feature_maps, N=N, target_scale=target_scale)


def sample_from_provider(
    provider: "BaselineProvider",
    N: int,
    device: torch.device | str,
    target_scale: float | None = None,
) -> Tensor:
    """Sample centered baselines from a baseline provider.

    This is the recommended way to sample data-aware baselines. The provider
    handles data loading, caching, and feature extraction internally; this
    function only validates it, requests ``N`` samples, and optionally
    rescales them.

    Args:
        provider: Initialized BaselineProvider instance.
        N: Number of baseline samples.
        device: Torch device (or device string) for the output tensor.
        target_scale: Optional target standard deviation. If provided, the
            samples are multiplied so their overall std equals it (skipped
            when the std is ≤ 1e-10). If None, the provider's natural scale
            is kept.

    Returns:
        Baseline samples [N, K], centered by the provider.

    Raises:
        TypeError: If ``provider`` is not a BaselineProvider.
        RuntimeError: If provider is not initialized.

    Example:
        >>> from expected_gradcam.baselines import baseline_from
        >>>
        >>> provider = baseline_from("/data/imagenet/train")
        >>> provider.initialize(model, target_layer, device)
        >>>
        >>> baselines = sample_from_provider(provider, N=20, device="cuda")
        >>> assert baselines.mean(dim=0).abs().max() < 1e-6  # Centered

    Note:
        The provider must be initialized before sampling. Call
        ``provider.initialize(model, target_layer, device)`` first.
    """
    # Imported at call time; only needed for the runtime isinstance check.
    from expected_gradcam.baselines.protocols import BaselineProvider

    if not isinstance(provider, BaselineProvider):
        raise TypeError(
            f"Expected BaselineProvider, got {type(provider).__name__}"
        )
    if not provider.is_initialized:
        raise RuntimeError(
            "Provider not initialized. Call provider.initialize() first."
        )

    target_device = torch.device(device) if isinstance(device, str) else device

    # The provider is responsible for returning already-centered samples.
    samples = provider.get_baseline_samples(N, target_device)
    if target_scale is None:
        return samples

    spread = samples.std()
    # Leave (numerically) constant samples untouched to avoid dividing by ~0.
    return samples * (target_scale / spread) if spread > 1e-10 else samples


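# Imported only for static type checkers so the "BaselineProvider" annotation
# on sample_from_provider resolves; at runtime that function imports it lazily.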
if TYPE_CHECKING:
    from expected_gradcam.baselines.protocols import BaselineProvider


# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    # Classes
    "CenteredBaselineSampler",
    "DataAwareBaselineSampler",
    "DataAwareEGBaselineSampler",
    # Convenience functions
    "sample_centered_baselines",
    "sample_data_aware_baselines",
    "sample_from_provider",
]
