"""
Random negative sampler implementation.
"""

import logging
from typing import List, Optional, Dict, Any
import numpy as np
from .base import BaseNegativeSampler

logger = logging.getLogger(__name__)


class RandomNegativeSampler(BaseNegativeSampler):
    """
    Random negative sampler that uniformly samples negative documents.

    Samples negatives uniformly at random from the document collection,
    avoiding positive documents for each query.
    """

    def __init__(self, seed: int = 42, **kwargs):
        """
        Initialize random negative sampler.

        Args:
            seed: Random seed for reproducibility; forwarded to the base class
            **kwargs: Additional arguments, forwarded unchanged to the base
                class (not used directly by this class)
        """
        super().__init__(seed=seed, **kwargs)

    def sample(
        self,
        query_ids: List[str],
        doc_ids: List[str],
        query_embeddings: Optional[np.ndarray] = None,
        doc_embeddings: Optional[np.ndarray] = None,
        qrels: Optional[Dict[str, Dict[str, int]]] = None,
        num_samples: int = 1,
        epoch: Optional[int] = None,
        **kwargs,
    ) -> Dict[str, List[str]]:
        """
        Randomly sample negative documents for each query.

        For each query, the documents returned by ``self._get_positive_docs``
        are removed from the candidate pool, and ``num_samples`` documents are
        drawn uniformly from the remainder with ``self.rng``. Draws are made
        without replacement when the pool is large enough. Otherwise they are
        made with replacement, and a warning is logged. If no candidates
        remain (every document is positive, or ``doc_ids`` is empty), the
        query maps to an empty list and a warning is logged instead of
        raising.

        Args:
            query_ids: List of query IDs to sample negatives for
            doc_ids: List of all available document IDs
            query_embeddings: Not used for random sampling
            doc_embeddings: Not used for random sampling
            qrels: Optional qrels to avoid sampling positive documents
            num_samples: Number of negative samples per query
            epoch: Not used for random sampling
            **kwargs: Additional arguments (ignored)

        Returns:
            Dict mapping query_id -> list of sampled negative doc_ids. Every
            query in ``query_ids`` gets an entry.
        """
        logger.info(
            "Randomly sampling %s negatives for %s queries",
            num_samples,
            len(query_ids),
        )

        sampled_negatives: Dict[str, List[str]] = {}
        doc_ids_array = np.array(doc_ids)

        for query_id in query_ids:
            # A set gives O(1) membership whatever container the base class
            # returns, so building the mask costs O(len(doc_ids)) per query.
            positive_set = set(self._get_positive_docs(query_id, qrels))

            if positive_set:
                # dtype=bool keeps this a valid boolean index even when
                # doc_ids is empty (np.array([]) would default to float64,
                # which numpy rejects as an index).
                available_mask = np.array(
                    [doc_id not in positive_set for doc_id in doc_ids],
                    dtype=bool,
                )
                available_doc_ids = doc_ids_array[available_mask]
            else:
                available_doc_ids = doc_ids_array

            num_available = len(available_doc_ids)

            # rng.choice cannot draw a non-zero number of samples from an
            # empty pool; return no negatives for this query instead.
            if num_available == 0 and num_samples > 0:
                logger.warning(
                    "Query %s: No candidate docs left after excluding "
                    "positives; returning no negatives.",
                    query_id,
                )
                sampled_negatives[query_id] = []
                continue

            replace = num_available < num_samples
            if replace:
                logger.warning(
                    "Query %s: Only %s available docs, but %s samples "
                    "requested. Sampling with replacement.",
                    query_id,
                    num_available,
                    num_samples,
                )

            # A single call site keeps the RNG draw sequence identical to
            # drawing in two separate branches, so seeded runs reproduce.
            sampled = self.rng.choice(
                available_doc_ids, size=num_samples, replace=replace
            )
            sampled_negatives[query_id] = sampled.tolist()

        logger.info("Sampled negatives for %s queries", len(sampled_negatives))

        return sampled_negatives
