"""Variational Autoencoder (VAE) for SMILES sequences.

This module implements a Transformer-based VAE that encodes SMILES sequences
into latent representations and decodes them back to sequences.
"""

from dataclasses import dataclass
from typing import Optional, Tuple, Dict
import math
import torch
import torch.nn as nn


class TransformerEncoder(nn.Module):
    """Batch-first stack of GELU Transformer encoder layers.

    Thin wrapper around ``torch.nn.TransformerEncoder`` that builds ``num_layers``
    identically configured ``nn.TransformerEncoderLayer`` blocks.

    Args:
        d_model: Width of the token representations (embedding size).
        nhead: Number of attention heads per layer.
        num_layers: How many encoder layers to stack.
        dim_ff: Hidden width of each layer's feed-forward sub-block.
        dropout: Dropout probability used inside every layer.
    """

    def __init__(self, d_model: int, nhead: int, num_layers: int, dim_ff: int, dropout: float):
        super().__init__()
        template = nn.TransformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward=dim_ff,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        # Attribute name ``enc`` is kept stable so checkpoint keys do not change.
        self.enc = nn.TransformerEncoder(template, num_layers)

    def forward(self, x: torch.Tensor, src_key_padding_mask=None) -> torch.Tensor:
        """Run the encoder stack over already-embedded sequences.

        Args:
            x: Tensor of shape (batch, seq_len, d_model); batch-first because the
                layers are built with ``batch_first=True``.
            src_key_padding_mask: Optional (batch, seq_len) boolean mask; True
                marks positions that must not be attended to (padding).

        Returns:
            Tensor with the same shape as ``x``.
        """
        encoded = self.enc(x, src_key_padding_mask=src_key_padding_mask)
        return encoded


