"""
Data collator for batching dense retrieval training examples.
"""

import logging
from typing import Dict, List, Any
import torch
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)


class DenseRetrievalCollator:
    """
    Collator for batching dense retrieval training examples.

    Handles:
    - Tokenization of queries and documents
    - Padding of variable-length sequences
    - Creating attention masks
    - Handling variable numbers of positives and negatives per query

    Each example in a batch is a dict read with the keys ``query_id``,
    ``query_text``, ``positive_doc_texts`` and ``positive_doc_ids``; when the
    corresponding option is enabled it is also read with
    ``mined_negative_doc_texts``/``mined_negative_doc_ids`` and
    ``sampled_negative_doc_texts``/``sampled_negative_doc_ids``.
    """

    def __init__(
        self,
        tokenizer: "AutoTokenizer",
        max_query_length: int = 64,
        max_doc_length: int = 128,
        use_mined_negatives: bool = False,
        use_sampled_negatives: bool = False,
        query_prefix: str = "",
        document_prefix: str = "",
    ):
        """
        Initialize the collator.

        Args:
            tokenizer: HuggingFace tokenizer (called with ``max_length``,
                ``padding="max_length"``, ``truncation=True`` and
                ``return_tensors="pt"``)
            max_query_length: Maximum query length in tokens
            max_doc_length: Maximum document length in tokens
            use_mined_negatives: Whether to include mined negatives
            use_sampled_negatives: Whether to include sampled negatives
            query_prefix: Prefix to prepend to queries (e.g., "query: " for e5 models)
            document_prefix: Prefix to prepend to documents (e.g., "passage: " for e5 models)
        """
        self.tokenizer = tokenizer
        self.max_query_length = max_query_length
        self.max_doc_length = max_doc_length
        self.use_mined_negatives = use_mined_negatives
        self.use_sampled_negatives = use_sampled_negatives
        self.query_prefix = query_prefix
        self.document_prefix = document_prefix

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Collate a batch of examples.

        Args:
            batch: Non-empty list of example dicts (see the class docstring
                for the keys that are read)

        Returns:
            Collated batch dict with:
                - query_input_ids: (batch_size, max_query_len)
                - query_attention_mask: (batch_size, max_query_len)
                - positive_input_ids: (batch_size, max_positives, max_doc_len)
                - positive_attention_mask: (batch_size, max_positives, max_doc_len)
                - positive_mask: bool (batch_size, max_positives) - True for real, False for padding
                - positive_doc_ids: List[List[str]], padded with "" to max_positives
                - query_ids: List[str]
            Only when ``use_mined_negatives`` is set (shape dim may be 0):
                - mined_negative_input_ids: (batch_size, max_mined_negs, max_doc_len)
                - mined_negative_attention_mask: (batch_size, max_mined_negs, max_doc_len)
                - mined_negative_mask: bool (batch_size, max_mined_negs)
                - mined_negative_doc_ids: List[List[str]]
            Only when ``use_sampled_negatives`` is set (shape dim may be 0):
                - sampled_negative_input_ids: (batch_size, max_sampled_negs, max_doc_len)
                - sampled_negative_attention_mask: (batch_size, max_sampled_negs, max_doc_len)
                - sampled_negative_mask: bool (batch_size, max_sampled_negs)
                - sampled_negative_doc_ids: List[List[str]]

        Raises:
            ValueError: If ``batch`` is empty, or if an example's document
                texts and document ids have different lengths.
        """
        # Fail fast with a clear message instead of tokenizing nothing and
        # then hitting ``max()`` on an empty sequence.
        if not batch:
            raise ValueError("Cannot collate an empty batch")

        # Extract data
        query_ids = [ex["query_id"] for ex in batch]
        query_texts = [ex["query_text"] for ex in batch]

        # Apply query prefix if specified
        if self.query_prefix:
            query_texts = [self.query_prefix + text for text in query_texts]

        # Tokenize queries
        query_encodings = self.tokenizer(
            query_texts,
            max_length=self.max_query_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Process positives (always present in the output)
        max_positives = max(len(ex["positive_doc_texts"]) for ex in batch)
        (
            positive_input_ids,
            positive_attention_mask,
            positive_mask,
            positive_doc_ids,
        ) = self._collate_documents(
            batch,
            text_key="positive_doc_texts",
            id_key="positive_doc_ids",
            max_docs=max_positives,
        )

        collated = {
            "query_input_ids": query_encodings["input_ids"],
            "query_attention_mask": query_encodings["attention_mask"],
            "positive_input_ids": positive_input_ids,
            "positive_attention_mask": positive_attention_mask,
            "positive_mask": positive_mask,
            "positive_doc_ids": positive_doc_ids,
            "query_ids": query_ids,
        }

        # Optional negative groups share one code path; order (mined, then
        # sampled) fixes the key order of the returned dict.
        if self.use_mined_negatives:
            collated.update(self._collate_optional_documents(batch, "mined_negative"))
        if self.use_sampled_negatives:
            collated.update(self._collate_optional_documents(batch, "sampled_negative"))

        return collated

    def _collate_optional_documents(
        self,
        batch: List[Dict[str, Any]],
        name: str,
    ) -> Dict[str, Any]:
        """
        Collate an optional document group such as ``"mined_negative"``.

        Reads ``f"{name}_doc_texts"`` / ``f"{name}_doc_ids"`` from each example
        and returns the four output entries ``f"{name}_input_ids"``,
        ``f"{name}_attention_mask"``, ``f"{name}_mask"`` and
        ``f"{name}_doc_ids"``.

        Args:
            batch: Non-empty list of examples
            name: Key stem of the document group

        Returns:
            Dict with the four output entries; when no example in the batch has
            any document of this group, the tensors have a zero-sized docs
            dimension and each doc-id list is empty.
        """
        text_key = f"{name}_doc_texts"
        id_key = f"{name}_doc_ids"
        max_docs = max(len(ex[text_key]) for ex in batch)

        if max_docs > 0:
            input_ids, attention_mask, mask, doc_ids = self._collate_documents(
                batch,
                text_key=text_key,
                id_key=id_key,
                max_docs=max_docs,
            )
        else:
            # Nothing of this kind in the batch: emit correctly-shaped empty
            # tensors without calling the tokenizer or reading ``id_key``.
            batch_size = len(batch)
            input_ids = torch.zeros(
                (batch_size, 0, self.max_doc_length), dtype=torch.long
            )
            attention_mask = torch.zeros_like(input_ids)
            mask = torch.zeros((batch_size, 0), dtype=torch.bool)
            doc_ids = [[] for _ in range(batch_size)]

        return {
            f"{name}_input_ids": input_ids,
            f"{name}_attention_mask": attention_mask,
            f"{name}_mask": mask,
            f"{name}_doc_ids": doc_ids,
        }

    def _collate_documents(
        self,
        batch: List[Dict[str, Any]],
        *,
        text_key: str,
        id_key: str,
        max_docs: int,
    ) -> tuple:
        """
        Collate document texts with padding.

        Args:
            batch: List of examples
            text_key: Key to access document texts in each example
            id_key: Key to access the matching document ids in each example
            max_docs: Maximum number of documents per example in this batch

        Returns:
            Tuple of (input_ids, attention_mask, doc_mask, doc_ids)
            - input_ids: (batch_size, max_docs, max_doc_len)
            - attention_mask: (batch_size, max_docs, max_doc_len)
            - doc_mask: bool (batch_size, max_docs) - True for real docs, False for padding
            - doc_ids: List[List[str]] of shape (batch_size, max_docs), padded with ""

        Raises:
            ValueError: If an example's texts and ids have different lengths.
        """
        batch_size = len(batch)

        # Collect all document texts/ids in row-major (example, doc) order
        all_doc_texts: List[str] = []
        all_doc_ids: List[str] = []
        doc_mask = torch.zeros((batch_size, max_docs), dtype=torch.bool)

        for i, ex in enumerate(batch):
            doc_texts = ex[text_key]
            doc_ids = ex[id_key]
            num_docs = len(doc_texts)

            if len(doc_ids) != num_docs:
                raise ValueError(
                    f"Mismatched lengths for {text_key} and {id_key}: "
                    f"{len(doc_texts)} vs {len(doc_ids)}"
                )

            # Prefix real documents here, before padding is appended, so a
            # genuinely empty document still gets the prefix while padding
            # slots remain bare empty strings.
            if self.document_prefix:
                doc_texts = [self.document_prefix + text for text in doc_texts]

            all_doc_texts.extend(doc_texts)
            all_doc_ids.extend(doc_ids)
            doc_mask[i, :num_docs] = True

            # Pad this example up to max_docs with empty documents
            padding_needed = max_docs - num_docs
            all_doc_texts.extend([""] * padding_needed)
            all_doc_ids.extend([""] * padding_needed)

        if all_doc_texts:
            # Tokenize all documents at once, then restore the per-example axis
            doc_encodings = self.tokenizer(
                all_doc_texts,
                max_length=self.max_doc_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            input_ids = doc_encodings["input_ids"].view(batch_size, max_docs, -1)
            attention_mask = doc_encodings["attention_mask"].view(
                batch_size, max_docs, -1
            )
        else:
            # No documents at all (max_docs == 0): empty docs dimension
            input_ids = torch.zeros(
                (batch_size, max_docs, self.max_doc_length), dtype=torch.long
            )
            attention_mask = torch.zeros(
                (batch_size, max_docs, self.max_doc_length), dtype=torch.long
            )

        # Reshape doc ids to (batch_size, max_docs)
        doc_ids_2d = [
            all_doc_ids[i * max_docs : (i + 1) * max_docs] for i in range(batch_size)
        ]

        return input_ids, attention_mask, doc_mask, doc_ids_2d
