"""Data and input-related exceptions."""

from __future__ import annotations

from typing import TYPE_CHECKING

from expected_gradcam.exceptions.base import ExpectedGradCAMError


if TYPE_CHECKING:
    from collections.abc import Sequence


class DataError(ExpectedGradCAMError):
    """Common parent of every data- and input-related error in this module.

    Catch ``DataError`` to handle any of its more specific subclasses
    (shape, dataset, image and device problems) in one place.
    """


class InvalidInputShapeError(DataError):
    """Raised when input tensor has an invalid shape.

    Expected GradCAM expects input tensors in NCHW format:
    - N: Batch size (must be 1 for single image)
    - C: Channels (typically 3 for RGB)
    - H: Height
    - W: Width

    Example:
        >>> raise InvalidInputShapeError(
        ...     expected=(1, 3, 224, 224),
        ...     actual=(3, 224, 224)
        ... )
    """

    def __init__(
        self,
        expected: tuple[int | str, ...],
        actual: tuple[int, ...],
    ) -> None:
        """Initialize the exception.

        Args:
            expected: Expected shape (can use strings like "C", "H", "W").
            actual: Actual shape of the input tensor.
        """
        # Kept as given so handlers can inspect the mismatch programmatically.
        self.expected = expected
        self.actual = actual

        expected_str = "(" + ", ".join(str(x) for x in expected) + ")"
        actual_str = "(" + ", ".join(str(x) for x in actual) + ")"

        message = f"Invalid input shape. Expected {expected_str}, got {actual_str}."

        # Provide specific suggestions based on the mismatch. Each hint is
        # tied to the expected shape so it is not offered when it would not
        # actually fix the problem.
        suggestions: list[str] = []

        # A 3-D input is only "missing the batch dimension" when a 4-D
        # (NCHW) shape was expected.
        if len(expected) == 4 and len(actual) == 3:
            suggestions.append("Add batch dimension: x = x.unsqueeze(0)")

        # Only blame the batch size when the leading dimension really differs
        # from what was expected. A symbolic entry such as "N" never equals an
        # int, so it still triggers the hint as before.
        batch_differs = not expected or expected[0] != actual[0]
        if len(actual) == 4 and actual[0] != 1 and batch_differs:
            suggestions.append("Process images one at a time (batch_size=1)")

        suggestion = "\n".join(suggestions) if suggestions else None

        super().__init__(message, suggestion=suggestion)


class EmptyDatasetError(DataError):
    """Signal that a dataset turned out to contain no usable samples.

    Example:
        >>> raise EmptyDatasetError(
        ...     dataset_name="baseline_dataset",
        ...     reason="Directory contains no image files"
        ... )
    """

    def __init__(
        self,
        dataset_name: str,
        *,
        reason: str | None = None,
    ) -> None:
        """Build the error message and a remediation checklist.

        Args:
            dataset_name: Human-readable name or description of the dataset.
            reason: Optional explanation of why it is empty. When ``None`` or
                an empty string, a generic message is used instead.
        """
        self.dataset_name = dataset_name
        self.reason = reason

        message = (
            f"Dataset '{dataset_name}' is empty: {reason}"
            if reason
            else f"Dataset '{dataset_name}' is empty or has no valid samples."
        )

        checklist = (
            "  1. Contains valid image files (jpg, png, etc.)",
            "  2. Has proper read permissions",
            "  3. Is not filtered out by any preprocessing",
        )
        suggestion = "Ensure the dataset:\n" + "\n".join(checklist)

        super().__init__(message, suggestion=suggestion)


class InvalidImageError(DataError):
    """Signal that an image could not be processed.

    Typical causes:
    - The image file is corrupted or unreadable
    - The image uses an unsupported format or color space
    - The image dimensions are invalid

    Example:
        >>> raise InvalidImageError(
        ...     path="/path/to/image.jpg",
        ...     reason="File is corrupted"
        ... )
    """

    def __init__(
        self,
        path: str | None = None,
        *,
        reason: str | None = None,
    ) -> None:
        """Build a message from whichever of ``path``/``reason`` is provided.

        Args:
            path: Optional location of the offending image.
            reason: Optional explanation of what is wrong with it.
        """
        self.path = path
        self.reason = reason

        # Empty strings are treated the same as missing values.
        if not path:
            message = f"Invalid image: {reason}" if reason else "Invalid image provided."
        elif reason:
            message = f"Invalid image '{path}': {reason}"
        else:
            message = f"Invalid image: '{path}'"

        suggestion = "\n".join(
            (
                "Ensure the image:",
                "  1. Is a valid image file (jpg, png, etc.)",
                "  2. Has 3 channels (RGB) or can be converted to RGB",
                "  3. Is not corrupted or truncated",
            )
        )

        super().__init__(message, suggestion=suggestion)


class DeviceMismatchError(DataError):
    """Signal that tensors that must share a device live on different ones.

    Typical causes:
    - The model and the input are on different devices
    - Baseline dataset tensors do not match the input device

    Example:
        >>> raise DeviceMismatchError(
        ...     expected_device="cuda:0",
        ...     actual_device="cpu",
        ...     tensor_name="input"
        ... )
    """

    def __init__(
        self,
        expected_device: str,
        actual_device: str,
        *,
        tensor_name: str | None = None,
    ) -> None:
        """Record both devices and build a message plus a `.to(...)` hint.

        Args:
            expected_device: Device the tensor was supposed to be on.
            actual_device: Device the tensor was found on.
            tensor_name: Optional label of the offending tensor; included in
                the message when non-empty.
        """
        self.expected_device = expected_device
        self.actual_device = actual_device
        self.tensor_name = tensor_name

        subject = f" for '{tensor_name}'" if tensor_name else ""
        message = f"Device mismatch{subject}: expected {expected_device}, got {actual_device}."

        suggestion = f"Move tensor to correct device: tensor = tensor.to('{expected_device}')"

        super().__init__(message, suggestion=suggestion)
