import math

import torch
import torch.nn.functional as F  # noqa: N812
from torch import nn
from transformers import AutoModel, AutoTokenizer

from .config import EBMConfig
from .utils import batch_semantic_sentence_split, normalize_sentences


class SentenceEmbedder(nn.Module):
    """Sentence embedding model that converts text sentences to dense vectors.

        Supports two embedding approaches:
        - SentenceBERT: Uses mean pooling of token embeddings with
        attention masking and L2 normalization
        - BERT_CLS: Uses the [CLS] token embedding from BERT

    Args:
            model_type (str): Type of embedding model to use.
                Either "SentenceBERT" or "BERT_CLS". Defaults to "SentenceBERT".

    Returns:
            torch.Tensor or None: Sentence embeddings of
                shape [num_valid_sentences, embedding_dim].
                Returns None if all input sentences are None.

    Note:
        Input sentences that are None are filtered out before processing.
        Only valid (non-None) sentences are embedded and returned.

    """

    def __init__(self, model_type: str = "SentenceBERT") -> None:
        super().__init__()
        if model_type == "SentenceBERT":
            # Possible SentenceBERT models:  https://www.sbert.net/docs/sentence_transformer/pretrained_models.html
            self.tokenizer = AutoTokenizer.from_pretrained(
                "sentence-transformers/all-mpnet-base-v2"
            )  # Embedding Dim: 768
            self.model = AutoModel.from_pretrained(
                "sentence-transformers/all-mpnet-base-v2"
            )  # Embedding Dim: 768
        elif model_type == "BERT_CLS":
            self.tokenizer = AutoTokenizer.from_pretrained(
                "bert-base-uncased"
            )  # Embedding Dim: 768
            self.model = AutoModel.from_pretrained(
                "bert-base-uncased"
            )  # Embedding Dim: 768
        else:
            msg = "Invalid model type"
            raise ValueError(msg)
        self.model_type = model_type

    def get_embedding_dim(self) -> int:
        """Return the embedding dimension of the selected model."""
        return self.model.config.hidden_size

    def forward(self, sentences: list[str | None]) -> torch.Tensor | None:
        if all(s is None for s in sentences):
            return None

        valid_sentences = [s for s in sentences if s is not None]
        if not valid_sentences:
            return None

        # Batch tokenization is already efficient in transformers
        tokens = self.tokenizer(
            valid_sentences, 
            return_tensors="pt", 
            padding=True, 
            truncation=True,
            max_length=512  # Add explicit max_length for consistent padding
        )
        
        # Move tokens to the same device as the model
        device = next(self.model.parameters()).device
        tokens = {k: v.to(device) for k, v in tokens.items()}

        with torch.no_grad():
            output = self.model(**tokens)

        if self.model_type == "SentenceBERT":
            # Sentence-BERT style, altered usage based on model's guide at https://huggingface.co/sentence-transformers/all-mpnet-base-v2
            embeddings = output[0]
            attention_mask = tokens["attention_mask"]
            
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
            sum_embeddings = torch.sum(embeddings * input_mask_expanded, 1)
            sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            embeddings = sum_embeddings / sum_mask
            
            # L2 normalization            
            embeddings = F.normalize(embeddings, p=2, dim=1)
        else:
            embeddings = output.last_hidden_state[:, 0]  # BERT CLS token

        return embeddings


