"""Baseline provider descriptor for configuration.

Provides a validated descriptor that accepts various baseline provider
specifications and creates the appropriate provider instance.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, overload

from expected_gradcam.config.descriptors import ValidatedDescriptor

if TYPE_CHECKING:
    from expected_gradcam.baselines.protocols import BaselineProvider


class BaselineProviderParam(ValidatedDescriptor["BaselineProvider | None"]):
    """Descriptor for baseline provider configuration.

    On assignment, a provider specification is normalised into a
    ``BaselineProvider`` instance (or ``None``) and stored in the owning
    instance's ``__dict__``:

    - None: No provider (synthetic baselines)
    - BaselineProvider instance: Stored as-is
    - Anything else: Passed to
      ``expected_gradcam.baselines.factory.baseline_from``, which is
      responsible for dispatching on the specification:

      - str/Path (directory): Creates DirectoryProvider
      - str/Path (.npy/.npz): Creates CachedFeatureProvider
      - str (non-path): Creates HuggingFaceProvider
      - dict with "type" key: Creates specified provider

    Any exception raised by the factory is re-raised as ``ValueError``.

    Example:
        >>> class MyConfig(BaseConfig):
        ...     baseline_provider = BaselineProviderParam(
        ...         None,
        ...         doc="Provider for data-aware baselines"
        ...     )
        ...
        >>> config = MyConfig()
        >>> config.baseline_provider = "/data/imagenet/train"  # DirectoryProvider
        >>> config.baseline_provider = "imagenet-1k"  # HuggingFaceProvider
        >>> config.baseline_provider = {"type": "cached", "cache_path": "features.npy"}
    """

    def __init__(
        self,
        default: "BaselineProvider | None" = None,
        *,
        doc: str = "",
    ) -> None:
        """Initialize the descriptor.

        Args:
            default: Default provider (usually None for synthetic baselines).
            doc: Documentation string for the parameter.
        """
        super().__init__(default, doc=doc)

    # The overloads mirror the implementation's signature, so
    # ``objtype`` stays optional in both.
    @overload
    def __get__(
        self, obj: None, objtype: type | None = None
    ) -> "BaselineProviderParam": ...

    @overload
    def __get__(
        self, obj: object, objtype: type | None = None
    ) -> "BaselineProvider | None": ...

    def __get__(
        self, obj: object | None, objtype: type | None = None
    ) -> "BaselineProviderParam | BaselineProvider | None":
        """Return the descriptor on class access, else the stored provider.

        Args:
            obj: The config instance, or ``None`` when accessed on the class.
            objtype: The owner class (unused).

        Returns:
            ``self`` for class-level access. For instance access, returns
            the provider stored in ``obj.__dict__``, falling back to
            ``self.default`` if the attribute was never assigned.
        """
        if obj is None:
            return self
        # NOTE(review): ``self._name`` is assumed to be set by
        # ValidatedDescriptor (e.g. in ``__set_name__``); confirm.
        return obj.__dict__.get(self._name, self.default)

    def __set__(
        self,
        obj: object,
        value: "BaselineProvider | str | Path | dict[str, Any] | None",
    ) -> None:
        """Set the provider with automatic type conversion.

        Args:
            obj: The config instance.
            value: Provider specification (various formats accepted).

        Raises:
            ValueError: If value cannot be converted to a provider.
        """
        provider: "BaselineProvider | None"
        if value is None:
            provider = None
        else:
            # Deferred runtime import: at module level BaselineProvider is
            # only imported under TYPE_CHECKING, presumably to avoid an
            # import cycle.
            from expected_gradcam.baselines.protocols import BaselineProvider

            # NOTE(review): isinstance() against a typing.Protocol requires
            # it to be decorated with @runtime_checkable; assumed here.
            if isinstance(value, BaselineProvider):
                provider = value
            else:
                provider = self._create_provider(value)

        # This method replaces the base descriptor's setter, so run the
        # validation hook explicitly; otherwise subclasses overriding
        # ``_validate`` would never be consulted.
        self._validate(provider)
        obj.__dict__[self._name] = provider

    def _create_provider(
        self,
        value: str | Path | dict[str, Any],
    ) -> "BaselineProvider":
        """Create provider from specification.

        Args:
            value: Provider specification.

        Returns:
            Created BaselineProvider instance.

        Raises:
            ValueError: If specification is invalid. Any exception raised
                by ``baseline_from`` is wrapped, with the original
                exception chained as ``__cause__``.
        """
        from expected_gradcam.baselines.factory import baseline_from

        try:
            return baseline_from(value)
        except Exception as e:
            # Normalise every factory failure into the single exception
            # type documented for assignment.
            raise ValueError(
                f"Failed to create baseline provider from {value!r}: {e}"
            ) from e

    def _validate(self, value: "BaselineProvider | None") -> None:
        """Validate provider (protocol compliance is checked on use).

        Intentionally a no-op. Subclasses may override it to add checks.
        """

    def __repr__(self) -> str:
        """Return an unambiguous string representation of the descriptor."""
        return f"{type(self).__name__}(default={self.default!r})"


# Public API of this module: only the descriptor class is exported.
__all__ = ["BaselineProviderParam"]