class TransformerDecoder(nn.Module):
    """Batch-first stack of GELU Transformer decoder layers.

    Each layer self-attends over the target sequence and cross-attends to a
    ``memory`` tensor (here: the projected latent tokens), so generation can be
    conditioned on that memory.

    Args:
        d_model: Width of the token representations (embedding size).
        nhead: Number of attention heads per layer.
        num_layers: How many decoder layers to stack.
        dim_ff: Hidden width of each layer's feed-forward sub-block.
        dropout: Dropout probability used inside every layer.
    """

    def __init__(self, d_model: int, nhead: int, num_layers: int, dim_ff: int, dropout: float):
        super().__init__()
        template = nn.TransformerDecoderLayer(
            d_model,
            nhead,
            dim_feedforward=dim_ff,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        # Attribute name ``dec`` is kept stable so checkpoint keys do not change.
        self.dec = nn.TransformerDecoder(template, num_layers)

    def forward(
        self, tgt, memory, tgt_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None
    ):
        """Decode ``tgt`` while attending to ``memory``.

        Args:
            tgt: Embedded target sequence, shape (batch, tgt_len, d_model).
            memory: Conditioning tensor, shape (batch, mem_len, d_model).
            tgt_mask: Optional (tgt_len, tgt_len) attention mask over the target,
                e.g. a causal mask.
            tgt_key_padding_mask: Optional (batch, tgt_len) boolean mask; True
                marks target padding.
            memory_key_padding_mask: Optional (batch, mem_len) boolean mask; True
                marks memory positions to ignore.

        Returns:
            Tensor of shape (batch, tgt_len, d_model).
        """
        masks = {
            "tgt_mask": tgt_mask,
            "tgt_key_padding_mask": tgt_key_padding_mask,
            "memory_key_padding_mask": memory_key_padding_mask,
        }
        return self.dec(tgt, memory, **masks)


class TokenPooling(nn.Module):
    """Attention-pool (B, L, d_model) encodings into K tokens of shape (B, K, d_model).

    K learned query vectors each attend over the L sequence positions with
    single-head scaled dot-product attention. The result is a fixed-size
    summary regardless of sequence length.

    Args:
        d_model: Width of the input encodings and of the pooled output.
        K: Number of learned queries, i.e. number of pooled tokens.
    """

    def __init__(self, d_model: int, K: int):
        super().__init__()
        self.K = K
        # Dividing by sqrt(d_model) gives each initial query a norm of roughly 1.
        self.query = nn.Parameter(torch.randn(K, d_model) / math.sqrt(d_model))
        self.proj_q = nn.Linear(d_model, d_model)
        self.proj_k = nn.Linear(d_model, d_model)
        self.proj_v = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(
        self, h: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Pool ``h`` into ``K`` tokens.

        Args:
            h: Encodings of shape (B, L, d_model).
            key_padding_mask: Optional boolean mask of shape (B, L); True marks
                positions that must receive zero attention weight.

        Returns:
            Tensor of shape (B, K, d_model). A row whose mask is entirely True
            has no valid key. It gets all-zero attention weights, so its output
            is ``self.out`` applied to a zero vector (i.e. the output bias)
            instead of NaN.
        """
        _, _, d = h.shape
        # The queries do not depend on the batch, so project the K vectors once.
        q = self.proj_q(self.query)  # (K, d)
        k = self.proj_k(h)  # (B, L, d)
        v = self.proj_v(h)  # (B, L, d)
        att = torch.einsum("kd,bld->bkl", q, k) / math.sqrt(d)  # (B, K, L)

        fully_masked = None
        if key_padding_mask is not None:
            fully_masked = key_padding_mask.all(dim=-1)[:, None, None]  # (B, 1, 1)
            # Masking every key in a row would make softmax produce NaN (and NaN
            # gradients). Leave such rows unmasked here and zero them afterwards.
            att = att.masked_fill(key_padding_mask[:, None, :] & ~fully_masked, float("-inf"))

        w = torch.softmax(att, dim=-1)
        if fully_masked is not None:
            w = w.masked_fill(fully_masked, 0.0)

        pooled = torch.einsum("bkl,bld->bkd", w, v)  # (B, K, d)
        return self.out(pooled)


@dataclass
class VAEConfig:
    """Hyper-parameters for the SmilesTokenVAE model.

    Only ``vocab_size`` and ``max_len`` are required. The remaining fields
    default to a 256-wide Transformer with 6 encoder and 6 decoder layers and
    8 latent tokens of width 128.
    """

    # Number of distinct token IDs (rows of the token embedding, width of the output logits).
    vocab_size: int
    # Longest sequence the learned positional-embedding table can index.
    max_len: int
    # Width of token embeddings and of every Transformer hidden state.
    d_model: int = 256
    # Attention heads per Transformer layer.
    nhead: int = 8
    # Depth of the encoder stack.
    enc_layers: int = 6
    # Depth of the decoder stack.
    dec_layers: int = 6
    # Hidden width of each layer's feed-forward sub-block.
    dim_ff: int = 1024
    # Dropout probability inside the Transformer layers.
    dropout: float = 0.1
    # Number of latent tokens produced by attention pooling.
    K: int = 8
    # Width of each latent token.
    d_latent: int = 128


class SmilesTokenVAE(nn.Module):
    """Sequence VAE with a token-latent posterior q(z|x), where z has shape (K, d_latent).

    Encoder path: token + learned absolute position embeddings -> Transformer
    encoder -> attention pooling to K tokens -> diagonal-Gaussian parameters
    (mu, logvar) per latent token.

    Decoder path: a causal Transformer decoder whose cross-attention memory is
    the K latent tokens projected back to d_model.

    Args:
        cfg: Model hyper-parameters.
        pad_id: Token ID treated as padding; such positions are masked out of
            attention in both the encoder and the decoder.
    """

    def __init__(self, cfg: VAEConfig, pad_id: int):
        super().__init__()
        self.cfg = cfg
        self.pad_id = pad_id

        # Construction order is kept as-is so seeded initialisation and
        # state_dict keys stay the same.
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_len, cfg.d_model)

        self.enc = TransformerEncoder(
            cfg.d_model, cfg.nhead, cfg.enc_layers, cfg.dim_ff, cfg.dropout
        )
        self.pool = TokenPooling(cfg.d_model, cfg.K)
        self.to_mu = nn.Linear(cfg.d_model, cfg.d_latent)
        self.to_logvar = nn.Linear(cfg.d_model, cfg.d_latent)

        self.z_to_mem = nn.Linear(cfg.d_latent, cfg.d_model)
        self.dec = TransformerDecoder(
            cfg.d_model, cfg.nhead, cfg.dec_layers, cfg.dim_ff, cfg.dropout
        )
        self.out = nn.Linear(cfg.d_model, cfg.vocab_size)

    def _embed(self, tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed (batch, seq_len) token IDs with positions; also return the pad mask.

        Raises:
            ValueError: if seq_len exceeds ``cfg.max_len``. Without this check,
                the positional embedding lookup fails with an opaque index error
                (a device-side assert on CUDA).
        """
        _, seq_len = tokens.shape  # also enforces a 2-D (batch, seq_len) input
        if seq_len > self.cfg.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.cfg.max_len}"
            )
        pos = torch.arange(seq_len, device=tokens.device)
        # (B, L, d) + (L, d) broadcasts the positional embedding over the batch.
        h = self.tok_emb(tokens) + self.pos_emb(pos)
        return h, tokens == self.pad_id

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences into sampled latent tokens.

        Args:
            x: Input token IDs of shape (batch, seq_len), seq_len <= cfg.max_len.

        Returns:
            Tuple ``(z, mu, logvar)``, each of shape (batch, K, d_latent):
            - z: reparameterised sample ``mu + eps * exp(0.5 * logvar)``. A fresh
              sample is drawn on every call, in train and eval mode alike.
            - mu: posterior mean.
            - logvar: posterior log-variance, clamped to [-8, 8].

        NOTE(review): a row made entirely of ``pad_id`` leaves the encoder with
        no unmasked key. Depending on the PyTorch version this may produce NaNs,
        so confirm that callers never pass all-pad rows.
        """
        h, padmask = self._embed(x)
        h = self.enc(h, src_key_padding_mask=padmask)
        pooled = self.pool(h, key_padding_mask=padmask)  # (B, K, d_model)
        mu = self.to_mu(pooled)
        # The clamp bounds the posterior std to [e^-4, e^4] and keeps exp() well-behaved.
        logvar = self.to_logvar(pooled).clamp(-8.0, 8.0)
        std = torch.exp(0.5 * logvar)
        z = mu + torch.randn_like(std) * std
        return z, mu, logvar

    def decode_logits(self, x_in: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Teacher-forced decoding of ``x_in`` conditioned on latent tokens ``z``.

        Args:
            x_in: Decoder input token IDs (typically ``x[:, :-1]``) of shape
                (batch, seq_len - 1).
            z: Latent tokens of shape (batch, K, d_latent).

        Returns:
            Logits over the vocabulary, shape (batch, seq_len - 1, vocab_size).
        """
        tgt, padmask = self._embed(x_in)
        length = x_in.shape[1]
        # Boolean causal mask: True above the diagonal blocks position i from
        # attending to any later position j > i. The mask is bool to match padmask.
        causal = torch.triu(
            torch.ones(length, length, device=x_in.device, dtype=torch.bool), diagonal=1
        )
        mem = self.z_to_mem(z)  # (B, K, d_model)
        h = self.dec(tgt=tgt, memory=mem, tgt_mask=causal, tgt_key_padding_mask=padmask)
        return self.out(h)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the full VAE: encode ``x``, then decode it with teacher forcing.

        Args:
            x: Input token IDs of shape (batch, seq_len).

        Returns:
            Dictionary with keys:
            - logits: decoder logits of shape (batch, seq_len - 1, vocab_size)
            - x_tgt: next-token targets ``x[:, 1:]`` aligned with ``logits``
            - mu: posterior mean of shape (batch, K, d_latent)
            - logvar: posterior log-variance (same shape as mu)
            - z: sampled latent tokens (same shape as mu)
        """
        z, mu, logvar = self.encode(x)
        # Shift by one: the decoder sees tokens [0, n-1) and predicts [1, n).
        x_in = x[:, :-1]
        x_tgt = x[:, 1:]
        logits = self.decode_logits(x_in, z)
        return {"logits": logits, "x_tgt": x_tgt, "mu": mu, "logvar": logvar, "z": z}