class PlaceholderEmbedding(nn.Module):
    """Generates placeholder embeddings for padding sentence sequences.

    Creates either learnable or static placeholder embeddings that can be used
    to pad sentence sequences to a fixed length when there are fewer sentences
    than required.

    Args:
        placeholder_type (str, optional): Type of placeholder embedding.
            Either "learnable" for trainable parameters or "static" for fixed values.
            Defaults to "learnable".
        init_type (str, optional): Initialization method for learnable placeholders.
            Either "random" for random normal initialization or "zero" for zero
            initialization. Only used when placeholder_type is "learnable".
            Defaults to "random".

    Returns:
        torch.Tensor: Placeholder embeddings of shape [n, embedding_dim] where
            n is the number of placeholder embeddings requested.

    Note:
        The embedding dimension is determined by the global embedding_dim variable.
        Static placeholders are always zero-initialized and registered as buffers.

    """

    def __init__(
        self,
        embedding_dim: int,
        placeholder_type: str | None = "learnable",
        init_type: str | None = "random",
    ) -> None:
        super().__init__()
        self.placeholder_type = placeholder_type
        if placeholder_type == "learnable":
            if init_type == "random":
                self.placeholder = nn.Parameter(torch.randn(embedding_dim))
            elif init_type == "zero":
                self.placeholder = nn.Parameter(torch.zeros(embedding_dim))
        else: # static
            self.register_buffer("placeholder", torch.zeros(embedding_dim))

    def forward(self, n: int = 1) -> torch.Tensor:
        return self.placeholder.unsqueeze(0).repeat(n, 1)


class SentenceEncoder(nn.Module):
    """End-to-end encoder turning raw texts into fixed-length sentence embeddings.

    Each text is split into sentences, normalized to ``n_sentences`` slots, and
    every slot is filled with either the embedding of its sentence or the
    placeholder embedding when the slot is empty (``None``).

    Args:
        model_type (str): Type of sentence embedding model to use.
            Either "SentenceBERT" or "BERT_CLS". Defaults to "SentenceBERT".
        placeholder_type (str, optional): Type of placeholder embedding for padding.
            Either "learnable" or "static". Defaults to "learnable".
        placeholder_init_type (str, optional): Initialization method for learnable
            placeholders. Either "random" or "zero". Defaults to "random".
        n_sentences (int): Number of sentence slots per text. Defaults to 16.

    Returns:
        tuple[torch.Tensor, torch.BoolTensor]: The stacked embeddings (one
            [slots, embedding_dim] matrix per text) and a boolean mask with one
            row per text that is True at placeholder positions.

    Note:
        Splitting and normalization are delegated to
        ``batch_semantic_sentence_split`` and ``normalize_sentences`` from
        ``.utils``. This code assumes ``normalize_sentences`` returns the same
        number of slots for every text (stacking would fail otherwise) and
        marks empty slots with ``None`` -- TODO confirm against ``.utils``.

    """

    def __init__(
        self,
        model_type: str = "SentenceBERT",
        placeholder_type: str | None = "learnable",
        placeholder_init_type: str | None = "random",
        n_sentences: int = 16,
    ) -> None:
        super().__init__()
        self.embedder = SentenceEmbedder(model_type)
        # The placeholder must have the same width as the real embeddings.
        self.placeholder = PlaceholderEmbedding(
            embedding_dim=self.embedder.get_embedding_dim(),
            placeholder_type=placeholder_type,
            init_type=placeholder_init_type,
        )
        self.n_sentences = n_sentences

    def forward(self, texts: list[str]) -> tuple[torch.Tensor, torch.BoolTensor]:
        """Encode ``texts`` into (embeddings, pad_mask) batch tensors."""
        # Sentence splitting is done for the whole batch up front.
        split_texts = batch_semantic_sentence_split(texts)

        per_text_embeddings = []
        per_text_pad_flags = []

        # Placeholder row, created lazily and re-created only if the target
        # device changes.
        pad_row = None
        device = next(self.parameters()).device

        # Texts are embedded one at a time to bound peak memory.
        for text_sentences in split_texts:
            slots = normalize_sentences(text_sentences, self.n_sentences)

            # True where the slot holds no sentence.
            pad_flags: list[bool] = [slot is None for slot in slots]

            # Embeddings of the real sentences only (None if there are none).
            real_emb = self.embedder([slot for slot in slots if slot is not None])

            if real_emb is not None:
                device = real_emb.device
            if pad_row is None or pad_row.device != device:
                pad_row = self.placeholder(1).to(device)

            # Interleave real rows and placeholder rows in slot order.
            rows: list[torch.Tensor] = []
            next_real = 0
            for flag in pad_flags:
                if flag:
                    rows.append(pad_row)
                elif real_emb is not None:
                    rows.append(real_emb[next_real].unsqueeze(0))
                    next_real += 1

            per_text_embeddings.append(torch.cat(rows, dim=0))
            per_text_pad_flags.append(pad_flags)

        final_embeddings = torch.stack(per_text_embeddings, dim=0)
        final_masks = torch.tensor(
            per_text_pad_flags, dtype=torch.bool, device=final_embeddings.device
        )

        return final_embeddings, final_masks


