"""
Random batch sampler implementation.
"""

import logging
from typing import List, Optional, Dict, Any
import numpy as np
from .base import BaseBatchSampler

logger = logging.getLogger(__name__)


class RandomBatchSampler(BaseBatchSampler):
    """
    Batch sampler that returns the query IDs in a random order.

    This is the simplest batch sampling strategy and serves as a baseline:
    it looks only at the query IDs and ignores embeddings, documents,
    relevance judgements and the epoch number.
    """

    def __init__(self, seed: int = 42, **kwargs):
        """
        Initialize random batch sampler.

        Args:
            seed: Random seed for reproducibility
            **kwargs: Additional arguments, forwarded unchanged to the
                base class constructor (this class does not read them)
        """
        super().__init__(seed=seed, **kwargs)

    def sample(
        self,
        query_ids: List[str],
        query_embeddings: Optional[np.ndarray] = None,
        doc_embeddings: Optional[np.ndarray] = None,
        doc_ids: Optional[List[str]] = None,
        qrels: Optional[Dict[str, Dict[str, int]]] = None,
        epoch: Optional[int] = None,
        **kwargs,
    ) -> List[str]:
        """
        Return the query IDs in a randomly shuffled order.

        The input collection is never modified; a new list is built and
        shuffled with ``self.rng``.

        Args:
            query_ids: All query IDs to be ordered. Normally a list, but any
                sized iterable (e.g. a tuple or dict key view) is accepted.
            query_embeddings: Not used for random sampling
            doc_embeddings: Not used for random sampling
            doc_ids: Not used for random sampling
            qrels: Not used for random sampling
            epoch: Not used for random sampling
            **kwargs: Additional arguments (ignored)

        Returns:
            A new list containing the same query IDs in shuffled order
        """
        logger.info(f"Randomly shuffling {len(query_ids)} queries")

        # Build a fresh list rather than calling ``.copy()``: this leaves the
        # caller's collection untouched and also works for inputs that have
        # no ``copy`` method (tuples, dict key views, ...).
        shuffled_ids = list(query_ids)

        # Shuffle the copy in place.
        # NOTE(review): ``self.rng`` is not defined in this class; presumably
        # the base class creates it from ``seed`` - confirm.
        self.rng.shuffle(shuffled_ids)

        logger.info(f"Shuffled query order: first 5 = {shuffled_ids[:5]}")

        return shuffled_ids
