"""Builder pattern for complex provider configurations.

Provides a fluent interface for configuring providers with
data sources, transforms, caching, and validation options.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Self

if TYPE_CHECKING:
    from torch import Tensor

    from expected_gradcam.baselines.protocols import BaselineProvider


class BaselineProviderBuilder:
    """Builder for creating baseline providers with complex configuration.

    Provides a fluent interface for configuring providers with:
    - Data source specification (directory, HuggingFace, dataset)
    - Transform configuration
    - Caching options
    - Validation settings

    Example::

        provider = (
            BaselineProviderBuilder()
            .from_directory("/data/imagenet/train")
            .with_extensions(".jpg", ".png")
            .with_max_images(1000)
            .with_cache("/cache/features.npy")
            .with_transform(my_transform)
            .build()
        )

        # Or using HuggingFace
        provider = (
            BaselineProviderBuilder()
            .from_huggingface("imagenet-1k")
            .with_split("train")
            .with_streaming()
            .build()
        )
    """

    def __init__(self) -> None:
        """Initialize the builder."""
        self._provider_type: str | None = None
        self._config: dict[str, Any] = {}

    # =========================================================================
    # Data Source Methods
    # =========================================================================

    def from_directory(self, path: str | Path) -> Self:
        """Use a directory as data source.

        Args:
            path: Path to directory containing images.

        Returns:
            Self for chaining.
        """
        self._provider_type = "directory"
        self._config["path"] = Path(path)
        return self

    def from_imagenet(
        self,
        root: str | Path,
        split: str = "train",
    ) -> Self:
        """Use ImageNet directory structure.

        Args:
            root: Root ImageNet directory.
            split: Dataset split ("train" or "val").

        Returns:
            Self for chaining.
        """
        self._provider_type = "imagenet"
        self._config["root"] = Path(root)
        self._config["split"] = split
        return self

    def from_huggingface(self, dataset_name: str) -> Self:
        """Use a HuggingFace dataset.

        Args:
            dataset_name: Name of the dataset on HuggingFace Hub.

        Returns:
            Self for chaining.
        """
        self._provider_type = "huggingface"
        self._config["dataset_name"] = dataset_name
        return self

    def from_torch_dataset(self, dataset: Any) -> Self:
        """Use an existing PyTorch dataset.

        Args:
            dataset: PyTorch Dataset instance.

        Returns:
            Self for chaining.
        """
        self._provider_type = "torch_dataset"
        self._config["dataset"] = dataset
        return self

    def from_cache(self, cache_path: str | Path) -> Self:
        """Use pre-extracted feature cache.

        Args:
            cache_path: Path to numpy cache file.

        Returns:
            Self for chaining.
        """
        self._provider_type = "cached"
        self._config["cache_path"] = Path(cache_path)
        return self

    # =========================================================================
    # Configuration Methods
    # =========================================================================

    def with_extensions(self, *extensions: str) -> Self:
        """Set file extensions for directory provider.

        Args:
            *extensions: File extensions (e.g., ".jpg", ".png").

        Returns:
            Self for chaining.
        """
        self._config["extensions"] = extensions
        return self

    def with_max_images(self, n: int) -> Self:
        """Limit number of images to load.

        Args:
            n: Maximum number of images.

        Returns:
            Self for chaining.
        """
        # Map to correct parameter name based on provider
        self._config["max_images"] = n
        self._config["max_samples"] = n
        return self

    def with_split(self, split: str) -> Self:
        """Set dataset split.

        Args:
            split: Dataset split ("train", "val", "test").

        Returns:
            Self for chaining.
        """
        self._config["split"] = split
        return self

    def with_streaming(self, enabled: bool = True) -> Self:
        """Enable streaming mode for large datasets.

        Args:
            enabled: Whether to enable streaming.

        Returns:
            Self for chaining.
        """
        self._config["streaming"] = enabled
        return self

    def with_cache(self, cache_path: str | Path) -> Self:
        """Set cache path for saving/loading features.

        Args:
            cache_path: Path to cache file.

        Returns:
            Self for chaining.
        """
        self._config["cache_path"] = Path(cache_path)
        return self

    def with_transform(
        self,
        transform: Callable[[Any], "Tensor"],
    ) -> Self:
        """Set custom image preprocessing transform.

        Args:
            transform: Transform function.

        Returns:
            Self for chaining.
        """
        self._config["transform"] = transform
        return self

    def with_shuffle(self, enabled: bool = True) -> Self:
        """Enable data shuffling.

        Args:
            enabled: Whether to shuffle data.

        Returns:
            Self for chaining.
        """
        self._config["shuffle"] = enabled
        return self

    def with_batch_size(self, batch_size: int) -> Self:
        """Set batch size for feature extraction.

        Args:
            batch_size: Batch size.

        Returns:
            Self for chaining.
        """
        self._config["batch_size"] = batch_size
        return self

    def with_balanced_sampling(self, enabled: bool = True) -> Self:
        """Enable balanced sampling across classes (ImageNet).

        Args:
            enabled: Whether to balance sampling.

        Returns:
            Self for chaining.
        """
        self._config["balanced"] = enabled
        return self

    def with_image_column(self, column: str) -> Self:
        """Set image column name (HuggingFace datasets).

        Args:
            column: Name of the image column.

        Returns:
            Self for chaining.
        """
        self._config["image_column"] = column
        return self

    def with_mmap_mode(self, mode: str) -> Self:
        """Set memory-map mode for cached features.

        Args:
            mode: Memory map mode ('r', 'r+', 'c').

        Returns:
            Self for chaining.
        """
        self._config["mmap_mode"] = mode
        return self

    # =========================================================================
    # Build Method
    # =========================================================================

    def build(self) -> "BaselineProvider":
        """Build the configured provider.

        Returns:
            Configured BaselineProvider instance.

        Raises:
            ValueError: If no data source specified.
        """
        if self._provider_type is None:
            raise ValueError(
                "No data source specified. Use from_directory(), "
                "from_huggingface(), from_cache(), etc."
            )

        from expected_gradcam.baselines.registry import get_registry

        registry = get_registry()

        # Filter config to only include relevant parameters
        filtered_config = self._get_filtered_config()

        return registry.create(self._provider_type, **filtered_config)

    def _get_filtered_config(self) -> dict[str, Any]:
        """Get config filtered for the current provider type."""
        # Common parameters
        common = {"cache_path", "transform", "shuffle"}

        # Provider-specific parameters
        provider_params = {
            "directory": {"path", "extensions", "max_images"},
            "imagenet": {"root", "split", "max_images", "balanced"},
            "huggingface": {
                "dataset_name",
                "split",
                "image_column",
                "streaming",
                "max_samples",
                "trust_remote_code",
            },
            "torch_dataset": {"dataset", "max_samples", "batch_size"},
            "cached": {"cache_path", "mmap_mode"},
        }

        allowed = common | provider_params.get(self._provider_type, set())

        return {k: v for k, v in self._config.items() if k in allowed}

    def to_dict(self) -> dict[str, Any]:
        """Convert builder configuration to dictionary.

        Useful for serialization or passing to config.

        Returns:
            Dictionary with provider type and config.
        """
        result = {"type": self._provider_type}
        result.update(self._config)
        return result

    def __repr__(self) -> str:
        """Return string representation."""
        parts = ["BaselineProviderBuilder("]
        if self._provider_type:
            parts.append(f"type='{self._provider_type}'")
            if self._config:
                parts.append(f", config_keys={list(self._config.keys())}")
        else:
            parts.append("not configured")
        parts.append(")")
        return "".join(parts)


# Public API of this module: only the builder is exported by star-imports.
__all__ = ["BaselineProviderBuilder"]
