"""Data-aware perturbation sampling for Expected GradCAM.

This module provides the sequential (non-batched) implementation of
data-aware perturbation sampling.

Mathematical specification (from paper):
    I = z_0 - α * h(x')

where:
    - z_0 = (1, ..., 1) is the reference point
    - α ~ U(0, 1) or linspace(0, 1, M) is the interpolation factor
    - h(x') = GAP(A') / GAP(A) maps baseline to feature map multipliers
    - x' ~ X is sampled from the data distribution (baseline dataset)

Key insight: By using data-aware perturbations, we stay within the
data manifold and avoid out-of-distribution artifacts.

For faster sampling, see batched_perturbation.py which provides
~60x speedup through GPU batching.
"""

from __future__ import annotations

import random
from typing import TYPE_CHECKING

import torch
from torch import nn

from expected_gradcam.hooks import FeatureMapHook

if TYPE_CHECKING:
    from torch import Tensor
    from torch.utils.data import Dataset


class DataAwarePerturbationSampler:
    """Data-aware perturbation sampler for Expected GradCAM.

    Generates perturbations I = z_0 - α * h(x') where:
    - x' ~ X is sampled from the data distribution (baseline dataset)
    - α ~ U(0,1) is a uniform interpolation factor
    - h(x') maps the baseline image to feature map multipliers

    The function h(x') is computed by:
    1. Forward passing x' through the model
    2. Capturing feature maps A' at the target layer
    3. Computing h(x') = GAP(A') / GAP(A) where A is the input's feature maps

    This normalizes the baseline feature maps relative to the input.

    Note: This is the sequential implementation. For GPU-optimized batched
    sampling (~60x faster), use BatchedPerturbationSampler instead.

    Attributes:
        model: CNN model for feature extraction.
        target_layer: Layer to extract feature maps from.
        baseline_dataset: Dataset of baseline images.
        device: Torch device.
        alpha_sampling: Strategy for sampling alpha values.

    Example:
        >>> sampler = DataAwarePerturbationSampler(
        ...     model, target_layer, imagenet_train, device="cuda"
        ... )
        >>> z0 = torch.ones(2048, device="cuda")
        >>> features = hook.features  # [1, 2048, 7, 7]
        >>> perturbations = sampler.sample(z0, features, M=50)
        >>> assert perturbations.shape == (50, 2048)
    """

    def __init__(
        self,
        model: nn.Module,
        target_layer: nn.Module,
        baseline_dataset: "Dataset",
        device: torch.device | str,
        alpha_sampling: str = "uniform",
    ) -> None:
        """Initialize data-aware perturbation sampler.

        Args:
            model: The CNN model.
            target_layer: Target convolutional layer for feature extraction.
            baseline_dataset: Dataset to sample baseline images from.
            device: Torch device.
            alpha_sampling: How to sample α.
                "uniform": α ~ U(0,1) - random
                "linear": α = linspace(0, 1, M) - deterministic grid
        """
        self.model = model
        self.target_layer = target_layer
        self.baseline_dataset = baseline_dataset
        if isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device
        self.alpha_sampling = alpha_sampling

    def sample(
        self,
        z0: Tensor,
        input_feature_maps: Tensor,
        M: int,
    ) -> Tensor:
        """Sample M perturbation vectors.

        Args:
            z0: Reference point [K], typically all ones.
            input_feature_maps: Feature maps A for the input image [1, K, U, V].
            M: Number of perturbation samples to generate.

        Returns:
            Perturbation samples I [M, K].
        """
        K = z0.shape[0]
        I_samples = []

        # Precompute input feature map statistics for normalization
        input_gap = input_feature_maps.mean(dim=(2, 3)).squeeze()  # [K]

        # Sample indices for baseline images
        n_baselines = len(self.baseline_dataset)
        baseline_indices = [random.randint(0, n_baselines - 1) for _ in range(M)]

        # Generate alpha values
        if self.alpha_sampling == "uniform":
            alpha_values = [torch.rand(1, device=self.device) for _ in range(M)]
        elif self.alpha_sampling == "linear":
            alpha_values = [
                torch.tensor([i / (M - 1)], device=self.device)
                if M > 1
                else torch.tensor([0.5], device=self.device)
                for i in range(M)
            ]
        else:
            raise ValueError(f"Unknown alpha_sampling: {self.alpha_sampling}")

        with FeatureMapHook(self.target_layer) as hook:
            for idx, alpha in zip(baseline_indices, alpha_values):
                # Get baseline image
                x_prime, _ = self.baseline_dataset[idx]
                if x_prime.dim() == 3:
                    x_prime = x_prime.unsqueeze(0)
                x_prime = x_prime.to(self.device)

                # Forward pass to get baseline feature maps
                with torch.no_grad():
                    _ = self.model(x_prime)
                A_prime = hook.features  # [1, K, U, V]

                # Compute h(x') = GAP(A') normalized by input GAP
                # This gives relative feature map "strength"
                baseline_gap = A_prime.mean(dim=(2, 3)).squeeze()  # [K]

                # Normalize to get multipliers (avoid division by zero)
                h_x_prime = baseline_gap / (input_gap + 1e-10)  # [K]

                # Compute perturbation: I = z_0 - α * h(x')
                I = z0 - alpha * h_x_prime
                I_samples.append(I)

        return torch.stack(I_samples)  # [M, K]


# Names exported by `from <this module> import *`; the sampler class is the
# only public name this module declares.
__all__ = [
    "DataAwarePerturbationSampler",
]
