from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from nvib_layer import DenoisingMultiheadAttention

@dataclass
class TransformerLayerConfig:
    """Global hyperparameters for a Transformer Layer.

    ``emb_dim`` is not a constructor argument; it is derived in
    ``__post_init__`` as ``num_heads * emb_dim_per_head``, so it is always
    divisible by ``num_heads``.

    Raises:
        ValueError: If ``num_heads`` or ``emb_dim_per_head`` is not positive.
    """

    # Number of attention heads handed to the attention modules.
    num_heads: int = 8
    # Width of a single head; the total embedding width is derived from it.
    emb_dim_per_head: int = 16
    # MlpBlock uses int(mlp_dim_factor * emb_dim) as its hidden width.
    mlp_dim_factor: float = 4.0
    # Probability passed to the nn.Dropout modules of the layers.
    dropout_rate: float = 0.0
    # Probability passed as ``dropout`` to the attention modules.
    attention_dropout_rate: float = 0.0
    # Toggles the bias of Linear/attention projections; the layers in this
    # file also pass it as ``elementwise_affine`` to their LayerNorms.
    use_bias: bool = False
    # NOTE(review): max_rows / max_cols are not read by the layer classes in
    # this file; presumably grid-size limits used elsewhere -- confirm.
    max_rows: int = 30
    max_cols: int = 30
    # Length of the causally masked tail segment read by
    # DenoisingCrossAttentionTransformerLayer._make_causal_mask; None = unset.
    seq_len: Optional[int] = None
    # Activation name understood by MlpBlock ("relu" or "silu").
    activation: str = "silu"
    # dtype handed to every parameterised sub-module.
    dtype: torch.dtype = torch.float32
    # Derived in __post_init__: num_heads * emb_dim_per_head.
    emb_dim: int = field(init=False)

    def __post_init__(self):
        # emb_dim is a product of num_heads, so a divisibility check can never
        # fail; what can actually go wrong is a zero or negative factor.
        if self.num_heads <= 0 or self.emb_dim_per_head <= 0:
            raise ValueError(
                "num_heads and emb_dim_per_head must be positive, got "
                f"num_heads={self.num_heads}, "
                f"emb_dim_per_head={self.emb_dim_per_head}"
            )
        self.emb_dim = self.num_heads * self.emb_dim_per_head


@dataclass
class EncoderTransformerConfig:
    """Global hyperparameters for the Encoder Transformer.

    ``dtype``, ``emb_dim`` and ``max_len`` are not constructor arguments; they
    are filled in by ``__post_init__`` from the other fields.
    """

    # Per-layer hyperparameters; also the source of the derived dtype/emb_dim.
    transformer_layer: TransformerLayerConfig = field(default_factory=TransformerLayerConfig)
    vocab_size: int = 10
    output_vocab_size: int = 10
    num_layers: int = 2
    latent_dim: int = 32
    variational: bool = False
    # Their product defines max_len.
    # NOTE(review): these are separate from transformer_layer.max_rows /
    # max_cols and nothing here keeps the two in sync -- confirm intent.
    max_rows: int = 30
    max_cols: int = 30
    latent_projection_bias: bool = False
    scaled_position_embeddings: bool = False
    # Copied from transformer_layer.dtype in __post_init__.
    dtype: torch.dtype = field(init=False)
    # Copied from transformer_layer.emb_dim in __post_init__.
    emb_dim: int = field(init=False)
    # Computed as max_rows * max_cols in __post_init__.
    max_len: int = field(init=False)
    # None means "not set".
    seq_len: Optional[int] = None
    disable_sampling: bool = False

    def __post_init__(self):
        self.dtype = self.transformer_layer.dtype
        self.emb_dim = self.transformer_layer.emb_dim
        self.max_len = self.max_rows * self.max_cols


