"""Baseline provider and data-aware perturbation exceptions.

This module provides a comprehensive exception hierarchy for:
- Baseline provider discovery and initialization
- Data source access and validation
- Cache management
- Sampling operations

Exception Hierarchy:
    ExpectedGradCAMError (base)
    └── BaselineProviderError
        ├── ProviderNotFoundError
        ├── ProviderInitializationError
        ├── DataSourceError
        │   ├── DirectoryNotFoundError
        │   ├── EmptyBaselineDatasetError
        │   ├── InvalidBaselineImageError
        │   ├── TransformError
        │   └── HuggingFaceLoadError
        ├── BaselineValidationError
        │   ├── InsufficientSamplesError
        │   ├── BaselineDeviceMismatchError
        │   ├── DimensionMismatchError
        │   └── UnsupportedFormatError
        ├── CacheError
        │   ├── CacheCorruptedError
        │   └── CacheSizeExceededError
        └── SamplingError
            ├── CenteredConstraintViolation
            └── SamplingNumericalInstabilityError
"""

from __future__ import annotations

from enum import Enum, auto
from pathlib import Path
from typing import TYPE_CHECKING, Any

from expected_gradcam.exceptions.base import ExpectedGradCAMError

if TYPE_CHECKING:
    from collections.abc import Sequence


# =============================================================================
# Error Codes
# =============================================================================


class BaselineErrorCode(Enum):
    """Unique error codes for baseline provider exceptions.

    Every category owns a block of one hundred integer values, so the
    category of a code can be read off its value and adding a member to one
    category never renumbers the members of another:

    - BP_xxx: BaselineProvider errors (100-199)
    - DS_xxx: DataSource errors (200-299)
    - VL_xxx: Validation errors (300-399)
    - CA_xxx: Cache errors (400-499)
    - SA_xxx: Sampling errors (500-599)

    The first member of each category pins the start of its range
    explicitly; the members after it use ``auto()``, which continues from
    the preceding value. New members must be appended at the end of their
    category to keep existing values stable.
    """

    # BaselineProvider errors (100-199)
    BP_PROVIDER_NOT_FOUND = 100
    BP_INITIALIZATION_FAILED = auto()
    BP_REGISTRY_ERROR = auto()

    # DataSource errors (200-299)
    DS_DIRECTORY_NOT_FOUND = 200
    DS_EMPTY_DATASET = auto()
    DS_INVALID_IMAGE = auto()
    DS_TRANSFORM_FAILED = auto()
    DS_HUGGINGFACE_LOAD_FAILED = auto()
    DS_PERMISSION_DENIED = auto()

    # Validation errors (300-399)
    VL_INSUFFICIENT_SAMPLES = 300
    VL_DEVICE_MISMATCH = auto()
    VL_DIMENSION_MISMATCH = auto()
    VL_UNSUPPORTED_FORMAT = auto()
    VL_SHAPE_MISMATCH = auto()

    # Cache errors (400-499)
    CA_CORRUPTED = 400
    CA_SIZE_EXCEEDED = auto()
    CA_VERSION_MISMATCH = auto()
    CA_WRITE_FAILED = auto()

    # Sampling errors (500-599)
    SA_CENTERED_CONSTRAINT = 500
    SA_NUMERICAL_INSTABILITY = auto()
    SA_CONVERGENCE_FAILED = auto()


# =============================================================================
# Base Baseline Provider Exception
# =============================================================================


