"""Factory function for quick baseline provider creation.

Provides convenient auto-detection of provider type from source.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from expected_gradcam.baselines.protocols import BaselineProvider


def baseline_from(
    source: str | Path | dict[str, Any],
    **kwargs: Any,
) -> "BaselineProvider":
    """Build a baseline provider from a loosely specified source.

    Dispatch happens in this order:

    1. A ``dict`` source is treated as an explicit configuration and handed
       to ``_from_dict``; it must carry a ``"type"`` key.
    2. If a truthy ``provider_type`` keyword is given, that provider is
       created directly via ``_create_with_type`` and no detection runs.
    3. Otherwise the provider is inferred from the path by ``_auto_detect``
       (cache-file suffix, existing directory or, failing those, a
       HuggingFace dataset name).

    Args:
        source: Where the baseline data comes from. One of:
            - a directory path (``str`` or ``Path``),
            - a path to a cache file such as ``.npy``/``.npz``,
            - a HuggingFace dataset name,
            - a dict holding ``"type"`` plus provider-specific options.
        **kwargs: Extra options forwarded to the provider. The special key
            ``provider_type`` is consumed here for non-dict sources and is
            not forwarded.

    Returns:
        The provider object returned by the registry's ``create``.

    Raises:
        ValueError: If a dict source has no ``"type"`` entry, or if
            ``provider_type`` names a type that cannot be built from a
            path/name source.
        ProviderNotFoundError: Presumably raised by the registry when the
            requested provider type is not registered (it originates
            outside this module; nothing here catches it).

    Examples::

        # Existing directory (auto-detected)
        provider = baseline_from("/data/imagenet/train")

        # Cache file (auto-detected from the suffix)
        provider = baseline_from("/cache/features.npy")

        # HuggingFace dataset, selected explicitly
        provider = baseline_from(
            "imagenet-1k",
            provider_type="huggingface",
            split="train",
        )

        # Explicit dict configuration
        provider = baseline_from({
            "type": "huggingface",
            "dataset_name": "imagenet-1k",
            "split": "train",
            "max_samples": 1000,
        })

        # Explicit provider type with extra options
        provider = baseline_from(
            "/data/imagenet",
            provider_type="imagenet",
            split="train",
            balanced=True,
        )
    """
    # Imported lazily, inside the function, exactly as the registry lookup
    # has always been done here.
    from expected_gradcam.baselines.registry import get_registry

    providers = get_registry()

    # Explicit dict configuration bypasses every path heuristic.
    if isinstance(source, dict):
        return _from_dict(source, providers, **kwargs)

    # ``provider_type`` is an instruction to this factory, not a provider
    # option, so it is always removed from the forwarded kwargs.
    forced_type = kwargs.pop("provider_type", None)
    if not forced_type:
        # Detection works on Path objects; non-str sources pass through as-is.
        candidate = source if not isinstance(source, str) else Path(source)
        return _auto_detect(candidate, providers, **kwargs)

    # The helper receives the caller's original value (not the Path) so a
    # dataset name keeps its exact spelling.
    return _create_with_type(forced_type, source, providers, **kwargs)


def _from_dict(
    config: dict[str, Any],
    registry: Any,
    **kwargs: Any,
) -> "BaselineProvider":
    """Instantiate a provider described by a configuration mapping.

    Args:
        config: Mapping with a ``"type"`` entry naming the provider; every
            other entry is forwarded as a keyword argument. The mapping is
            left untouched.
        registry: Object exposing ``create(name, **options)`` and
            ``list_providers()``.
        **kwargs: Additional options; on key clashes these override the
            values found in ``config``.

    Returns:
        Whatever ``registry.create`` returns.

    Raises:
        ValueError: If ``"type"`` is missing or falsy.
    """
    type_name = config.get("type")
    if not type_name:
        known = ", ".join(registry.list_providers())
        raise ValueError(
            "Dictionary config must include 'type' key. "
            "Valid types: " + known
        )

    # Build a fresh options dict instead of popping, so the caller's mapping
    # is never modified; kwargs are applied last and therefore win.
    options = {key: value for key, value in config.items() if key != "type"}
    options.update(kwargs)
    return registry.create(type_name, **options)


def _create_with_type(
    provider_type: str,
    source: str | Path,
    registry: Any,
    **kwargs: Any,
) -> "BaselineProvider":
    """Create a provider of an explicitly requested type.

    ``source`` is injected into ``kwargs`` under the keyword that the given
    provider type uses for its location (``path``, ``root``, ``cache_path``
    or ``dataset_name``), replacing any value already supplied under that
    keyword.

    Args:
        provider_type: Registry name of the provider to build.
        source: Filesystem location or dataset name.
        registry: Object exposing ``create(name, **options)``.
        **kwargs: Further options forwarded to the provider.

    Returns:
        Whatever ``registry.create`` returns.

    Raises:
        ValueError: If ``provider_type`` has no source keyword in the table
            below (``"torch_dataset"`` and any name not listed).
    """
    location = Path(source)

    # Keyword under which each provider type receives its source.
    # ``None`` marks a type that needs a dataset object rather than a path.
    source_keyword = {
        "directory": "path",
        "imagenet": "root",
        "cached": "cache_path",
        "huggingface": "dataset_name",
        "torch_dataset": None,
    }

    keyword = source_keyword.get(provider_type)
    if keyword is None:
        raise ValueError(
            f"Provider type '{provider_type}' does not accept path source. "
            "Use the builder pattern for complex configurations."
        )

    # HuggingFace takes the dataset name as a plain string; every other
    # provider receives a Path.
    value = str(source) if provider_type == "huggingface" else location
    return registry.create(provider_type, **{**kwargs, keyword: value})


def _auto_detect(
    source: Path,
    registry: Any,
    **kwargs: Any,
) -> "BaselineProvider":
    """Infer the provider type from ``source`` and create the provider.

    The rules are applied in order; the first match wins:

    1. Suffix ``.npy`` or ``.npz`` -> ``"cached"`` (``cache_path=source``).
    2. ``source`` is an existing directory -> ``"imagenet"`` when
       ``_looks_like_imagenet`` accepts it, otherwise ``"directory"``.
    3. Suffix ``.pt`` or ``.pth`` -> ``"cached"``.
    4. Any other suffix, or a multi-component path whose parent exists
       (a location that may simply not have been created yet)
       -> ``"directory"`` (``path=source``).
    5. Anything else -> ``"huggingface"`` with ``dataset_name=str(source)``.

    Args:
        source: Candidate location or dataset name, already wrapped in a
            ``Path``.
        registry: Object exposing ``create(name, **options)``.
        **kwargs: Extra options forwarded to the provider.

    Returns:
        Whatever ``registry.create`` returns for the detected type.
    """
    # Array cache files are recognised purely by suffix, whether or not
    # they exist yet.
    if source.suffix in {".npy", ".npz"}:
        return registry.create("cached", cache_path=source, **kwargs)

    # ``is_dir`` is already False for paths that do not exist.
    if source.is_dir():
        if _looks_like_imagenet(source):
            return registry.create("imagenet", root=source, **kwargs)
        return registry.create("directory", path=source, **kwargs)

    # Checked after the directory test so that an existing directory whose
    # name happens to end in ".pt" is still treated as a directory.
    if source.suffix in {".pt", ".pth"}:
        return registry.create("cached", cache_path=source, **kwargs)

    # "Parent exists" only counts as evidence of a filesystem path when the
    # path has more than one component. The parent of a bare name such as
    # "imagenet-1k" is ".", which always exists, so without this guard every
    # simple dataset name would be mistaken for a directory and the
    # HuggingFace fallback below could never be reached for it.
    parent_exists = len(source.parts) > 1 and source.parent.exists()
    if source.suffix or parent_exists:
        return registry.create("directory", path=source, **kwargs)

    # Nothing path-like about it: assume a HuggingFace dataset name.
    return registry.create("huggingface", dataset_name=str(source), **kwargs)


def _looks_like_imagenet(path: Path) -> bool:
    """Return True if ``path`` appears to hold an ImageNet-style layout.

    The check looks for a ``train`` subdirectory (falling back to ``val``),
    samples up to 10 of its subdirectories in iteration order, and requires
    at least half of the sampled names to look like synset IDs (``n``
    followed by 8 digits, e.g. ``n01440764``).

    Args:
        path: Directory to inspect.

    Returns:
        False when neither ``train`` nor ``val`` is a directory, when the
        chosen split has no subdirectories, or when fewer than half of the
        sampled subdirectory names match the synset pattern; True otherwise.
    """
    import re
    from itertools import islice

    # Require an actual directory: a regular file that merely happens to be
    # called "train" or "val" must not be listed (iterdir would raise
    # NotADirectoryError on it).
    split_dir = next(
        (d for d in (path / "train", path / "val") if d.is_dir()),
        None,
    )
    if split_dir is None:
        return False

    # Only the first 10 subdirectories are inspected, so stop filtering as
    # soon as that many have been found instead of checking every entry.
    sample = list(islice((d for d in split_dir.iterdir() if d.is_dir()), 10))
    if not sample:
        return False

    # ImageNet synset IDs start with 'n' followed by 8 digits.
    synset_pattern = re.compile(r"^n\d{8}$")
    matching = sum(1 for d in sample if synset_pattern.match(d.name))

    # At least half of the sampled names must look like synset IDs.
    return matching >= len(sample) * 0.5


# Public API of this module: only the factory function is exported; the
# underscore-prefixed helpers above are implementation details.
__all__ = ["baseline_from"]
