"""
Dual Encoder model for dense retrieval.
"""

import logging
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig

logger = logging.getLogger(__name__)


class DualEncoder(nn.Module):
    """
    Siamese Dual Encoder model for dense retrieval.

    A single HuggingFace encoder is shared between queries and documents.
    Token-level hidden states are pooled into one vector per sequence
    (first-token "cls" pooling or attention-masked "mean" pooling) and are
    optionally L2-normalized.
    """

    # Pooling strategies understood by `_pool_embeddings`.
    _POOLING_STRATEGIES = ("cls", "mean")

    def __init__(
        self,
        encoder_name: str = "distilbert-base-uncased",
        pooling_strategy: str = "cls",
        embedding_dim: int = 768,
        normalize_embeddings: bool = True,
    ):
        """
        Initialize the dual encoder model.

        Args:
            encoder_name: HuggingFace model name or path
            pooling_strategy: "cls" or "mean"
            embedding_dim: Expected embedding dimension (for validation only;
                the encoder's hidden size wins when they disagree)
            normalize_embeddings: Whether to L2 normalize embeddings

        Raises:
            ValueError: If `pooling_strategy` is not "cls" or "mean".
        """
        super().__init__()

        # Fail fast, before paying for the encoder download/load.
        if pooling_strategy not in self._POOLING_STRATEGIES:
            raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")

        self.encoder_name = encoder_name
        # Name to persist in `model_config.json`. `from_pretrained` overrides
        # it with the stored value so that it survives repeated save/load
        # cycles even though `encoder_name` becomes a local directory there.
        self.original_encoder_name = encoder_name
        self.pooling_strategy = pooling_strategy
        self.embedding_dim = embedding_dim
        self.normalize_embeddings = normalize_embeddings

        # Load pre-trained encoder
        logger.info("Loading encoder: %s", encoder_name)
        self.encoder = AutoModel.from_pretrained(encoder_name)

        # Verify embedding dimension against the config of the encoder that
        # was just loaded (no second config fetch needed).
        actual_dim = self.encoder.config.hidden_size
        if actual_dim != embedding_dim:
            logger.warning(
                "Encoder hidden size (%s) doesn't match "
                "specified embedding_dim (%s). Using encoder's dimension.",
                actual_dim,
                embedding_dim,
            )
            self.embedding_dim = actual_dim

        logger.info("Model initialized with %s pooling", self.pooling_strategy)
        logger.info("Embedding dimension: %s", self.embedding_dim)
        logger.info("Normalize embeddings: %s", self.normalize_embeddings)

    def _pool_embeddings(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Pool token embeddings to get sequence embedding.

        Args:
            hidden_states: (batch_size, seq_len, hidden_dim)
            attention_mask: (batch_size, seq_len)

        Returns:
            Pooled embeddings: (batch_size, hidden_dim)

        Raises:
            ValueError: If `self.pooling_strategy` is not "cls" or "mean".
        """
        if self.pooling_strategy == "cls":
            # Use CLS token (first token)
            return hidden_states[:, 0, :]

        if self.pooling_strategy == "mean":
            # (batch_size, seq_len, 1); broadcasting over hidden_dim avoids
            # materializing a full-size float copy of the mask.
            token_mask = attention_mask.unsqueeze(-1).float()

            # Sum of embeddings for non-padding tokens
            summed = torch.sum(hidden_states * token_mask, dim=1)

            # Number of non-padding tokens per sequence; the clamp keeps an
            # all-padding row from dividing by zero.
            token_counts = torch.clamp(token_mask.sum(dim=1), min=1e-9)

            return summed / token_counts

        raise ValueError(f"Unknown pooling strategy: {self.pooling_strategy}")

    def encode(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Encode text inputs to embeddings.

        Args:
            input_ids: (batch_size, seq_len)
            attention_mask: (batch_size, seq_len)

        Returns:
            Embeddings: (batch_size, embedding_dim)
        """
        # Forward pass through encoder
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # Pool token-level states to get sequence-level embeddings
        embeddings = self._pool_embeddings(outputs.last_hidden_state, attention_mask)

        # Normalize if requested
        if self.normalize_embeddings:
            embeddings = F.normalize(embeddings, p=2, dim=-1)

        return embeddings

    def _encode_group(
        self,
        name: str,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        group_mask: Optional[torch.Tensor],
        query_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Encode one group of per-query documents (positives or negatives).

        Args:
            name: Group prefix, used only in error messages
            input_ids: (batch_size, num_docs, max_doc_len)
            attention_mask: (batch_size, num_docs, max_doc_len)
            group_mask: (batch_size, num_docs), passed through unchanged
            query_embeddings: (batch_size, embedding_dim); supplies the batch
                size, device and dtype for the result

        Returns:
            Tuple of embeddings (batch_size, num_docs, embedding_dim) and the
            mask. When num_docs is 0, both are freshly created empty tensors.

        Raises:
            ValueError: If the group is non-empty and `attention_mask` is None.
        """
        batch_size = query_embeddings.size(0)
        num_docs = input_ids.size(1)

        if num_docs == 0:
            # Nothing to encode: return correctly-shaped empty tensors.
            empty_embeddings = torch.zeros(
                (batch_size, 0, self.embedding_dim),
                device=query_embeddings.device,
                dtype=query_embeddings.dtype,
            )
            empty_mask = torch.zeros(
                (batch_size, 0), device=query_embeddings.device, dtype=torch.bool
            )
            return empty_embeddings, empty_mask

        if attention_mask is None:
            raise ValueError(
                f"{name}_attention_mask is required when {name}_input_ids is provided"
            )

        # Flatten: (batch_size, num_docs, max_doc_len) -> (batch_size * num_docs, max_doc_len)
        # `reshape` (unlike `view`) also accepts non-contiguous inputs.
        flat_input_ids = input_ids.reshape(-1, input_ids.size(-1))
        flat_attention_mask = attention_mask.reshape(-1, attention_mask.size(-1))

        flat_embeddings = self.encode(flat_input_ids, flat_attention_mask)

        # Reshape back: (batch_size * num_docs, embedding_dim) -> (batch_size, num_docs, embedding_dim)
        return flat_embeddings.reshape(batch_size, num_docs, -1), group_mask

    def forward(
        self,
        query_input_ids: torch.Tensor,
        query_attention_mask: torch.Tensor,
        positive_input_ids: Optional[torch.Tensor] = None,
        positive_attention_mask: Optional[torch.Tensor] = None,
        positive_mask: Optional[torch.Tensor] = None,
        mined_negative_input_ids: Optional[torch.Tensor] = None,
        mined_negative_attention_mask: Optional[torch.Tensor] = None,
        mined_negative_mask: Optional[torch.Tensor] = None,
        sampled_negative_input_ids: Optional[torch.Tensor] = None,
        sampled_negative_attention_mask: Optional[torch.Tensor] = None,
        sampled_negative_mask: Optional[torch.Tensor] = None,
    ) -> dict:
        """
        Forward pass for training.

        Args:
            query_input_ids: (batch_size, max_query_len)
            query_attention_mask: (batch_size, max_query_len)
            positive_input_ids: (batch_size, max_positives, max_doc_len)
            positive_attention_mask: (batch_size, max_positives, max_doc_len)
            positive_mask: (batch_size, max_positives) - True for real, False for padding
            mined_negative_input_ids: (batch_size, max_mined_negs, max_doc_len)
            mined_negative_attention_mask: (batch_size, max_mined_negs, max_doc_len)
            mined_negative_mask: (batch_size, max_mined_negs)
            sampled_negative_input_ids: (batch_size, max_sampled_negs, max_doc_len)
            sampled_negative_attention_mask: (batch_size, max_sampled_negs, max_doc_len)
            sampled_negative_mask: (batch_size, max_sampled_negs)

        Returns:
            Dict with (document entries are present only for groups whose
            `*_input_ids` were given; masks are passed through as received):
                - query_embeddings: (batch_size, embedding_dim)
                - positive_embeddings: (batch_size, max_positives, embedding_dim)
                - positive_mask: (batch_size, max_positives)
                - mined_negative_embeddings: (batch_size, max_mined_negs, embedding_dim)
                - mined_negative_mask: (batch_size, max_mined_negs)
                - sampled_negative_embeddings: (batch_size, max_sampled_negs, embedding_dim)
                - sampled_negative_mask: (batch_size, max_sampled_negs)

        Raises:
            ValueError: If a non-empty group is given without its attention mask.
        """
        # Encode queries
        query_embeddings = self.encode(query_input_ids, query_attention_mask)
        output = {"query_embeddings": query_embeddings}

        # All three document groups go through the same
        # flatten -> encode -> unflatten path.
        groups = (
            (
                "positive",
                positive_input_ids,
                positive_attention_mask,
                positive_mask,
            ),
            (
                "mined_negative",
                mined_negative_input_ids,
                mined_negative_attention_mask,
                mined_negative_mask,
            ),
            (
                "sampled_negative",
                sampled_negative_input_ids,
                sampled_negative_attention_mask,
                sampled_negative_mask,
            ),
        )
        for name, group_input_ids, group_attention_mask, group_mask in groups:
            if group_input_ids is None:
                continue
            embeddings, mask = self._encode_group(
                name,
                group_input_ids,
                group_attention_mask,
                group_mask,
                query_embeddings,
            )
            output[f"{name}_embeddings"] = embeddings
            output[f"{name}_mask"] = mask

        return output

    def encode_queries(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encode queries for embedding generation.

        Args:
            input_ids: (batch_size, seq_len)
            attention_mask: (batch_size, seq_len)

        Returns:
            Query embeddings: (batch_size, embedding_dim)
        """
        return self.encode(input_ids, attention_mask)

    def encode_documents(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encode documents for embedding generation.

        Args:
            input_ids: (batch_size, seq_len)
            attention_mask: (batch_size, seq_len)

        Returns:
            Document embeddings: (batch_size, embedding_dim)
        """
        return self.encode(input_ids, attention_mask)

    def save_pretrained(self, save_directory: str) -> None:
        """
        Save model to directory.

        Writes the encoder via its own `save_pretrained` plus a
        `model_config.json` holding this wrapper's settings.

        Args:
            save_directory: Directory to save the model
        """
        import json
        import os

        os.makedirs(save_directory, exist_ok=True)

        # Save encoder
        self.encoder.save_pretrained(save_directory)

        # Save config. Persist the original encoder name rather than
        # `self.encoder_name`: on a model restored by `from_pretrained` the
        # latter is a local directory, and writing it would lose the name
        # on the next save/load cycle.
        config = {
            "encoder_name": getattr(
                self, "original_encoder_name", self.encoder_name
            ),
            "pooling_strategy": self.pooling_strategy,
            "embedding_dim": self.embedding_dim,
            "normalize_embeddings": self.normalize_embeddings,
        }

        config_path = os.path.join(save_directory, "model_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        logger.info("Model saved to %s", save_directory)

    @classmethod
    def from_pretrained(cls, load_directory: str) -> "DualEncoder":
        """
        Load model from directory.

        Args:
            load_directory: Directory to load the model from (as written by
                `save_pretrained`)

        Returns:
            DualEncoder instance
        """
        import json
        import os

        # Load config
        config_path = os.path.join(load_directory, "model_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        # Create model instance
        model = cls(
            encoder_name=load_directory,  # Load encoder from this directory
            pooling_strategy=config["pooling_strategy"],
            embedding_dim=config["embedding_dim"],
            normalize_embeddings=config["normalize_embeddings"],
        )

        # But preserve the original encoder name for tokenizer loading
        model.original_encoder_name = config.get(
            "encoder_name", "distilbert-base-uncased"
        )

        logger.info("Model loaded from %s", load_directory)
        return model