class PositionalEncoding(nn.Module):
    """Adds positional encoding to input embeddings for transformer models.

    Provides either learnable or sinusoidal positional encodings to help the model
    understand sequence order information.

    Args:
        d_model (int): Model dimension/embedding size. Both even and odd
            values are supported.
        max_len (int): Maximum sequence length. Defaults to 16.
        encoding_type (str): Type of encoding - "learnable" or "sinusoidal".
            Defaults to "sinusoidal".

    Returns:
        torch.Tensor: Input tensor with positional encoding added.

    Raises:
        ValueError: If ``encoding_type`` is not "learnable" or "sinusoidal".

    """

    def __init__(
        self, d_model: int, max_len: int = 16, encoding_type: str = "sinusoidal"
    ) -> None:
        super().__init__()
        self.encoding_type = encoding_type
        if encoding_type == "learnable":
            self.pos_embedding = nn.Parameter(torch.randn(max_len, d_model))
        elif encoding_type == "sinusoidal":
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            # One frequency per even column index: 10000^(-2i / d_model).
            div_term = torch.exp(
                torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
            )
            pe[:, 0::2] = torch.sin(position * div_term)
            # For odd d_model there is one fewer odd column than even column,
            # so the frequency table is truncated to match.
            pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
            # Buffer: moves with the module and is saved, but is not trained.
            self.register_buffer("pos_embedding", pe)
        else:
            msg = "Invalid encoding type"
            raise ValueError(msg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [batch, seq_len, d_model]
        """Add the encodings of the first ``seq_len`` positions to ``x``."""
        return x + self.pos_embedding[: x.size(1)]


class SelfAttentionEncoder(nn.Module):
    """Transformer self-attention stack over a sequence of sentence embeddings.

    Positional encodings are added to the input, which is then passed through
    ``n_layers`` ``nn.TransformerEncoderLayer`` blocks (GELU activation,
    feed-forward width ``4 * d_model``, batch-first layout).

    Args:
        d_model (int): Model dimension/embedding size.
        n_layers (int): Number of encoder layers.
        n_heads (int): Number of attention heads.
        dropout (float): Dropout rate for regularization.
        n_sentences (int): Maximum sequence length handed to the positional
            encoding.
        positional_encoding_type (str): Type of positional encoding to use.
            Defaults to "sinusoidal".

    Returns:
        torch.Tensor: Self-attended embeddings of shape [batch, seq_len, d_model].

    """

    def __init__(  # noqa: PLR0913
        self,
        d_model: int,
        n_layers: int,
        n_heads: int,
        dropout: float,
        n_sentences: int,
        positional_encoding_type: str = "sinusoidal",
    ) -> None:
        super().__init__()
        self.pos_encoder = PositionalEncoding(
            d_model, max_len=n_sentences, encoding_type=positional_encoding_type
        )
        feedforward_width = d_model * 4
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                dim_feedforward=feedforward_width,
                dropout=dropout,
                activation="gelu",
                batch_first=True,
            ),
            num_layers=n_layers,
        )

    def forward(
        self, x: torch.Tensor, pad_mask: torch.BoolTensor | None = None
    ) -> torch.Tensor:
        """Encode ``x`` ([batch, seq_len, d_model]); ``pad_mask`` marks padded positions."""
        positioned = self.pos_encoder(x)
        # pad_mask is forwarded as the key padding mask of every layer.
        return self.encoder(positioned, src_key_padding_mask=pad_mask)


