"""Custom exceptions for metrics computation.

This module defines the exception hierarchy for metric computation errors,
providing clear error messages and actionable suggestions.

Example:
    Raise an infidelity error with an actionable suggestion (the class
    prefixes the message with "Failed to compute metric 'infidelity': ")::

        raise InfidelityComputationError(
            "alpha has NaN values",
            suggestion="Check that solver converged properly",
        )
"""

from __future__ import annotations

from collections.abc import Iterable

from expected_gradcam.exceptions.base import ExpectedGradCAMError


class MetricError(ExpectedGradCAMError):
    """Root of the metric exception hierarchy.

    Every metric-related exception in this module derives from this class,
    so a single ``except MetricError`` clause handles all of them.
    """


class MetricComputationError(MetricError):
    """Raised when a metric cannot be computed.

    Typical causes are invalid inputs, numerical problems, or any other
    failure encountered while evaluating the metric.

    Attributes:
        metric_name: Name of the metric whose computation failed.
    """

    def __init__(
        self,
        metric_name: str,
        *,
        details: str = "",
        suggestion: str | None = None,
    ) -> None:
        """Build the error message from the metric name and optional details.

        Args:
            metric_name: Name of the metric whose computation failed.
            details: Extra context appended after a colon; omitted when empty.
            suggestion: Optional hint on how to resolve the problem, forwarded
                to the base exception.
        """
        headline = f"Failed to compute metric '{metric_name}'"
        full_message = f"{headline}: {details}" if details else headline
        super().__init__(full_message, suggestion=suggestion)
        self.metric_name = metric_name


class InfidelityComputationError(MetricComputationError):
    """Raised when the infidelity metric cannot be computed.

    Common triggers are NaN/Inf values, tensors whose shapes do not match,
    or other numerical problems. The metric name is always ``"infidelity"``.

    Example:
        Raise with a details string and a fix hint::

            raise InfidelityComputationError(
                "alpha weights contain NaN values",
                suggestion="Check that M_I is invertible",
            )
    """

    def __init__(self, details: str = "", *, suggestion: str | None = None) -> None:
        """Create an infidelity error.

        Args:
            details: Description of what went wrong.
            suggestion: Optional hint on how to resolve the problem.
        """
        super().__init__(
            "infidelity",
            details=details,
            suggestion=suggestion,
        )


class MetricNotFoundError(MetricError):
    """Metric not found in registry.

    Raised when attempting to access a metric that hasn't been registered.

    Attributes:
        name: The metric name that was not found.
        available: List of the metric names that were available when the
            error was raised (a copy, in the order they were supplied).
    """

    def __init__(self, name: str, available: Iterable[str]) -> None:
        """Initialize the exception.

        Args:
            name: The metric name that was not found.
            available: Available metric names. Any iterable is accepted
                (e.g. a list or ``dict.keys()``); it is copied into a list.
        """
        # Materialize once: a generator would otherwise be exhausted by
        # ``sorted`` before being stored, and a dict view would keep tracking
        # later registry changes instead of the state at raise time.
        available_names = list(available)
        message = f"Unknown metric: '{name}'"
        if available_names:
            listed = ", ".join(sorted(available_names))
            suggestion = f"Available metrics: {listed}"
        else:
            # Avoid a dangling "Available metrics: " when nothing is registered.
            suggestion = "No metrics are registered"
        super().__init__(message, suggestion=suggestion)
        self.name = name
        self.available = available_names


class InsufficientSamplesError(MetricComputationError):
    """Raised when too few samples were supplied to compute a metric reliably.

    Attributes:
        required: Minimum number of samples the metric needs.
        actual: Number of samples that were actually provided.
    """

    def __init__(
        self,
        metric_name: str,
        required: int,
        actual: int,
        *,
        suggestion: str | None = None,
    ) -> None:
        """Create the error from the required and actual sample counts.

        Args:
            metric_name: Name of the metric.
            required: Minimum number of samples required.
            actual: Number of samples provided.
            suggestion: Optional hint; when omitted, a default hint asking to
                raise M (perturbation samples) to ``required`` is used.
        """
        hint = (
            suggestion
            if suggestion is not None
            else f"Increase M (perturbation samples) to at least {required}"
        )
        super().__init__(
            metric_name,
            details=f"requires at least {required} samples, got {actual}",
            suggestion=hint,
        )
        self.required = required
        self.actual = actual


class NumericalInstabilityError(MetricComputationError):
    """Raised when a metric computation yields a numerically unstable result.

    Examples of such results are NaN or Inf values.

    Attributes:
        value: The offending value, as passed by the caller (e.g. a float
            such as NaN/Inf, or a textual description).
    """

    def __init__(
        self,
        metric_name: str,
        value: float | str,
        *,
        suggestion: str | None = None,
    ) -> None:
        """Create the error for an unstable computed value.

        Args:
            metric_name: Name of the metric.
            value: The problematic value, embedded in the message.
            suggestion: Optional hint; when omitted, a default hint about
                regularization and finite inputs is used.
        """
        hint = suggestion
        if hint is None:
            hint = (
                "Try using regularization (solver_method='adaptive_reg') "
                "or check that inputs are finite"
            )
        super().__init__(
            metric_name,
            details=f"computed value is numerically unstable ({value})",
            suggestion=hint,
        )
        self.value = value


class InvalidMetricInputError(MetricComputationError):
    """Raised when an input to a metric has the wrong shape, type, or value.

    Attributes:
        param_name: Name of the offending parameter.
        expected: Human-readable description of what was expected.
        actual: Human-readable description of what was received.
    """

    def __init__(
        self,
        metric_name: str,
        param_name: str,
        expected: str,
        actual: str,
        *,
        suggestion: str | None = None,
    ) -> None:
        """Create the error describing the expected vs. received input.

        Args:
            metric_name: Name of the metric.
            param_name: Name of the offending parameter.
            expected: Description of the expected input.
            actual: Description of the input actually received.
            suggestion: Optional hint on how to resolve the problem.
        """
        super().__init__(
            metric_name,
            details=(
                f"parameter '{param_name}' expected {expected}, got {actual}"
            ),
            suggestion=suggestion,
        )
        self.param_name, self.expected, self.actual = param_name, expected, actual
