"""Batched path integration methods for efficient GPU utilization.

This module provides optimized versions of Integrated Gradients and Expected
Gradients that batch computations across:
- T: Integration steps (batch all steps in single forward/backward)
- N: Baseline samples for Expected Gradients
- M: Multiple perturbations (process multiple I vectors simultaneously)

Optimizations:
- Automatic Mixed Precision (AMP) for 1.5-2x speedup and 50% memory reduction
- Precomputed shared tensors to reduce per-iteration overhead
- Preallocated output tensors to eliminate list append/concat
- no_grad wrapper for tensor construction to reduce autograd overhead

Achieves ~3-5x speedup over sequential implementation for large M (>2000).

NOTE: The baseline distribution D for Expected Gradients should be data-aware
(sampled from real image feature maps), not Gaussian noise. This is critical
for the theoretical guarantees of the method.
"""

from __future__ import annotations

import math
import time
from typing import TYPE_CHECKING, Callable

import torch
from torch import Tensor

from expected_gradcam.gpu.amp_context import AMPContext, is_amp_available

if TYPE_CHECKING:
    from expected_gradcam.core.callbacks import ChunkResult, IntermediateHeatmap
    from expected_gradcam.core.observer_manager import ObserverManager
    from expected_gradcam.core.optimal_weights import SolverDiagnostics


