"""Batched perturbation sampling for efficient GPU utilization.

This module provides optimized sampling that processes multiple baseline
images in batched forward passes instead of sequential single-image passes.

Performance comparison for M=2500 perturbations:
- Sequential: 2500 forward passes
- Batched (batch_size=64): ~40 forward passes
- Speedup: ~60x for perturbation sampling phase

Mathematical specification:
    I = z_0 - α * h(x')

where h(x') = GAP(A') / GAP(A_input) maps baselines to feature multipliers.

NOTE: Pure data-aware perturbations are CORRECT per the paper specification.
The rank deficiency (~900/2048 for ResNet-50) is a fundamental property of
CNN representations and should be handled with pseudo-inverse, NOT by mixing
Gaussian noise in z-space.
"""

from __future__ import annotations

import random
from typing import TYPE_CHECKING

import torch
from torch import nn
from torch.utils.data import DataLoader

from expected_gradcam.hooks import FeatureMapHook
from expected_gradcam.sampling.utils import safe_divide, normalize_perturbations

if TYPE_CHECKING:
    from torch import Tensor
    from torch.utils.data import Dataset


class BatchedPerturbationSampler:
    """Batched data-aware perturbation sampler.

    Generates perturbations I = z_0 - alpha * h(x') using batched forward
    passes through the model to extract feature maps for multiple baseline
    images at once.

    This needs roughly ``M / batch_size`` forward passes instead of the ``M``
    passes of a one-image-at-a-time sampler.

    Both public entry points (:meth:`sample` and
    :meth:`sample_with_dataloader`) share the same alpha sampling, the same
    robust denominator for ``h(x')`` and the same clamping, so they only
    differ in how baseline images are loaded and in whether the result is
    normalized by default.

    Attributes:
        model: CNN model for feature extraction.
        target_layer: Layer to extract feature maps from.
        baseline_dataset: Dataset of baseline images. Items are unpacked as
            ``(image, _)`` pairs.
        device: Torch device.
        batch_size: Batch size for forward passes.
        alpha_sampling: Strategy for sampling alpha values
            (``"uniform"`` or ``"linear"``).
        num_workers: DataLoader workers for I/O (only used by
            :meth:`sample_with_dataloader`).

    Example:
        >>> sampler = BatchedPerturbationSampler(
        ...     model, target_layer, imagenet_train, device="cuda", batch_size=64
        ... )
        >>> z0 = torch.ones(2048, device="cuda")
        >>> features = hook.features
        >>> perturbations = sampler.sample(z0, features, M=2500)
        >>> # ~40 forward passes instead of 2500
    """

    # Multipliers h(x') are clipped to [-_H_CLAMP, _H_CLAMP] so that a few
    # channels with a tiny input activation cannot dominate the samples.
    _H_CLAMP = 100

    # The division floor is this fraction of the median |GAP(A_input)| ...
    _GAP_FLOOR_FRACTION = 0.01
    # ... but never smaller than this absolute value.
    _GAP_FLOOR_ABS = 1e-6
    # Input GAP magnitudes below this are treated as exactly zero.
    _GAP_ZERO_EPS = 1e-10

    def __init__(
        self,
        model: nn.Module,
        target_layer: nn.Module,
        baseline_dataset: "Dataset",
        device: torch.device | str,
        batch_size: int = 64,
        alpha_sampling: str = "uniform",
        num_workers: int = 0,
    ) -> None:
        """Initialize batched perturbation sampler.

        Args:
            model: The CNN model (should be in eval mode; this is not
                enforced here).
            target_layer: Target convolutional layer for feature extraction.
            baseline_dataset: Dataset to sample baseline images from.
            device: Torch device, or a device string such as ``"cuda"``.
            batch_size: Batch size for forward passes. Larger = faster but more memory.
            alpha_sampling: How to sample alpha values.
                "uniform": alpha ~ U(0, 1)
                "linear": alpha = linspace(0, 1, M)
                The value is validated when sampling, not here.
            num_workers: DataLoader workers. 0 for main thread.
        """
        self.model = model
        self.target_layer = target_layer
        self.baseline_dataset = baseline_dataset
        if isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device
        self.batch_size = batch_size
        self.alpha_sampling = alpha_sampling
        self.num_workers = num_workers

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _sample_alphas(self, M: int) -> Tensor:
        """Draw the M interpolation coefficients on ``self.device``.

        Raises:
            ValueError: If ``self.alpha_sampling`` is not a known strategy.
        """
        if self.alpha_sampling == "uniform":
            return torch.rand(M, device=self.device)
        if self.alpha_sampling == "linear":
            if M > 1:
                return torch.linspace(0, 1, M, device=self.device)
            # linspace with a single point would yield 0; use the midpoint.
            return torch.tensor([0.5], device=self.device)
        raise ValueError(f"Unknown alpha_sampling: {self.alpha_sampling}")

    @classmethod
    def _safe_input_gap(cls, input_feature_maps: Tensor) -> Tensor:
        """Return GAP(A_input) with magnitudes floored so it is safe to divide by.

        The floor is 1% of the median absolute GAP (at least 1e-6). The sign
        of each entry is preserved; entries that are (numerically) zero have
        no sign and are replaced by the positive floor.
        """
        input_gap = input_feature_maps.mean(dim=(2, 3)).squeeze()
        gap_abs = input_gap.abs()
        floor = max(
            float(gap_abs.median().item()) * cls._GAP_FLOOR_FRACTION,
            cls._GAP_FLOOR_ABS,
        )
        floored = gap_abs.clamp(min=floor) * input_gap.sign()
        # sign() is 0 for exact zeros, which would reintroduce a zero
        # denominator, so patch those entries explicitly.
        return torch.where(
            gap_abs < cls._GAP_ZERO_EPS,
            torch.full_like(input_gap, floor),
            floored,
        )

    @classmethod
    def _perturb_batch(
        cls,
        z0: Tensor,
        batch_alphas: Tensor,
        baseline_features: Tensor,
        input_gap_safe: Tensor,
    ) -> Tensor:
        """Compute I = z0 - alpha * h(x') for one batch of baseline feature maps.

        ``baseline_features`` is averaged over dims 2 and 3, divided by the
        safe input GAP and clamped before being combined with ``z0``.
        """
        baseline_gap = baseline_features.mean(dim=(2, 3))
        h_batch = baseline_gap / input_gap_safe.unsqueeze(0)
        h_batch = h_batch.clamp(-cls._H_CLAMP, cls._H_CLAMP)
        return z0.unsqueeze(0) - batch_alphas.unsqueeze(1) * h_batch

    def _load_batch(self, batch_indices: list[int]) -> Tensor:
        """Fetch the given dataset items and concatenate them on ``self.device``."""
        images = []
        for idx in batch_indices:
            x_prime, _ = self.baseline_dataset[idx]
            # Accept both unbatched (3-D) and already-batched items.
            if x_prime.dim() == 3:
                x_prime = x_prime.unsqueeze(0)
            images.append(x_prime)
        return torch.cat(images, dim=0).to(self.device)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def sample(
        self,
        z0: Tensor,
        input_feature_maps: Tensor,
        M: int,
        target_scale: float = 0.3,
    ) -> Tensor:
        """Sample M perturbation vectors using batched forward passes.

        The perturbations are computed as:
            I = z_0 - alpha * h(x')

        where h(x') = GAP(A') / GAP(A_input) is the relative feature
        activation ratio, computed with a floored denominator and clamped to
        avoid numerical blow-up. The stacked result is then passed through
        ``normalize_perturbations`` (described at the call site as centering
        and scaling) with ``target_scale``.

        Baseline indices are drawn with Python's ``random`` module, alphas
        with torch's global RNG; seed both for reproducibility.

        Args:
            z0: Reference point [K], typically all ones.
            input_feature_maps: Feature maps A for the input image [1, K, U, V].
            M: Number of perturbation samples.
            target_scale: Target standard deviation for perturbations.

        Returns:
            Perturbation samples I [M, K]. For ``M == 0`` an empty ``[0, K]``
            tensor with the dtype/device of ``z0`` is returned.

        Raises:
            ValueError: If ``M`` is negative or ``alpha_sampling`` is unknown.
        """
        if M < 0:
            raise ValueError(f"M must be non-negative, got {M}")
        if M == 0:
            # Nothing to sample; torch.cat on an empty list would raise.
            return z0.new_empty((0, z0.shape[0]))

        input_gap_safe = self._safe_input_gap(input_feature_maps)

        # Sample random indices into baseline dataset (with replacement).
        n_baselines = len(self.baseline_dataset)
        indices = [random.randint(0, n_baselines - 1) for _ in range(M)]

        alphas = self._sample_alphas(M)

        I_samples = []
        with FeatureMapHook(self.target_layer) as hook:
            for start in range(0, M, self.batch_size):
                end = min(start + self.batch_size, M)
                batch_tensor = self._load_batch(indices[start:end])

                # Single batched forward pass; only the hooked features matter.
                with torch.no_grad():
                    self.model(batch_tensor)

                # hook.features is assumed to hold the target layer's output
                # for the forward pass above: [batch, K, U, V].
                I_samples.append(
                    self._perturb_batch(
                        z0, alphas[start:end], hook.features, input_gap_safe
                    )
                )

        I_all = torch.cat(I_samples, dim=0)  # [M, K]

        # Center and scale perturbations for numerical stability
        return normalize_perturbations(I_all, target_scale=target_scale)

    def sample_with_dataloader(
        self,
        z0: Tensor,
        input_feature_maps: Tensor,
        M: int,
        target_scale: float | None = None,
    ) -> Tensor:
        """Alternative sampling using PyTorch DataLoader for potentially faster I/O.

        Useful when dataset loading is a bottleneck (e.g., loading from disk).
        Uses the same robust denominator, clamping and alpha validation as
        :meth:`sample`.

        Args:
            z0: Reference point [K].
            input_feature_maps: Input feature maps [1, K, U, V].
            M: Number of samples.
            target_scale: If given, the result is passed through
                ``normalize_perturbations`` with this scale, as in
                :meth:`sample`. ``None`` (default) returns the raw
                perturbations.

        Returns:
            Perturbation samples [M, K]. For ``M == 0`` an empty ``[0, K]``
            tensor with the dtype/device of ``z0`` is returned.

        Raises:
            ValueError: If ``M`` is negative or ``alpha_sampling`` is unknown.
        """
        if M < 0:
            raise ValueError(f"M must be non-negative, got {M}")
        if M == 0:
            return z0.new_empty((0, z0.shape[0]))

        device = self.device
        input_gap_safe = self._safe_input_gap(input_feature_maps)

        # Random baseline indices (with replacement).
        n_baselines = len(self.baseline_dataset)
        indices = torch.randint(0, n_baselines, (M,)).tolist()

        alphas = self._sample_alphas(M)

        # torch's own Subset is defined at module level and therefore
        # picklable, which worker processes need under the spawn start method.
        subset = torch.utils.data.Subset(self.baseline_dataset, indices)
        loader = DataLoader(
            subset,
            batch_size=self.batch_size,
            shuffle=False,  # keep batches aligned with the alpha ordering
            num_workers=self.num_workers,
            pin_memory=device.type == "cuda",
        )

        I_samples = []
        sample_idx = 0

        with FeatureMapHook(self.target_layer) as hook:
            for batch_images, _ in loader:
                batch_images = batch_images.to(device)
                current_batch_size = batch_images.shape[0]

                with torch.no_grad():
                    self.model(batch_images)

                batch_alphas = alphas[sample_idx : sample_idx + current_batch_size]
                I_samples.append(
                    self._perturb_batch(
                        z0, batch_alphas, hook.features, input_gap_safe
                    )
                )
                sample_idx += current_batch_size

        I_all = torch.cat(I_samples, dim=0)[:M]
        if target_scale is not None:
            I_all = normalize_perturbations(I_all, target_scale=target_scale)
        return I_all


