"""
InfoNCE loss for dense retrieval with support for multiple negative types.
"""

import logging
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

logger = logging.getLogger(__name__)


class AllGatherWithGrad(torch.autograd.Function):
    """
    All-gather along the batch dimension with full gradient flow across GPUs.

    ``dist.all_gather`` on its own is not differentiable. In forward, this
    Function concatenates every rank's tensor. In backward, it first sums
    (``all_reduce``) the gradients that each rank computed for the gathered
    tensor, then gives each rank the slice that belongs to its own input. A
    document embedded on one GPU therefore receives gradient from the queries
    on every GPU, not only from the queries on its own GPU.

    Requirements:
        * Every rank must pass a tensor of identical shape, because the gather
          writes into ``empty_like`` buffers of the local tensor.
        * Every rank must run backward through this op, because the backward
          pass issues a collective ``all_reduce``.

    Without an initialized process group it is the identity in both directions.
    """

    @staticmethod
    def forward(ctx, tensor):
        """
        Gather tensors from all GPUs.

        Args:
            tensor: Input tensor of shape (batch_size, ...) or (batch_size, max_pos, emb_dim)

        Returns:
            Gathered tensor of shape (world_size * batch_size, ...) or
            (world_size * batch_size, max_pos, emb_dim). Rank r's rows occupy
            [r * batch_size, (r + 1) * batch_size). Without an initialized
            process group, the input is returned unchanged.
        """
        ctx.batch_size = tensor.shape[0]

        if not dist.is_initialized():
            # Single process: identity. Record a 1-rank "world" so that
            # backward is the identity too, instead of failing on missing ctx fields.
            ctx.rank = 0
            ctx.world_size = 1
            return tensor

        ctx.rank = dist.get_rank()
        ctx.world_size = dist.get_world_size()

        # all_gather expects contiguous input buffers.
        tensor = tensor.contiguous()
        gathered = [torch.empty_like(tensor) for _ in range(ctx.world_size)]
        dist.all_gather(gathered, tensor)

        # Concatenate along batch dimension
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Route the gradients of the gathered tensor back to their source ranks.

        Each rank's ``grad_output`` (shape (world_size * batch_size, ...)) only
        contains the gradient of that rank's own loss. The gradient for this
        rank's input rows is the sum of those contributions over all ranks.
        That sum is formed with ``all_reduce``, and this rank's slice is then
        returned.

        Args:
            grad_output: Gradients w.r.t. the gathered output

        Returns:
            Gradients w.r.t. the input tensor on this GPU
        """
        if ctx.world_size == 1:
            return grad_output

        # all_reduce works in place: clone so the incoming grad_output is not
        # mutated, and make the buffer contiguous for the collective.
        grad = grad_output.clone(memory_format=torch.contiguous_format)
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)

        # Extract gradients corresponding to this GPU's batch
        start_idx = ctx.rank * ctx.batch_size
        end_idx = start_idx + ctx.batch_size
        return grad[start_idx:end_idx]


class InfoNCELoss(nn.Module):
    """
    InfoNCE (Normalized Temperature-scaled Cross Entropy) loss.

    Supports:
    - Multiple positives per query
    - Three types of negatives: mined, sampled, in-batch
    - Distributed training with gradient gathering for in-batch negatives

    For each positive p_i of query q, computes:
        loss_i = -log(exp(q·p_i / temp) / (exp(q·p_i / temp) + Σ_n exp(q·n / temp)))

    The sum runs over negatives only. The query's other positives appear in
    neither the numerator nor the denominator of loss_i. The final loss is the
    mean over all (query, real positive) pairs in the local batch.
    """

    def __init__(
        self,
        temperature: float = 1.0,
        use_mined_negatives: bool = False,
        use_sampled_negatives: bool = False,
        use_inbatch_negatives: bool = True,
        gather_across_gpus: bool = True,
    ):
        """
        Initialize InfoNCE loss.

        Args:
            temperature: Temperature scaling factor; every score is divided by
                it. Must be strictly positive.
            use_mined_negatives: Whether to use mined negatives from qrels
            use_sampled_negatives: Whether to use sampled negatives from negative sampler
            use_inbatch_negatives: Whether to use in-batch negatives
            gather_across_gpus: Whether to gather in-batch negatives across GPUs in DDP.
                If True, uses documents from all GPUs as negatives (default behavior).
                If False, uses only local GPU's documents as negatives.

        Raises:
            ValueError: If ``temperature`` is not > 0, or if every negative
                type is disabled.
        """
        super().__init__()

        # A zero temperature divides by zero and a negative one inverts the
        # ranking objective; neither is a meaningful configuration.
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        self.temperature = temperature
        self.use_mined_negatives = use_mined_negatives
        self.use_sampled_negatives = use_sampled_negatives
        self.use_inbatch_negatives = use_inbatch_negatives
        self.gather_across_gpus = gather_across_gpus

        # Check that at least one negative type is enabled
        if not any([use_mined_negatives, use_sampled_negatives, use_inbatch_negatives]):
            raise ValueError("At least one type of negatives must be enabled!")

        logger.info("InfoNCE Loss initialized:")
        logger.info(f"  Temperature: {temperature}")
        logger.info(f"  Use mined negatives: {use_mined_negatives}")
        logger.info(f"  Use sampled negatives: {use_sampled_negatives}")
        logger.info(f"  Use in-batch negatives: {use_inbatch_negatives}")
        logger.info(f"  Gather across GPUs: {gather_across_gpus}")

        # Initialize the custom all_gather function
        self.all_gather_with_grad = AllGatherWithGrad.apply

    def forward(
        self,
        query_embeddings: torch.Tensor,
        positive_embeddings: torch.Tensor,
        positive_mask: torch.Tensor,
        mined_negative_embeddings: Optional[torch.Tensor] = None,
        mined_negative_mask: Optional[torch.Tensor] = None,
        sampled_negative_embeddings: Optional[torch.Tensor] = None,
        sampled_negative_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute InfoNCE loss.

        Masks may be bool or 0/1 integer tensors (nonzero = real entry).

        Args:
            query_embeddings: (batch_size, emb_dim)
            positive_embeddings: (batch_size, max_pos, emb_dim)
            positive_mask: (batch_size, max_pos) - True for real positives
            mined_negative_embeddings: (batch_size, max_mined_neg, emb_dim)
            mined_negative_mask: (batch_size, max_mined_neg) - True for real
                negatives; None means every slot is real
            sampled_negative_embeddings: (batch_size, max_sampled_neg, emb_dim)
            sampled_negative_mask: (batch_size, max_sampled_neg) - True for
                real negatives; None means every slot is real

        Returns:
            Scalar loss value

        Raises:
            RuntimeError: If no negative source produced scores, or if the
                batch contains no real positive.
        """
        # `~` on a non-bool tensor is a bitwise NOT, so normalize to bool once.
        positive_mask = positive_mask.bool()

        # (batch_size, max_pos); padded slots are never selected below.
        positive_scores = self._score_candidates(query_embeddings, positive_embeddings)

        # Collect all negative scores, each (batch_size, n_i)
        all_negative_scores = []

        # 1. Mined negatives
        if self.use_mined_negatives and mined_negative_embeddings is not None:
            all_negative_scores.append(
                self._score_candidates(
                    query_embeddings, mined_negative_embeddings, mined_negative_mask
                )
            )

        # 2. Sampled negatives
        if self.use_sampled_negatives and sampled_negative_embeddings is not None:
            all_negative_scores.append(
                self._score_candidates(
                    query_embeddings, sampled_negative_embeddings, sampled_negative_mask
                )
            )

        # 3. In-batch negatives
        if self.use_inbatch_negatives:
            all_negative_scores.append(
                self._compute_inbatch_negative_scores(
                    query_embeddings,
                    positive_embeddings,
                    positive_mask,
                )
            )

        if not all_negative_scores:
            raise RuntimeError("No negatives found!")

        # (batch_size, total_negatives)
        negative_scores = torch.cat(all_negative_scores, dim=1)

        # Every real (query, positive) pair, in query-major order.
        query_idx, slot_idx = positive_mask.nonzero(as_tuple=True)
        if query_idx.numel() == 0:
            raise RuntimeError("No valid positives found in batch!")

        pair_pos_scores = positive_scores[query_idx, slot_idx]  # (num_pairs,)

        # loss = -(pos_score - logsumexp([pos_score] + [neg_scores])).
        # Each row starts with a finite positive score, so logsumexp stays
        # finite even when all of a query's negatives are masked to -inf.
        combined_scores = torch.cat(
            [pair_pos_scores.unsqueeze(1), negative_scores[query_idx]], dim=1
        )  # (num_pairs, 1 + total_negatives)
        log_denominator = torch.logsumexp(combined_scores, dim=1)

        return (log_denominator - pair_pos_scores).mean()

    def _score_candidates(
        self,
        query_embeddings: torch.Tensor,
        candidate_embeddings: torch.Tensor,
        candidate_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Score each query against its own row of candidates.

        Args:
            query_embeddings: (batch_size, emb_dim)
            candidate_embeddings: (batch_size, num_candidates, emb_dim)
            candidate_mask: (batch_size, num_candidates), nonzero for real
                candidates; None keeps every candidate.

        Returns:
            (batch_size, num_candidates) dot products divided by temperature,
            with masked-out candidates set to -inf.
        """
        scores = (
            torch.sum(query_embeddings.unsqueeze(1) * candidate_embeddings, dim=-1)
            / self.temperature
        )
        if candidate_mask is None:
            return scores
        return scores.masked_fill(~candidate_mask.bool(), float("-inf"))

    def _compute_inbatch_negative_scores(
        self,
        query_embeddings: torch.Tensor,
        positive_embeddings: torch.Tensor,
        positive_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute scores for in-batch negatives with optional DDP gathering.

        For each query, all documents (positives) from OTHER queries in the batch
        become negatives. Exclusion is by position only: a document that is
        also a positive of a different query still counts as a negative here.

        If gather_across_gpus=True (default): In distributed training, gather
        documents from all GPUs with gradient preservation.

        If gather_across_gpus=False: Use only local GPU's documents as negatives,
        no gathering across GPUs.

        Args:
            query_embeddings: (batch_size, emb_dim)
            positive_embeddings: (batch_size, max_pos, emb_dim)
            positive_mask: (batch_size, max_pos)

        Returns:
            In-batch negative scores: (batch_size, num_inbatch_negatives), with
            padding documents and each query's own positives set to -inf.
        """
        batch_size, max_pos, emb_dim = positive_embeddings.shape
        positive_mask = positive_mask.bool()

        # Determine whether to gather across GPUs
        should_gather = self.gather_across_gpus and dist.is_initialized()

        if should_gather:
            # Gather embeddings and masks from all GPUs
            world_size = dist.get_world_size()
            rank = dist.get_rank()

            # Gather embeddings with full gradient preservation
            try:
                all_positive_embeddings = self._gather_with_grad(positive_embeddings)
            except RuntimeError as e:
                logger.error(
                    f"[Rank {rank}] all_gather failed for positive_embeddings. "
                    f"Shape: {positive_embeddings.shape}. "
                    f"This usually means batches have different sizes across ranks. "
                    f"Ensure drop_last=True in DataLoader. Error: {e}"
                )
                raise

            # Gather masks (no gradient needed)
            local_mask = positive_mask.contiguous()
            gathered_masks = [torch.zeros_like(local_mask) for _ in range(world_size)]
            try:
                dist.all_gather(gathered_masks, local_mask)
            except RuntimeError as e:
                logger.error(
                    f"[Rank {rank}] all_gather failed for positive_mask. "
                    f"Shape: {positive_mask.shape}. "
                    f"This usually means batches have different sizes across ranks. "
                    f"Ensure drop_last=True in DataLoader. Error: {e}"
                )
                raise
            all_positive_mask = torch.cat(gathered_masks, dim=0)

            # This rank's queries own rows [rank * batch_size, (rank + 1) * batch_size)
            local_batch_start = rank * batch_size
        else:
            # No gathering: use local batch only
            all_positive_embeddings = positive_embeddings
            all_positive_mask = positive_mask
            local_batch_start = 0

        # Flatten all documents. reshape (not view) because the local input
        # may be non-contiguous.
        total_batch_size = all_positive_embeddings.size(0)
        flat_doc_embeddings = all_positive_embeddings.reshape(-1, emb_dim)
        flat_doc_mask = all_positive_mask.reshape(-1)  # (total_batch * max_pos,)

        # Compute scores between queries and ALL documents
        # (batch_size, total_batch * max_pos)
        all_scores = (
            torch.matmul(query_embeddings, flat_doc_embeddings.t()) / self.temperature
        )

        # Flattened document column j belongs to global query j // max_pos.
        device = all_scores.device
        doc_owner = torch.arange(total_batch_size, device=device).repeat_interleave(
            max_pos
        )
        query_ids = local_batch_start + torch.arange(batch_size, device=device)
        is_own_positive = query_ids.unsqueeze(1) == doc_owner.unsqueeze(0)

        # Exclude padding documents and each query's OWN positives in one pass.
        excluded = is_own_positive | ~flat_doc_mask.unsqueeze(0)
        return all_scores.masked_fill(excluded, float("-inf"))

    def _gather_with_grad(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        All-gather tensors across GPUs with full gradient preservation.

        Uses custom autograd function to ensure gradients flow back properly.
        Each GPU's documents will receive gradients from queries on ALL GPUs.

        Args:
            tensor: Tensor to gather (batch_size, ...) or (batch_size, max_pos, emb_dim)

        Returns:
            Gathered tensor from all GPUs (world_size * batch_size, ...)
        """
        if not dist.is_initialized():
            return tensor

        # Use custom all_gather with proper gradient flow
        return self.all_gather_with_grad(tensor)