@dataclass
class DecoderTransformerConfig:
    """Global hyperparameters for the Decoder Transformer.

    ``dtype``, ``emb_dim`` and ``max_len`` are not constructor arguments; they
    are filled in by ``__post_init__`` from the other fields.
    """

    # Per-layer hyperparameters; also the source of the derived dtype/emb_dim.
    transformer_layer: TransformerLayerConfig = field(default_factory=TransformerLayerConfig)
    vocab_size: int = 10
    output_vocab_size: int = 10
    num_layers: int = 2
    # Their product defines max_len.
    # NOTE(review): these are separate from transformer_layer.max_rows /
    # max_cols and nothing here keeps the two in sync -- confirm intent.
    max_rows: int = 30
    max_cols: int = 30
    scaled_position_embeddings: bool = False
    next_position_embeddings: bool = True
    next_position_embeddings_new_input_embeds: bool = False
    logits_projection_bias: bool = False
    # Copied from transformer_layer.dtype in __post_init__.
    dtype: torch.dtype = field(init=False)
    # Copied from transformer_layer.emb_dim in __post_init__.
    emb_dim: int = field(init=False)
    # Computed as max_rows * max_cols in __post_init__.
    max_len: int = field(init=False)
    # None means "not set".
    seq_len: Optional[int] = None

    def __post_init__(self):
        self.dtype = self.transformer_layer.dtype
        self.emb_dim = self.transformer_layer.emb_dim
        self.max_len = self.max_rows * self.max_cols


