"""Result dataclasses for Expected GradCAM outputs."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import torch


if TYPE_CHECKING:
    from torch import Tensor


@dataclass(frozen=True)
class SolverDiagnostics:
    """Numerical health report produced by the linear-system solver.

    Inspecting these values reveals how well-posed the weight optimization
    was, which in turn helps when tuning solver parameters.

    Attributes:
        method: Name of the solver method that was used.
        condition_number: Condition number of M_I; smaller values are better.
        effective_rank: Effective rank of M_I; ideally this equals K.
        K: Number of feature channels.
        regularization_eps: Regularization epsilon, when the method uses one.
        residual_norm: Relative residual ||M_I @ α - b|| / ||b||; smaller is better.
        eigenvalue_min: Smallest eigenvalue of M_I, if recorded.
        eigenvalue_max: Largest eigenvalue of M_I, if recorded.
    """

    method: str
    condition_number: float
    effective_rank: int
    K: int
    regularization_eps: float | None = None
    residual_norm: float | None = None
    eigenvalue_min: float | None = None
    eigenvalue_max: float | None = None

    @property
    def is_full_rank(self) -> bool:
        """True when the effective rank of M_I equals the channel count K."""
        channel_count = self.K
        return self.effective_rank == channel_count

    @property
    def is_well_conditioned(self) -> bool:
        """True when the condition number of M_I is strictly below 1e6.

        A NaN condition number compares False and is therefore reported
        as not well-conditioned.
        """
        conditioning_limit = 1e6
        return self.condition_number < conditioning_limit


@dataclass(frozen=True)
class CompletenessResult:
    """Result of completeness axiom verification.

    The completeness axiom states that:
        I^T @ φ = g(z_0) - g(z_0 - I)

    This verifies that the attributions explain the full difference
    between the reference and perturbed predictions.

    Attributes:
        passed: Whether ``relative_error <= tolerance``. This is always False
            when either side is NaN, because NaN propagates into the error
            and compares False.
        lhs: Left-hand side value (I^T @ φ).
        rhs: Right-hand side value (g(z_0) - g(z_0 - I)).
        absolute_error: |lhs - rhs|
        relative_error: |lhs - rhs| / max(|rhs|, eps). The ``eps`` floor
            (1e-10 by default) avoids division by zero when rhs is (near) zero.
    """

    passed: bool
    lhs: float
    rhs: float
    absolute_error: float
    relative_error: float

    @classmethod
    def from_tensors(
        cls,
        lhs: "Tensor | float",
        rhs: "Tensor | float",
        tolerance: float = 0.01,
        eps: float = 1e-10,
    ) -> "CompletenessResult":
        """Create from scalar tensor (or plain numeric) values.

        Args:
            lhs: Left-hand side value. Either a single-element tensor (or any
                object exposing ``.item()``) or a plain Python number.
            rhs: Right-hand side value, in the same forms as ``lhs``.
            tolerance: Relative error tolerance for passing (inclusive).
            eps: Positive floor for the relative-error denominator, so that
                a zero ``rhs`` does not cause a division by zero.

        Returns:
            CompletenessResult instance.

        Raises:
            ValueError: If ``eps`` is not strictly positive.
        """
        if not eps > 0:
            raise ValueError(f"eps must be strictly positive, got {eps!r}")

        def _as_float(value: Any) -> float:
            # Tensors and numpy scalars expose ``.item()``; Python numbers do not.
            item = getattr(value, "item", None)
            return float(item() if callable(item) else value)

        lhs_val = _as_float(lhs)
        rhs_val = _as_float(rhs)
        abs_error = abs(lhs_val - rhs_val)
        rel_error = abs_error / max(abs(rhs_val), eps)

        return cls(
            passed=rel_error <= tolerance,
            lhs=lhs_val,
            rhs=rhs_val,
            absolute_error=abs_error,
            relative_error=rel_error,
        )


@dataclass
class IntermediateValues:
    """Container for intermediate computation values (for research probes).

    These values are captured during computation when probes are enabled,
    allowing detailed analysis of the E-GradCAM algorithm.

    Attributes:
        I_samples: Perturbation samples [M, K].
        phi_samples: Attribution samples [M, K].
        M_I: Second moment matrix [K, K].
        b: Cross-moment vector [K].
        eigenvalues: Eigenvalues of M_I [K].
        eigenvectors: Eigenvectors of M_I [K, K].
        alpha_raw: Optimal weights before transform [K].
        alpha_transformed: Optimal weights after transform [K].
        coarse_heatmap: Heatmap at feature resolution [U, V].
        timings: Dict of operation timings in seconds.
        extra: Dict for additional values from probes.
    """

    # Names of the optional tensor fields. ``to_cpu`` and ``to_dict`` share this
    # tuple so the two conversions cannot drift apart. It is deliberately
    # unannotated, so the dataclass machinery treats it as a plain class
    # attribute rather than a field.
    _TENSOR_FIELDS = (
        "I_samples",
        "phi_samples",
        "M_I",
        "b",
        "eigenvalues",
        "eigenvectors",
        "alpha_raw",
        "alpha_transformed",
        "coarse_heatmap",
    )

    I_samples: "Tensor | None" = None
    phi_samples: "Tensor | None" = None
    M_I: "Tensor | None" = None
    b: "Tensor | None" = None
    eigenvalues: "Tensor | None" = None
    eigenvectors: "Tensor | None" = None
    alpha_raw: "Tensor | None" = None
    alpha_transformed: "Tensor | None" = None
    coarse_heatmap: "Tensor | None" = None
    timings: dict[str, float] = field(default_factory=dict)
    extra: dict[str, Any] = field(default_factory=dict)

    def to_cpu(self) -> "IntermediateValues":
        """Return a copy with all tensors moved to CPU.

        ``timings`` and ``extra`` are shallow-copied. Tensor values inside
        ``extra`` are moved to CPU; other values are shared with this object.
        Tensors are not detached.
        """

        def _to_cpu(t: "Tensor | None") -> "Tensor | None":
            return t.cpu() if t is not None else None

        moved = {name: _to_cpu(getattr(self, name)) for name in self._TENSOR_FIELDS}
        return IntermediateValues(
            **moved,
            timings=self.timings.copy(),
            extra={k: _to_cpu(v) if isinstance(v, torch.Tensor) else v for k, v in self.extra.items()},
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary of NumPy arrays and plain values for serialization.

        - Tensor fields that are ``None`` are omitted.
        - Tensors, including tensor values in ``extra``, are detached before
          conversion, because ``Tensor.numpy()`` refuses tensors that require
          grad.
        - ``timings`` and ``extra`` are returned as new dicts, so mutating the
          result does not mutate this object.
        """

        def _to_numpy(t: "Tensor") -> Any:
            # detach() first: numpy() raises for tensors that require grad.
            return t.detach().cpu().numpy()

        result: dict[str, Any] = {}
        for name in self._TENSOR_FIELDS:
            tensor = getattr(self, name)
            if tensor is not None:
                result[name] = _to_numpy(tensor)

        result["timings"] = dict(self.timings)
        result["extra"] = {
            k: _to_numpy(v) if isinstance(v, torch.Tensor) else v
            for k, v in self.extra.items()
        }

        return result


