"""Computation observer protocol for real-time visualization.

This module defines the protocol for observing Expected GradCAM computation
progress, enabling real-time visualization and monitoring.

Example:
    A runtime ``isinstance`` check requires *all four* hook methods to be
    present, including ``on_metrics_snapshot``:

    >>> from expected_gradcam.core.callbacks import ComputationObserver
    >>>
    >>> class MyObserver:
    ...     def on_chunk_complete(self, result: ChunkResult) -> None:
    ...         print(f"Chunk {result.chunk_idx}/{result.total_chunks}")
    ...     def on_intermediate_heatmap(self, heatmap: IntermediateHeatmap) -> None:
    ...         pass
    ...     def on_solver_progress(self, progress: SolverProgress) -> None:
    ...         print(f"Condition: {progress.condition_number:.2e}")
    ...     def on_metrics_snapshot(self, snapshot: MetricsSnapshot) -> None:
    ...         pass
    >>>
    >>> isinstance(MyObserver(), ComputationObserver)
    True
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Protocol, runtime_checkable

if TYPE_CHECKING:
    from torch import Tensor


@dataclass(frozen=True)
class ChunkResult:
    """Immutable snapshot emitted once an M-chunk of perturbations is done.

    Carries running estimates from the Expected Gradients computation so
    observers can render progress (partial M_I, optimal weights) while the
    computation is still running.

    Attributes:
        chunk_idx: Zero-based position of this chunk.
        total_chunks: Number of chunks in the whole run.
        samples_processed: Running total of samples consumed so far.
        total_samples: Total number of M samples in the run.
        partial_M_I: Running second-moment matrix estimate [K, K].
        partial_b: Running cross-moment estimate [K], if available.
        partial_alpha: Interim optimal weights [K], when computable.
        condition_number: Current condition-number estimate, if available.
        elapsed_seconds: Seconds elapsed since the computation started.
    """

    # Progress counters (required).
    chunk_idx: int
    total_chunks: int
    samples_processed: int
    total_samples: int
    # Running estimates; ``Tensor`` is only imported under TYPE_CHECKING.
    partial_M_I: "Tensor"
    partial_b: "Tensor | None" = None
    partial_alpha: "Tensor | None" = None
    # Diagnostics.
    condition_number: float | None = None
    elapsed_seconds: float = 0.0

    @property
    def progress(self) -> float:
        """Return ``samples_processed / total_samples``.

        A ``total_samples`` below 1 is replaced by 1 so the division cannot
        fail. The result is not clamped.
        """
        denominator = max(self.total_samples, 1)
        return self.samples_processed / denominator


@dataclass(frozen=True)
class IntermediateHeatmap:
    """Immutable heatmap produced at a checkpoint during computation.

    Emitted at configurable checkpoints so a visualizer can show how the
    heatmap evolves as more samples are incorporated.

    Attributes:
        checkpoint_idx: Position of this checkpoint.
        samples_processed: Number of samples behind this heatmap.
        total_samples: Total number of samples in the run.
        coarse_heatmap: Heatmap at feature resolution [U, V].
        full_heatmap: Heatmap upsampled to image resolution [H, W].
        weights: Optimal weights that produced the heatmap [K].
        condition_number: Condition number observed at this checkpoint.
        infidelity_estimate: Running infidelity estimate, if computed.
    """

    # Progress counters.
    checkpoint_idx: int
    samples_processed: int
    total_samples: int
    # Heatmap payloads.
    coarse_heatmap: "Tensor"
    full_heatmap: "Tensor"
    weights: "Tensor"
    # Diagnostics.
    condition_number: float
    infidelity_estimate: float | None = None

    @property
    def progress(self) -> float:
        """Return ``samples_processed / total_samples``.

        A ``total_samples`` below 1 is replaced by 1 to avoid division by zero.
        """
        denominator = max(self.total_samples, 1)
        return self.samples_processed / denominator


@dataclass(frozen=True)
class SolverProgress:
    """Immutable diagnostics reported during the linear-solve phase.

    Describes how well conditioned M_I is and which solver is in use.

    Attributes:
        eigenvalue_min: Smallest positive eigenvalue.
        eigenvalue_max: Largest eigenvalue.
        condition_number: Ratio of the largest to the smallest eigenvalue.
        effective_rank: Count of eigenvalues above the threshold.
        total_channels: Total number of channels K.
        method: Name of the solver method.
        regularization_eps: Regularization epsilon, when one is applied.
    """

    # Spectrum.
    eigenvalue_min: float
    eigenvalue_max: float
    condition_number: float
    # Rank.
    effective_rank: int
    total_channels: int
    # Solver configuration.
    method: str
    regularization_eps: float | None = None

    @property
    def rank_percentage(self) -> float:
        """Return ``effective_rank`` as a percentage of ``total_channels``.

        A channel count below 1 is replaced by 1 to avoid division by zero.
        """
        channels = max(self.total_channels, 1)
        return self.effective_rank * 100.0 / channels

    @property
    def is_well_conditioned(self) -> bool:
        """Return True when the condition number is strictly below 1e6."""
        threshold = 1e6
        return self.condition_number < threshold


@dataclass(frozen=True)
class MetricsSnapshot:
    """Immutable bundle of every available metric at one point in time.

    Combines solver diagnostics, infidelity, and weight metrics. It is
    emitted at configurable intervals during computation and drives the
    visualizer's metric plots. Optional metrics are ``None`` when they were
    not computed for this snapshot.

    Attributes:
        checkpoint_idx: Position of this checkpoint.
        samples_processed: Running total of samples consumed.
        total_samples: Total number of samples in the run.
        elapsed_seconds: Seconds since the computation started.
        condition_number: κ(M_I) = λ_max / λ_min.
        effective_rank: Number of significant eigenvalues.
        residual_norm: ||M_I @ α - b|| / ||b||.
        eigenvalue_min: Smallest positive eigenvalue.
        eigenvalue_max: Largest eigenvalue.
        infidelity: Feature-space infidelity (zero-cost).
        weight_norm: ||α||₂.
        total_channels: Total number of channels K.
    """

    # Progress counters (required).
    checkpoint_idx: int
    samples_processed: int
    total_samples: int
    elapsed_seconds: float

    # Linear-solve diagnostics.
    condition_number: float | None = None
    effective_rank: int | None = None
    residual_norm: float | None = None
    eigenvalue_min: float | None = None
    eigenvalue_max: float | None = None

    # Infidelity derived from intermediates.
    infidelity: float | None = None

    # Weight statistics.
    weight_norm: float | None = None

    # Problem dimensionality.
    total_channels: int | None = None

    @property
    def progress(self) -> float:
        """Return ``samples_processed / total_samples``.

        A ``total_samples`` below 1 is replaced by 1 to avoid division by zero.
        """
        denominator = max(self.total_samples, 1)
        return self.samples_processed / denominator

    @property
    def is_well_conditioned(self) -> bool:
        """Return True if no condition number is known or it is below 1e6."""
        return self.condition_number is None or self.condition_number < 1e6

    @property
    def rank_percentage(self) -> float | None:
        """Return the effective rank as a percentage of the channel count.

        Returns None when either quantity is missing. A channel count below
        1 is replaced by 1 to avoid division by zero.
        """
        rank, channels = self.effective_rank, self.total_channels
        if rank is not None and channels is not None:
            return 100.0 * rank / max(channels, 1)
        return None

    def to_dict(self) -> dict[str, float | int | None]:
        """Return a flat, ordered dict of all metrics for serialization.

        Includes the derived ``progress`` value right after
        ``elapsed_seconds``.
        """
        ordered_keys = (
            "checkpoint_idx",
            "samples_processed",
            "total_samples",
            "elapsed_seconds",
            "progress",
            "condition_number",
            "effective_rank",
            "residual_norm",
            "eigenvalue_min",
            "eigenvalue_max",
            "infidelity",
            "weight_norm",
            "total_channels",
        )
        return {key: getattr(self, key) for key in ordered_keys}


@runtime_checkable
class ComputationObserver(Protocol):
    """Protocol for observing Expected GradCAM computation progress.

    Observers receive updates about computational progress and intermediate
    results, enabling real-time visualization and monitoring.

    The protocol is decorated with ``@runtime_checkable``, so
    ``isinstance(obj, ComputationObserver)`` works. Such runtime checks only
    verify that every hook method listed below exists; they do not inspect
    signatures. An observer that omits any hook, including
    ``on_metrics_snapshot``, fails the check.

    The caller (ObserverManager, defined outside this module) is documented
    as invoking these hooks in a thread-safe manner. Implementations should
    be efficient to avoid slowing computation.

    Implementations:
    - NullComputationObserver: No-op default, in this module
    - LoggingObserver: Progress logging, in this module
    - DataCollectorObserver: Collects intermediate results for analysis,
      in this module
    - VisualizationObserver: Real-time matplotlib visualization
    """

    def on_chunk_complete(self, result: ChunkResult) -> None:
        """Called after each M-chunk of perturbations is processed.

        This is the primary hook for real-time updates during the
        M x N x T forward/backward computation (the main bottleneck).

        Args:
            result: ChunkResult with partial M_I and optional weights.
        """
        ...

    def on_intermediate_heatmap(self, heatmap: IntermediateHeatmap) -> None:
        """Called when an intermediate heatmap is generated.

        Triggered at configurable checkpoints during computation.
        The heatmap reflects the current optimal weights estimate.

        Args:
            heatmap: IntermediateHeatmap with coarse and full resolution maps.
        """
        ...

    def on_solver_progress(self, progress: SolverProgress) -> None:
        """Called with solver diagnostics during linear solve.

        Provides insight into the conditioning of M_I and solver behavior.

        Args:
            progress: SolverProgress with eigenvalue and rank info.
        """
        ...

    def on_metrics_snapshot(self, snapshot: MetricsSnapshot) -> None:
        """Called with comprehensive metrics update.

        Emitted at configurable intervals with all available metrics.
        This is the primary hook for multi-metric visualization.

        Args:
            snapshot: MetricsSnapshot with all metrics at current point.
        """
        ...


class NullComputationObserver:
    """No-op observer that ignores all computation updates.

    Useful as a default when no observation is needed, or for testing.
    Every hook accepts its payload and returns ``None`` without reading it.
    It defines all four hooks, so it satisfies ``ComputationObserver``.

    Example:
        >>> observer = NullComputationObserver()
        >>> isinstance(observer, ComputationObserver)
        True
    """

    def on_chunk_complete(self, result: ChunkResult) -> None:
        """Ignore chunk completion."""
        pass

    def on_intermediate_heatmap(self, heatmap: IntermediateHeatmap) -> None:
        """Ignore intermediate heatmap."""
        pass

    def on_solver_progress(self, progress: SolverProgress) -> None:
        """Ignore solver progress."""
        pass

    def on_metrics_snapshot(self, snapshot: MetricsSnapshot) -> None:
        """Ignore metrics snapshot."""
        pass


class LoggingObserver:
    """Observer that logs computation progress.

    Useful for debugging and monitoring in non-interactive contexts.
    Per-chunk, heatmap, and solver messages use lazy ``%``-style logging
    arguments, so string formatting is skipped when the level is disabled
    (heatmap checkpoints are logged at DEBUG).

    Example (``egcam`` and ``image`` are placeholders defined elsewhere)::

        observer = LoggingObserver(log_interval=5)
        egcam.add_observer(observer)
        result = egcam.generate(image)
        # Logs e.g.: "Chunk 5/20 (25.0%), cond=1.23e+03, 2.50s"
    """

    def __init__(
        self,
        logger: logging.Logger | None = None,
        log_interval: int = 5,
    ) -> None:
        """Initialize logging observer.

        Args:
            logger: Logger to use. If None, uses this module's logger.
            log_interval: Log every N chunks of a computation. Must be >= 1.

        Raises:
            ValueError: If ``log_interval`` is less than 1.
        """
        if log_interval < 1:
            # Reject early: a zero interval would otherwise surface later as a
            # ZeroDivisionError inside on_chunk_complete.
            raise ValueError(f"log_interval must be >= 1, got {log_interval}")
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self.log_interval = log_interval
        self._chunk_count = 0

    def on_chunk_complete(self, result: ChunkResult) -> None:
        """Log chunk completion at configured intervals.

        The first chunk (``chunk_idx == 0``) is always logged. After that,
        every ``log_interval``-th chunk of the current computation is logged.
        """
        if result.chunk_idx == 0:
            # A new computation started: restart the interval count so the
            # cadence does not drift when the observer is reused across runs.
            self._chunk_count = 0
        self._chunk_count += 1
        if result.chunk_idx != 0 and self._chunk_count % self.log_interval != 0:
            return

        # Test with ``is not None`` rather than truthiness so a 0.0 estimate
        # is still reported, matching on_metrics_snapshot.
        cond_str = (
            f", cond={result.condition_number:.2e}"
            if result.condition_number is not None
            else ""
        )
        self.logger.info(
            "Chunk %s/%s (%.1f%%)%s, %.2fs",
            result.chunk_idx + 1,
            result.total_chunks,
            result.progress * 100,
            cond_str,
            result.elapsed_seconds,
        )

    def on_intermediate_heatmap(self, heatmap: IntermediateHeatmap) -> None:
        """Log intermediate heatmap generation at DEBUG level."""
        self.logger.debug(
            "Heatmap checkpoint %s: %.1f%%, cond=%.2e",
            heatmap.checkpoint_idx,
            heatmap.progress * 100,
            heatmap.condition_number,
        )

    def on_solver_progress(self, progress: SolverProgress) -> None:
        """Log solver method, effective rank, and condition number."""
        self.logger.info(
            "Solver: method=%s, rank=%s/%s (%.1f%%), cond=%.2e",
            progress.method,
            progress.effective_rank,
            progress.total_channels,
            progress.rank_percentage,
            progress.condition_number,
        )

    def on_metrics_snapshot(self, snapshot: MetricsSnapshot) -> None:
        """Log every metric present (non-None) in the snapshot on one line."""
        parts = [f"Metrics [{snapshot.progress * 100:.1f}%]"]

        if snapshot.condition_number is not None:
            parts.append(f"cond={snapshot.condition_number:.2e}")
        if snapshot.infidelity is not None:
            parts.append(f"infid={snapshot.infidelity:.4f}")
        if snapshot.effective_rank is not None and snapshot.total_channels is not None:
            parts.append(f"rank={snapshot.effective_rank}/{snapshot.total_channels}")
        if snapshot.residual_norm is not None:
            parts.append(f"resid={snapshot.residual_norm:.2e}")
        if snapshot.weight_norm is not None:
            parts.append(f"||α||={snapshot.weight_norm:.2f}")

        self.logger.info(", ".join(parts))


class DataCollectorObserver:
    """Observer that collects intermediate data for analysis.

    Useful for research and debugging, collecting intermediate heatmaps and
    metrics for later analysis.

    The convenience lists are fed by more than one hook:

    - ``condition_numbers`` gets values from ``on_chunk_complete`` and
      ``on_metrics_snapshot``.
    - ``infidelities`` gets values from ``on_intermediate_heatmap`` and
      ``on_metrics_snapshot``.
    - ``effective_ranks`` gets values from ``on_solver_progress`` and
      ``on_metrics_snapshot``.

    Values therefore appear in arrival order across sources. Use
    ``metrics_snapshots`` when a single consistent series is needed.

    Example (``egcam`` and ``image`` are placeholders defined elsewhere)::

        collector = DataCollectorObserver()
        egcam.add_observer(collector)
        result = egcam.generate(image)
        print(f"Collected {len(collector.heatmaps)} heatmaps")
        print(f"Condition numbers: {collector.condition_numbers}")
        print(f"Infidelities: {collector.infidelities}")
    """

    def __init__(
        self,
        max_heatmaps: int = 100,
        *,
        max_chunk_results: int | None = None,
    ) -> None:
        """Initialize data collector.

        Args:
            max_heatmaps: Maximum number of heatmaps to store. Only the first
                ``max_heatmaps`` are kept; later ones are dropped.
            max_chunk_results: Maximum number of ChunkResult objects to store,
                or None (default) for no limit. Each ChunkResult holds a
                partial M_I matrix, so a cap bounds memory on long runs. Like
                ``max_heatmaps``, the first entries are kept. Condition numbers
                are still recorded for dropped chunks.
        """
        self.max_heatmaps = max_heatmaps
        self.max_chunk_results = max_chunk_results
        self.heatmaps: list[IntermediateHeatmap] = []
        self.chunk_results: list[ChunkResult] = []
        self.condition_numbers: list[float] = []
        self.infidelities: list[float] = []
        self.effective_ranks: list[int] = []
        self.residual_norms: list[float] = []
        self.weight_norms: list[float] = []
        self.metrics_snapshots: list[MetricsSnapshot] = []
        self.solver_progress: SolverProgress | None = None

    def on_chunk_complete(self, result: ChunkResult) -> None:
        """Store the chunk result (subject to the cap) and its condition number."""
        if self.max_chunk_results is None or len(self.chunk_results) < self.max_chunk_results:
            self.chunk_results.append(result)
        if result.condition_number is not None:
            self.condition_numbers.append(result.condition_number)

    def on_intermediate_heatmap(self, heatmap: IntermediateHeatmap) -> None:
        """Store the heatmap (up to ``max_heatmaps``) and its infidelity."""
        if len(self.heatmaps) < self.max_heatmaps:
            self.heatmaps.append(heatmap)
        # The infidelity scalar is recorded even when the heatmap is dropped.
        if heatmap.infidelity_estimate is not None:
            self.infidelities.append(heatmap.infidelity_estimate)

    def on_solver_progress(self, progress: SolverProgress) -> None:
        """Keep the latest solver progress and record its effective rank."""
        self.solver_progress = progress
        self.effective_ranks.append(progress.effective_rank)

    def on_metrics_snapshot(self, snapshot: MetricsSnapshot) -> None:
        """Store the snapshot and append each non-None metric to its list."""
        self.metrics_snapshots.append(snapshot)

        # Also update individual lists for convenience.
        if snapshot.condition_number is not None:
            self.condition_numbers.append(snapshot.condition_number)
        if snapshot.infidelity is not None:
            self.infidelities.append(snapshot.infidelity)
        if snapshot.effective_rank is not None:
            self.effective_ranks.append(snapshot.effective_rank)
        if snapshot.residual_norm is not None:
            self.residual_norms.append(snapshot.residual_norm)
        if snapshot.weight_norm is not None:
            self.weight_norms.append(snapshot.weight_norm)

    def clear(self) -> None:
        """Clear all collected data (configured limits are kept)."""
        self.heatmaps.clear()
        self.chunk_results.clear()
        self.condition_numbers.clear()
        self.infidelities.clear()
        self.effective_ranks.clear()
        self.residual_norms.clear()
        self.weight_norms.clear()
        self.metrics_snapshots.clear()
        self.solver_progress = None