class CrossAttentionBlock(nn.Module):
    """Single cross-attention block for processing query-key-value attention.

    Implements a transformer-style cross-attention layer where the query sequence
    attends to a separate key-value sequence, followed by feed-forward processing
    and residual connections with layer normalization.

    Args:
        d_model (int): Model dimension/embedding size. Defaults to embedding_dim.
        n_heads (int): Number of attention heads. Defaults to attention_heads.
        dropout (float): Dropout rate for regularization. Defaults to dropout_rate.

    Returns:
        torch.Tensor: Cross-attended output of shape [batch, seq_len, d_model].

    Note:
        The query sequence attends to the key_value sequence, allowing information
        flow from key_value to query while maintaining the query's sequence structure.

    """

    def __init__(self, d_model: int, n_heads: int, dropout: float) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=n_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        query: torch.Tensor,  # [B, Lq, d]
        key_value: torch.Tensor,  # [B, Lk, d]
        query_mask: torch.BoolTensor | None,
        kv_mask: torch.BoolTensor | None,
    ) -> torch.Tensor:
        attn_out, _ = self.attn(query, key_value, key_value, key_padding_mask=kv_mask)
        x = self.norm1(query + attn_out)
        ff_out = self.ff(x)
        out = self.norm2(x + ff_out)

        # zero-out positions that were padding in the *query*
        if query_mask is not None:
            out = out.masked_fill(query_mask.unsqueeze(-1), 0.0)
        return out


class CrossAttentionEncoder(nn.Module):
    """Multi-layer cross-attention encoder for processing query-key-value relationships.

    Applies multiple layers of cross-attention where query sequences attend to
    key-value sequences, enabling information flow between different input modalities
    or sequences while preserving the query sequence structure.

    Args:
        d_model (int): Model dimension/embedding size. Defaults to embedding_dim.
        n_layers (int): Number of cross-attention layers. Defaults to cross_attention_layers
        n_heads (int): Number of attention heads per layer. Defaults to attention_heads.
        dropout (float): Dropout rate for regularization. Defaults to dropout_rate.

    Returns:
        torch.Tensor: Cross-attended output embeddings of shape [batch, seq_len, d_model].

    Note:
        The output_emb (query) attends to input_emb (key-value) across all layers,
        allowing the output sequence to gather relevant information from the input
        sequence while maintaining its own positional structure.

    """

    def __init__(self, d_model: int, n_layers: int, n_heads: int, dropout: float) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            [CrossAttentionBlock(d_model, n_heads, dropout) for _ in range(n_layers)]
        )

    def forward(
        self,
        query: torch.Tensor,
        key_value: torch.Tensor,
        query_mask: torch.BoolTensor | None,
        kv_mask: torch.BoolTensor | None,
    ) -> torch.Tensor:
        for layer in self.layers:
            query = layer(query, key_value, query_mask, kv_mask)
        return query


class AttentionPooling(nn.Module):
    """Attention-based pooling layer for aggregating sequence embeddings.

    Uses a learned attention mechanism to pool variable-length sequences into
    fixed-size representations by computing attention weights over the sequence
    dimension and taking a weighted average.

    Args:
        d_model (int): Model dimension/embedding size. Defaults to embedding_dim.

    Returns:
        torch.Tensor: Pooled representation of shape [batch, d_model].

    """

    def __init__(self, d_model: int) -> None:
        super().__init__()
        self.query = nn.Linear(d_model, 1)

    def forward(
        self, x: torch.Tensor, pad_mask: torch.BoolTensor | None = None
    ) -> torch.Tensor:
        logits = self.query(x).squeeze(-1)  # [B, L]
        if pad_mask is not None:
            logits = logits.masked_fill(pad_mask, -1e9)
        weights = F.softmax(logits, dim=-1).unsqueeze(-1)  # [B, L, 1]
        return torch.sum(x * weights, dim=1)  # [B, d]