class MlpBlock(nn.Module):
    """Feed-forward sub-layer: ``fc1 -> activation -> fc2 -> dropout``.

    The hidden width is ``int(config.mlp_dim_factor * config.emb_dim)`` and is
    exposed as ``self.mlp_dim``. Input and output both have ``config.emb_dim``
    features on the last axis.

    Raises:
        ValueError: If ``config.activation`` is neither "relu" nor "silu".
    """

    def __init__(self, config: TransformerLayerConfig):
        super().__init__()
        self.config = config

        model_width = config.emb_dim
        hidden_width = int(config.mlp_dim_factor * model_width)
        self.mlp_dim = hidden_width

        # Both projections share the same bias / dtype settings.
        linear_options = {"bias": config.use_bias, "dtype": config.dtype}
        self.fc1 = nn.Linear(model_width, hidden_width, **linear_options)
        self.fc2 = nn.Linear(hidden_width, model_width, **linear_options)
        self.dropout = nn.Dropout(config.dropout_rate)

        # Table-driven lookup of the activation by name.
        known_activations = (("relu", F.relu), ("silu", F.silu))
        chosen = next(
            (fn for label, fn in known_activations if config.activation == label),
            None,
        )
        if chosen is None:
            raise ValueError(f"Unsupported activation: {config.activation}")
        self.activation = chosen

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the two-layer MLP followed by dropout to ``x``."""
        hidden = self.activation(self.fc1(x))
        return self.dropout(self.fc2(hidden))


class TransformerLayer(nn.Module):
    """Pre-LN Transformer layer: self-attention and an MLP, each wrapped in a
    residual connection.

    Every intermediate result is checked for NaNs and a ``ValueError`` naming
    the stage is raised on the first one found.
    """

    def __init__(self, config: TransformerLayerConfig):
        super().__init__()
        self.config = config
        assert config.emb_dim % config.num_heads == 0, "emb_dim must be divisible by num_heads"

        # NOTE: ``use_bias`` also decides whether the LayerNorms are affine.
        self.norm1 = nn.LayerNorm(config.emb_dim, eps=1e-6, elementwise_affine=config.use_bias, dtype=config.dtype)

        self.attn = nn.MultiheadAttention(
            embed_dim=config.emb_dim,
            num_heads=config.num_heads,
            dropout=config.attention_dropout_rate,
            bias=config.use_bias,
            batch_first=True,
            dtype=config.dtype
        )
        self.dropout = nn.Dropout(config.dropout_rate)

        self.norm2 = nn.LayerNorm(config.emb_dim, eps=1e-6, elementwise_affine=config.use_bias, dtype=config.dtype)
        self.mlp = MlpBlock(config=config)

    @staticmethod
    def _raise_if_nan(tensor: torch.Tensor, stage: str) -> None:
        """Raise ``ValueError("NaN after <stage>")`` if ``tensor`` contains a NaN."""
        if torch.isnan(tensor).any():
            raise ValueError(f"NaN after {stage}")

    def _build_attention_mask(
        self,
        pad_mask: Optional[torch.Tensor],
        attn_mask: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Merge ``pad_mask`` and ``attn_mask`` into one mask for ``self.attn``.

        ``pad_mask`` is interpreted as "True = may attend" (it is inverted
        before being handed to PyTorch, where True means "blocked") and is
        tiled once per head into shape ``(batch * num_heads, L, S)``.

        Returns ``attn_mask`` untouched when ``pad_mask`` is None.

        Raises:
            ValueError: If ``pad_mask`` does not reduce to three dimensions.
        """
        if pad_mask is None:
            return attn_mask

        if pad_mask.ndim < 3:
            raise ValueError(f"pad_mask ndim expected >= 3, got {pad_mask.ndim}")
        if pad_mask.ndim > 3:
            # Drops a singleton head/broadcast axis, e.g. (B, 1, L, S) -> (B, L, S).
            pad_mask = pad_mask.squeeze(-3)
        if pad_mask.ndim != 3:
            raise ValueError(
                f"pad_mask must reduce to 3 dims (batch, L, S), got shape {tuple(pad_mask.shape)}"
            )

        blocked = ~pad_mask.to(torch.bool)
        # A query row with every key blocked makes softmax return NaN, which
        # the NaN checks in forward() would turn into an error. Such rows
        # (typically padded queries) are left unmasked instead.
        blocked = blocked & ~blocked.all(dim=-1, keepdim=True)

        # Tile per head in batch-major order, the layout nn.MultiheadAttention
        # expects for a 3-D attn_mask.
        blocked = blocked.unsqueeze(1).expand(-1, self.config.num_heads, -1, -1)
        blocked = blocked.reshape(-1, blocked.size(-2), blocked.size(-1))

        if attn_mask is None:
            return blocked
        if attn_mask.dtype == torch.bool:
            # A position is blocked if either mask blocks it.
            return blocked | attn_mask
        # Additive (float) attn_mask: fold the padding mask in as -inf.
        additive = torch.zeros_like(blocked, dtype=attn_mask.dtype)
        additive = additive.masked_fill(blocked, float("-inf"))
        return additive + attn_mask

    def forward(
        self,
        embeddings: torch.Tensor,
        pad_mask: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the layer on ``embeddings``.

        Args:
            embeddings: Input with ``config.emb_dim`` features on the last
                axis (batch-first, as ``self.attn`` is built with
                ``batch_first=True``).
            pad_mask: Optional mask with at least three dims; True marks
                positions that may be attended to.
            attn_mask: Optional mask in ``nn.MultiheadAttention`` convention;
                combined with ``pad_mask`` when both are given.

        Returns:
            Tensor of the same shape as ``embeddings``.

        Raises:
            ValueError: On a malformed ``pad_mask`` or if any intermediate
                result contains a NaN.
        """
        x = self.norm1(embeddings)
        self._raise_if_nan(x, "norm1")

        # Previously the padding-derived mask was computed and then dropped;
        # it is now actually passed to the attention module.
        combined_mask = self._build_attention_mask(pad_mask, attn_mask)

        attn_output, _ = self.attn(x, x, x, attn_mask=combined_mask, key_padding_mask=None, need_weights=False)
        self._raise_if_nan(attn_output, "attention")
        x = self.dropout(attn_output)
        self._raise_if_nan(x, "dropout")
        embeddings = embeddings + x
        self._raise_if_nan(embeddings, "residual connection 1")

        x = self.norm2(embeddings)
        self._raise_if_nan(x, "norm2")
        mlp_output = self.mlp(x)
        self._raise_if_nan(mlp_output, "MLP")
        embeddings = embeddings + mlp_output
        self._raise_if_nan(embeddings, "residual connection 2")
        return embeddings

class DenoisingCrossAttentionTransformerLayer(nn.Module):
    """Transformer Decoder layer incorporating standard self-attention and
    denoising multihead cross-attention, designed to work with Nvib output.
    Follows Pre-LN structure and assumes batch_first=True convention.

    Sub-layers, in order: self-attention, denoising cross-attention over the
    Nvib latents, MLP. Each is preceded by a LayerNorm and added back to the
    residual stream.
    """

    def __init__(self, config: TransformerLayerConfig):
        super().__init__()
        self.config = config
        assert config.emb_dim % config.num_heads == 0, "emb_dim must be divisible by num_heads"

        # NOTE: ``use_bias`` also decides whether the LayerNorms are affine.
        self.norm1 = nn.LayerNorm(config.emb_dim, eps=1e-6, elementwise_affine=config.use_bias, dtype=config.dtype)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=config.emb_dim,
            num_heads=config.num_heads,
            dropout=config.attention_dropout_rate,
            bias=config.use_bias,
            batch_first=True,
            dtype=config.dtype
        )
        self.dropout1 = nn.Dropout(config.dropout_rate)

        self.norm2 = nn.LayerNorm(config.emb_dim, eps=1e-6, elementwise_affine=config.use_bias, dtype=config.dtype)
        self.cross_attn_denoising = DenoisingMultiheadAttention(
            embed_dim=config.emb_dim,
            num_heads=config.num_heads,
            dropout=config.attention_dropout_rate,
            bias=config.use_bias,
            batch_first=True,
            kdim=config.emb_dim,
            vdim=config.emb_dim,
            dtype=config.dtype
        )
        self.dropout2 = nn.Dropout(config.dropout_rate)

        self.norm3 = nn.LayerNorm(config.emb_dim, eps=1e-6, elementwise_affine=config.use_bias, dtype=config.dtype)
        self.mlp = MlpBlock(config=config)

    def _make_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generates a boolean attention mask with causal masking applied only to the final segment.

        The last ``config.seq_len`` positions form the causal segment; the
        positions before it are the "head". In the returned mask True means
        "blocked":

        * head queries see every head key and no segment key;
        * segment queries see every head key and the segment keys up to and
          including their own position.

        Args:
            seq_len: Total sequence length ``T`` of the mask (``T x T``).
            device: Device on which to build the mask.

        Raises:
            ValueError: If ``config.seq_len`` is None or negative.
            AssertionError: If the causal segment is longer than ``seq_len``.
        """
        causal_len = self.config.seq_len
        total_len = seq_len
        if causal_len is None:
            raise ValueError("config.seq_len must be set to build the causal mask")
        if causal_len < 0:
            raise ValueError(f"config.seq_len must be non-negative, got {causal_len}")
        assert causal_len <= total_len, "Causal segment must fit in sequence."

        head_len = total_len - causal_len
        allowed = torch.ones(total_len, total_len, dtype=torch.bool, device=device)

        # Index from head_len rather than with "-causal_len:": for a zero
        # length segment "-0:" would select the whole matrix.
        allowed[head_len:, head_len:] = torch.tril(
            torch.ones(causal_len, causal_len, dtype=torch.bool, device=device)
        )
        allowed[:head_len, head_len:] = False

        return ~allowed

    def _resolve_self_attn_mask(
        self,
        causal_mask: Optional[torch.Tensor],
        diffusion: Optional[int],
        target_seq_len: int,
        device: torch.device,
    ) -> Optional[torch.Tensor]:
        """Pick the self-attention mask: none when ``diffusion`` is truthy,
        else the caller's ``causal_mask``, else a generated one."""
        if diffusion:
            return None
        if causal_mask is not None:
            # Previously a caller-supplied mask was silently discarded.
            return causal_mask
        return self._make_causal_mask(target_seq_len, device)

    def forward(
        self,
        decoder_embeddings: torch.Tensor,
        nvib_output: Dict[str, Any],
        decoder_key_padding_mask: Optional[torch.Tensor] = None,
        causal_mask: Optional[torch.Tensor] = None,
        diffusion: Optional[int] = None
    ) -> torch.Tensor:
        """Run the decoder layer.

        Args:
            decoder_embeddings: Tensor of shape ``(batch, target_len, embed_dim)``.
            nvib_output: Mapping with key ``'z'`` holding the tuple
                ``(z, pi, mu, logvar)`` and key ``'memory_key_padding_mask'``
                holding a ``(batch, source_len)`` mask. ``z``, ``mu`` and
                ``logvar`` are ``(batch, source_len, embed_dim)`` and ``pi``
                is ``(batch, source_len, 1)``; sequence-first tensors are
                transposed when that is unambiguous from the batch size.
            decoder_key_padding_mask: Passed as ``key_padding_mask`` to the
                self-attention.
            causal_mask: Self-attention mask. When None (and ``diffusion`` is
                falsy) one is generated by ``_make_causal_mask``.
            diffusion: When truthy, self-attention runs without any
                attention mask.

        Returns:
            Tensor of shape ``(batch, target_len, embed_dim)``.

        Raises:
            AssertionError: If the Nvib tensors do not have the shapes above.
            ValueError: If a causal mask must be generated but
                ``config.seq_len`` is unset or negative.
        """
        batch_size, target_seq_len, embed_dim = decoder_embeddings.shape
        x = decoder_embeddings

        # --- Self-attention sub-layer ---
        x_norm = self.norm1(x)
        self_attn_mask = self._resolve_self_attn_mask(
            causal_mask, diffusion, target_seq_len, decoder_embeddings.device
        )
        x_attn, _ = self.self_attn(
            query=x_norm,
            key=x_norm,
            value=x_norm,
            key_padding_mask=decoder_key_padding_mask,
            attn_mask=self_attn_mask,
            need_weights=False
        )
        x = x + self.dropout1(x_attn)

        # --- Denoising cross-attention sub-layer ---
        z_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] = nvib_output['z']
        nvib_padding_mask: torch.Tensor = nvib_output['memory_key_padding_mask']

        z, pi, mu, logvar = z_tuple

        # Heuristic: treat the latents as sequence-first only when axis 1
        # matches the batch size and axis 0 does not.
        if z.shape[1] == batch_size and z.shape[0] != batch_size:
            z = z.transpose(0, 1)
            pi = pi.transpose(0, 1)
            mu = mu.transpose(0, 1)
            logvar = logvar.transpose(0, 1)

        source_latent_seq_len = z.shape[1]
        assert z.shape == (batch_size, source_latent_seq_len, embed_dim), f"Shape mismatch for z: {z.shape}, expected: {batch_size, source_latent_seq_len, embed_dim}"
        assert pi.shape == (batch_size, source_latent_seq_len, 1), f"Shape mismatch for pi: {pi.shape}"
        assert mu.shape == (batch_size, source_latent_seq_len, embed_dim), f"Shape mismatch for mu: {mu.shape}"
        assert logvar.shape == (batch_size, source_latent_seq_len, embed_dim), f"Shape mismatch for logvar: {logvar.shape}"
        assert nvib_padding_mask.shape == (batch_size, source_latent_seq_len), f"Shape mismatch for mask: {nvib_padding_mask.shape}"

        # Two distinct tuple objects on purpose, in case the attention module
        # distinguishes ``key is value`` -- not verifiable from this file.
        cross_key = (z, pi, mu, logvar)
        cross_value = (z, pi, mu, logvar)
        x_norm = self.norm2(x)
        x_attn, _ = self.cross_attn_denoising(
            query=x_norm,
            key=cross_key,
            value=cross_value,
            key_padding_mask=nvib_padding_mask,
            attn_mask=None,
            need_weights=False
        )
        x = x + self.dropout2(x_attn)

        # --- MLP sub-layer ---
        x_mlp = self.mlp(self.norm3(x))
        decoder_embeddings = x + x_mlp

        return decoder_embeddings