"""
Base class for negative samplers.
"""

import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Any, Set
import numpy as np

from ..monitoring import monitor_performance

logger = logging.getLogger(__name__)


class BaseNegativeSampler(ABC):
    """
    Base class for negative sampling strategies.

    A negative sampler generates additional negative documents for each query
    beyond those already present in the qrels (mined negatives).

    Any subclass that defines its own ``sample()`` has that method wrapped with
    ``monitor_performance('negative_sampler')`` when the subclass is created.
    """

    def __init_subclass__(cls, **kwargs):
        """
        Wrap a subclass's own ``sample()`` method with performance monitoring.

        Only a ``sample`` defined directly in this subclass's body is wrapped.
        An inherited ``sample`` is either already wrapped (by the class that
        defined it) or is the abstract stub on this base. Wrapping it again
        would record each call once more per level of inheritance.
        """
        super().__init_subclass__(**kwargs)
        if "sample" in cls.__dict__ and callable(cls.sample):
            cls.sample = monitor_performance('negative_sampler')(cls.sample)

    def __init__(self, seed: int = 42, **kwargs):
        """
        Initialize the negative sampler.

        Args:
            seed: Random seed for reproducibility; used to build ``self.rng``
                (a ``np.random.RandomState``).
            **kwargs: Additional sampler-specific arguments (accepted but not
                used by this base class).
        """
        self.seed = seed
        self.rng = np.random.RandomState(seed)
        logger.info("Initialized %s with seed %s", self.__class__.__name__, seed)

    @abstractmethod
    def sample(
        self,
        query_ids: List[str],
        doc_ids: List[str],
        query_embeddings: Optional[np.ndarray] = None,
        doc_embeddings: Optional[np.ndarray] = None,
        qrels: Optional[Dict[str, Dict[str, int]]] = None,
        num_samples: int = 1,
        epoch: Optional[int] = None,
        **kwargs,
    ) -> Dict[str, List[str]]:
        """
        Sample negative documents for each query.

        Args:
            query_ids: List of query IDs to sample negatives for
            doc_ids: List of all available document IDs
            query_embeddings: Optional query embeddings (num_queries, embedding_dim)
                             Rows correspond to query_ids in the same order
            doc_embeddings: Optional document embeddings (num_docs, embedding_dim)
                           Rows correspond to doc_ids in the same order
            qrels: Optional qrels dict mapping query_id -> {doc_id: relevance}
                   Used to avoid sampling positive documents as negatives
            num_samples: Number of negative samples per query
            epoch: Optional current epoch number (can be used for curriculum learning)
            **kwargs: Additional sampler-specific arguments

        Returns:
            Dict mapping query_id -> list of sampled negative doc_ids
        """
        pass

    def reset_seed(self, seed: int):
        """
        Reset the random seed and rebuild ``self.rng`` from it.

        Args:
            seed: New random seed
        """
        self.seed = seed
        self.rng = np.random.RandomState(seed)
        logger.info("Reset seed to %s", seed)

    def _get_positive_docs(
        self, query_id: str, qrels: Optional[Dict[str, Dict[str, int]]]
    ) -> Set[str]:
        """
        Get the set of positive document IDs for a query.

        A document counts as positive when its relevance is strictly > 0.

        Args:
            query_id: Query ID
            qrels: Qrels dict mapping query_id -> {doc_id: relevance}, or None

        Returns:
            Set of positive document IDs (empty if qrels is None or has no
            entry for query_id)
        """
        if qrels is None or query_id not in qrels:
            return set()

        return {doc_id for doc_id, rel in qrels[query_id].items() if rel > 0}

    def compute_similarity_scores(
        self,
        query_embeddings: np.ndarray,
        doc_embeddings: np.ndarray,
    ) -> np.ndarray:
        """
        Compute similarity scores between queries and documents.
        Helper method for hard negative mining strategies.

        Args:
            query_embeddings: Query embeddings (num_queries, embedding_dim)
            doc_embeddings: Document embeddings (num_docs, embedding_dim)

        Returns:
            Similarity matrix (num_queries, num_docs) of raw dot products
        """
        # Plain dot product; equals cosine similarity only if the caller has
        # already L2-normalized both embedding matrices.
        return np.matmul(query_embeddings, doc_embeddings.T)

    def create_id_to_index_mapping(self, ids: List[str]) -> Dict[str, int]:
        """
        Create mapping from ID to its position in ``ids``.
        Helper method for efficient lookups when working with embeddings.

        Args:
            ids: List of IDs

        Returns:
            Dict mapping id -> index (for duplicate IDs the last index wins)
        """
        return {id_: idx for idx, id_ in enumerate(ids)}

    def sample_top_k_negatives(
        self,
        query_id: str,
        query_idx: int,
        doc_ids: List[str],
        similarity_scores: np.ndarray,
        qrels: Optional[Dict[str, Dict[str, int]]],
        k: int,
        exclude_positives: bool = True,
    ) -> List[str]:
        """
        Select the k highest-scoring non-positive documents for a query.
        Helper method for hard negative mining strategies like ANCE.

        Documents are ranked by descending score. Ties keep the order in which
        the documents appear in ``doc_ids``. NaN scores rank last.

        Args:
            query_id: Query ID
            query_idx: Index of query in similarity_scores matrix
            doc_ids: List of all document IDs; position i matches score column i
            similarity_scores: Similarity matrix (num_queries, num_docs)
            qrels: Qrels dict (used only when exclude_positives is True)
            k: Number of negatives to return; k <= 0 returns an empty list
            exclude_positives: Whether to skip documents with relevance > 0

        Returns:
            Up to k negative document IDs, hardest (highest score) first
        """
        # Guard: a negative k would otherwise act as a "drop from the end" slice.
        if k <= 0:
            return []

        # Cast to float so negation below is well-defined for int/bool scores.
        query_scores = np.asarray(similarity_scores[query_idx], dtype=float)

        num_scores = query_scores.shape[0]
        num_docs = min(len(doc_ids), num_scores)
        if len(doc_ids) != num_scores:
            # Only the overlapping prefix is ranked (zip-like truncation); a
            # mismatch usually means doc_ids and the score matrix disagree.
            logger.warning(
                "sample_top_k_negatives: %d doc_ids but %d score columns for "
                "query %s; only the first %d documents are considered",
                len(doc_ids), num_scores, query_id, num_docs,
            )

        positive_docs = (
            self._get_positive_docs(query_id, qrels) if exclude_positives else set()
        )

        # Stable ascending sort of the negated scores gives descending order
        # with ties kept in doc_ids order.
        order = np.argsort(-query_scores[:num_docs], kind="stable")

        top_k_docs: List[str] = []
        for doc_idx in order:
            doc_id = doc_ids[doc_idx]
            if doc_id in positive_docs:
                continue
            top_k_docs.append(doc_id)
            if len(top_k_docs) == k:
                break

        return top_k_docs