class EnergyHead(nn.Module):
    """Energy computation head for processing pooled sequence embeddings.

    Computes energy scores from sequence embeddings using either attention-based
    pooling or flattening, followed by a multi-layer perceptron. Lower energy
    scores indicate better matches in the energy-based model framework.

    Args:
        d_model (int): Model dimension/embedding size. Defaults to embedding_dim.
        pooling_type (str): Pooling strategy - "attention" or "flatten".
            Defaults to pooling_type.
        mlp_layers (int): Number of MLP layers for energy computation.
            Defaults to mlp_layers.
        hidden_factor (int): Hidden dimension multiplier for MLP layers.
            Defaults to hidden_dim_factor.
        activation (str): Activation function - "ReLU" or "GELU".
            Defaults to activation_fn.

    Returns:
        torch.Tensor: Energy scores of shape [batch] where lower values
            indicate better input-output compatibility.

    """

    def __init__(  # noqa: PLR0913
        self,
        d_model: int,
        pooling_type: str,
        mlp_layers: int,
        hidden_factor: int,
        activation: str = "GELU",
        n_sentences: int = 16,
    ) -> None:
        super().__init__()
        self.pooling_type = pooling_type
        if pooling_type == "flatten":
            self.pool = None
            in_dim = d_model * n_sentences
        else: # attention
            self.pool = AttentionPooling(d_model)
            in_dim = d_model

        layers = (
            [nn.Linear(in_dim, 1)]
            if mlp_layers == 1
            else [
                nn.Linear(in_dim, hidden_factor * d_model),
                nn.ReLU() if activation.lower() == "relu" else nn.GELU(),
                nn.Linear(hidden_factor * d_model, 1),
            ]
        )
        self.mlp = nn.Sequential(*layers)

    def forward(
        self, x: torch.Tensor, pad_mask: torch.BoolTensor | None = None
    ) -> torch.Tensor:
        if self.pooling_type == "flatten":
            pooled = x.masked_fill(pad_mask.unsqueeze(-1), 0.0).view(x.size(0), -1)
        else:
            pooled = self.pool(x, pad_mask)
        return self.mlp(pooled).squeeze(-1)  # [B]


