"""
Base class for batch samplers.
"""

import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Any
import numpy as np

from ..monitoring import monitor_performance

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class BaseBatchSampler(ABC):
    """
    Base class for batch sampling strategies.

    A batch sampler determines the order in which queries are processed during training.
    This can be random, or based on some strategy (e.g., curriculum learning, clustering, etc.).

    Subclasses implement ``sample()``; a concrete ``sample()`` is wrapped with
    ``monitor_performance('batch_sampler')`` when the subclass is created.
    """

    def __init_subclass__(cls, **kwargs):
        """
        Wrap a newly provided ``sample()`` implementation with performance monitoring.

        The wrapper is applied only when ``sample`` is defined on ``cls`` itself
        or comes from a class outside the sampler hierarchy (e.g. a mixin).
        A ``sample`` inherited from another sampler class is left alone: it is
        either the abstract placeholder or was already wrapped when that
        parent class was created, so wrapping it again would stack monitoring.

        Args:
            **kwargs: Class keyword arguments, forwarded to the parent hook
        """
        super().__init_subclass__(**kwargs)

        # Locate the class in the MRO that actually defines ``sample``.
        owner = next(
            (klass for klass in cls.__mro__ if 'sample' in vars(klass)),
            None,
        )
        if owner is None:
            return

        # Inherited from the sampler hierarchy: nothing new to wrap.
        if owner is not cls and issubclass(owner, BaseBatchSampler):
            return

        sample_impl = cls.sample
        if not callable(sample_impl):
            return

        # Never replace an abstract method: swapping it for a wrapper could
        # drop the abstract marker and make the class instantiable by accident.
        if getattr(sample_impl, '__isabstractmethod__', False):
            return

        cls.sample = monitor_performance('batch_sampler')(sample_impl)

    def __init__(self, seed: int = 42, **kwargs):
        """
        Initialize the batch sampler.

        Args:
            seed: Random seed for reproducibility
            **kwargs: Additional sampler-specific arguments (accepted and
                      ignored by this base implementation)
        """
        self.seed = seed
        # Each sampler owns a private generator, seeded from ``seed``.
        self.rng = np.random.RandomState(seed)
        logger.info(f"Initialized {self.__class__.__name__} with seed {seed}")

    @abstractmethod
    def sample(
        self,
        query_ids: List[str],
        query_embeddings: Optional[np.ndarray] = None,
        doc_embeddings: Optional[np.ndarray] = None,
        doc_ids: Optional[List[str]] = None,
        qrels: Optional[Dict[str, Dict[str, int]]] = None,
        epoch: Optional[int] = None,
        **kwargs,
    ) -> List[str]:
        """
        Sample a batch ordering for queries.

        Args:
            query_ids: List of all query IDs to be ordered
            query_embeddings: Optional query embeddings (num_queries, embedding_dim)
                             Rows correspond to query_ids in the same order
            doc_embeddings: Optional document embeddings (num_docs, embedding_dim)
                           Rows correspond to doc_ids in the same order
            doc_ids: Optional list of document IDs corresponding to doc_embeddings
            qrels: Optional qrels dict mapping query_id -> {doc_id: relevance}
            epoch: Optional current epoch number (can be used for curriculum learning)
            **kwargs: Additional sampler-specific arguments

        Returns:
            Ordered list of query IDs for this epoch
        """
        pass

    def reset_seed(self, seed: int):
        """
        Reset the random seed.

        Stores the new seed and replaces ``self.rng`` with a freshly seeded
        generator.

        Args:
            seed: New random seed
        """
        self.seed = seed
        self.rng = np.random.RandomState(seed)
        logger.info(f"Reset seed to {seed}")

    def compute_similarity_scores(
        self,
        query_embeddings: np.ndarray,
        doc_embeddings: np.ndarray,
    ) -> np.ndarray:
        """
        Compute similarity scores between queries and documents.
        Helper method for samplers that need similarity computation.

        The score is the plain dot product of each query row with each
        document row; no normalization is performed here.

        Args:
            query_embeddings: Query embeddings (num_queries, embedding_dim)
            doc_embeddings: Document embeddings (num_docs, embedding_dim)

        Returns:
            Similarity matrix (num_queries, num_docs)
        """
        # Assumes embeddings are already normalized (dot product = cosine similarity)
        scores = np.matmul(query_embeddings, doc_embeddings.T)
        return scores

    def create_id_to_index_mapping(self, ids: List[str]) -> Dict[str, int]:
        """
        Create mapping from ID to index.
        Helper method for efficient lookups.

        If an ID appears more than once, the index of its last occurrence is kept.

        Args:
            ids: List of IDs

        Returns:
            Dict mapping id -> index
        """
        return {id_: idx for idx, id_ in enumerate(ids)}