class SimpleBatchedSampler:
    """Gaussian-noise perturbation sampler that needs no baseline dataset.

    Drawing noise is far cheaper than data-aware sampling, but the resulting
    perturbations are not tied to the data manifold and may therefore be
    less meaningful.

    Typical use cases:
    - No baseline dataset is available
    - Speed matters more than theoretical guarantees
    - Quick experiments or prototyping

    Attributes:
        scale: Standard deviation for Gaussian perturbations.
        device: Torch device the samples are created on.

    Example:
        >>> sampler = SimpleBatchedSampler(scale=0.3, device="cuda")
        >>> z0 = torch.ones(2048, device="cuda")
        >>> perturbations = sampler.sample(z0, M=2500)  # Instant
    """

    def __init__(
        self,
        scale: float = 0.3,
        device: torch.device | str | None = None,
    ) -> None:
        """Store the noise scale and resolve the target device.

        Args:
            scale: Standard deviation for Gaussian perturbations.
            device: Torch device or device string. ``None`` selects the CPU.
        """
        self.scale = scale
        # Fall back to CPU when no device is given; strings are converted,
        # anything else is stored unchanged.
        requested = "cpu" if device is None else device
        self.device = (
            torch.device(requested) if isinstance(requested, str) else requested
        )

    def sample(self, z0: Tensor, M: int) -> Tensor:
        """Draw M zero-mean Gaussian perturbation vectors.

        Only the length of ``z0`` is used (to determine K); its values do
        not influence the samples.

        Args:
            z0: Reference point [K].
            M: Number of samples.

        Returns:
            Perturbation samples [M, K] on ``self.device``, i.e. standard
            normal noise multiplied by ``self.scale``.
        """
        num_features = z0.shape[0]
        noise = torch.randn(M, num_features, device=self.device)
        return noise * self.scale


# Public API of this module: the names exported by a star-import.
__all__ = [
    "BatchedPerturbationSampler",
    "SimpleBatchedSampler",
]
