"""Model registration system for embedding models with capability metadata.

This registry stores each model class together with a small capability spec that callers
can query for discovery and validation (e.g., embedding dimensions, zero-shot and text support).
"""

import logging
from collections.abc import Callable
from typing import TypedDict, TypeVar

from pathfmtools.embedding_models.embedding_model import EmbeddingModel

logger = logging.getLogger(__name__)


class _Capabilities(TypedDict, total=False):
    """Capabilities/specification for a registered embedding model.

    Declared with ``total=False``, so static type checkers accept dicts that omit
    keys; ``register_model`` in this module nonetheless always fills in all four.
    """

    # Dimensionality of the model's feature embeddings.
    embedding_dim: int
    # Dimensionality of zero-shot/text-aligned embeddings; None when not applicable.
    zeroshot_dim: int | None
    # Whether the model produces zero-shot patch embeddings.
    supports_zeroshot: bool
    # Whether the model implements text embedding.
    supports_text: bool


# Global registry: lowercased model name -> (model class, capability spec).
# Written only by ``register_model``; read by the lookup helpers in this module.
_MODEL_REGISTRY: dict[str, tuple[type[EmbeddingModel], _Capabilities]] = {}

# Bound to EmbeddingModel subclasses so the decorator's signature says it returns
# exactly the class it was given (static checkers keep the concrete subclass type).
T = TypeVar("T", bound=type[EmbeddingModel])


def register_model(
    name: str,
    *,
    embedding_dim: int,
    zeroshot_dim: int | None = None,
    supports_zeroshot: bool = False,
    supports_text: bool = False,
) -> Callable[[T], T]:
    """Register an embedding model class with capability metadata. Used as a decorator.

    Names are matched case-insensitively: the registry key is ``name.lower()``.

    Args:
        name: Canonical model name (e.g., "conch", "uni2"). Stored lowercased.
        embedding_dim: Feature embedding dimensionality (normalized with ``int``).
        zeroshot_dim: Zero-shot/text-aligned embedding dimensionality, if applicable
            (normalized with ``int`` when not ``None``).
        supports_zeroshot: Whether model produces zero-shot patch embeddings.
        supports_text: Whether model implements text embedding.

    Returns:
        A decorator that records the class in the registry and returns the class
        itself unchanged.

    Raises:
        ValueError: If a model with the same name (case-insensitive) is already
            registered. This is raised when the decorator is applied to a class,
            not when ``register_model`` itself is called.

    """

    def decorator(model_class: T) -> T:
        key = name.lower()
        if key in _MODEL_REGISTRY:
            existing_class, _ = _MODEL_REGISTRY[key]
            msg = f"Model '{name}' is already registered to {existing_class.__name__}"
            logger.error(msg)
            raise ValueError(msg)

        # Normalize numeric/boolean fields so stored specs have consistent types,
        # treating both dimensions the same way.
        capabilities: _Capabilities = {
            "embedding_dim": int(embedding_dim),
            "zeroshot_dim": None if zeroshot_dim is None else int(zeroshot_dim),
            "supports_zeroshot": bool(supports_zeroshot),
            "supports_text": bool(supports_text),
        }

        logger.debug(
            "Registering model '%s' -> %s | caps=%s", name, model_class.__name__, capabilities
        )
        _MODEL_REGISTRY[key] = (model_class, capabilities)
        return model_class

    return decorator


def get_embedding_model(model_name: str) -> type[EmbeddingModel]:
    """Look up a registered EmbeddingModel class by name.

    The lookup is case-insensitive: ``model_name`` is lowercased before being
    matched against the registry keys.

    Args:
        model_name: The name of the embedding model.

    Returns:
        The EmbeddingModel class registered under that name.

    Raises:
        ValueError: If no model is registered under that name; the message lists
            the available model names.

    """
    # Registry values are always (class, caps) tuples, so None means "not registered".
    entry = _MODEL_REGISTRY.get(model_name.lower())
    if entry is None:
        choices = list_available_models()
        msg = f"Invalid embedding model: {model_name}. Available models: {choices}"
        logger.error(msg)
        raise ValueError(msg)

    return entry[0]


def list_available_models() -> list[str]:
    """Return the registered (lowercased) model names in ascending alphabetical order."""
    names = list(_MODEL_REGISTRY)
    names.sort()
    return names


def get_capabilities(model_name: str | None = None) -> dict:
    """Get capability metadata for one or all registered models.

    Args:
        model_name: Name of the model, matched case-insensitively. If ``None``
            (the default), returns the capabilities of every registered model.

    Returns:
        - If ``model_name`` is provided: a shallow copy of that model's capability
          dict (as built by ``register_model``: ``embedding_dim``, ``zeroshot_dim``,
          ``supports_zeroshot``, ``supports_text``).
        - If ``model_name`` is ``None``: a mapping of lowercased model name to a
          shallow copy of its capability dict.

    Raises:
        ValueError: If ``model_name`` is given but no model is registered under it.

    """
    if model_name is None:
        # Copy each spec so callers cannot mutate the registry's stored metadata.
        return {name: dict(caps) for name, (_, caps) in _MODEL_REGISTRY.items()}

    # Registry values are always (class, caps) tuples, so None means "not registered".
    entry = _MODEL_REGISTRY.get(model_name.lower())
    if entry is None:
        available = list_available_models()
        msg = f"Unknown model '{model_name}'. Available: {available}"
        # Log before raising, consistent with get_embedding_model.
        logger.error(msg)
        raise ValueError(msg)

    _, capabilities = entry
    # Return a copy to avoid accidental mutation of the registry entry.
    return dict(capabilities)