@dataclass(frozen=True)
class ExpectedGradCAMResult:
    """Immutable output bundle returned by ExpectedGradCAM.__call__().

    Attributes:
        heatmap: Final heatmap at input resolution [H, W], normalized to [0, 1].
        coarse_heatmap: Heatmap at feature-map resolution [U, V].
        optimal_weights: Optimal per-channel feature-map weights [K].
        target_class: Index of the class being explained.
        feature_maps: Target-layer feature maps [1, K, U, V], if kept.
        solver_diagnostics: Solver diagnostics, if collected.
        completeness_results: Completeness verification results, if computed.
        intermediates: Intermediate research values, if probes were enabled.
    """

    heatmap: "Tensor"
    coarse_heatmap: "Tensor"
    optimal_weights: "Tensor"
    target_class: int
    feature_maps: "Tensor | None" = None
    solver_diagnostics: SolverDiagnostics | None = None
    completeness_results: list[CompletenessResult] | None = None
    intermediates: IntermediateValues | None = None

    @property
    def shape(self) -> tuple[int, int]:
        """Dimensions of the final heatmap, as a plain tuple (H, W)."""
        dims = self.heatmap.shape
        return tuple(dims)  # type: ignore

    @property
    def num_channels(self) -> int:
        """Number of feature channels K, i.e. the length of the weight vector."""
        leading_dim = self.optimal_weights.shape[0]
        return int(leading_dim)

    def to_cpu(self) -> "ExpectedGradCAMResult":
        """Build a new result whose tensors all live on the CPU.

        Optional tensors stay ``None`` when absent. Intermediates are converted
        through their own ``to_cpu``. Solver diagnostics and completeness
        results are passed through unchanged.
        """
        maps = self.feature_maps
        probes = self.intermediates

        heatmap_cpu = self.heatmap.cpu()
        coarse_cpu = self.coarse_heatmap.cpu()
        weights_cpu = self.optimal_weights.cpu()
        maps_cpu = None if maps is None else maps.cpu()
        probes_cpu = None if probes is None else probes.to_cpu()

        return ExpectedGradCAMResult(
            heatmap=heatmap_cpu,
            coarse_heatmap=coarse_cpu,
            optimal_weights=weights_cpu,
            target_class=self.target_class,
            feature_maps=maps_cpu,
            solver_diagnostics=self.solver_diagnostics,
            completeness_results=self.completeness_results,
            intermediates=probes_cpu,
        )
