"""Intermediate value collector for research probes.

This module provides a dataclass for organizing collected intermediate
values from E-GradCAM computation.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    import numpy as np
    from numpy.typing import NDArray


@dataclass
class IntermediateCollector:
    """Collector for intermediate E-GradCAM values.

    This dataclass organizes all intermediate values that can be collected
    during E-GradCAM computation. It provides a structured way to store
    and access collected data, plus NPZ persistence via :meth:`save` and
    :meth:`load`.

    All arrays are stored as numpy arrays to ensure they can be serialized
    to storage backends. The three dictionaries (``config``,
    ``solver_diagnostics``, ``metadata``) are persisted as JSON strings so
    that archives never require pickle to be read back.

    Attributes:
        sample_id: Unique identifier for this computation.
        config: Configuration used for computation.
        perturbations: Perturbation samples [M, K].
        attributions: Attribution samples [M, K].
        second_moment: Second moment matrix [K, K].
        eigenvalues: Eigenvalues of second moment matrix [K].
        optimal_weights: Optimal channel weights [K].
        solver_diagnostics: Diagnostics from linear solver.
        raw_heatmap: Pre-normalization heatmap [H, W].
        heatmap: Final normalized heatmap [H, W].
        metadata: Additional metadata from computation.

    Example:
        >>> collector = IntermediateCollector(sample_id="sample_001")
        >>> collector.perturbations = I_samples.cpu().numpy()
        >>> collector.attributions = phi_samples.cpu().numpy()
        >>> collector.save("output/sample_001.npz")
    """

    sample_id: str
    config: dict[str, Any] = field(default_factory=dict)

    # Perturbation stage
    perturbations: "NDArray[np.floating] | None" = None

    # Attribution stage
    attributions: "NDArray[np.floating] | None" = None
    second_moment: "NDArray[np.floating] | None" = None
    eigenvalues: "NDArray[np.floating] | None" = None

    # Weight solving stage
    optimal_weights: "NDArray[np.floating] | None" = None
    solver_diagnostics: dict[str, Any] = field(default_factory=dict)

    # Heatmap stage
    raw_heatmap: "NDArray[np.floating] | None" = None
    heatmap: "NDArray[np.floating] | None" = None

    # Additional metadata
    metadata: dict[str, Any] = field(default_factory=dict)

    # Deliberately unannotated: the dataclass machinery only turns annotated
    # names into fields, so these stay plain class-level constants.
    # Array attributes written to / read from the archive under their own name.
    _ARRAY_FIELDS = (
        "perturbations",
        "attributions",
        "second_moment",
        "eigenvalues",
        "optimal_weights",
        "raw_heatmap",
        "heatmap",
    )
    # Dict attributes, stored as JSON text under the key "<name>_json".
    _MAPPING_FIELDS = ("config", "solver_diagnostics", "metadata")

    def is_complete(self) -> bool:
        """Check if all stages have been collected.

        ``raw_heatmap`` is not required: the heatmap stage counts as
        collected once the final ``heatmap`` is present.

        Returns:
            True if perturbations, attributions, second moment, eigenvalues,
            optimal weights and the final heatmap are all present.
        """
        required = (
            self.perturbations,
            self.attributions,
            self.second_moment,
            self.eigenvalues,
            self.optimal_weights,
            self.heatmap,
        )
        return all(value is not None for value in required)

    def save(self, path: str) -> None:
        """Save collector to NPZ file.

        The archive holds ``sample_id``, every array attribute that is not
        ``None``, and each non-empty dictionary attribute encoded as a JSON
        string under ``config_json`` / ``solver_diagnostics_json`` /
        ``metadata_json``. Numpy arrays and scalars inside the dictionaries
        are converted to lists / Python scalars; any other value JSON cannot
        represent is stored as its ``repr``. A dictionary that still cannot
        be encoded is skipped with a ``RuntimeWarning`` instead of aborting
        the save.

        Note that ``numpy.savez`` appends ``.npz`` to ``path`` when the name
        does not already end with it, so pass the full file name to
        :meth:`load`.

        Args:
            path: Path to save file.
        """
        import json
        import warnings

        import numpy as np

        def _jsonable(obj: Any) -> Any:
            # Fallback for values the json module cannot encode natively.
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, np.generic):
                return obj.item()
            return repr(obj)

        payload: dict[str, Any] = {"sample_id": self.sample_id}

        for name in self._ARRAY_FIELDS:
            value = getattr(self, name)
            if value is not None:
                payload[name] = value

        for name in self._MAPPING_FIELDS:
            mapping = getattr(self, name)
            if not mapping:
                continue
            try:
                payload[f"{name}_json"] = json.dumps(mapping, default=_jsonable)
            except (TypeError, ValueError) as exc:
                # E.g. non-scalar dict keys or circular references. Losing
                # one dictionary is preferable to losing the arrays too.
                warnings.warn(
                    f"{name} could not be serialized and was not saved: {exc}",
                    RuntimeWarning,
                    stacklevel=2,
                )

        np.savez(path, **payload)

    @classmethod
    def load(cls, path: str) -> "IntermediateCollector":
        """Load collector from NPZ file.

        Entries missing from the archive leave the corresponding attribute
        at its default (``None`` for arrays, ``{}`` for dictionaries), so
        archives written before the dictionaries were persisted still load.

        Args:
            path: Path to NPZ file.

        Returns:
            Loaded IntermediateCollector. ``sample_id`` falls back to
            ``"unknown"`` when the archive has no such entry.
        """
        import json

        import numpy as np

        # SECURITY: pickle is disabled on purpose. save() only writes plain
        # numeric/string arrays, and unpickling a crafted archive would
        # execute arbitrary code. Archives containing object arrays are
        # therefore rejected by numpy with a ValueError.
        # The context manager closes the underlying zip file; every entry is
        # fully read into memory when it is indexed, so the arrays remain
        # valid after the archive is closed.
        with np.load(path, allow_pickle=False) as archive:
            if "sample_id" in archive:
                sample_id = str(archive["sample_id"])
            else:
                sample_id = "unknown"
            collector = cls(sample_id=sample_id)

            for name in cls._ARRAY_FIELDS:
                if name in archive:
                    setattr(collector, name, archive[name])

            for name in cls._MAPPING_FIELDS:
                key = f"{name}_json"
                if key in archive:
                    setattr(collector, name, json.loads(str(archive[key])))

        return collector

    def get_summary(self) -> dict[str, Any]:
        """Get summary statistics of collected data.

        Only statistics whose inputs are available are included:

        * ``M`` / ``K`` and ``perturbation_std`` when perturbations are set
          (``K`` only if the array has at least two dimensions).
        * ``condition_number`` and ``effective_rank`` when a non-empty
          eigenvalue array is set.
        * ``weight_mean`` / ``weight_std`` / ``weight_max`` / ``weight_min``
          when a non-empty weight array is set.
        * ``solver_diagnostics`` when that dictionary is non-empty (the same
          dict object is returned, not a copy).

        Returns:
            Dictionary with summary statistics.
        """
        import numpy as np

        summary: dict[str, Any] = {
            "sample_id": self.sample_id,
            "is_complete": self.is_complete(),
        }

        if self.perturbations is not None:
            samples = np.asarray(self.perturbations)
            # Guard each axis so an unexpectedly low-rank array does not
            # raise IndexError.
            if samples.ndim >= 1:
                summary["M"] = samples.shape[0]
            if samples.ndim >= 2:
                summary["K"] = samples.shape[1]
            summary["perturbation_std"] = float(np.std(samples))

        if self.eigenvalues is not None:
            spectrum = np.asarray(self.eigenvalues)
            if spectrum.size:
                # Use magnitudes: round-off can make the smallest eigenvalue
                # of a positive semi-definite matrix slightly negative, which
                # would otherwise give a negative or infinite ratio. For a
                # non-negative spectrum this is identical to max / min.
                magnitudes = np.abs(spectrum)
                # 1e-10 keeps the ratio finite for singular matrices.
                summary["condition_number"] = float(
                    magnitudes.max() / (magnitudes.min() + 1e-10)
                )
                summary["effective_rank"] = int((spectrum > 1e-6).sum())

        if self.optimal_weights is not None:
            weights = np.asarray(self.optimal_weights)
            # max/min of an empty array raise ValueError, so skip the stats.
            if weights.size:
                summary["weight_mean"] = float(np.mean(weights))
                summary["weight_std"] = float(np.std(weights))
                summary["weight_max"] = float(np.max(weights))
                summary["weight_min"] = float(np.min(weights))

        if self.solver_diagnostics:
            summary["solver_diagnostics"] = self.solver_diagnostics

        return summary