class BaselineProviderError(ExpectedGradCAMError):
    """Common ancestor of every baseline provider exception.

    Catching this class handles any baseline-related failure in one place,
    while the concrete subclasses stay available for finer-grained handling.

    Attributes:
        error_code: ``BaselineErrorCode`` member identifying the failure.
        context: Extra key-value details about the failure; an empty dict
            when none were supplied.

    Example:
        >>> try:
        ...     provider = get_provider("unknown")
        ... except BaselineProviderError as e:
        ...     print(f"Error {e.error_code.name}: {e.message}")
        ...     if e.suggestion:
        ...         print(f"Suggestion: {e.suggestion}")
    """

    def __init__(
        self,
        message: str,
        *,
        error_code: BaselineErrorCode,
        suggestion: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> None:
        """Record the error code and context, then defer to the base class.

        Args:
            message: Description of what went wrong.
            error_code: Code identifying this kind of failure.
            suggestion: Optional hint on how to resolve the problem.
            context: Optional key-value details; a falsy value becomes ``{}``.
        """
        # Own attributes are assigned before the base initializer runs.
        self.context = context if context else {}
        self.error_code = error_code

        super().__init__(message, suggestion=suggestion)

    def __repr__(self) -> str:
        """Return ``ClassName(error_code=..., message=...[, context=...])``.

        The ``context`` part is only included when the context is non-empty.
        """
        # NOTE(review): ``self.message`` is presumably set by the base
        # class initializer -- it is not assigned in this class.
        context_part = f", context={self.context!r}" if self.context else ""
        return (
            f"{self.__class__.__name__}("
            f"error_code={self.error_code.name}, "
            f"message={self.message!r}"
            f"{context_part})"
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert the exception into a plain dict for logging/serialization.

        Returns:
            Mapping with the keys ``error_code`` (member name), ``message``,
            ``suggestion``, ``context`` and ``exception_type`` (class name).
        """
        # NOTE(review): ``message`` and ``suggestion`` are presumably
        # provided by the base class -- confirm against its definition.
        payload: dict[str, Any] = {}
        payload["error_code"] = self.error_code.name
        payload["message"] = self.message
        payload["suggestion"] = self.suggestion
        payload["context"] = self.context
        payload["exception_type"] = self.__class__.__name__
        return payload


# =============================================================================
# Provider Discovery and Initialization Errors
# =============================================================================


class ProviderNotFoundError(BaselineProviderError):
    """Signals that a requested baseline provider could not be located.

    Typical situations:
    - The provider name is not registered in the provider registry
    - Auto-detection fails to find a suitable provider
    - The provider module cannot be imported

    Attributes:
        provider_name: Name that was looked up.
        available: List of known provider names (empty when none given).
        reason: Optional explanation appended to the message.

    Example:
        >>> raise ProviderNotFoundError(
        ...     "custom_provider",
        ...     available=["imagenet", "huggingface", "directory", "cached"]
        ... )
    """

    def __init__(
        self,
        provider_name: str,
        *,
        available: Sequence[str] | None = None,
        reason: str | None = None,
    ) -> None:
        """Build the message and a suggestion listing the known providers.

        Args:
            provider_name: Name of the provider that was not found.
            available: Names of the providers that are available.
            reason: Specific reason why the provider was not found.
        """
        known = list(available) if available else []

        self.provider_name = provider_name
        self.available = known
        self.reason = reason

        headline = f"Baseline provider '{provider_name}' not found."
        message = f"{headline} {reason}" if reason else headline

        if known:
            quoted = ", ".join(f"'{name}'" for name in known)
            hint_lines = [
                f"Available providers: {quoted}.",
                "You can also register a custom provider using:",
                "  from expected_gradcam.baselines import baseline_provider",
                "  @baseline_provider('my_provider')",
                "  class MyProvider(BaseProvider): ...",
            ]
        else:
            # Nothing registered at all: point at installation instead.
            hint_lines = [
                "No providers are currently registered. Install baseline providers:",
                "  pip install expected-gradcam[hf]",
                "",
                "Or register a custom provider.",
            ]

        super().__init__(
            message,
            error_code=BaselineErrorCode.BP_PROVIDER_NOT_FOUND,
            suggestion="\n".join(hint_lines),
            context={"provider_name": provider_name, "available": known},
        )


class ProviderInitializationError(BaselineProviderError):
    """Signals that a baseline provider could not be initialized.

    Typical situations:
    - Required configuration is missing
    - Dependencies are not installed
    - Resources (models, datasets) cannot be loaded
    - Invalid configuration is provided

    Attributes:
        provider_name: Name of the provider that failed.
        reason: Optional explanation of the failure.
        config: Configuration in use during initialization (``{}`` if none).
        cause: Underlying exception, if one was supplied.

    Example:
        >>> raise ProviderInitializationError(
        ...     "huggingface",
        ...     reason="Dataset 'imagenet-1k' requires authentication",
        ...     config={"dataset_name": "imagenet-1k", "split": "validation"}
        ... )
    """

    def __init__(
        self,
        provider_name: str,
        *,
        reason: str | None = None,
        config: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        """Compose the message and checklist, then chain ``cause`` if given.

        Args:
            provider_name: Name of the provider that failed to initialize.
            reason: Specific reason for the initialization failure.
            config: Configuration that was used during initialization.
            cause: The underlying exception that caused this error.
        """
        self.provider_name = provider_name
        self.reason = reason
        self.config = config if config else {}
        self.cause = cause

        message = (
            f"Failed to initialize provider '{provider_name}': {reason}"
            if reason
            else f"Failed to initialize baseline provider '{provider_name}'."
        )

        checklist = [
            "Check that:",
            "  1. All required dependencies are installed",
            "  2. Configuration parameters are valid",
            "  3. Required resources (datasets, models) are accessible",
        ]
        cause_type = None
        if cause:
            cause_type = type(cause).__name__
            checklist.append(f"\nUnderlying error: {cause_type}: {cause}")

        super().__init__(
            message,
            error_code=BaselineErrorCode.BP_INITIALIZATION_FAILED,
            suggestion="\n".join(checklist),
            context={
                "provider_name": provider_name,
                "config": self.config,
                "cause_type": cause_type,
            },
        )

        # Link the underlying exception so tracebacks show it as the cause.
        if cause:
            self.__cause__ = cause


# =============================================================================
# Data Source Errors
# =============================================================================


class DataSourceError(BaselineProviderError):
    """Base exception for data source errors.

    This is the parent class for all exceptions related to loading,
    reading, or processing data from various sources (directories,
    HuggingFace, URLs, etc.).

    Attributes:
        source: String form of the data source, or ``None`` when no
            (or a falsy) source was given.

    Example:
        >>> try:
        ...     dataset = load_baseline_data(source)
        ... except DataSourceError as e:
        ...     print(f"Data loading failed: {e.message}")
    """

    def __init__(
        self,
        message: str,
        *,
        error_code: BaselineErrorCode,
        source: str | Path | None = None,
        suggestion: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> None:
        """Initialize the exception.

        Args:
            message: The error message.
            error_code: Unique error code.
            source: The data source that caused the error.
            suggestion: Suggestion for fixing the issue.
            context: Additional context. The mapping is copied, so the
                caller's dict is never modified.
        """
        self.source = str(source) if source else None

        # Work on a shallow copy: adding "source" below must not leak into
        # a dict the caller still owns and may reuse.
        ctx = dict(context) if context else {}
        if self.source:
            ctx["source"] = self.source

        super().__init__(
            message,
            error_code=error_code,
            suggestion=suggestion,
            context=ctx,
        )


class DirectoryNotFoundError(DataSourceError):
    """Signals that a baseline image directory is missing or unusable.

    Typical situations:
    - The baseline image directory path does not exist
    - The path exists but is not a directory
    - The path is inaccessible due to permissions

    Attributes:
        path: The offending path as a ``Path``.
        searched_paths: Paths that were tried, as ``Path`` objects.
        is_file: Whether the path exists but points at a file.

    Example:
        >>> raise DirectoryNotFoundError(
        ...     "/path/to/imagenet/train",
        ...     searched_paths=[
        ...         "/data/imagenet/train",
        ...         "/mnt/datasets/imagenet/train",
        ...     ]
        ... )
    """

    def __init__(
        self,
        path: str | Path,
        *,
        searched_paths: Sequence[str | Path] | None = None,
        is_file: bool = False,
    ) -> None:
        """Build a message for either the "is a file" or the "missing" case.

        Args:
            path: The directory path that was not found.
            searched_paths: List of paths that were searched.
            is_file: True if the path exists but is a file, not a directory.
        """
        self.path = Path(path)
        self.searched_paths = list(map(Path, searched_paths)) if searched_paths else []
        self.is_file = is_file

        # String forms are needed both for the suggestion and the context.
        searched_text = [str(candidate) for candidate in self.searched_paths]

        if is_file:
            message = f"Path '{path}' exists but is a file, not a directory."
            suggestion = "Provide a directory path containing baseline images."
        else:
            message = f"Directory not found: '{path}'"

            hints = ["Ensure the directory exists and is accessible."]
            if searched_text:
                listing = "\n  ".join(searched_text)
                hints.append(f"\nSearched paths:\n  {listing}")
            hints += [
                "\nYou can also:",
                "  1. Use from_imagenet(imagenet_root) for ImageNet",
                "  2. Use HuggingFaceProvider for cloud datasets",
                "  3. Create the directory and populate it with images",
            ]
            suggestion = "\n".join(hints)

        super().__init__(
            message,
            error_code=BaselineErrorCode.DS_DIRECTORY_NOT_FOUND,
            source=path,
            suggestion=suggestion,
            context={
                "searched_paths": searched_text,
                "is_file": is_file,
            },
        )


class EmptyBaselineDatasetError(DataSourceError):
    """Signals that a baseline dataset yielded no usable samples.

    Attributes:
        reason: Optional explanation of why the dataset is empty.
        total_files: Number of files seen before filtering, if known.
        filtered_out: Number of files removed by filtering, if known.
        extensions: File extensions that were searched for (may be empty).

    Example:
        >>> raise EmptyBaselineDatasetError(
        ...     source="/path/to/images",
        ...     reason="No files matching extensions: .jpg, .jpeg, .png",
        ...     total_files=150,
        ...     filtered_out=150
        ... )
    """

    def __init__(
        self,
        source: str | Path,
        *,
        reason: str | None = None,
        total_files: int | None = None,
        filtered_out: int | None = None,
        extensions: Sequence[str] | None = None,
    ) -> None:
        """Compose the message and a checklist of things to verify.

        Args:
            source: The data source (directory, dataset name, etc.).
            reason: Specific reason why the dataset is empty.
            total_files: Total number of files found before filtering.
            filtered_out: Number of files filtered out.
            extensions: File extensions that were searched for.
        """
        wanted = list(extensions) if extensions else []

        self.reason = reason
        self.total_files = total_files
        self.filtered_out = filtered_out
        self.extensions = wanted

        tail = f": {reason}" if reason else " or has no valid samples."
        message = f"Baseline dataset '{source}' is empty{tail}"

        if wanted:
            ext_text = ", ".join(wanted)
            extension_hint = f"  2. Has files with supported extensions: {ext_text}"
        else:
            extension_hint = "  2. Has files with extensions: .jpg, .jpeg, .png, .JPEG"

        hints = [
            "Ensure the dataset:",
            "  1. Contains valid image files",
            extension_hint,
            "  3. Has proper read permissions",
        ]

        # File statistics are only reported when both counts are known.
        have_counts = not (total_files is None or filtered_out is None)
        if have_counts:
            hints += [
                f"\nFound {total_files} files, {filtered_out} were filtered out.",
                "Check your filtering criteria or image validation settings.",
            ]

        super().__init__(
            message,
            error_code=BaselineErrorCode.DS_EMPTY_DATASET,
            source=source,
            suggestion="\n".join(hints),
            context={
                "total_files": total_files,
                "filtered_out": filtered_out,
                "extensions": wanted,
            },
        )


class InvalidBaselineImageError(DataSourceError):
    """Signals that a baseline image could not be processed.

    Attributes:
        path: Location of the image as a ``Path``, or ``None`` if not given.
        reason: Optional explanation of why the image is invalid.
        image_info: Details about the image (``{}`` when none supplied).
        batch_index: Position in the batch, if the image came from one.

    Example:
        >>> raise InvalidBaselineImageError(
        ...     path="/path/to/image.jpg",
        ...     reason="Image has 4 channels (RGBA), expected 3 (RGB)",
        ...     image_info={"shape": (224, 224, 4), "mode": "RGBA"}
        ... )
    """

    def __init__(
        self,
        path: str | Path | None = None,
        *,
        reason: str | None = None,
        image_info: dict[str, Any] | None = None,
        batch_index: int | None = None,
    ) -> None:
        """Assemble the message from whichever details were provided.

        Args:
            path: Path to the invalid image.
            reason: Specific reason why the image is invalid.
            image_info: Information about the image (shape, mode, etc.).
            batch_index: Index in batch if processing multiple images.
        """
        details = image_info if image_info else {}

        self.path = Path(path) if path else None
        self.reason = reason
        self.image_info = details
        self.batch_index = batch_index

        message = "Invalid baseline image"
        if path:
            message += f" '{path}'"
        if batch_index is not None:
            # Index 0 is a valid position, hence the explicit None check.
            message += f" (batch index {batch_index})"
        message += f": {reason}" if reason else "."

        hints = [
            "Ensure the image:",
            "  1. Is a valid image file (not corrupted or truncated)",
            "  2. Has 3 channels (RGB) or can be converted to RGB",
            "  3. Has reasonable dimensions (not 0x0)",
            "  4. Can be opened by PIL/Pillow",
        ]
        if details:
            rendered = ", ".join(f"{key}={value}" for key, value in details.items())
            hints.append(f"\nImage info: {rendered}")

        super().__init__(
            message,
            error_code=BaselineErrorCode.DS_INVALID_IMAGE,
            source=path,
            suggestion="\n".join(hints),
            context={
                "image_info": details,
                "batch_index": batch_index,
            },
        )


class TransformError(DataSourceError):
    """Raised when an image transform fails.

    This error occurs when:
    - A transform in the preprocessing pipeline fails
    - Input image has incompatible dimensions
    - Transform produces invalid output (NaN, Inf)

    Attributes:
        transform_name: Name of the transform that failed.
        reason: Optional explanation of the failure.
        input_shape: Shape of the transform input, if known.
        transform_config: Transform configuration (``{}`` when none given).
        cause: Underlying exception, if one was supplied.

    Example:
        >>> raise TransformError(
        ...     transform_name="Normalize",
        ...     reason="Input tensor contains NaN values",
        ...     input_shape=(3, 224, 224)
        ... )
    """

    def __init__(
        self,
        transform_name: str,
        *,
        reason: str | None = None,
        input_shape: tuple[int, ...] | None = None,
        transform_config: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        """Initialize the exception.

        Args:
            transform_name: Name of the transform that failed.
            reason: Specific reason for the failure.
            input_shape: Shape of the input to the transform. An empty
                shape ``()`` is reported as well; only ``None`` is omitted.
            transform_config: Configuration of the transform.
            cause: The underlying exception; chained as ``__cause__``.
        """
        self.transform_name = transform_name
        self.reason = reason
        self.input_shape = input_shape
        self.transform_config = transform_config or {}
        self.cause = cause

        message = f"Transform '{transform_name}' failed"
        if reason:
            message = f"{message}: {reason}"

        suggestion_parts = [
            "Check that:",
            "  1. Input image has expected shape and type",
            "  2. Transform configuration is valid",
            "  3. Input values are in expected range (e.g., [0, 1] or [0, 255])",
        ]

        # Test against None rather than truthiness: the empty tuple is a
        # legitimate shape and must not be silently dropped from the report.
        if input_shape is not None:
            suggestion_parts.append(f"\nInput shape: {input_shape}")

        # Likewise, an exception object that happens to be falsy (e.g. one
        # defining __len__) is still a real cause.
        if cause is not None:
            suggestion_parts.append(
                f"\nUnderlying error: {type(cause).__name__}: {cause}"
            )

        suggestion = "\n".join(suggestion_parts)

        super().__init__(
            message,
            error_code=BaselineErrorCode.DS_TRANSFORM_FAILED,
            suggestion=suggestion,
            context={
                "transform_name": transform_name,
                "input_shape": input_shape,
                "transform_config": self.transform_config,
            },
        )

        # Link the underlying exception so tracebacks show it as the cause.
        if cause is not None:
            self.__cause__ = cause


class HuggingFaceLoadError(DataSourceError):
    """Signals that a HuggingFace dataset could not be loaded.

    Typical situations:
    - Dataset does not exist on HuggingFace Hub
    - Authentication is required but not provided
    - Network errors during download
    - Dataset format is incompatible

    Attributes:
        dataset_name: Name of the dataset that failed to load.
        reason: Optional explanation of the failure.
        split: Requested split, if any.
        config_name: Requested dataset configuration, if any.
        requires_auth: Whether the dataset needs authentication.
        cause: Underlying exception, if one was supplied.

    Example:
        >>> raise HuggingFaceLoadError(
        ...     dataset_name="imagenet-1k",
        ...     reason="This dataset requires authentication",
        ...     requires_auth=True
        ... )
    """

    def __init__(
        self,
        dataset_name: str,
        *,
        reason: str | None = None,
        split: str | None = None,
        config_name: str | None = None,
        requires_auth: bool = False,
        cause: Exception | None = None,
    ) -> None:
        """Compose the message and auth- or connectivity-oriented hints.

        Args:
            dataset_name: Name of the HuggingFace dataset.
            reason: Specific reason for the failure.
            split: Dataset split that was requested.
            config_name: Configuration name for datasets with multiple configs.
            requires_auth: Whether the dataset requires authentication.
            cause: The underlying exception.
        """
        self.dataset_name = dataset_name
        self.reason = reason
        self.split = split
        self.config_name = config_name
        self.requires_auth = requires_auth
        self.cause = cause

        message = f"Failed to load HuggingFace dataset '{dataset_name}'"
        if config_name:
            message += f" (config: {config_name})"
        if split:
            message += f" split '{split}'"
        if reason:
            message += f": {reason}"

        if requires_auth:
            terms_url = f"https://huggingface.co/datasets/{dataset_name}"
            hints = [
                "This dataset requires authentication:",
                "  1. Run: huggingface-cli login",
                f"  2. Accept the dataset terms at: {terms_url}",
                "  3. Retry the operation",
            ]
        else:
            hints = [
                "Check that:",
                f"  1. Dataset '{dataset_name}' exists on HuggingFace Hub",
                "  2. You have internet connectivity",
                "  3. The datasets library is installed: pip install datasets",
            ]

        if cause:
            hints.append(f"\nUnderlying error: {type(cause).__name__}: {cause}")

        super().__init__(
            message,
            error_code=BaselineErrorCode.DS_HUGGINGFACE_LOAD_FAILED,
            source=f"huggingface://{dataset_name}",
            suggestion="\n".join(hints),
            context={
                "dataset_name": dataset_name,
                "split": split,
                "config_name": config_name,
                "requires_auth": requires_auth,
            },
        )

        # Link the underlying exception so tracebacks show it as the cause.
        if cause:
            self.__cause__ = cause


# =============================================================================
# Validation Errors
# =============================================================================


class BaselineValidationError(BaselineProviderError):
    """Base exception for validation errors.

    Raised when data or configuration fails validation checks.

    Attributes:
        field: The field or parameter that failed validation, if given.
    """

    def __init__(
        self,
        message: str,
        *,
        error_code: BaselineErrorCode,
        field: str | None = None,
        suggestion: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> None:
        """Initialize the exception.

        Args:
            message: The error message.
            error_code: Unique error code.
            field: The field or parameter that failed validation.
            suggestion: Suggestion for fixing the issue.
            context: Additional context. The mapping is copied, so the
                caller's dict is never modified.
        """
        self.field = field

        # Work on a shallow copy: adding "field" below must not leak into
        # a dict the caller still owns and may reuse.
        ctx = dict(context) if context else {}
        if field:
            ctx["field"] = field

        super().__init__(
            message,
            error_code=error_code,
            suggestion=suggestion,
            context=ctx,
        )


class InsufficientSamplesError(BaselineValidationError):
    """Raised when there are not enough samples for the operation.

    This error occurs when:
    - Dataset has fewer samples than required for baseline sampling
    - M (perturbation samples) cannot be satisfied
    - N (baseline samples) exceeds available data

    Attributes:
        required: Number of samples required.
        available: Number of samples available.
        parameter: Name of the parameter that requires the samples, if any.
        operation: The operation requiring samples, if any.
        allow_replacement: Whether sampling with replacement is allowed.

    Example:
        >>> raise InsufficientSamplesError(
        ...     required=1000,
        ...     available=50,
        ...     parameter="M"
        ... )
    """

    def __init__(
        self,
        required: int,
        available: int,
        *,
        parameter: str | None = None,
        operation: str | None = None,
        allow_replacement: bool = True,
    ) -> None:
        """Initialize the exception.

        Args:
            required: Number of samples required.
            available: Number of samples available.
            parameter: Parameter name (M, N, etc.) that requires samples.
            operation: The operation requiring samples.
            allow_replacement: Whether sampling with replacement is allowed.
                When False, the "enable replacement" option is left out of
                the suggestion.
        """
        self.required = required
        self.available = available
        self.parameter = parameter
        self.operation = operation
        self.allow_replacement = allow_replacement

        parts = [f"Insufficient samples: required {required}, available {available}"]
        if parameter:
            parts.insert(0, f"Parameter '{parameter}': ")
        if operation:
            parts.append(f" for {operation}")
        message = "".join(parts)

        # Collect the applicable options first and number them afterwards,
        # so the list stays contiguous (1, 2, 3) when the replacement option
        # is omitted instead of jumping from 1 to 3.
        options = []
        if allow_replacement:
            options.append(
                "Enable sampling with replacement (sample_with_replacement=True)"
            )
        options.extend(
            [
                f"Reduce {parameter or 'sample count'} to <= {available}",
                "Add more images to the baseline dataset",
                "Use a larger baseline dataset (e.g., ImageNet train split)",
            ]
        )

        suggestion_parts = ["Options to resolve this:"]
        suggestion_parts.extend(
            f"  {number}. {text}" for number, text in enumerate(options, start=1)
        )
        suggestion = "\n".join(suggestion_parts)

        super().__init__(
            message,
            error_code=BaselineErrorCode.VL_INSUFFICIENT_SAMPLES,
            field=parameter,
            suggestion=suggestion,
            context={
                "required": required,
                "available": available,
                "operation": operation,
                "allow_replacement": allow_replacement,
            },
        )


class BaselineDeviceMismatchError(BaselineValidationError):
    """Signals that a baseline tensor lives on an unexpected device.

    Attributes:
        expected_device: Device the tensor should be on.
        actual_device: Device the tensor is actually on.
        tensor_name: Name of the mismatched tensor, if given.
        tensors_info: Mapping of tensor names to devices (``{}`` if none).

    Example:
        >>> raise BaselineDeviceMismatchError(
        ...     expected_device="cuda:0",
        ...     actual_device="cpu",
        ...     tensor_name="baseline_features"
        ... )
    """

    def __init__(
        self,
        expected_device: str,
        actual_device: str,
        *,
        tensor_name: str | None = None,
        tensors_info: dict[str, str] | None = None,
    ) -> None:
        """Compose the mismatch message and a hint on moving the tensor.

        Args:
            expected_device: Device where tensor should be.
            actual_device: Device where tensor actually is.
            tensor_name: Name of the mismatched tensor.
            tensors_info: Dictionary of tensor names to their devices.
        """
        placements = tensors_info if tensors_info else {}

        self.expected_device = expected_device
        self.actual_device = actual_device
        self.tensor_name = tensor_name
        self.tensors_info = placements

        # Only the leading phrase depends on whether the tensor is named.
        subject = (
            f"Device mismatch for baseline '{tensor_name}'"
            if tensor_name
            else "Baseline device mismatch"
        )
        message = f"{subject}: expected {expected_device}, got {actual_device}."

        hints = [
            f"Move tensor to correct device: tensor = tensor.to('{expected_device}')"
        ]
        if placements:
            listing = "\n  ".join(
                f"{name}: {device}" for name, device in placements.items()
            )
            hints.append(f"\nCurrent device assignments:\n  {listing}")
        hints.append(
            "\nEnsure all tensors (model, input, baselines) are on the same device."
        )

        super().__init__(
            message,
            error_code=BaselineErrorCode.VL_DEVICE_MISMATCH,
            field=tensor_name,
            suggestion="\n".join(hints),
            context={
                "expected_device": expected_device,
                "actual_device": actual_device,
                "tensors_info": placements,
            },
        )


class DimensionMismatchError(BaselineValidationError):
    """Raised when tensor dimensions do not match expected values.

    This error occurs when:
    - Feature map channels (K) don't match between input and baseline
    - Spatial dimensions are incompatible
    - Batch dimensions are unexpected

    Example:
        >>> raise DimensionMismatchError(
        ...     expected_shape=(1, 2048, 7, 7),
        ...     actual_shape=(1, 512, 14, 14),
        ...     tensor_name="baseline_features"
        ... )
    """

    def __init__(
        self,
        expected_shape: tuple[int, ...],
        actual_shape: tuple[int, ...],
        *,
        tensor_name: str | None = None,
        dimension_names: Sequence[str] | None = None,
        mismatch_dims: Sequence[int] | None = None,
    ) -> None:
        """Initialize the exception.

        Args:
            expected_shape: Expected tensor shape.
            actual_shape: Actual tensor shape. It may have a different rank
                than ``expected_shape``.
            tensor_name: Name of the tensor with wrong dimensions.
            dimension_names: Human-readable names for each dimension.
            mismatch_dims: Indices of dimensions that don't match. An index
                that cannot be resolved in ``dimension_names``,
                ``expected_shape`` and ``actual_shape`` alike is left out of
                the per-dimension breakdown.
        """
        self.expected_shape = expected_shape
        self.actual_shape = actual_shape
        self.tensor_name = tensor_name
        self.dimension_names = list(dimension_names) if dimension_names else []
        self.mismatch_dims = list(mismatch_dims) if mismatch_dims else []

        parts = ["Dimension mismatch"]
        if tensor_name:
            parts.append(f" for '{tensor_name}'")
        parts.append(f": expected {expected_shape}, got {actual_shape}")
        message = "".join(parts)

        suggestion_parts = []

        if self.dimension_names and self.mismatch_dims:
            mismatches = []
            for dim in self.mismatch_dims:
                try:
                    name = self.dimension_names[dim]
                    expected = expected_shape[dim]
                    actual = actual_shape[dim]
                except IndexError:
                    # A rank mismatch (or a stray index) must not make the
                    # construction of this exception itself fail and mask the
                    # real error; the full shapes are already in the message.
                    continue
                mismatches.append(
                    f"  - {name} (dim {dim}): expected {expected}, got {actual}"
                )
            if mismatches:
                suggestion_parts.append("Mismatched dimensions:")
                suggestion_parts.extend(mismatches)

        suggestion_parts.extend(
            [
                "\nEnsure that:",
                "  1. Baseline features are extracted from the same target layer",
                "  2. Input image size matches baseline image size",
                "  3. The same model architecture is used for baseline extraction",
            ]
        )

        suggestion = "\n".join(suggestion_parts)

        super().__init__(
            message,
            error_code=BaselineErrorCode.VL_DIMENSION_MISMATCH,
            field=tensor_name,
            suggestion=suggestion,
            context={
                "expected_shape": expected_shape,
                "actual_shape": actual_shape,
                "dimension_names": self.dimension_names,
                "mismatch_dims": self.mismatch_dims,
            },
        )


class UnsupportedFormatError(BaselineValidationError):
    """Signal that a requested data format is not supported.

    Typical triggers:
    - Image format is not supported
    - Cache file format is unknown
    - Dataset format is incompatible

    Example:
        >>> raise UnsupportedFormatError(
        ...     format_name="webp",
        ...     supported_formats=["jpg", "jpeg", "png", "JPEG"],
        ...     context_type="image"
        ... )
    """

    def __init__(
        self,
        format_name: str,
        *,
        supported_formats: Sequence[str] | None = None,
        context_type: str = "format",
    ) -> None:
        """Assemble the message and a hint listing the accepted formats.

        Args:
            format_name: The format that was rejected.
            supported_formats: Formats that are accepted. Stored as a list,
                which is empty when nothing is given.
            context_type: Kind of format being described (image, cache,
                dataset). It is interpolated into the message and suggestion.
        """
        self.format_name = format_name
        self.supported_formats = list(supported_formats) if supported_formats else []
        self.context_type = context_type

        if not self.supported_formats:
            hint = f"Check documentation for supported {context_type} formats."
        else:
            quoted = [f"'{fmt}'" for fmt in self.supported_formats]
            hint = f"Supported {context_type} formats: " + ", ".join(quoted)

        super().__init__(
            f"Unsupported {context_type} format: '{format_name}'",
            error_code=BaselineErrorCode.VL_UNSUPPORTED_FORMAT,
            field="format",
            suggestion=hint,
            context={
                "format_name": format_name,
                "supported_formats": self.supported_formats,
                "context_type": context_type,
            },
        )


# =============================================================================
# Cache Errors
# =============================================================================


class CacheError(BaselineProviderError):
    """Base exception for cache-related errors."""

    def __init__(
        self,
        message: str,
        *,
        error_code: BaselineErrorCode,
        cache_path: str | Path | None = None,
        suggestion: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> None:
        """Initialize the exception.

        Args:
            message: The error message.
            error_code: Unique error code.
            cache_path: Path to the cache file. Stored on the instance as a
                ``Path``, or ``None`` when no (or an empty) path is given.
            suggestion: Suggestion for fixing the issue.
            context: Additional context. The mapping is copied, so the
                caller's dict is never modified.
        """
        self.cache_path = Path(cache_path) if cache_path else None

        # Work on a copy: adding "cache_path" below must not leak into a dict
        # the caller still owns and may reuse for other errors.
        ctx: dict[str, Any] = dict(context) if context else {}
        if self.cache_path is not None:
            ctx["cache_path"] = str(self.cache_path)

        super().__init__(
            message,
            error_code=error_code,
            suggestion=suggestion,
            context=ctx,
        )


class CacheCorruptedError(CacheError):
    """Signal that a cache file is corrupted or cannot be read.

    Typical triggers:
    - Cache file header is invalid
    - Data integrity check fails
    - File format version is incompatible
    - File is truncated or partially written

    Example:
        >>> raise CacheCorruptedError(
        ...     cache_path="/path/to/features.npy",
        ...     reason="Checksum mismatch",
        ...     recoverable=True
        ... )
    """

    def __init__(
        self,
        cache_path: str | Path,
        *,
        reason: str | None = None,
        expected_version: str | None = None,
        actual_version: str | None = None,
        recoverable: bool = True,
    ) -> None:
        """Assemble the corruption message and the recovery instructions.

        Args:
            cache_path: Path to the corrupted cache file.
            reason: Specific reason for corruption, appended to the message.
            expected_version: Expected cache format version.
            actual_version: Actual cache format version.
            recoverable: Whether the cache can be rebuilt. This selects which
                set of recovery steps is put in the suggestion.
        """
        self.reason = reason
        self.expected_version = expected_version
        self.actual_version = actual_version
        self.recoverable = recoverable

        message = f"Cache file corrupted: '{cache_path}'"
        if reason:
            message += f". {reason}"

        # Both recovery paths start by deleting the broken file.
        delete_step = f"  1. Delete the corrupted cache: rm '{cache_path}'"
        if recoverable:
            steps = [
                "The cache can be rebuilt:",
                delete_step,
                "  2. Re-run the operation to rebuild the cache",
                "  3. Or use force_rebuild=True in the provider constructor",
            ]
        else:
            steps = [
                "The cache cannot be automatically recovered:",
                delete_step,
                "  2. Re-extract features from the baseline dataset",
            ]

        # The version note is only emitted when both versions are known.
        if expected_version and actual_version:
            steps.append(
                f"\nVersion mismatch: expected {expected_version}, got {actual_version}"
            )

        super().__init__(
            message,
            error_code=BaselineErrorCode.CA_CORRUPTED,
            cache_path=cache_path,
            suggestion="\n".join(steps),
            context={
                "expected_version": expected_version,
                "actual_version": actual_version,
                "recoverable": recoverable,
            },
        )


class CacheSizeExceededError(CacheError):
    """Signal that a cache operation would exceed a size limit.

    Typical triggers:
    - Cache file would exceed disk space
    - Memory-mapped cache exceeds available RAM
    - Cache size exceeds configured maximum

    Example:
        >>> raise CacheSizeExceededError(
        ...     cache_path="/path/to/features.npy",
        ...     required_bytes=50 * 1024**3,  # 50 GB
        ...     available_bytes=10 * 1024**3,  # 10 GB
        ... )
    """

    def __init__(
        self,
        cache_path: str | Path,
        *,
        required_bytes: int,
        available_bytes: int | None = None,
        max_bytes: int | None = None,
        limit_type: str = "disk",
    ) -> None:
        """Assemble the size-limit message and the remediation options.

        Args:
            cache_path: Path to the cache file.
            required_bytes: Bytes required for the cache.
            available_bytes: Bytes currently available. When given, it takes
                precedence over ``max_bytes`` in the message.
            max_bytes: Maximum allowed cache size.
            limit_type: Type of limit (disk, memory, config). Only "disk" and
                "memory" add an options list to the suggestion.
        """
        self.required_bytes = required_bytes
        self.available_bytes = available_bytes
        self.max_bytes = max_bytes
        self.limit_type = limit_type

        def humanize(size: float) -> str:
            """Render a byte count with one decimal and a 1024-based unit."""
            remaining = size
            for suffix in ("B", "KB", "MB", "GB", "TB"):
                if remaining < 1024:
                    return f"{remaining:.1f} {suffix}"
                remaining = remaining / 1024
            # Anything still >= 1024 TB is reported in petabytes.
            return f"{remaining:.1f} PB"

        required_str = humanize(required_bytes)

        message = f"Cache size limit exceeded: requires {required_str}"
        if available_bytes is not None:
            message += f", only {humanize(available_bytes)} available"
        elif max_bytes is not None:
            message += f", maximum is {humanize(max_bytes)}"

        hints = [f"The cache requires {required_str} of {limit_type} space."]
        if limit_type == "disk":
            hints += [
                "\nOptions:",
                "  1. Free up disk space",
                "  2. Use a different cache location with more space",
                "  3. Reduce cache size by sampling fewer baseline images",
                "  4. Use compressed cache format (compress=True)",
            ]
        elif limit_type == "memory":
            hints += [
                "\nOptions:",
                "  1. Use memory-mapped cache (mmap_mode='r')",
                "  2. Reduce batch size for feature extraction",
                "  3. Sample fewer baseline images",
            ]

        super().__init__(
            message,
            error_code=BaselineErrorCode.CA_SIZE_EXCEEDED,
            cache_path=cache_path,
            suggestion="\n".join(hints),
            context={
                "required_bytes": required_bytes,
                "available_bytes": available_bytes,
                "max_bytes": max_bytes,
                "limit_type": limit_type,
            },
        )


# =============================================================================
# Sampling Errors
# =============================================================================


class SamplingError(BaselineProviderError):
    """Common parent for every sampling-related exception."""

    def __init__(
        self,
        message: str,
        *,
        error_code: BaselineErrorCode,
        suggestion: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> None:
        """Forward all arguments unchanged to the parent initializer.

        Args:
            message: The error message.
            error_code: Unique error code.
            suggestion: Suggestion for fixing the issue.
            context: Additional context.
        """
        forwarded = {
            "error_code": error_code,
            "suggestion": suggestion,
            "context": context,
        }
        super().__init__(message, **forwarded)


class CenteredConstraintViolation(SamplingError):
    """Signal that samples violate the centering constraint E[z'] = 0.

    For Expected Gradients to satisfy the completeness axiom, baseline
    samples must be centered. Typical triggers:
    - Mean of samples exceeds tolerance threshold
    - Centering cannot be achieved (e.g., single sample)
    - Post-centering check fails

    Example:
        >>> raise CenteredConstraintViolation(
        ...     mean_norm=0.15,
        ...     tolerance=0.01,
        ...     n_samples=20
        ... )
    """

    def __init__(
        self,
        mean_norm: float,
        tolerance: float,
        *,
        n_samples: int | None = None,
        dimension: int | None = None,
        max_dim_violation: float | None = None,
        violating_dims: Sequence[int] | None = None,
    ) -> None:
        """Assemble the violation message and the remediation options.

        Args:
            mean_norm: L2 norm of the sample mean.
            tolerance: Maximum allowed mean norm.
            n_samples: Number of samples. A value of 1 switches the message
                and suggestion to the "single sample" wording.
            dimension: Dimensionality of samples.
            max_dim_violation: Maximum violation in any single dimension.
            violating_dims: Indices of dimensions exceeding tolerance. At most
                the first five are listed in the suggestion.
        """
        self.mean_norm = mean_norm
        self.tolerance = tolerance
        self.n_samples = n_samples
        self.dimension = dimension
        self.max_dim_violation = max_dim_violation
        self.violating_dims = list(violating_dims) if violating_dims else []

        # ``None == 1`` is False, so no separate None check is needed.
        single_sample = n_samples == 1

        message = (
            f"Centering constraint violated: ||E[z']|| = {mean_norm:.2e} "
            f"(tolerance: {tolerance:.2e})"
        )
        if single_sample:
            message += ". Cannot center a single sample."

        hints = [
            "For Expected Gradients completeness, baselines must satisfy E[z'] = 0.",
            "\nOptions:",
        ]
        if single_sample:
            hints.append("  1. Increase N (number of baseline samples) to >= 2")
        else:
            hints += [
                "  1. Apply centering: samples = samples - samples.mean(dim=0)",
                "  2. Use center_samples() utility function",
                "  3. Use CenteredBaselineSampler which auto-centers",
            ]

        n_violating = len(self.violating_dims)
        if n_violating:
            # Keep the listing short: show five indices, then the total count.
            listing = ", ".join([str(idx) for idx in self.violating_dims[:5]])
            if n_violating > 5:
                listing = f"{listing}, ... ({n_violating} total)"
            hints.append(f"\nViolating dimensions: {listing}")

        super().__init__(
            message,
            error_code=BaselineErrorCode.SA_CENTERED_CONSTRAINT,
            suggestion="\n".join(hints),
            context={
                "mean_norm": mean_norm,
                "tolerance": tolerance,
                "n_samples": n_samples,
                "dimension": dimension,
                "max_dim_violation": max_dim_violation,
                "n_violating_dims": n_violating,
            },
        )


class SamplingNumericalInstabilityError(SamplingError):
    """Signal numerical instability detected during sampling.

    Example:
        >>> raise SamplingNumericalInstabilityError(
        ...     operation="perturbation_sampling",
        ...     details="Division by near-zero GAP value",
        ...     has_nan=True
        ... )
    """

    def __init__(
        self,
        operation: str,
        *,
        details: str | None = None,
        affected_channels: Sequence[int] | None = None,
        values_info: dict[str, float] | None = None,
        has_nan: bool = False,
        has_inf: bool = False,
    ) -> None:
        """Assemble the instability message and the remediation hints.

        Args:
            operation: The operation where instability was detected.
            details: Additional details, appended to the message.
            affected_channels: Channel indices with problematic values. At
                most the first five are listed in the suggestion.
            values_info: Statistics about the values (min, max, etc.). Each
                value is rendered in scientific notation.
            has_nan: Whether NaN values were detected.
            has_inf: Whether Inf values were detected.
        """
        self.operation = operation
        self.details = details
        self.affected_channels = list(affected_channels) if affected_channels else []
        self.values_info = values_info or {}
        self.has_nan = has_nan
        self.has_inf = has_inf

        message = f"Numerical instability in sampling '{operation}'"
        if details:
            message += f": {details}"
        if has_nan:
            message += " [NaN detected]"
        if has_inf:
            message += " [Inf detected]"

        # Non-finite values get division/clipping advice; otherwise the advice
        # is about epsilon and normalization.
        if has_nan or has_inf:
            remedies = [
                "  1. Check for division by zero in GAP normalization",
                "  2. Use safe_divide() utility for protected division",
                "  3. Clip extreme values before operations",
            ]
        else:
            remedies = [
                "  1. Increase epsilon for numerical stability (e.g., 1e-6)",
                "  2. Use normalize_perturbations() to scale values",
                "  3. Check that input features are properly normalized",
            ]
        hints = ["Potential causes and solutions:", *remedies]

        if self.values_info:
            stats = [f"{key}={val:.2e}" for key, val in self.values_info.items()]
            hints.append("\nValue statistics: " + ", ".join(stats))

        n_affected = len(self.affected_channels)
        if n_affected:
            # Keep the listing short: show five channels, then the total count.
            listing = ", ".join([str(ch) for ch in self.affected_channels[:5]])
            if n_affected > 5:
                listing = f"{listing}, ... ({n_affected} total)"
            hints.append(f"\nAffected channels: {listing}")

        super().__init__(
            message,
            error_code=BaselineErrorCode.SA_NUMERICAL_INSTABILITY,
            suggestion="\n".join(hints),
            context={
                "operation": operation,
                "has_nan": has_nan,
                "has_inf": has_inf,
                "n_affected_channels": n_affected,
                "values_info": self.values_info,
            },
        )


# =============================================================================
# Module Exports
# =============================================================================

# Public API of this module. The entries are the error-code enum plus every
# exception class listed in the hierarchy in the module docstring, grouped
# below by category. Kept as a plain list literal.
__all__ = [
    # Error codes
    "BaselineErrorCode",
    # Base exceptions (category roots)
    "BaselineProviderError",
    "DataSourceError",
    "BaselineValidationError",
    "CacheError",
    "SamplingError",
    # Provider errors
    "ProviderNotFoundError",
    "ProviderInitializationError",
    # Data source errors
    "DirectoryNotFoundError",
    "EmptyBaselineDatasetError",
    "InvalidBaselineImageError",
    "TransformError",
    "HuggingFaceLoadError",
    # Validation errors
    "InsufficientSamplesError",
    "BaselineDeviceMismatchError",
    "DimensionMismatchError",
    "UnsupportedFormatError",
    # Cache errors
    "CacheCorruptedError",
    "CacheSizeExceededError",
    # Sampling errors
    "CenteredConstraintViolation",
    "SamplingNumericalInstabilityError",
]