class BatchedIntegratedGradients:
    """Batched Integrated Gradients that processes all T steps simultaneously.

    Instead of T sequential forward/backward passes, performs one batched
    forward pass with T*batch_size points and one backward pass.

    Memory: O(T * K) per perturbation
    Speedup: ~T times faster than sequential (GPU parallelism)

    Attributes:
        T: Number of integration steps.
    """

    def __init__(self, T: int = 50) -> None:
        """Initialize batched IG.

        Args:
            T: Number of integration steps (must be at least 1).

        Raises:
            ValueError: If ``T`` is smaller than 1 (the midpoint rule would
                otherwise average over zero points and return NaN).
        """
        if T < 1:
            raise ValueError(f"T must be a positive integer, got {T}")
        self.T = T

    def _midpoint_steps(self, z0: Tensor, I: Tensor) -> Tensor:
        """Return the midpoint-rule nodes t_j = (j + 0.5) / T as a [T] tensor.

        The nodes are created directly in the dtype the interpolation points
        end up in (default dtype promoted with the input dtypes), so float64
        inputs are not integrated on a float32-precision grid.
        """
        step_dtype = torch.promote_types(
            torch.get_default_dtype(), torch.promote_types(z0.dtype, I.dtype)
        )
        return torch.linspace(
            0.5 / self.T,
            1 - 0.5 / self.T,
            self.T,
            device=z0.device,
            dtype=step_dtype,
        )

    def _path_gradients(
        self,
        predictor_fn: Callable[[Tensor], Tensor],
        baselines: Tensor,
        directions: Tensor,
        t_values: Tensor,
    ) -> Tensor:
        """Average the predictor gradient along each straight path.

        Args:
            predictor_fn: Function mapping [N, K] -> [N] class scores.
            baselines: Path start points [R, K].
            directions: Path direction vectors [R, K].
            t_values: Interpolation coefficients [T].

        Returns:
            Gradients averaged over the T steps, [R, K].
        """
        rows, K = directions.shape

        # [R, T, K] interpolation points, flattened to one [R * T, K] batch.
        points = baselines.unsqueeze(1) + t_values.view(1, -1, 1) * directions.unsqueeze(1)
        # Detached copy: gradients are taken w.r.t. the points only and never
        # flow back into the caller's tensors.
        points = points.reshape(-1, K).clone().detach().requires_grad_(True)

        outputs = predictor_fn(points)
        grads = torch.autograd.grad(
            outputs.sum(),
            points,
            create_graph=False,
            retain_graph=False,
        )[0]

        return grads.reshape(rows, self.T, K).mean(dim=1)

    def compute_single(
        self,
        predictor_fn: Callable[[Tensor], Tensor],
        z0: Tensor,
        I: Tensor,
    ) -> Tensor:
        """Compute IG for a single perturbation using batched integration.

        The path runs from the baseline ``z0 - I`` to ``z0`` and is sampled
        with the midpoint rule in one forward/backward pass.

        Args:
            predictor_fn: Function mapping [N, K] -> [N] class scores.
            z0: Reference point [K].
            I: Perturbation vector [K].

        Returns:
            Integrated gradients [K].
        """
        t_values = self._midpoint_steps(z0, I)
        baseline = z0 - I
        return self._path_gradients(
            predictor_fn, baseline.unsqueeze(0), I.unsqueeze(0), t_values
        )[0]

    def compute_multi(
        self,
        predictor_fn: Callable[[Tensor], Tensor],
        z0: Tensor,
        I_batch: Tensor,
        max_batch_size: int = 1024,
    ) -> Tensor:
        """Compute IG for multiple perturbations efficiently.

        Batches both across T steps and M perturbations for maximum GPU utilization.

        Args:
            predictor_fn: Function mapping [N, K] -> [N] class scores.
            z0: Reference point [K].
            I_batch: Multiple perturbations [M, K].
            max_batch_size: Maximum batch size for GPU memory management.
                At least one perturbation (T points) is processed per pass
                even if this is smaller than T.

        Returns:
            Integrated gradients [M, K].
        """
        M = I_batch.shape[0]
        t_values = self._midpoint_steps(z0, I_batch)

        # Baselines for all perturbations: [M, K]
        baselines = z0.unsqueeze(0) - I_batch

        # Largest number of perturbations whose T points fit in one pass.
        rows_per_pass = max(1, max_batch_size // self.T)

        if M <= rows_per_pass:
            return self._path_gradients(predictor_fn, baselines, I_batch, t_values)

        results = [
            self._path_gradients(
                predictor_fn,
                baselines[start : start + rows_per_pass],
                I_batch[start : start + rows_per_pass],
                t_values,
            )
            for start in range(0, M, rows_per_pass)
        ]
        return torch.cat(results, dim=0)


class BatchedExpectedGradients:
    """Batched Expected Gradients computation.

    Batches across:
    - T: Integration steps
    - N: Baseline samples

    For a single perturbation with N=20, T=50:
    - Sequential: 1000 forward/backward passes
    - Batched: 1 forward/backward pass with batch size 1000

    Memory requirement: O(N * T * K) per perturbation
    """

    def __init__(self, T: int = 50) -> None:
        """Initialize batched EG.

        Args:
            T: Number of integration steps (must be at least 1).

        Raises:
            ValueError: If ``T`` is smaller than 1 (the midpoint rule would
                otherwise average over zero points and return NaN).
        """
        if T < 1:
            raise ValueError(f"T must be a positive integer, got {T}")
        self.T = T

    def compute_single(
        self,
        predictor_fn: Callable[[Tensor], Tensor],
        z0: Tensor,
        I: Tensor,
        D_samples: Tensor,
        max_batch_size: int = 2048,
    ) -> Tensor:
        """Compute Expected Gradients for a single perturbation.

        For every baseline the gradient is averaged along the straight path
        from the (re-centered) baseline to the target ``z0 - I``; the result
        is the mean of those path averages over all baselines.

        Args:
            predictor_fn: Function mapping [N, K] -> [N] class scores.
            z0: Reference point [K].
            I: Perturbation vector [K].
            D_samples: Baseline samples [N, K]. They are re-centered here
                (column mean subtracted), so passing centered samples is a
                no-op.
            max_batch_size: Maximum points per batch. At least one baseline
                (T points) is processed per pass even if this is smaller
                than T.

        Returns:
            Expected gradients attribution [K].
        """
        device = z0.device
        K = z0.shape[0]
        N = D_samples.shape[0]

        # Ensure baselines are centered
        D_centered = D_samples - D_samples.mean(dim=0, keepdim=True)

        # Direction vectors from each baseline to the target point: [N, K]
        target = z0 - I
        directions = target.unsqueeze(0) - D_centered

        # Build the midpoint nodes t_j = (j + 0.5) / T in the dtype the
        # interpolation points end up in, so float64 inputs are not
        # integrated on a float32-precision grid.
        step_dtype = torch.promote_types(torch.get_default_dtype(), directions.dtype)
        t_column = torch.linspace(
            0.5 / self.T,
            1 - 0.5 / self.T,
            self.T,
            device=device,
            dtype=step_dtype,
        ).view(1, -1, 1)

        # Largest number of baselines whose T points fit in one pass. When
        # all N fit, the loop below runs exactly once.
        baselines_per_pass = max(1, max_batch_size // self.T)

        phi_sum = torch.zeros(K, device=device, dtype=step_dtype)
        for start in range(0, N, baselines_per_pass):
            D_chunk = D_centered[start : start + baselines_per_pass]
            dir_chunk = directions[start : start + baselines_per_pass]

            # [n, T, K] interpolation points flattened to [n * T, K].
            points = (D_chunk.unsqueeze(1) + t_column * dir_chunk.unsqueeze(1)).reshape(-1, K)
            # Detached copy: gradients are taken w.r.t. the points only.
            points = points.clone().detach().requires_grad_(True)

            outputs = predictor_fn(points)
            grads = torch.autograd.grad(
                outputs.sum(),
                points,
                create_graph=False,
                retain_graph=False,
            )[0]

            # Average over T per baseline, then accumulate over baselines.
            grads = grads.reshape(D_chunk.shape[0], self.T, K)
            phi_sum = phi_sum + grads.mean(dim=1).sum(dim=0)

        return phi_sum / N


class FullyBatchedExpectedGradients:
    """Expected Gradients computation batched across all dimensions: M, N, and T.

    This is the most efficient implementation for computing phi^{EG} for
    multiple perturbations simultaneously.

    Optimizations:
    - AMP (FP16) for gradient computation: 1.5-2x speedup, 50% memory reduction
    - Precomputed shared tensors: reduces per-iteration overhead
    - Preallocated output: eliminates list append/concat
    - no_grad for tensor construction: reduces autograd overhead

    For M=2500, N=20, T=50 on 80GB GPU:
    - Previous: ~625 batched passes (max_batch_size=4096)
    - Optimized: ~80 batched passes (max_batch_size=32000)
    - With AMP: additional 1.5-2x speedup

    Attributes:
        T: Number of integration steps.
        N: Number of baseline samples per perturbation. Only used to size the
            Gaussian fallback; when baselines are passed in explicitly their
            actual row count is used instead.
    """

    def __init__(self, T: int = 50, N: int = 20) -> None:
        """Initialize fully batched EG.

        Args:
            T: Number of integration steps (must be at least 1).
            N: Number of baseline samples per perturbation.

        Raises:
            ValueError: If ``T`` is smaller than 1.
        """
        if T < 1:
            raise ValueError(f"T must be a positive integer, got {T}")
        self.T = T
        self.N = N

    def compute_batch(
        self,
        predictor_fn: Callable[[Tensor], Tensor],
        z0: Tensor,
        I_batch: Tensor,
        D_samples: Tensor | None = None,
        baseline_scale: float = 0.1,
        max_batch_size: int = 4096,
        shared_baselines: bool = True,
        use_amp: bool = True,
        *,
        observer_manager: "ObserverManager | None" = None,
        feature_maps: Tensor | None = None,
        target_size: tuple[int, int] | None = None,
        heatmap_checkpoint_interval: int = 0,
    ) -> Tensor:
        """Compute Expected Gradients for multiple perturbations.

        Args:
            predictor_fn: Function mapping [N, K] -> [N] class scores.
            z0: Reference point [K].
            I_batch: Perturbation vectors [M, K].
            D_samples: Pre-computed baseline samples [N, K]. RECOMMENDED: Use
                data-aware baselines from DataAwareEGBaselineSampler. If None,
                falls back to ``self.N`` Gaussian baselines (not recommended).
                The row count does not have to equal ``self.N``.
            baseline_scale: Scale for Gaussian baseline sampling (only used if
                D_samples is None).
            max_batch_size: Maximum points per forward/backward pass. At least
                one perturbation is processed per pass even if its N * T
                points exceed this limit.
            shared_baselines: Accepted for API compatibility. This
                implementation does not read it: the same baselines are
                always shared by all M perturbations.
            use_amp: Whether to use automatic mixed precision (FP16) for gradient
                computation. Only takes effect on CUDA devices when AMP is
                available.
            observer_manager: Optional observer manager for real-time callbacks.
                Callbacks are emitted per chunk on the chunked path only; when
                all M * N * T points fit in a single pass no observer is
                notified.
            feature_maps: Feature maps [1, K, U, V] for intermediate heatmap
                generation. Only used when observer_manager is set.
            target_size: Target size (H, W) for upsampling intermediate heatmaps.
                Only used when feature_maps is provided.
            heatmap_checkpoint_interval: Generate intermediate heatmap every N
                chunks. 0 disables intermediate heatmap generation.

        Returns:
            Expected gradients attributions [M, K].

        Raises:
            ValueError: If the baseline set is empty.
        """
        device = z0.device
        M, K = I_batch.shape
        use_amp = use_amp and is_amp_available() and device.type == "cuda"

        # Setup for observer callbacks
        has_observers = observer_manager is not None and observer_manager.has_observers
        moment_computer = None
        start_time = None

        if has_observers:
            from expected_gradcam.core.second_moment import IncrementalMomentComputer

            moment_computer = IncrementalMomentComputer(K, device)
            start_time = time.perf_counter()

        # Use provided baselines or sample
        if D_samples is None:
            # Fallback: Sample Gaussian baselines (NOT RECOMMENDED)
            D_samples = torch.randn(self.N, K, device=device) * baseline_scale
            D_samples = D_samples - D_samples.mean(dim=0, keepdim=True)

        # All shape arithmetic below uses the baselines actually in hand, so
        # caller-provided sets of any size work regardless of self.N.
        num_baselines = D_samples.shape[0]
        if num_baselines == 0:
            raise ValueError("D_samples must contain at least one baseline")

        # Compute targets: [M, K]
        targets = z0.unsqueeze(0) - I_batch

        # Midpoint-rule nodes t_j = (j + 0.5) / T, built in the dtype the
        # interpolation points end up in so float64 inputs are not integrated
        # on a float32-precision grid.
        step_dtype = torch.promote_types(
            torch.get_default_dtype(),
            torch.promote_types(targets.dtype, D_samples.dtype),
        )
        t_values = torch.linspace(
            0.5 / self.T,
            1 - 0.5 / self.T,
            self.T,
            device=device,
            dtype=step_dtype,
        )

        # Total computation points: M * N * T
        points_per_I = num_baselines * self.T
        total_points = M * points_per_I

        if total_points <= max_batch_size:
            return self._compute_all_at_once(
                predictor_fn,
                D_samples,
                targets,
                t_values,
                K,
                use_amp,
            )

        # Process in chunks over M
        chunk_size_M = max(1, max_batch_size // points_per_I)
        total_chunks = math.ceil(M / chunk_size_M)

        # Preallocate output tensor
        phi_all = torch.empty(M, K, device=device, dtype=I_batch.dtype)

        # AMP context
        amp_ctx = AMPContext(enabled=use_amp)

        # Partial weights are attempted once this many samples have been
        # accumulated: the smaller of K // 4 and 10.
        min_samples_for_alpha = min(K // 4, 10)

        for chunk_idx, start in enumerate(range(0, M, chunk_size_M)):
            end = min(start + chunk_size_M, M)

            phi_per_I = self._attribute_targets(
                predictor_fn, D_samples, targets[start:end], t_values, amp_ctx
            )

            # Write directly to preallocated output
            phi_all[start:end] = phi_per_I

            # Notify observers if registered
            if has_observers and moment_computer is not None:
                I_chunk = I_batch[start:end]
                moment_computer.update(I_chunk, phi_per_I)

                # Compute partial weights (pinv solver handles under-determined systems)
                partial_alpha = None
                condition_number = None
                if moment_computer.num_samples >= min_samples_for_alpha:
                    # Deliberately best-effort: a failed partial solve only
                    # means this chunk is reported without weights; it must
                    # not abort the attribution run.
                    try:
                        from expected_gradcam.core.optimal_weights import (
                            solve_linear_system_robust,
                        )

                        partial_M_I = moment_computer.get_current_M_I()
                        partial_b = moment_computer.get_current_b()
                        partial_alpha, diag = solve_linear_system_robust(
                            partial_M_I, partial_b, method="pinv"
                        )
                        condition_number = diag.condition_number
                    except Exception:
                        pass

                # Create and emit chunk result
                from expected_gradcam.core.callbacks import ChunkResult

                chunk_result = ChunkResult(
                    chunk_idx=chunk_idx,
                    total_chunks=total_chunks,
                    samples_processed=moment_computer.num_samples,
                    total_samples=M,
                    partial_M_I=moment_computer.get_current_M_I().detach(),
                    partial_b=moment_computer.get_current_b().detach(),
                    partial_alpha=partial_alpha.detach() if partial_alpha is not None else None,
                    condition_number=condition_number,
                    elapsed_seconds=time.perf_counter() - start_time,
                )
                observer_manager.notify_chunk_complete(chunk_result)

                # Generate intermediate heatmap at checkpoints
                should_emit_heatmap = (
                    heatmap_checkpoint_interval > 0
                    and partial_alpha is not None
                    and feature_maps is not None
                    and (chunk_idx + 1) % heatmap_checkpoint_interval == 0
                )
                if should_emit_heatmap:
                    from expected_gradcam.core.callbacks import IntermediateHeatmap
                    from expected_gradcam.core.heatmap import (
                        generate_heatmap,
                        upsample_heatmap,
                    )

                    coarse = generate_heatmap(feature_maps, partial_alpha)
                    full = (
                        upsample_heatmap(coarse, target_size)
                        if target_size
                        else coarse
                    )

                    heatmap_result = IntermediateHeatmap(
                        checkpoint_idx=(chunk_idx + 1) // heatmap_checkpoint_interval,
                        samples_processed=moment_computer.num_samples,
                        total_samples=M,
                        coarse_heatmap=coarse.detach(),
                        full_heatmap=full.detach(),
                        weights=partial_alpha.detach(),
                        condition_number=condition_number or 0.0,
                    )
                    observer_manager.notify_intermediate_heatmap(heatmap_result)

        return phi_all

    def _attribute_targets(
        self,
        predictor_fn: Callable[[Tensor], Tensor],
        D_samples: Tensor,
        targets: Tensor,
        t_values: Tensor,
        amp_ctx: "AMPContext",
    ) -> Tensor:
        """Run one forward/backward pass for a block of target points.

        Args:
            predictor_fn: Function mapping [N, K] -> [N] class scores.
            D_samples: Baseline samples [N, K].
            targets: Path end points [R, K].
            t_values: Interpolation coefficients [T].
            amp_ctx: AMP context whose ``forward_context()`` wraps the
                forward pass only.

        Returns:
            Gradients averaged over T and then over the N baselines, [R, K].
        """
        rows, K = targets.shape
        num_baselines = D_samples.shape[0]

        # Use no_grad for tensor construction to reduce autograd overhead
        with torch.no_grad():
            # Directions from every baseline to every target: [R, N, K]
            directions = targets.unsqueeze(1) - D_samples.unsqueeze(0)

            # Interpolation points: [R, N, T, K]
            z_all = D_samples.unsqueeze(0).unsqueeze(2) + t_values.view(
                1, 1, -1, 1
            ) * directions.unsqueeze(2)

            # Reshape for batched forward: [R * N * T, K]
            z_flat = z_all.reshape(-1, K)

        # Enable gradients for the computation
        z_flat = z_flat.requires_grad_(True)

        # Forward with optional AMP, backward outside autocast
        with amp_ctx.forward_context():
            outputs = predictor_fn(z_flat)

        grads = torch.autograd.grad(
            outputs.sum(),
            z_flat,
            create_graph=False,
            retain_graph=False,
        )[0]

        # Reshape back: [R, N, T, K]; average T -> per baseline, N -> per target
        grads = grads.reshape(rows, num_baselines, self.T, K)
        return grads.mean(dim=2).mean(dim=1)

    def _compute_all_at_once(
        self,
        predictor_fn: Callable[[Tensor], Tensor],
        D_samples: Tensor,
        targets: Tensor,
        t_values: Tensor,
        K: int,
        use_amp: bool,
    ) -> Tensor:
        """Compute all perturbations in a single batch.

        ``K`` is kept for signature compatibility; the feature dimension is
        read from ``targets``.
        """
        return self._attribute_targets(
            predictor_fn,
            D_samples,
            targets,
            t_values,
            AMPContext(enabled=use_amp),
        )


def compute_optimal_weights_batched(
    predictor_fn: Callable[[Tensor], Tensor],
    z0: Tensor,
    I_samples: Tensor,
    T: int = 50,
    N: int = 20,
    D_samples: Tensor | None = None,
    baseline_scale: float = 0.1,
    regularization_eps: float = 1e-6,
    max_batch_size: int = 4096,
    solver_method: str = "pinv",
    rank_threshold: float = 1e-6,
    use_amp: bool = True,
) -> tuple[Tensor, Tensor, Tensor, "SolverDiagnostics | None"]:
    """Compute optimal weights using fully batched Expected Gradients.

    This is the main entry point for efficient optimal weight computation.

    Optimizations (enabled by default):
    - AMP (FP16) for gradient computation: 1.5-2x speedup
    - Aggressive batch sizing for 80GB+ GPUs: 8x fewer iterations
    - Precomputed tensors and preallocated output

    Args:
        predictor_fn: Function mapping [batch, K] -> [batch] class scores.
        z0: Reference point [K].
        I_samples: Perturbation samples [M, K].
        T: Integration steps.
        N: Baseline samples.
        D_samples: Pre-computed baseline samples [N, K] for Expected Gradients.
            RECOMMENDED: Use data-aware baselines from DataAwareEGBaselineSampler.
            If None, falls back to Gaussian baselines (not recommended).
        baseline_scale: Scale for baseline sampling. Only used if D_samples is None.
        regularization_eps: Regularization for M_I inversion.
        max_batch_size: Maximum batch size for GPU.
        solver_method: Solver for M_I @ alpha = b:
            - "pinv": Pseudo-inverse (recommended for rank-deficient)
            - "adaptive_reg": Adaptive Tikhonov regularization
            - "subspace": Eigenspace projection
            - "regularized": Simple regularization
        rank_threshold: Threshold for determining significant eigenvalues.
        use_amp: Whether to use automatic mixed precision (FP16) for gradient
            computation. Provides 1.5-2x speedup. Default: True.

    Returns:
        Tuple of:
        - alpha_opt: Optimal weights [K]
        - M_I: Second moment matrix [K, K]
        - phi_samples: Attribution samples [M, K]
        - solver_diagnostics: Diagnostics from solver (or None if regularized)

    Raises:
        ValueError: If ``I_samples`` contains no perturbations (the moment
            estimates would divide by zero).
    """
    from expected_gradcam.core.optimal_weights import solve_linear_system_robust

    M, K = I_samples.shape
    if M == 0:
        raise ValueError("I_samples must contain at least one perturbation")

    # Compute all attributions using batched EG with optimizations
    eg = FullyBatchedExpectedGradients(T=T, N=N)
    phi_samples = eg.compute_batch(
        predictor_fn=predictor_fn,
        z0=z0,
        I_batch=I_samples,
        D_samples=D_samples,
        baseline_scale=baseline_scale,
        max_batch_size=max_batch_size,
        shared_baselines=True,
        use_amp=use_amp,
    )

    # Compute second moment matrix: M_I = (1/M) * I^T @ I
    M_I = torch.mm(I_samples.T, I_samples) / M

    # Compute b = E[I * <I, phi>]
    inner_prods = (I_samples * phi_samples).sum(dim=1)
    b = (I_samples * inner_prods.unsqueeze(1)).mean(dim=0)

    # Solve M_I @ alpha = b using the specified method
    if solver_method == "regularized":
        # Build the ridge term with M_I's own dtype/device so the system
        # matrix keeps the precision of the inputs.
        ridge = regularization_eps * torch.eye(K, device=M_I.device, dtype=M_I.dtype)
        solution = torch.linalg.lstsq(M_I + ridge, b.unsqueeze(1)).solution
        # Drop only the right-hand-side column: a bare squeeze() would also
        # collapse the K axis when K == 1 and return a 0-d tensor.
        alpha_opt = solution.squeeze(1)
        solver_diagnostics = None
    else:
        alpha_opt, solver_diagnostics = solve_linear_system_robust(
            M_I=M_I,
            b=b,
            method=solver_method,
            rcond=rank_threshold,
            regularization_eps=regularization_eps,
        )

    return alpha_opt, M_I, phi_samples, solver_diagnostics
