"""Expected GradCAM weight computer for SC methods.

This module provides a reusable weight computation component that encapsulates
the Expected GradCAM optimal weight algorithm. It is designed to be used by
Segment-Constrained (SC) methods to compute optimal weights (α* = M_I^{-1} * b)
instead of simple GradCAM weights (α = GAP(gradients)).

Example:
    >>> from expected_gradcam.core.weight_computer import ExpectedGradCAMWeightComputer
    >>> from expected_gradcam.config import ExpectedGradCAMConfig
    >>>
    >>> computer = ExpectedGradCAMWeightComputer(
    ...     model=model,
    ...     target_layer=model.layer4[-1],
    ...     config=ExpectedGradCAMConfig(M=100, N=20, T=50),
    ... )
    >>> weights, diagnostics = computer.compute_weights(features, class_idx=243)
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch
from torch import nn

from expected_gradcam.architectures import extract_classifier_head
from expected_gradcam.config import DEFAULT_CONFIG, ExpectedGradCAMConfig
from expected_gradcam.core import (
    BatchedPredictor,
    compute_second_moment_matrix,
    compute_cross_moment,
    solve_linear_system_robust,
    transform_weights,
)
from expected_gradcam.gpu.batched_ops import FullyBatchedExpectedGradients
from expected_gradcam.sampling import (
    sample_centered_baselines,
    BatchedPerturbationSampler,
)
from expected_gradcam.sampling.baseline import sample_from_provider
from expected_gradcam.types.results import SolverDiagnostics

if TYPE_CHECKING:
    from torch import Tensor
    from expected_gradcam.baselines.protocols import BaselineProvider
    from expected_gradcam.core.observer_manager import ObserverManager


class ExpectedGradCAMWeightComputer:
    """Computes Expected GradCAM optimal weights for SC methods.

    This class encapsulates the Expected GradCAM weight computation pipeline,
    making it reusable by Segment-Constrained methods. It computes optimal
    weights that minimize explanation infidelity:

        α* = M_I^{-1} * E[I * <I, φ>]

    where:
        - M_I = E[I * I^T] is the second moment matrix of perturbations
        - φ is the Expected Gradients attribution
        - I are perturbation vectors

    The pipeline consists of:
    1. Extract classifier head from model
    2. Sample perturbations (M samples)
    3. Sample baselines (N samples; from the configured provider if any,
       otherwise centered random baselines)
    4. Compute Expected Gradients via path integration (T steps)
    5. Build second moment matrix M_I and cross moment b
    6. Solve linear system for optimal weights
    7. Apply weight transformation

    Attributes:
        model: Target CNN model.
        target_layer: Layer to extract feature maps from.
        config: Configuration parameters (M, N, T, etc.).
        device: Computation device.
        observer_manager: Optional observer manager forwarded to the
            Expected Gradients computation for callbacks.

    Example:
        >>> computer = ExpectedGradCAMWeightComputer(model, layer)
        >>> weights, diagnostics = computer.compute_weights(features, class_idx=243)
        >>> print(weights.shape)  # [K]
    """

    def __init__(
        self,
        model: nn.Module,
        target_layer: nn.Module,
        config: ExpectedGradCAMConfig | None = None,
        device: torch.device | str | None = None,
        classifier_head: nn.Module | None = None,
        observer_manager: "ObserverManager | None" = None,
    ) -> None:
        """Initialize the weight computer.

        Args:
            model: Target CNN model.
            target_layer: Layer to extract feature maps from.
            config: Configuration (uses DEFAULT_CONFIG if None).
                Set config.baseline_provider for data-aware sampling.
            device: Computation device. If None, the device of the model's
                first parameter is used (then its first buffer, then CPU
                when the model holds no tensors at all).
            classifier_head: Optional classifier head module. If None,
                will be auto-extracted using the architectures plugin.
                Provide this for custom/unsupported architectures.
            observer_manager: Optional observer manager for real-time callbacks.
        """
        self.model = model
        self.target_layer = target_layer

        # Auto-detect device. `next(..., None)` avoids a bare StopIteration
        # for modules that own no parameters (e.g. purely functional heads).
        if device is None:
            reference = next(model.parameters(), None)
            if reference is None:
                reference = next(model.buffers(), None)
            device = reference.device if reference is not None else torch.device("cpu")
        elif isinstance(device, str):
            device = torch.device(device)
        self.device = device

        # Configuration
        if config is None:
            config = DEFAULT_CONFIG
        self.config = config

        # Observer manager for callbacks
        self.observer_manager = observer_manager

        # Initialize baseline provider if configured
        self._baseline_provider: "BaselineProvider | None" = None
        if config.baseline_provider is not None:
            self._baseline_provider = config.baseline_provider
            # Initialize provider if not already initialized
            if not self._baseline_provider.is_initialized:
                self._baseline_provider.initialize(model, target_layer, self.device)

        # Lazy-load components
        self._classifier_head: nn.Module | None = classifier_head
        self._perturbation_sampler: Any = None

    @property
    def classifier_head(self) -> nn.Module:
        """Return the classifier head, extracting and caching it on first use."""
        if self._classifier_head is None:
            self._classifier_head = extract_classifier_head(
                self.model, self.target_layer
            )
        return self._classifier_head

    def _sample_perturbations(
        self,
        features: "Tensor",
        M: int,
    ) -> "Tensor":
        """Sample perturbation vectors I.

        When ``config.alpha_sampling == "uniform"`` every entry is drawn
        independently and uniformly from [alpha_min, alpha_max]. Any other
        value selects linear sampling: M evenly spaced values between
        alpha_min and alpha_max, each repeated across all K channels (the
        result is an expanded view, so all columns are identical).

        Args:
            features: Feature maps [1, K, U, V]; only the channel count K
                (``features.shape[1]``) is read.
            M: Number of perturbation samples.

        Returns:
            Perturbation samples [M, K] in [alpha_min, alpha_max].
        """
        K = features.shape[1]

        # Simple sampling (uniform in [alpha_min, alpha_max])
        # Data-aware perturbation sampling can be added by extending this
        alpha_min = self.config.alpha_min
        alpha_max = self.config.alpha_max

        if self.config.alpha_sampling == "uniform":
            I_samples = torch.rand(M, K, device=self.device)
            return I_samples * (alpha_max - alpha_min) + alpha_min

        # Linear sampling
        grid = torch.linspace(alpha_min, alpha_max, M, device=self.device)
        return grid.unsqueeze(1).expand(-1, K)

    def _sample_baselines(self, K: int, N: int) -> "Tensor":
        """Sample N baselines for the Expected Gradients computation.

        Uses the configured baseline provider when one is set; otherwise
        falls back to ``sample_centered_baselines`` with the distribution
        and scale taken from the config.

        Args:
            K: Number of feature channels.
            N: Number of baseline samples.

        Returns:
            Baseline samples as returned by the underlying sampler.
        """
        if self._baseline_provider is not None:
            # Data-aware baselines from provider (recommended)
            return sample_from_provider(
                provider=self._baseline_provider,
                N=N,
                device=self.device,
                target_scale=self.config.baseline_scale,
            )

        # Random centered baselines (fallback)
        return sample_centered_baselines(
            K=K,
            N=N,
            scale=self.config.baseline_scale,
            distribution=self.config.baseline_distribution,
            device=self.device,
        )

    def compute_weights(
        self,
        features: "Tensor",
        class_idx: int,
        mask: "Tensor | None" = None,
        *,
        target_size: tuple[int, int] = (224, 224),
    ) -> tuple["Tensor", SolverDiagnostics | None]:
        """Compute optimal weights for feature maps.

        This is the main entry point for weight computation. It executes
        the full Expected GradCAM pipeline:
        1. Sample perturbations and baselines
        2. Compute Expected Gradients
        3. Build second moment matrix
        4. Solve for optimal weights
        5. Apply weight transformation

        Args:
            features: Feature maps [1, K, U, V].
            class_idx: Target class index.
            mask: Optional segment mask [U, V] for per-segment computation.
                If provided, it is moved to the features' device/dtype and
                the feature maps are masked before weight computation.
            target_size: Size forwarded to the Expected Gradients computation
                for intermediate heatmaps. Only used when
                ``config.enable_computation_callbacks`` is truthy. Defaults
                to the standard ImageNet input size (224, 224).

        Returns:
            Tuple of (optimal_weights [K], solver_diagnostics).
            If diagnostics are disabled in config, diagnostics will be None.

        Example:
            >>> weights, diagnostics = computer.compute_weights(features, class_idx=243)
            >>> print(f"Condition number: {diagnostics.condition_number:.2f}")
        """
        K = features.shape[1]

        # Apply mask if provided (for per-segment computation)
        if mask is not None:
            # Align device/dtype first so a bool/CPU mask neither errors on a
            # device mismatch nor silently promotes the feature dtype.
            mask = mask.to(device=features.device, dtype=features.dtype)
            # Expand mask to feature dimensions: [U, V] -> [1, 1, U, V]
            features = features * mask.unsqueeze(0).unsqueeze(0)

        # Create predictor
        # NOTE(review): `use_compile` is driven by `config.use_batching`;
        # confirm this mapping is intentional.
        predictor = BatchedPredictor(
            classifier_head=self.classifier_head,
            target_class=class_idx,
            feature_maps=features,
            use_compile=self.config.use_batching,
        )

        M = self.config.M
        N = self.config.N
        T = self.config.T

        # Sample perturbations first, then baselines (keeps RNG draw order).
        I_samples = self._sample_perturbations(features, M)
        D_samples = self._sample_baselines(K, N)

        # Compute Expected Gradients using FullyBatchedExpectedGradients
        z0 = torch.ones(K, device=self.device)
        eg = FullyBatchedExpectedGradients(T=T, N=N)

        # Feature maps and target size are only handed over when computation
        # callbacks are enabled.
        callbacks_enabled = self.config.enable_computation_callbacks

        phi_samples = eg.compute_batch(
            predictor_fn=predictor,
            z0=z0,
            I_batch=I_samples,
            D_samples=D_samples,
            use_amp=self.config.use_amp,
            # Pass observer callback parameters
            observer_manager=self.observer_manager,
            feature_maps=features if callbacks_enabled else None,
            target_size=target_size if callbacks_enabled else None,
            heatmap_checkpoint_interval=self.config.heatmap_checkpoint_interval,
        )

        # Compute second moment matrix M_I
        M_I = compute_second_moment_matrix(I_samples)

        # Compute cross moment b = E[I * <I, φ>]
        b = compute_cross_moment(I_samples, phi_samples)

        # Solve linear system: M_I @ α* = b
        # solve_linear_system_robust returns (alpha, SolverDiagnostics)
        alpha_raw, diagnostics = solve_linear_system_robust(
            M_I,
            b,
            method=self.config.solver_method,
            rcond=self.config.rank_threshold,
            regularization_eps=self.config.regularization_eps,
        )

        # Apply weight transformation
        alpha_transformed = transform_weights(
            alpha_raw,
            method=self.config.weight_transform,
            feature_maps=features,
            exponent=self.config.transform_exponent,
        )

        return alpha_transformed, diagnostics

    def compute_weights_batched_segments(
        self,
        features: "Tensor",
        class_idx: int,
        masks: "Tensor",
    ) -> "Tensor":
        """Compute weights for multiple segments, one segment at a time.

        This is for per-segment refinement mode. It computes optimal
        weights separately for each segment by calling
        :meth:`compute_weights` once per mask; solver diagnostics are
        discarded.

        Args:
            features: Feature maps [1, K, U, V].
            class_idx: Target class index.
            masks: Segment masks [N, U, V].

        Returns:
            Per-segment weights [N, K] allocated on ``self.device``. An
            empty ``masks`` tensor yields an empty [0, K] result.

        Note:
            This method is computationally expensive as it runs the full
            Expected GradCAM pipeline for each segment. Use sparingly,
            preferably only for high-importance segments in hybrid mode.
        """
        num_segments = masks.shape[0]
        K = features.shape[1]

        per_segment_weights = torch.zeros(num_segments, K, device=self.device)

        for i, mask in enumerate(masks):
            weights, _ = self.compute_weights(features, class_idx, mask=mask)
            per_segment_weights[i] = weights

        return per_segment_weights

    def __repr__(self) -> str:
        """Return a short summary: model class, M/N/T, data-aware flag, device."""
        return (
            f"ExpectedGradCAMWeightComputer("
            f"model={self.model.__class__.__name__}, "
            f"config=M={self.config.M}/N={self.config.N}/T={self.config.T}, "
            f"data_aware={self._baseline_provider is not None}, "
            f"device={self.device})"
        )
