"""Thread-safe observer management for computation callbacks.

This module provides the ObserverManager class that handles notification
of multiple computation observers in a thread-safe manner.

Example:
    >>> from expected_gradcam.core.observer_manager import ObserverManager
    >>> from expected_gradcam.core.callbacks import LoggingObserver
    >>>
    >>> manager = ObserverManager()
    >>> manager.add_observer(LoggingObserver())
    >>>
    >>> if manager.has_observers:
    ...     manager.notify_chunk_complete(result)
"""

from __future__ import annotations

import logging
import threading
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from expected_gradcam.core.callbacks import (
        ChunkResult,
        ComputationObserver,
        IntermediateHeatmap,
        SolverProgress,
    )


class ObserverManager:
    """Thread-safe manager for computation observers.

    Fans computation events out to every registered observer. When no
    observers are registered (or the manager is disabled), each ``notify_*``
    call returns after a single cheap ``has_observers`` check. Callers that
    must build an expensive payload should check ``has_observers`` first.

    Observer failures never propagate into the computation. An exception
    raised by an observer is logged with its traceback, and the remaining
    observers are still notified.

    Thread Safety:
        A reentrant lock guards every mutation of the observer list.
        Notifications hold the lock only long enough to copy the list. They
        then call observers outside the lock, so an observer may add or
        remove observers (including itself) without deadlocking. Such
        changes take effect from the next notification. ``has_observers``
        and ``observer_count`` read the list without locking and may be
        momentarily stale.

    Example:
        >>> manager = ObserverManager()
        >>> manager.add_observer(my_observer)
        >>>
        >>> # In computation loop:
        >>> if manager.has_observers:
        ...     manager.notify_chunk_complete(result)
        >>>
        >>> manager.remove_observer(my_observer)
    """

    # Used to report (rather than silently drop) exceptions raised by observers.
    _logger = logging.getLogger(__name__)

    def __init__(self) -> None:
        """Initialize an empty, enabled observer manager."""
        self._observers: list[ComputationObserver] = []
        self._lock = threading.RLock()
        self._enabled = True

    @property
    def has_observers(self) -> bool:
        """Whether notifications would currently reach any observer.

        Fast-path check intended to be called before preparing notification
        data. It reads the observer list without taking the lock. Under
        concurrent modification the answer may be momentarily stale, which is
        harmless because the notify methods re-read the list under the lock.

        Returns:
            True if the manager is enabled and at least one observer is
            registered.
        """
        return self._enabled and bool(self._observers)

    @property
    def enabled(self) -> bool:
        """Whether observer notifications are currently enabled."""
        return self._enabled

    @enabled.setter
    def enabled(self, value: bool) -> None:
        """Enable or disable observer notifications.

        When disabled, has_observers returns False even if observers
        are registered, effectively muting all notifications. Registered
        observers are kept and resume receiving events once re-enabled.
        """
        self._enabled = value

    @property
    def observer_count(self) -> int:
        """Number of registered observers (read without locking)."""
        return len(self._observers)

    def add_observer(self, observer: ComputationObserver) -> None:
        """Register a computation observer.

        The observer will receive notifications for all computation events.
        Adding an observer that is already registered (compared with ``==``)
        is a no-op.

        Args:
            observer: Observer implementing ComputationObserver protocol.
        """
        with self._lock:
            if observer not in self._observers:
                self._observers.append(observer)

    def remove_observer(self, observer: ComputationObserver) -> None:
        """Unregister a computation observer.

        Removing an observer that is not registered is a no-op.

        Args:
            observer: Observer to remove.
        """
        with self._lock:
            try:
                self._observers.remove(observer)
            except ValueError:
                pass  # Not registered: nothing to do.

    def clear_observers(self) -> None:
        """Remove all registered observers."""
        with self._lock:
            self._observers.clear()

    def _notify(self, method_name: str, payload: object) -> None:
        """Call ``observer.<method_name>(payload)`` on a snapshot of observers.

        Returns immediately when ``has_observers`` is False. Any exception
        raised while dispatching to an observer is logged and suppressed,
        including an AttributeError for a missing callback. One faulty
        observer therefore cannot break the computation or stop the others
        from being notified.

        Args:
            method_name: Name of the observer callback to invoke.
            payload: Single argument passed to that callback.
        """
        if not self.has_observers:
            return

        # Copy under the lock and call outside it, so observers can mutate the
        # registry without deadlocking or invalidating this iteration.
        with self._lock:
            observers = list(self._observers)

        for observer in observers:
            try:
                getattr(observer, method_name)(payload)
            except Exception:
                # Deliberately broad: observers must never break computation,
                # but failures should still be visible.
                self._logger.exception(
                    "Observer %r raised in %s; ignoring", observer, method_name
                )

    def notify_chunk_complete(self, result: ChunkResult) -> None:
        """Notify all observers of chunk completion.

        Called after each M-chunk in the Expected Gradients computation.

        Args:
            result: ChunkResult with partial computation state.
        """
        self._notify("on_chunk_complete", result)

    def notify_intermediate_heatmap(self, heatmap: IntermediateHeatmap) -> None:
        """Notify all observers of intermediate heatmap generation.

        Called at configurable checkpoints during computation.

        Args:
            heatmap: IntermediateHeatmap with current visualization state.
        """
        self._notify("on_intermediate_heatmap", heatmap)

    def notify_solver_progress(self, progress: SolverProgress) -> None:
        """Notify all observers of solver progress.

        Called during the linear system solve phase.

        Args:
            progress: SolverProgress with solver diagnostics.
        """
        self._notify("on_solver_progress", progress)

    def __enter__(self) -> "ObserverManager":
        """Enter the context, returning this manager."""
        return self

    def __exit__(self, *args: object) -> None:
        """Exit the context by clearing all observers.

        Returns None, so an exception raised inside the ``with`` block
        propagates normally.
        """
        self.clear_observers()
