"""Core Expected GradCAM implementation without segment constraints.

This module provides the full Expected GradCAM algorithm that computes
optimal feature map weights to minimize explanation infidelity.

Mathematical foundation:
    - Predictor: g(z; A) = y^c(z_1 * A^1, ..., z_K * A^K)
    - Reference: z_0 = (1, ..., 1)
    - Infidelity: INFD(α) = E[(I^T α - (g(z_0) - g(z_0 - I)))^2]
    - Optimal weights: α* = M_I^{-1} * E[I * <I, φ>]
    - Heatmap: L^c = ReLU(Σ_k α*_k * A^k)

Example:
    >>> from expected_gradcam import ExpectedGradCAM, ExpectedGradCAMConfig
    >>> config = ExpectedGradCAMConfig(M=50, N=20, T=50)
    >>> egcam = ExpectedGradCAM.core(model, layer, config=config)
    >>> result = egcam.generate(image, class_idx=243)
    >>> heatmap = result.heatmap
    >>> weights = result.optimal_weights
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch
import torch.nn.functional as F
from torch import nn

from expected_gradcam.config import DEFAULT_CONFIG, ExpectedGradCAMConfig
from expected_gradcam.core import (
    BatchedPredictor,
    compute_second_moment_matrix,
    compute_cross_moment,
    analyze_second_moment_matrix,
    solve_linear_system_robust,
    transform_weights,
    generate_heatmap,
    upsample_heatmap,
    normalize_heatmap,
    apply_contrast_enhancement,
)
from expected_gradcam.gpu.batched_ops import FullyBatchedExpectedGradients
from expected_gradcam.sampling import sample_centered_baselines
from expected_gradcam.types.results import (
    ExpectedGradCAMResult,
    SolverDiagnostics,
    IntermediateValues,
    CompletenessResult,
)

if TYPE_CHECKING:
    from torch import Tensor


class CoreExpectedGradCAM:
    """Core Expected GradCAM without segment constraints.

    This class exposes the full mathematical pipeline for computing
    optimal feature map weights:
    - Perturbation sampling
    - Path integration (Expected Gradients)
    - Second moment matrix computation
    - Optimal weight solving via linear system
    - Weight transformation
    - Heatmap generation with diagnostics

    This directly uses the optimal weights from the infidelity minimization
    formulation.

    Attributes:
        model: Target CNN model.
        target_layer: Layer to extract feature maps from.
        device: Computation device.
        config: Configuration parameters.

    Example:
        >>> from expected_gradcam import ExpectedGradCAM, ExpectedGradCAMConfig
        >>> config = ExpectedGradCAMConfig(M=50, N=20, T=50)
        >>> egcam = ExpectedGradCAM.core(model, layer, config=config)
        >>> result = egcam.generate(image, class_idx=243)
        >>> print(result.heatmap.shape)        # [H, W]
        >>> print(result.optimal_weights.shape) # [K]
        >>> print(result.solver_diagnostics)   # Solver info
    """

    def __init__(
        self,
        model: nn.Module,
        target_layer: nn.Module,
        config: ExpectedGradCAMConfig | None = None,
        device: torch.device | str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize core Expected GradCAM.

        Args:
            model: Target CNN model.
            target_layer: Layer to extract feature maps from.
            config: Configuration (uses DEFAULT_CONFIG if None).
            device: Computation device. If None, the device of the model's
                first parameter (or buffer) is used, falling back to CPU for
                models that have neither.
            **kwargs: Config overrides. They are applied to a *copy* of
                ``config``, so the passed object (including the shared
                DEFAULT_CONFIG) is never mutated. Keys that are not attributes
                of the config are ignored with a warning.
        """
        # Hook state is initialised first so remove_hooks()/__del__ remain
        # safe even if a later step of __init__ raises.
        self._features: Tensor | None = None
        self._gradients: Tensor | None = None
        self._hooks: list[Any] = []

        self.model = model
        self.target_layer = target_layer
        self.device = self._resolve_device(model, device)

        if config is None:
            config = DEFAULT_CONFIG
        if kwargs:
            config = self._apply_config_overrides(config, kwargs)
        self.config = config

        self._register_hooks()

    @staticmethod
    def _resolve_device(
        model: nn.Module, device: torch.device | str | None
    ) -> torch.device:
        """Normalise ``device`` to a torch.device, inferring it from ``model`` if None."""
        if device is None:
            # A bare next(model.parameters()) raises StopIteration for
            # parameter-free models (e.g. nn.Identity), hence the fallbacks.
            first = next(model.parameters(), None)
            if first is None:
                first = next(model.buffers(), None)
            return first.device if first is not None else torch.device("cpu")
        if isinstance(device, str):
            return torch.device(device)
        return device

    @staticmethod
    def _apply_config_overrides(config: Any, overrides: dict[str, Any]) -> Any:
        """Return a copy of ``config`` with every override whose key it already has."""
        import copy
        import dataclasses
        import warnings

        known = {key: value for key, value in overrides.items() if hasattr(config, key)}
        unknown = sorted(set(overrides) - set(known))
        if unknown:
            # Warn instead of raising: these keys were silently ignored before.
            warnings.warn(
                f"Ignoring unknown config override(s): {', '.join(unknown)}",
                stacklevel=3,
            )
        if not known:
            return config
        if dataclasses.is_dataclass(config) and not isinstance(config, type):
            # Works for frozen dataclasses too.
            return dataclasses.replace(config, **known)
        updated = copy.copy(config)
        for key, value in known.items():
            setattr(updated, key, value)
        return updated

    def _register_hooks(self) -> None:
        """Register forward and backward hooks on target layer.

        The forward hook stores the detached layer output in ``_features``;
        the full backward hook stores the detached ``grad_output[0]`` in
        ``_gradients``.
        """

        def forward_hook(
            module: nn.Module, input: tuple, output: Tensor
        ) -> None:
            self._features = output.detach()

        def backward_hook(
            module: nn.Module, grad_input: tuple, grad_output: tuple
        ) -> None:
            self._gradients = grad_output[0].detach()

        fwd_handle = self.target_layer.register_forward_hook(forward_hook)
        bwd_handle = self.target_layer.register_full_backward_hook(backward_hook)
        self._hooks = [fwd_handle, bwd_handle]

    def remove_hooks(self) -> None:
        """Remove registered hooks."""
        for hook in self._hooks:
            hook.remove()
        self._hooks = []

    def __del__(self) -> None:
        """Best-effort hook cleanup when the instance is garbage-collected."""
        # getattr guards against instances whose __init__ never completed.
        if getattr(self, "_hooks", None):
            self.remove_hooks()

    def _extract_classifier_head(self) -> nn.Module:
        """Extract the classifier head from the model.

        Returns a module that maps feature maps to class logits.

        Returns:
            Module that performs [B, K, U, V] -> [B, num_classes].

        Raises:
            RuntimeError: If neither the architecture registry nor the
                top-level-children fallback can build a head. The original
                registry error is chained as ``__cause__``.
        """
        # This is a simplified extraction. For complex architectures,
        # users may need to override or configure this.
        from expected_gradcam.architectures import extract_classifier_head

        try:
            return extract_classifier_head(self.model, self.target_layer)
        except Exception as err:
            # Fallback: use every *top-level* child that follows the target
            # layer. Only works when target_layer is a direct child of model.
            # NOTE(review): no Flatten is inserted between these children, so
            # models whose flatten happens in forward() (not as a child module)
            # will yield a head that fails on Linear layers — confirm per model.
            classifier_modules: list[nn.Module] = []
            found_target = False
            for child in self.model.children():
                if child is self.target_layer:
                    found_target = True
                    continue
                if found_target:
                    classifier_modules.append(child)

            if classifier_modules:
                return nn.Sequential(*classifier_modules)

            raise RuntimeError(
                "Could not automatically extract classifier head. "
                "Please ensure model architecture is supported."
            ) from err

    def _forward_backward(
        self,
        image: Tensor,
        class_idx: int | None = None,
    ) -> tuple[Tensor, int]:
        """Perform forward and backward pass to extract features.

        The caller's ``image`` tensor is not modified; a detached copy on
        ``self.device`` is used as the graph input.

        Args:
            image: Preprocessed image [1, C, H, W].
            class_idx: Target class. Uses predicted class if None.

        Returns:
            Tuple of (features [1, K, U, V], class_idx). ``features`` is the
            detached output of ``target_layer`` captured by the forward hook.

        Raises:
            RuntimeError: If the forward hook on ``target_layer`` did not fire
                (e.g. the layer is not used by ``model.forward``).
        """
        self.model.eval()
        # detach() first: ``.to`` returns the caller's tensor itself when it is
        # already on self.device, and requires_grad_ would then mutate it.
        image = image.detach().to(self.device).requires_grad_(True)

        # Clear hook state so a capture from a previous call is never reused.
        self._features = None
        self._gradients = None

        output = self.model(image)
        if self._features is None:
            raise RuntimeError(
                "Forward hook on target_layer did not fire; make sure "
                "target_layer is a submodule used by model.forward()."
            )

        if class_idx is None:
            class_idx = output.argmax(dim=1).item()

        # Backward pass (populates the gradient hook). The graph is not
        # reused afterwards, so it does not need to be retained.
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0, class_idx] = 1
        output.backward(gradient=one_hot)

        return self._features, class_idx

    def _sample_perturbations(self, K: int, M: int) -> Tensor:
        """Sample perturbation vectors I.

        Args:
            K: Number of feature channels.
            M: Number of perturbation samples.

        Returns:
            Contiguous perturbation samples [M, K] with values in
            [alpha_min, alpha_max].
        """
        alpha_min = self.config.alpha_min
        alpha_max = self.config.alpha_max

        if self.config.alpha_sampling == "uniform":
            # Independent uniform draws in [alpha_min, alpha_max) per entry.
            I_samples = torch.rand(M, K, device=self.device)
            return I_samples * (alpha_max - alpha_min) + alpha_min

        # Any other setting: an evenly spaced grid shared by all K channels.
        # NOTE(review): every row is t * ones(K), so if
        # compute_second_moment_matrix forms the sample E[I I^T] it has rank 1
        # here and is singular for K > 1 — confirm the solver handles this.
        grid = torch.linspace(alpha_min, alpha_max, M, device=self.device)
        # repeat (not expand) gives real storage, safe for in-place ops.
        return grid.unsqueeze(1).repeat(1, K)

    def generate(
        self,
        image: Tensor,
        class_idx: int | None = None,
    ) -> ExpectedGradCAMResult:
        """Generate heatmap using optimal weights.

        Computes the full Expected GradCAM pipeline:
        1. Extract features via forward pass
        2. Sample perturbations
        3. Compute path integrals (Expected Gradients)
        4. Build second moment matrix M_I
        5. Solve for optimal weights α*
        6. Apply weight transformation
        7. Generate and upsample heatmap

        Args:
            image: Preprocessed image tensor [1, C, H, W] (an unbatched
                [C, H, W] tensor is also accepted).
            class_idx: Target class. Uses predicted class if None.

        Returns:
            ExpectedGradCAMResult with heatmap, weights, and diagnostics.

        Raises:
            ValueError: If ``image`` is neither 3-D nor 4-D.

        Example:
            >>> result = egcam.generate(image, class_idx=243)
            >>> heatmap = result.heatmap  # [H, W]
            >>> weights = result.optimal_weights  # [K]
        """
        if image.dim() == 3:
            image = image.unsqueeze(0)
        if image.dim() != 4:
            raise ValueError(
                "Expected image of shape [1, C, H, W] or [C, H, W], "
                f"got {tuple(image.shape)}"
            )
        H, W = image.shape[-2:]

        # Extract features
        features, class_idx = self._forward_backward(image, class_idx)
        K = features.shape[1]

        # Get classifier head
        classifier_head = self._extract_classifier_head()

        # Create predictor
        # NOTE(review): use_compile is driven by config.use_batching — confirm
        # this is intended rather than a dedicated compile flag.
        predictor = BatchedPredictor(
            classifier_head,
            target_class=class_idx,
            feature_maps=features,
            use_compile=self.config.use_batching,
        )

        # Sample perturbations
        M = self.config.M
        N = self.config.N
        T = self.config.T

        I_samples = self._sample_perturbations(K, M)

        # Sample baselines for Expected Gradients
        # Using centered Gaussian baselines (data-aware baselines require dataset)
        D_samples = sample_centered_baselines(
            K=K,
            N=N,
            scale=self.config.baseline_scale,
            distribution="gaussian",
            device=self.device,
        )

        # Compute Expected Gradients using FullyBatchedExpectedGradients
        # This batches across M perturbations, N baselines, and T integration steps
        z0 = torch.ones(K, device=self.device)
        eg = FullyBatchedExpectedGradients(T=T, N=N)
        phi_samples = eg.compute_batch(
            predictor_fn=predictor,
            z0=z0,
            I_batch=I_samples,
            D_samples=D_samples,
            use_amp=self.config.use_amp,
        )

        # Compute second moment matrix M_I
        M_I = compute_second_moment_matrix(I_samples)

        # Analyze M_I: (is_invertible, condition_number, rank, eigenvalues).
        # NOTE(review): none of these values are consumed below —
        # solver_diagnostics comes from the solver. Consider surfacing them.
        _is_invertible, _condition_number, _effective_rank, _eigenvalues = (
            analyze_second_moment_matrix(M_I)
        )

        # Compute cross moment b = E[I * <I, φ>]
        b = compute_cross_moment(I_samples, phi_samples)

        # Solve linear system: M_I @ α* = b
        # solve_linear_system_robust returns (alpha, SolverDiagnostics)
        alpha_raw, solver_diagnostics = solve_linear_system_robust(
            M_I,
            b,
            method=self.config.solver_method,
            rcond=self.config.rank_threshold,
            regularization_eps=self.config.regularization_eps,
        )

        # Apply weight transformation
        alpha_transformed = transform_weights(
            alpha_raw,
            method=self.config.weight_transform,
            feature_maps=features,
            exponent=self.config.transform_exponent,
        )

        # Generate heatmap
        coarse_heatmap = generate_heatmap(features, alpha_transformed, apply_relu=True)

        # Upsample to original size
        heatmap = upsample_heatmap(coarse_heatmap, target_size=(H, W))

        # Normalize
        heatmap = normalize_heatmap(
            heatmap,
            method=self.config.normalization_method,
            quantile_low=self.config.quantile_low,
            quantile_high=self.config.quantile_high,
        )

        # Apply contrast enhancement if enabled (matches progressive output quality)
        if self.config.apply_contrast_enhancement:
            heatmap = apply_contrast_enhancement(
                heatmap,
                boost_factor=self.config.contrast_boost_factor,
            )

        # Collect intermediates if requested
        intermediates = None
        if self.config.collect_intermediates:
            # Eigenvalues of the symmetric M_I (ascending, via eigvalsh).
            eigenvalues = torch.linalg.eigvalsh(M_I)
            intermediates = IntermediateValues(
                I_samples=I_samples.cpu(),
                phi_samples=phi_samples.cpu(),
                M_I=M_I.cpu(),
                b=b.cpu(),
                eigenvalues=eigenvalues.cpu(),
                eigenvectors=None,  # Not computed for efficiency
                alpha_raw=alpha_raw.cpu(),
                alpha_transformed=alpha_transformed.cpu(),
                coarse_heatmap=coarse_heatmap.cpu(),
                timings={},
                extra={},
            )

        # Validate completeness if requested
        completeness_results = None
        if self.config.validate_completeness:
            completeness_results = self._validate_completeness(
                I_samples, phi_samples, predictor
            )

        return ExpectedGradCAMResult(
            heatmap=heatmap.squeeze(),
            coarse_heatmap=coarse_heatmap.squeeze(),
            optimal_weights=alpha_transformed.squeeze(),
            target_class=class_idx,
            feature_maps=features if self.config.collect_intermediates else None,
            solver_diagnostics=solver_diagnostics,
            completeness_results=completeness_results,
            intermediates=intermediates,
        )

    def _validate_completeness(
        self,
        I_samples: Tensor,
        phi_samples: Tensor,
        predictor: BatchedPredictor,
    ) -> list[CompletenessResult]:
        """Validate completeness axiom: I^T @ φ = g(z_0) - g(z_0 - I).

        Checks up to 10 randomly chosen perturbations.

        Args:
            I_samples: Perturbation samples [M, K].
            phi_samples: Attribution samples [M, K].
            predictor: The predictor function.

        Returns:
            List of CompletenessResult for a sample of perturbations.
        """
        results = []
        tolerance = self.config.completeness_tolerance

        # Check a random subset of samples.
        num_checks = min(10, I_samples.shape[0])
        indices = torch.randperm(I_samples.shape[0])[:num_checks]

        for idx in indices.tolist():
            perturbation = I_samples[idx]
            phi = phi_samples[idx]

            lhs = (perturbation * phi).sum()
            rhs = predictor.compute_output_difference(perturbation)

            results.append(
                CompletenessResult.from_tensors(
                    lhs.unsqueeze(0),
                    rhs,
                    tolerance,
                )
            )

        return results

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"CoreExpectedGradCAM("
            f"model={self.model.__class__.__name__}, "
            f"config=M={self.config.M}/N={self.config.N}/T={self.config.T}, "
            f"device={self.device})"
        )
