"""Synchronous probe implementation.

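This module provides a blocking probe that persists data synchronously:
callback data is accumulated per sample and written to storage, inline in
the calling code, when a callback arrives for a different sample or when
``flush()``/``close()`` is called. Best for debugging and small datasets.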
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np

from expected_gradcam.probes.storage import StorageBackend

if TYPE_CHECKING:
    from torch import Tensor


class SyncProbe:
    """Probe that persists captured data synchronously, one sample at a time.

    Callback data is accumulated in memory for the current sample and written
    to ``storage`` (via ``storage.save``) when a callback arrives for a
    different sample ID, or when :meth:`flush` / :meth:`close` is called.
    Writes are performed inline (blocking) inside those calls.
    Use this for:
    - Debugging (immediate feedback)
    - Small datasets
    - When latency is not critical

    For high-throughput scenarios, use AsyncProbe instead.

    Sample IDs come from ``metadata["sample_id"]``. When that key is missing
    or ``None``, an ID of the form ``sample_NNNNNN`` is generated from the
    number of samples returned by ``storage.list_samples()``. The pending
    sample is not in storage until it is saved (presumably ``list_samples``
    only reports saved samples). So consecutive computations that neither
    pass a ``sample_id`` nor call :meth:`flush` in between get the same
    generated ID and are merged into one record. Call :meth:`flush` after
    each computation, or pass explicit IDs, to keep them apart.

    Attributes:
        storage: Storage backend for persisting data.
        include_perturbations: Whether to store perturbation samples.
        include_attributions: Whether to store attribution samples.
        include_moment_matrix: Whether to store the second moment matrix,
            its eigenvalues, and the derived condition number / effective rank.
        include_heatmaps: Whether to store the final and raw heatmaps.

    Example:
        >>> storage = NPZStorage("output/")
        >>> probe = SyncProbe(storage)
        >>>
        >>> # Use with ExpectedGradCAM (hypothetical)
        >>> result = egcam.compute(image, class_idx, probe=probe)
        >>> probe.flush()
    """

    def __init__(
        self,
        storage: StorageBackend,
        include_perturbations: bool = True,
        include_attributions: bool = True,
        include_moment_matrix: bool = True,
        include_heatmaps: bool = True,
    ) -> None:
        """Initialize synchronous probe.

        Args:
            storage: Storage backend for persisting data.
            include_perturbations: Whether to store perturbation samples.
            include_attributions: Whether to store attribution samples.
            include_moment_matrix: Whether to store second moment matrix.
            include_heatmaps: Whether to store heatmaps.
        """
        self.storage = storage
        self.include_perturbations = include_perturbations
        self.include_attributions = include_attributions
        self.include_moment_matrix = include_moment_matrix
        self.include_heatmaps = include_heatmaps

        # ID and data of the sample currently being accumulated; None / {}
        # when nothing is pending.
        self._current_sample_id: str | None = None
        self._current_data: dict[str, Any] = {}

    def _get_sample_id(self, metadata: dict[str, Any]) -> str:
        """Return the sample ID for a callback.

        Uses ``metadata["sample_id"]`` when present and not ``None``;
        otherwise generates ``sample_NNNNNN`` from the number of samples
        currently reported by ``storage.list_samples()``.
        """
        sample_id = metadata.get("sample_id")
        if sample_id is not None:
            return sample_id
        # Query storage lazily: this runs on every callback, and listing
        # samples is unnecessary (and may be costly for some backends) when
        # the caller already supplied an ID.
        return f"sample_{len(self.storage.list_samples()):06d}"

    def _ensure_sample_started(self, metadata: dict[str, Any]) -> None:
        """Make the callback's sample current, saving the previous one if it changed."""
        sample_id = self._get_sample_id(metadata)
        if sample_id == self._current_sample_id:
            return

        # Compare against None (not truthiness) so falsy IDs such as 0 are
        # still persisted when the sample switches.
        if self._current_sample_id is not None and self._current_data:
            self.storage.save(self._current_sample_id, self._current_data)

        self._current_sample_id = sample_id
        self._current_data = {"sample_id": sample_id}

    def on_perturbation_sampled(
        self,
        I_samples: "Tensor",
        metadata: dict[str, Any],
    ) -> None:
        """Record perturbation samples for the current sample.

        Args:
            I_samples: Perturbation samples [M, K].
            metadata: Additional context (``sample_id`` selects the sample).
        """
        self._ensure_sample_started(metadata)

        if self.include_perturbations:
            self._current_data["perturbations"] = I_samples.detach().cpu().numpy()
            self._current_data["M"] = I_samples.shape[0]
            self._current_data["K"] = I_samples.shape[1]

    def on_attribution_computed(
        self,
        phi_samples: "Tensor",
        M_I: "Tensor",
        eigenvalues: "Tensor",
        metadata: dict[str, Any],
    ) -> None:
        """Record attribution samples and second-moment diagnostics.

        Args:
            phi_samples: Attribution samples [M, K].
            M_I: Second moment matrix [K, K].
            eigenvalues: Eigenvalues of M_I [K] (assumed real — TODO confirm
                they come from a symmetric eigensolver).
            metadata: Additional context (``sample_id`` selects the sample).
        """
        self._ensure_sample_started(metadata)

        if self.include_attributions:
            self._current_data["attributions"] = phi_samples.detach().cpu().numpy()

        if self.include_moment_matrix:
            eig_np = eigenvalues.detach().cpu().numpy()
            self._current_data["second_moment"] = M_I.detach().cpu().numpy()
            self._current_data["eigenvalues"] = eig_np

            # Condition number of a symmetric matrix is max|λ| / min|λ|.
            # A second-moment matrix is PSD in exact arithmetic, but round-off
            # can produce tiny negative eigenvalues; using magnitudes keeps the
            # result positive and the denominator >= 1e-10 (never zero).
            abs_eig = np.abs(eig_np)
            self._current_data["condition_number"] = float(
                abs_eig.max() / (abs_eig.min() + 1e-10)
            )
            # Number of eigenvalues above the absolute threshold 1e-6.
            self._current_data["effective_rank"] = int((eig_np > 1e-6).sum())

    def on_weights_computed(
        self,
        alpha_opt: "Tensor",
        solver_diagnostics: dict[str, Any] | None,
        metadata: dict[str, Any],
    ) -> None:
        """Record optimal weights and (optionally) solver diagnostics.

        Weights are always stored; there is no ``include_*`` flag for them.

        Args:
            alpha_opt: Optimal weights [K].
            solver_diagnostics: Solver information; tensor-like values (those
                with a ``cpu`` attribute) are converted to numpy arrays.
                Skipped when ``None`` or empty.
            metadata: Additional context (``sample_id`` selects the sample).
        """
        self._ensure_sample_started(metadata)

        self._current_data["optimal_weights"] = alpha_opt.detach().cpu().numpy()

        if solver_diagnostics:
            self._current_data["solver_diagnostics"] = {
                key: value.detach().cpu().numpy() if hasattr(value, "cpu") else value
                for key, value in solver_diagnostics.items()
            }

    def on_heatmap_generated(
        self,
        heatmap: "Tensor",
        raw_heatmap: "Tensor",
        metadata: dict[str, Any],
    ) -> None:
        """Record the final and pre-normalization heatmaps.

        Args:
            heatmap: Final normalized heatmap [H, W].
            raw_heatmap: Pre-normalization heatmap [H, W].
            metadata: Additional context (``sample_id`` selects the sample).
        """
        self._ensure_sample_started(metadata)

        if self.include_heatmaps:
            self._current_data["heatmap"] = heatmap.detach().cpu().numpy()
            self._current_data["raw_heatmap"] = raw_heatmap.detach().cpu().numpy()

    def flush(self) -> None:
        """Save the pending sample (if any) and reset the accumulator.

        Call this after each computation so its data is persisted and the next
        computation starts a fresh sample. If ``storage.save`` raises, the
        pending data is left in place.
        """
        if self._current_sample_id is None or not self._current_data:
            return
        self.storage.save(self._current_sample_id, self._current_data)
        self._current_sample_id = None
        self._current_data = {}

    def close(self) -> None:
        """Flush pending data, then close the storage backend.

        The storage is closed even if flushing raises; the flush error still
        propagates to the caller.
        """
        try:
            self.flush()
        finally:
            self.storage.close()

    def __enter__(self) -> "SyncProbe":
        """Context manager entry; returns the probe itself."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Context manager exit; flushes and closes (see :meth:`close`)."""
        self.close()
