"""
Factory functions for creating samplers from configuration.
"""

import logging
from typing import Any, Dict, Optional

from .batch_sampler import (
    BaseBatchSampler,
    RandomBatchSampler,
    HOBITBatchSampler,
)
from .negative_sampler import (
    BaseNegativeSampler,
    RandomNegativeSampler
)

logger = logging.getLogger(__name__)


# Built-in batch samplers, keyed by the name accepted by create_batch_sampler().
# More entries can be added at runtime via register_batch_sampler(). Insertion
# order is the order shown in the "Available samplers" error message.
BATCH_SAMPLER_REGISTRY = dict(
    random=RandomBatchSampler,
    hobit=HOBITBatchSampler,
)


# Built-in negative samplers, keyed by the name accepted by
# create_negative_sampler(). Extended at runtime via register_negative_sampler().
NEGATIVE_SAMPLER_REGISTRY = dict(
    random=RandomNegativeSampler,
)


def create_batch_sampler(
    name: str,
    seed: int = 42,
    args: Optional[Dict[str, Any]] = None,
) -> BaseBatchSampler:
    """
    Create a batch sampler from configuration.

    The class registered under ``name`` in BATCH_SAMPLER_REGISTRY is
    instantiated as ``sampler_class(seed=seed, **args)``.

    Args:
        name: Name of the batch sampler (must be in BATCH_SAMPLER_REGISTRY)
        seed: Random seed, passed to the sampler as the ``seed`` keyword
        args: Additional sampler-specific keyword arguments. ``None`` (or an
            empty mapping) means no extra arguments. Must not contain a
            ``seed`` key; use the ``seed`` parameter instead.

    Returns:
        Instantiated batch sampler

    Raises:
        ValueError: If the sampler name is not recognized, or if ``args``
            contains a ``seed`` entry that would conflict with ``seed``.
    """
    if name not in BATCH_SAMPLER_REGISTRY:
        available = ", ".join(BATCH_SAMPLER_REGISTRY)
        raise ValueError(
            f"Unknown batch sampler: {name}. Available samplers: {available}"
        )

    sampler_class = BATCH_SAMPLER_REGISTRY[name]
    sampler_args = args or {}

    # Without this check the call below would fail with an opaque
    # "got multiple values for keyword argument 'seed'" TypeError.
    if "seed" in sampler_args:
        raise ValueError(
            f"Batch sampler '{name}': pass the seed via the 'seed' parameter, "
            f"not inside 'args'"
        )

    logger.info("Creating batch sampler: %s", name)
    return sampler_class(seed=seed, **sampler_args)


def create_negative_sampler(
    name: str,
    seed: int = 42,
    args: Optional[Dict[str, Any]] = None,
) -> BaseNegativeSampler:
    """
    Create a negative sampler from configuration.

    The class registered under ``name`` in NEGATIVE_SAMPLER_REGISTRY is
    instantiated as ``sampler_class(seed=seed, **args)``.

    Args:
        name: Name of the negative sampler (must be in NEGATIVE_SAMPLER_REGISTRY)
        seed: Random seed, passed to the sampler as the ``seed`` keyword
        args: Additional sampler-specific keyword arguments. ``None`` (or an
            empty mapping) means no extra arguments. Must not contain a
            ``seed`` key; use the ``seed`` parameter instead.

    Returns:
        Instantiated negative sampler

    Raises:
        ValueError: If the sampler name is not recognized, or if ``args``
            contains a ``seed`` entry that would conflict with ``seed``.
    """
    if name not in NEGATIVE_SAMPLER_REGISTRY:
        available = ", ".join(NEGATIVE_SAMPLER_REGISTRY)
        raise ValueError(
            f"Unknown negative sampler: {name}. Available samplers: {available}"
        )

    sampler_class = NEGATIVE_SAMPLER_REGISTRY[name]
    sampler_args = args or {}

    # Without this check the call below would fail with an opaque
    # "got multiple values for keyword argument 'seed'" TypeError.
    if "seed" in sampler_args:
        raise ValueError(
            f"Negative sampler '{name}': pass the seed via the 'seed' parameter, "
            f"not inside 'args'"
        )

    logger.info("Creating negative sampler: %s", name)
    return sampler_class(seed=seed, **sampler_args)


def register_batch_sampler(name: str, sampler_class: type):
    """
    Register a new batch sampler class.

    After registration, ``create_batch_sampler(name, ...)`` instantiates
    ``sampler_class``. Registering a name that is already present replaces
    the previous class; a warning is logged when the class differs.

    Args:
        name: Name to register the sampler under
        sampler_class: Batch sampler class (subclass of BaseBatchSampler)

    Raises:
        TypeError: If ``sampler_class`` is not a class (e.g. an instance)
        ValueError: If ``sampler_class`` is not a subclass of BaseBatchSampler
    """
    # issubclass() raises an unhelpful TypeError for non-class arguments
    # (a common mistake is passing an instance), so reject those up front.
    if not isinstance(sampler_class, type):
        raise TypeError(
            f"{sampler_class!r} must be a class (subclass of BaseBatchSampler), "
            f"not an instance of {type(sampler_class).__name__}"
        )
    if not issubclass(sampler_class, BaseBatchSampler):
        raise ValueError(f"{sampler_class} must be a subclass of BaseBatchSampler")

    previous = BATCH_SAMPLER_REGISTRY.get(name)
    if previous is not None and previous is not sampler_class:
        # Replacing a built-in or earlier registration is allowed, but should
        # not happen silently.
        logger.warning(
            "Overriding batch sampler %r: %s -> %s", name, previous, sampler_class
        )

    BATCH_SAMPLER_REGISTRY[name] = sampler_class
    logger.info("Registered batch sampler: %s", name)


def register_negative_sampler(name: str, sampler_class: type):
    """
    Register a new negative sampler class.

    After registration, ``create_negative_sampler(name, ...)`` instantiates
    ``sampler_class``. Registering a name that is already present replaces
    the previous class; a warning is logged when the class differs.

    Args:
        name: Name to register the sampler under
        sampler_class: Negative sampler class (subclass of BaseNegativeSampler)

    Raises:
        TypeError: If ``sampler_class`` is not a class (e.g. an instance)
        ValueError: If ``sampler_class`` is not a subclass of BaseNegativeSampler
    """
    # issubclass() raises an unhelpful TypeError for non-class arguments
    # (a common mistake is passing an instance), so reject those up front.
    if not isinstance(sampler_class, type):
        raise TypeError(
            f"{sampler_class!r} must be a class (subclass of BaseNegativeSampler), "
            f"not an instance of {type(sampler_class).__name__}"
        )
    if not issubclass(sampler_class, BaseNegativeSampler):
        raise ValueError(f"{sampler_class} must be a subclass of BaseNegativeSampler")

    previous = NEGATIVE_SAMPLER_REGISTRY.get(name)
    if previous is not None and previous is not sampler_class:
        # Replacing a built-in or earlier registration is allowed, but should
        # not happen silently.
        logger.warning(
            "Overriding negative sampler %r: %s -> %s", name, previous, sampler_class
        )

    NEGATIVE_SAMPLER_REGISTRY[name] = sampler_class
    logger.info("Registered negative sampler: %s", name)
