"""Asynchronous probe implementation.

This module provides a queue-based probe that stores data asynchronously
using a background thread. Best for high-throughput scenarios.
"""

from __future__ import annotations

import queue
import threading
from typing import TYPE_CHECKING, Any

from expected_gradcam.probes.storage import StorageBackend

if TYPE_CHECKING:
    from torch import Tensor


class AsyncProbe:
    """Asynchronous probe with background storage.

    This probe queues data for storage by a background thread, minimizing
    overhead during computation. Use this for:
    - High-throughput scenarios
    - Large datasets
    - When computation speed is critical

    Data is queued as numpy arrays and written by a background thread.
    Call flush() to ensure all queued data is written before exiting.

    Hook calls are grouped into samples by ``metadata["sample_id"]``: a
    sample is handed to the writer as soon as a hook arrives with a
    different ID, or when flush() is called. When ``sample_id`` is missing,
    an ID of the form ``sample_000000`` is generated and every hook call up
    to the next flush() is treated as belonging to that same sample.

    Attributes:
        storage: Storage backend for persisting data.
        max_queue_size: Maximum items in the queue before blocking.

    Example:
        >>> storage = NPZStorage("output/")
        >>> probe = AsyncProbe(storage)
        >>>
        >>> # Process many samples
        >>> for image in dataset:
        ...     result = egcam.compute(image, probe=probe)
        >>>
        >>> # Wait for all writes to complete
        >>> probe.flush()
        >>> probe.close()
    """

    def __init__(
        self,
        storage: StorageBackend,
        max_queue_size: int = 100,
        include_perturbations: bool = True,
        include_attributions: bool = True,
        include_moment_matrix: bool = True,
        include_heatmaps: bool = True,
    ) -> None:
        """Initialize asynchronous probe and start its writer thread.

        Args:
            storage: Storage backend for persisting data.
            max_queue_size: Maximum items in queue before blocking.
            include_perturbations: Whether to store perturbation samples.
            include_attributions: Whether to store attribution samples.
            include_moment_matrix: Whether to store second moment matrix.
            include_heatmaps: Whether to store heatmaps.
        """
        self.storage = storage
        self.max_queue_size = max_queue_size
        self.include_perturbations = include_perturbations
        self.include_attributions = include_attributions
        self.include_moment_matrix = include_moment_matrix
        self.include_heatmaps = include_heatmaps

        # Current sample accumulator (guarded by ``_lock``).
        self._current_sample_id: str | None = None
        self._current_data: dict[str, Any] = {}
        self._lock = threading.Lock()

        # Index used for the next auto-generated sample ID. A counter keeps
        # the fallback ID stable while a sample is being accumulated, unlike
        # the queue size, which changes as the worker drains the queue.
        self._auto_index = 0

        # Set once close() has run so that repeated calls are harmless.
        self._closed = False

        # Queue and background thread. ``None`` is the shutdown sentinel.
        self._queue: queue.Queue[tuple[str, dict[str, Any]] | None] = queue.Queue(
            maxsize=max_queue_size
        )
        self._thread: threading.Thread | None = None
        self._running = False

        # Start background thread
        self._start_worker()

    def _start_worker(self) -> None:
        """Start background worker thread."""
        self._running = True
        self._thread = threading.Thread(target=self._worker_loop, daemon=True)
        self._thread.start()

    def _worker_loop(self) -> None:
        """Background worker that processes queued writes.

        Runs until the shutdown sentinel (``None``) is received, or until
        ``_running`` is cleared and the queue has been drained.
        """
        while self._running or not self._queue.empty():
            try:
                item = self._queue.get(timeout=0.1)
            except queue.Empty:
                # Wake up periodically to re-check the loop condition.
                continue

            try:
                if item is None:
                    # Shutdown signal
                    break

                sample_id, data = item
                self.storage.save(sample_id, data)
            except Exception as e:
                # Log error but continue processing
                print(f"AsyncProbe error: {e}")
            finally:
                # Acknowledge every dequeued item exactly once -- including
                # the shutdown sentinel -- so Queue.join() can never wait on
                # an item that nobody will mark as done.
                self._queue.task_done()

    @staticmethod
    def _to_numpy(tensor: Any) -> Any:
        """Convert a tensor to a numpy array that is safe to write later.

        ``Tensor.cpu()`` returns the tensor itself when it already lives on
        the CPU, and ``Tensor.numpy()`` shares memory with its tensor. The
        array is only written some time later by the worker thread, so in
        that case it is copied to keep later in-place changes made by the
        caller out of the stored data.
        """
        detached = tensor.detach()
        host = detached.cpu()
        array = host.numpy()
        if host is detached:
            # No device transfer happened: ``array`` aliases caller memory.
            array = array.copy()
        return array

    def _get_sample_id(self, metadata: dict[str, Any]) -> str:
        """Extract or generate sample ID from metadata.

        The generated fallback only advances once the sample carrying it has
        been queued, so consecutive hooks without an explicit ``sample_id``
        resolve to the same sample.
        """
        return metadata.get("sample_id", f"sample_{self._auto_index:06d}")

    def _enqueue_current_locked(self) -> None:
        """Queue the accumulated sample and reset the accumulator.

        The caller must hold ``self._lock``. May block while the queue is
        full.
        """
        if self._current_sample_id and self._current_data:
            self._queue.put((self._current_sample_id, self._current_data))
            if self._current_sample_id == f"sample_{self._auto_index:06d}":
                # The current fallback name has been used; move on so the
                # next auto-named sample does not overwrite this one.
                self._auto_index += 1
        self._current_sample_id = None
        self._current_data = {}

    def _ensure_sample_started(self, metadata: dict[str, Any]) -> None:
        """Ensure we have a sample ID for accumulation.

        If the ID resolved from ``metadata`` differs from the sample being
        accumulated, the previous sample is queued and a new one is started.
        """
        with self._lock:
            sample_id = self._get_sample_id(metadata)
            if self._current_sample_id != sample_id:
                # Queue previous sample if exists
                self._enqueue_current_locked()

                # Start new sample
                self._current_sample_id = sample_id
                self._current_data = {"sample_id": sample_id}

    def on_perturbation_sampled(
        self,
        I_samples: "Tensor",
        metadata: dict[str, Any],
    ) -> None:
        """Queue perturbation samples for storage.

        Args:
            I_samples: Perturbation samples [M, K].
            metadata: Additional context.
        """
        self._ensure_sample_started(metadata)

        if self.include_perturbations:
            # Convert to numpy immediately to avoid keeping GPU tensors
            with self._lock:
                self._current_data["perturbations"] = self._to_numpy(I_samples)
                self._current_data["M"] = I_samples.shape[0]
                self._current_data["K"] = I_samples.shape[1]

    def on_attribution_computed(
        self,
        phi_samples: "Tensor",
        M_I: "Tensor",
        eigenvalues: "Tensor",
        metadata: dict[str, Any],
    ) -> None:
        """Queue attribution data for storage.

        Args:
            phi_samples: Attribution samples [M, K].
            M_I: Second moment matrix [K, K].
            eigenvalues: Eigenvalues of M_I [K].
            metadata: Additional context.
        """
        self._ensure_sample_started(metadata)

        with self._lock:
            if self.include_attributions:
                self._current_data["attributions"] = self._to_numpy(phi_samples)

            if self.include_moment_matrix:
                self._current_data["second_moment"] = self._to_numpy(M_I)

                # Convert once and reuse for both storage and diagnostics.
                eig_np = self._to_numpy(eigenvalues)
                self._current_data["eigenvalues"] = eig_np
                # The small epsilon avoids division by zero for a zero
                # smallest eigenvalue.
                self._current_data["condition_number"] = float(
                    eig_np.max() / (eig_np.min() + 1e-10)
                )
                self._current_data["effective_rank"] = int((eig_np > 1e-6).sum())

    def on_weights_computed(
        self,
        alpha_opt: "Tensor",
        solver_diagnostics: dict[str, Any] | None,
        metadata: dict[str, Any],
    ) -> None:
        """Queue optimal weights for storage.

        Args:
            alpha_opt: Optimal weights [K].
            solver_diagnostics: Solver information. Values exposing a
                ``cpu`` attribute are converted to numpy arrays; all other
                values are stored unchanged.
            metadata: Additional context.
        """
        self._ensure_sample_started(metadata)

        with self._lock:
            self._current_data["optimal_weights"] = self._to_numpy(alpha_opt)

            if solver_diagnostics:
                self._current_data["solver_diagnostics"] = {
                    key: self._to_numpy(value) if hasattr(value, "cpu") else value
                    for key, value in solver_diagnostics.items()
                }

    def on_heatmap_generated(
        self,
        heatmap: "Tensor",
        raw_heatmap: "Tensor",
        metadata: dict[str, Any],
    ) -> None:
        """Queue heatmap data for storage.

        Args:
            heatmap: Final normalized heatmap [H, W].
            raw_heatmap: Pre-normalization heatmap [H, W].
            metadata: Additional context.
        """
        self._ensure_sample_started(metadata)

        if self.include_heatmaps:
            with self._lock:
                self._current_data["heatmap"] = self._to_numpy(heatmap)
                self._current_data["raw_heatmap"] = self._to_numpy(raw_heatmap)

    def flush(self) -> None:
        """Wait for all queued writes to complete.

        Queues the sample currently being accumulated (if any), then blocks
        until the worker has processed every queued item.

        Call this before exiting to ensure all data is persisted.
        """
        # Queue current sample if exists
        with self._lock:
            self._enqueue_current_locked()

        # Wait for queue to empty
        self._queue.join()

    def close(self) -> None:
        """Stop background thread and close storage.

        Pending data is flushed first. Calling close() more than once is
        safe; calls after the first do nothing.
        """
        with self._lock:
            if self._closed:
                return
            self._closed = True

        self.flush()

        # Signal shutdown
        self._running = False
        self._queue.put(None)

        # Wait for thread to finish
        if self._thread is not None:
            self._thread.join(timeout=5.0)

        self.storage.close()

    @property
    def pending_count(self) -> int:
        """Number of samples pending write."""
        return self._queue.qsize()

    def __enter__(self) -> "AsyncProbe":
        """Context manager entry."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Context manager exit."""
        self.close()
