"""Computation-related exceptions."""

from __future__ import annotations

from expected_gradcam.exceptions.base import ExpectedGradCAMError


class ComputationError(ExpectedGradCAMError):
    """Common base class for failures that arise during numerical computation.

    Every computation-specific exception in this module derives from it:
    singular/ill-conditioned matrices, numerical instability, gradient
    failures, and non-convergence. Catching it therefore catches them all.
    """


class SingularMatrixError(ComputationError):
    """Raised when the second moment matrix M_I is singular or ill-conditioned.

    This error occurs when:
    - M_I has zero or near-zero eigenvalues
    - The condition number is too large for stable inversion
    - The effective rank is much smaller than K (number of channels)

    This is common when:
    - M (perturbation samples) is too small relative to K
    - Using pure data-aware perturbations (expected behavior, use pinv solver)

    All diagnostics are optional, keyword-only arguments. Whichever ones are
    supplied are stored as attributes and included in the error message.

    Example:
        >>> err = SingularMatrixError(condition_number=1e15, rank=900, K=2048)
        >>> (err.condition_number, err.rank, err.K)
        (1e+15, 900, 2048)
    """

    def __init__(
        self,
        *,
        condition_number: float | None = None,
        rank: int | None = None,
        K: int | None = None,
    ) -> None:
        """Initialize the exception.

        Args:
            condition_number: Condition number of M_I.
            rank: Effective rank of M_I.
            K: Total number of channels (expected full rank).
        """
        self.condition_number = condition_number
        self.rank = rank
        self.K = K

        parts = ["Second moment matrix M_I is singular or ill-conditioned."]
        if condition_number is not None:
            parts.append(f"Condition number: {condition_number:.2e}")
        # Report every diagnostic the caller supplied. Previously, rank and K
        # were only shown when BOTH were given, so a lone value was dropped.
        if rank is not None and K is not None:
            parts.append(f"Effective rank: {rank}/{K}")
        elif rank is not None:
            parts.append(f"Effective rank: {rank}")
        elif K is not None:
            parts.append(f"Expected rank (K): {K}")

        message = " ".join(parts)

        suggestion = (
            "Try one of the following:\n"
            "  1. Use solver_method='pinv' (pseudo-inverse, recommended for data-aware)\n"
            "  2. Increase M (perturbation samples) to M >= K\n"
            "  3. Increase regularization_eps (e.g., 1e-4)\n"
            "  4. Use solver_method='adaptive_reg' for automatic regularization"
        )

        super().__init__(message, suggestion=suggestion)


class NumericalInstabilityError(ComputationError):
    """Raised when numerical instability is detected.

    This error occurs when:
    - NaN or Inf values are produced during computation
    - Gradients explode or vanish unexpectedly
    - Intermediate values exceed safe bounds

    Example:
        >>> err = NumericalInstabilityError(
        ...     "path_integration",
        ...     details="NaN detected in gradient computation",
        ... )
        >>> err.operation
        'path_integration'
    """

    def __init__(
        self,
        operation: str,
        *,
        details: str | None = None,
    ) -> None:
        """Initialize the exception.

        Args:
            operation: The operation where instability was detected.
            details: Additional details about the instability. An empty
                string is treated the same as ``None``.
        """
        self.operation = operation
        self.details = details

        if details:
            message = f"Numerical instability in '{operation}': {details}"
        else:
            message = f"Numerical instability detected during '{operation}'."

        # Mixed precision (AMP) uses reduced-precision floats. These have a
        # much smaller dynamic range and are a common source of NaN/Inf, so
        # the remedy is to turn AMP off. The previous text advised enabling it.
        suggestion = (
            "Try one of the following:\n"
            "  1. Disable AMP (use_amp=False) to compute in full precision\n"
            "  2. Reduce the integration steps T\n"
            "  3. Check that the model produces reasonable outputs\n"
            "  4. Ensure input images are properly normalized"
        )

        super().__init__(message, suggestion=suggestion)


class GradientComputationError(ComputationError):
    """Signal that gradients could not be obtained through the model.

    Typical causes:
    - a disconnected computation graph
    - tensors that were expected to carry a ``grad_fn`` but do not
    - model parameters that do not require gradients

    Example:
        >>> err = GradientComputationError(reason="Computation graph disconnected")
        >>> err.reason
        'Computation graph disconnected'
    """

    def __init__(self, *, reason: str | None = None) -> None:
        """Build the error message and attach troubleshooting hints.

        Args:
            reason: Optional explanation of why gradient computation failed.
                A falsy value (``None`` or ``""``) yields the generic message.
        """
        self.reason = reason

        # Append the reason after a colon when present; otherwise just end the
        # sentence.
        tail = f": {reason}" if reason else "."
        message = f"Failed to compute gradients{tail}"

        hints = [
            "Ensure that:",
            "  1. The model is in eval mode but gradients are enabled",
            "  2. Input tensors have requires_grad=True if needed",
            "  3. No in-place operations broke the computation graph",
            "  4. torch.no_grad() is not active during computation",
        ]
        super().__init__(message, suggestion="\n".join(hints))


class ConvergenceError(ComputationError):
    """Signal that an iterative algorithm stopped without converging.

    Typical causes:
    - the regularized solver hit its iteration limit
    - an optimization-based weight computation failed

    Example:
        >>> err = ConvergenceError("adaptive_reg", iterations=1000, residual=0.5)
        >>> (err.algorithm, err.iterations, err.residual)
        ('adaptive_reg', 1000, 0.5)
    """

    def __init__(
        self,
        algorithm: str,
        *,
        iterations: int | None = None,
        residual: float | None = None,
    ) -> None:
        """Build the error message and attach troubleshooting hints.

        Args:
            algorithm: Name of the algorithm that did not converge.
            iterations: How many iterations ran before giving up, if known.
            residual: The last residual/error value, if known.
        """
        self.algorithm = algorithm
        self.iterations = iterations
        self.residual = residual

        # Optional diagnostics become None when absent and are filtered out
        # before joining, so only supplied values appear in the message.
        pieces = (
            f"Algorithm '{algorithm}' failed to converge.",
            None if iterations is None else f"Iterations: {iterations}",
            None if residual is None else f"Final residual: {residual:.2e}",
        )
        message = " ".join(piece for piece in pieces if piece is not None)

        hints = [
            "Try one of the following:",
            "  1. Use solver_method='pinv' instead (direct method)",
            "  2. Increase regularization_eps",
            "  3. Reduce the problem size (lower M or K)",
        ]
        super().__init__(message, suggestion="\n".join(hints))
