"""Path integration methods: Integrated Gradients and Expected Gradients.

These methods compute attributions by integrating gradients along paths
from baselines to the reference point.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Callable

import torch


if TYPE_CHECKING:
    from torch import Tensor
    from expected_gradcam.core.predictor import Predictor


class IntegratedGradients:
    """Integrated Gradients attribution method.

    Computes attributions by integrating gradients along the straight-line
    path from baseline (z_0 - I) to reference (z_0).

    Mathematical specification:
        φ^{IG} = ∫_0^1 ∇_z g(z_0 - I + t·I; A) dt
               = ∫_0^1 ∇_z g(baseline + t·I; A) dt

    The integral is approximated with the midpoint rule: the gradient is
    evaluated at t_j = (j + 0.5) / T for j = 0..T-1 and the T gradients are
    averaged.

    Properties:
        - Completeness: I^T · φ^{IG} ≈ g(z_0) - g(z_0 - I), up to
          quadrature error.
        - Path-based attribution (integrates along straight line).

    Attributes:
        T: Number of integration steps.

    Example:
        >>> ig = IntegratedGradients(T=50)
        >>> phi = ig.compute(predictor, z0, I)  # Shape: [K]
    """

    def __init__(self, T: int = 50) -> None:
        """Initialize Integrated Gradients.

        Args:
            T: Number of integration steps (midpoint-rule nodes). Must be
                at least 1.

        Raises:
            ValueError: If ``T`` is smaller than 1. Without this check a
                zero step count would only surface later as NaN output or
                a ZeroDivisionError.
        """
        if T < 1:
            raise ValueError(f"T must be a positive integer, got {T!r}")
        self.T = T

    def compute(
        self,
        predictor: "Predictor",
        z0: "Tensor",
        I: "Tensor",
    ) -> "Tensor":
        """Compute integrated gradients from baseline to reference.

        Runs one forward and one backward pass per integration step.

        Args:
            predictor: The predictor function g(z; A). Called with a single
                point of the same shape as ``z0``; its output is summed
                before differentiation.
            z0: Reference point [K], typically all ones.
            I: Perturbation vector [K].

        Returns:
            Integrated gradients attribution φ^{IG} [K]: the mean of the
            gradients at the T midpoints of the path.
        """
        device = z0.device
        K = z0.shape[0]
        baseline = z0 - I

        # Running sum of the gradients evaluated along the path.
        grads_sum = torch.zeros(K, device=device)

        for j in range(self.T):
            # Midpoint rule: sample the centre of the j-th sub-interval.
            t = (j + 0.5) / self.T

            # Fresh leaf tensor so the gradient is taken w.r.t. this point
            # only and no graph from z0 / I is carried along.
            z_interp = baseline + t * I
            z_interp = z_interp.clone().detach().requires_grad_(True)

            output = predictor(z_interp)

            grad = torch.autograd.grad(
                output.sum(),
                z_interp,
                create_graph=False,
                retain_graph=False,
            )[0]

            grads_sum = grads_sum + grad

        # Equal-weight average over the T midpoints.
        return grads_sum / self.T

    def compute_batched(
        self,
        predictor: Callable[["Tensor"], "Tensor"],
        z0: "Tensor",
        I: "Tensor",
    ) -> "Tensor":
        """Compute integrated gradients using one batched forward pass.

        All T interpolation points are stacked into a [T, K] batch, so the
        predictor is called once. More memory-intensive than ``compute``
        but potentially faster on GPU.

        Args:
            predictor: Batched predictor function g(z_batch; A), called
                with a [T, K] tensor. The per-row gradients are only
                meaningful if the predictor treats rows independently
                (assumption; not checked here).
            z0: Reference point [K].
            I: Perturbation vector [K].

        Returns:
            Integrated gradients attribution φ^{IG} [K].
        """
        device = z0.device
        baseline = z0 - I

        # Build the midpoints in the dtype the interpolated batch will end
        # up in. Leaving linspace at the default dtype would give float64
        # inputs float32-rounded t values and make this method disagree
        # with compute().
        t_dtype = torch.promote_types(baseline.dtype, torch.get_default_dtype())
        t_values = torch.linspace(
            0.5 / self.T,
            1 - 0.5 / self.T,
            self.T,
            device=device,
            dtype=t_dtype,
        )

        # [1, K] + [T, 1] * [1, K] -> [T, K]
        z_batch = baseline.unsqueeze(0) + t_values.unsqueeze(1) * I.unsqueeze(0)

        # Fresh leaf tensor so gradients are taken w.r.t. the batch itself.
        z_batch = z_batch.clone().detach().requires_grad_(True)

        outputs = predictor(z_batch)

        # Differentiating the summed outputs yields one gradient row per
        # interpolation point.
        grads = torch.autograd.grad(
            outputs.sum(),
            z_batch,
            create_graph=False,
            retain_graph=False,
        )[0]

        # Equal-weight average over the T midpoints.
        return grads.mean(dim=0)


class ExpectedGradients:
    """Expected Gradients attribution method.

    Extends Integrated Gradients by averaging over multiple baselines
    sampled from a distribution D.

    Mathematical specification:
        φ^{EG}(g, z_0, I; A, D) = E_{z'~D}[∫_0^1 ∇_z g(z' + t(target - z'); A) dt]

    where target = z_0 - I.

    The method is specified for a centered baseline distribution
    (E[z'] = 0); ``compute`` enforces this by subtracting the sample mean
    from the supplied baselines before integrating.

    Each path integral is approximated with the midpoint rule using T
    evaluation points.

    Attributes:
        T: Number of integration steps per baseline.

    Example:
        >>> eg = ExpectedGradients(T=50)
        >>> phi = eg.compute(predictor, z0, I, D_samples)  # Shape: [K]
    """

    def __init__(self, T: int = 50) -> None:
        """Initialize Expected Gradients.

        Args:
            T: Number of integration steps per baseline. Must be at
                least 1.

        Raises:
            ValueError: If ``T`` is smaller than 1 (a zero step count
                would otherwise produce NaN attributions).
        """
        if T < 1:
            raise ValueError(f"T must be a positive integer, got {T!r}")
        self.T = T
        # Not used by compute(), which integrates inline; retained so the
        # attribute stays available to existing code.
        self._ig = IntegratedGradients(T)

    def compute(
        self,
        predictor: "Predictor",
        z0: "Tensor",
        I: "Tensor",
        D_samples: "Tensor",
    ) -> "Tensor":
        """Compute expected gradients over baseline samples.

        For every (centered) baseline z', gradients are evaluated at the T
        midpoints of the straight line from z' to ``z0 - I`` and averaged;
        the per-baseline averages are then averaged over all N baselines.

        Args:
            predictor: The predictor function g(z; A). Called with a single
                point; its output is summed before differentiation.
            z0: Reference point [K], typically all ones.
            I: Perturbation vector [K].
            D_samples: Baseline samples [N, K] from distribution D, with
                N >= 1. The sample mean is subtracted internally, so the
                samples do not need to be pre-centered.

        Returns:
            Expected gradients attribution φ^{EG} [K].

        Raises:
            ValueError: If ``D_samples`` is not 2-D or contains no samples.
        """
        # A 1-D tensor would be indexed element-wise and broadcast
        # silently; an empty one would turn the result into NaN.
        if D_samples.dim() != 2:
            raise ValueError(
                f"D_samples must have shape [N, K], got {tuple(D_samples.shape)}"
            )
        N = D_samples.shape[0]
        if N == 0:
            raise ValueError("D_samples must contain at least one baseline sample")

        device = z0.device
        K = z0.shape[0]

        # Center the baselines by removing their sample mean.
        D_centered = D_samples - D_samples.mean(dim=0, keepdim=True)

        # End point shared by every integration path.
        target = z0 - I

        # Midpoint-rule nodes are identical for every baseline, so they
        # are computed once.
        midpoints = [(j + 0.5) / self.T for j in range(self.T)]

        phi_sum = torch.zeros(K, device=device)

        for z_prime in D_centered:
            # Direction of the path from this baseline to the target.
            direction = target - z_prime

            grads_sum = torch.zeros(K, device=device)

            for t in midpoints:
                # Fresh leaf tensor so the gradient is taken w.r.t. this
                # point only.
                z_interp = z_prime + t * direction
                z_interp = z_interp.clone().detach().requires_grad_(True)

                output = predictor(z_interp)
                grad = torch.autograd.grad(
                    output.sum(),
                    z_interp,
                    create_graph=False,
                    retain_graph=False,
                )[0]

                grads_sum = grads_sum + grad

            # Mean gradient along this baseline's path.
            phi_sum = phi_sum + grads_sum / self.T

        # Mean over baselines.
        return phi_sum / N


def compute_attribution(
    predictor: "Predictor",
    z0: "Tensor",
    I: "Tensor",
    method: str = "ig",
    T: int = 50,
    D_samples: "Tensor | None" = None,
) -> "Tensor":
    """Dispatch to the requested path-integration attribution method.

    Args:
        predictor: The predictor function g(z; A).
        z0: Reference point [K].
        I: Perturbation vector [K].
        method: Either "ig" (Integrated Gradients) or "eg"
            (Expected Gradients).
        T: Number of integration steps handed to the chosen method.
        D_samples: Baseline samples [N, K]; only consulted when
            method="eg", where it is mandatory.

    Returns:
        Attribution vector φ [K] produced by the selected method.

    Raises:
        ValueError: If ``method`` is neither "ig" nor "eg".
        ValueError: If method="eg" and ``D_samples`` is None.
    """
    # Integrated Gradients needs no baseline samples.
    if method == "ig":
        return IntegratedGradients(T).compute(predictor, z0, I)

    # Anything other than "eg" at this point is unsupported.
    if method != "eg":
        raise ValueError(f"Unknown method: {method}. Use 'ig' or 'eg'.")

    # Expected Gradients cannot run without baselines.
    if D_samples is None:
        raise ValueError("D_samples required for Expected Gradients")

    return ExpectedGradients(T).compute(predictor, z0, I, D_samples)