class EnergyModel(nn.Module):
    """Energy-based model scoring the compatibility of text pairs (x, y).

    Pipeline: both texts are turned into fixed-length sentence embeddings,
    each side is passed through its own self-attention encoder, the y side
    then cross-attends to the x side, and an energy head reduces the result
    to one scalar per pair.

    Args:
        config (EBMConfig): Hyperparameters for all sub-modules. Note that
            ``config.d_model`` is overwritten in place with the embedding
            dimension of the text encoder.

    Returns:
        torch.Tensor: Energy scores of shape [B].

    """

    def __init__(self, config: EBMConfig) -> None:
        super().__init__()
        self.text_encoder = SentenceEncoder(
            model_type=config.text_encoder_model_type,
            placeholder_type=config.placeholder_type,
            placeholder_init_type=config.placeholder_init_type,
            n_sentences=config.n_sentences,
        )
        # The model width is dictated by the pretrained embedder, so it is
        # written back to the config for every downstream module to share.
        config.d_model = self.text_encoder.embedder.get_embedding_dim()

        self.x_sa = SelfAttentionEncoder(
            d_model=config.d_model,
            n_layers=config.self_attention_n_layers,
            n_heads=config.attention_n_heads,
            dropout=config.dropout_rate,
            n_sentences=config.n_sentences,
        )
        self.y_sa = SelfAttentionEncoder(
            d_model=config.d_model,
            n_layers=config.self_attention_n_layers,
            n_heads=config.attention_n_heads,
            dropout=config.dropout_rate,
            n_sentences=config.n_sentences,
        )
        self.cross = CrossAttentionEncoder(
            d_model=config.d_model,
            n_layers=config.cross_attention_n_layers,
            n_heads=config.attention_n_heads,
            dropout=config.dropout_rate,
        )
        self.head = EnergyHead(
            d_model=config.d_model,
            pooling_type=config.energy_head_pooling_type,
            mlp_layers=config.energy_head_mlp_layers,
            hidden_factor=config.energy_head_hidden_factor,
            activation=config.activation_fn,
            n_sentences=config.n_sentences,
        )

    def _encode_pair(
        self, x_texts: list[str], y_texts: list[str]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode x and y texts in one encoder call; return (x_emb, x_mask, y_emb, y_mask)."""
        all_emb, all_masks = self.text_encoder(x_texts + y_texts)

        # The first len(x_texts) rows belong to x, the remainder to y.
        batch_size = len(x_texts)
        return (
            all_emb[:batch_size],
            all_masks[:batch_size],
            all_emb[batch_size:],
            all_masks[batch_size:],
        )

    def _energy(
        self,
        x_emb: torch.Tensor,
        x_mask: torch.Tensor,
        y_emb: torch.Tensor,
        y_mask: torch.Tensor,
        concept_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Score already-embedded pairs; optionally mask y after self-attention."""
        x_proc = self.x_sa(x_emb, x_mask)
        y_proc = self.y_sa(y_emb, y_mask)

        if concept_mask is not None:
            # concept_mask shape: [batch_size, n_sentences] or [batch_size, n_sentences, 1]
            if concept_mask.dim() == 2:
                concept_mask = concept_mask.unsqueeze(-1)  # [B, L, 1]
            # The mask may have been built on another device than the model.
            y_proc = y_proc * concept_mask.to(y_proc.device)

        y_cross = self.cross(
            query=y_proc,
            key_value=x_proc,
            query_mask=y_mask,
            kv_mask=x_mask,
        )
        return self.head(y_cross, y_mask)  # [B]

    def forward(self, x_texts: list[str], y_texts: list[str]) -> torch.Tensor:
        """Return the energy of each (x_texts[i], y_texts[i]) pair."""
        x_emb, x_mask, y_emb, y_mask = self._encode_pair(x_texts, y_texts)
        return self._energy(x_emb, x_mask, y_emb, y_mask)

    def forward_from_encoded(
        self,
        x_encoded: tuple[torch.Tensor, torch.Tensor],
        y_encoded: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """
        Performs the forward pass starting from pre-encoded batch tensors.
        This is the primary method for cached and memory-safe computation.

        Each argument is an ``(embeddings, pad_mask)`` pair; all four tensors
        are moved to the model's device before scoring.
        """
        x_emb_batch, x_mask_batch = x_encoded
        y_emb_batch, y_mask_batch = y_encoded

        # Move tensors from cache (CPU) to the model's active device
        device = next(self.parameters()).device
        return self._energy(
            x_emb_batch.to(device),
            x_mask_batch.to(device),
            y_emb_batch.to(device),
            y_mask_batch.to(device),
        )

    def forward_with_concept_masking(
        self,
        x_texts: list[str],
        y_texts: list[str],
        concept_mask: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass with concept-space masking applied after self-attention.
        This preserves semantic context during masking operations.

        ``concept_mask`` is multiplied element-wise into the self-attended y
        representation before cross-attention; a 2-D mask is given a trailing
        singleton dimension first.
        """
        x_emb, x_mask, y_emb, y_mask = self._encode_pair(x_texts, y_texts)
        return self._energy(x_emb, x_mask, y_emb, y_mask, concept_mask)