import torch
from torch import nn


class QuantumBiasMultiheadAttention(nn.Module):
    """
    Multi-head self-attention with an additive structural bias.

    The bias (e.g. derived from CTQW probability matrices) is supplied by the
    caller at forward time and added to the pre-softmax attention logits of
    every head. This module itself holds no bias parameters.

    Args:
        dim (int): Input embedding dimension. Must be divisible by ``heads``.
        heads (int): Number of attention heads.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``dim``.
    """
    def __init__(self, dim, heads=4):
        super().__init__()
        if heads <= 0 or dim % heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive number of heads ({heads})"
            )
        self.heads = heads
        self.head_dim = dim // heads
        # Scaled dot-product attention divides by sqrt(d_k), the per-head
        # dimension; scaling by the full embedding dim over-flattens softmax.
        self.scale = self.head_dim ** -0.5
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.fc_out = nn.Linear(dim, dim)

    def forward(self, x, mask=None, q_bias=None):
        """
        Compute attention with an optional structural bias.

        Args:
            x (Tensor): Input features of shape [B, N, C].
            mask (Tensor or None): Optional mask; nonzero/True entries are
                attended, zero/False entries are masked out. Accepts a
                key-padding mask [B, N] (hides those keys from every query and
                head), a per-query mask [B, N, N] (shared across heads), or any
                tensor broadcastable to [B, H, N, N].
            q_bias (Tensor or None): Additive bias of shape [B, N, N] (shared
                by all heads) or [B, H, N, N]. It is cast to the dtype/device of
                the attention scores before being added.

        Returns:
            Tensor: Output features of shape [B, N, C]. Queries whose keys are
            all masked receive zero attention weights (instead of NaN), so their
            output is ``fc_out`` applied to a zero vector.
        """
        B, N, C = x.size()
        H = self.heads
        d_k = C // H

        # Project input to Q, K, V and split heads: [B, H, N, d_k].
        q = self.to_q(x).view(B, N, H, d_k).transpose(1, 2)
        k = self.to_k(x).view(B, N, H, d_k).transpose(1, 2)
        v = self.to_v(x).view(B, N, H, d_k).transpose(1, 2)

        # Scaled dot-product attention scores: [B, H, N, N].
        attn = (q @ k.transpose(-2, -1)) * self.scale

        if q_bias is not None:
            # Match the scores' dtype/device; a float64 bias would otherwise
            # promote `attn` and make `attn @ v` fail on mixed dtypes.
            q_bias = q_bias.to(dtype=attn.dtype, device=attn.device)
            if q_bias.dim() == 3:
                # Share one bias across all heads: [B, N, N] -> [B, 1, N, N].
                q_bias = q_bias.unsqueeze(1)
            attn = attn + q_bias

        fully_masked = None
        if mask is not None:
            if mask.dim() == 2:
                # Key-padding mask [B, N] -> [B, 1, 1, N]. A plain unsqueeze(1)
                # would broadcast the batch axis against the head axis.
                mask = mask[:, None, None, :]
            elif mask.dim() == 3:
                # Per-query mask [B, N, N] -> [B, 1, N, N], shared across heads.
                mask = mask.unsqueeze(1)
            blocked = mask == 0
            attn = attn.masked_fill(blocked, float('-inf'))
            # Rows where every key is blocked would softmax to NaN.
            fully_masked = blocked.all(dim=-1, keepdim=True)

        attn = attn.softmax(dim=-1)
        if fully_masked is not None:
            attn = attn.masked_fill(fully_masked, 0.0)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)  # [B, N, C]
        return self.fc_out(out)


class TransformerEncoderLayerWithQuantumBias(nn.Module):
    """
    Pre-norm Transformer encoder layer with a CTQW-derived attention bias.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``. The
    attention logits receive an additive bias built from the quantum-walk
    probabilities passed to ``forward``.

    Args:
        dim (int): Input embedding dimension.
        heads (int): Number of attention heads.
        learnable_bias_scale (bool): If True, multiply the structural bias by a
            learnable scalar initialised to 1.0. Defaults to False, which leaves
            the bias unscaled and adds no parameters to the state_dict.
    """
    def __init__(self, dim, heads=4, learnable_bias_scale=False):
        super().__init__()
        self.attn = QuantumBiasMultiheadAttention(dim, heads)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )
        if learnable_bias_scale:
            self.bias_scale = nn.Parameter(torch.tensor(1.0))
        else:
            # Registered as None so `self.bias_scale` always exists while the
            # default state_dict stays identical to the unscaled layer.
            self.register_parameter("bias_scale", None)

    def _structural_bias(self, qw_probs, like):
        """Convert CTQW probabilities into an additive attention bias [B, N, N].

        Args:
            qw_probs (Tensor): Evolution tensor [B, T, N, N] (the last time step
                is used) or a single snapshot [B, N, N].
            like (Tensor): Tensor whose dtype and device the bias is cast to.

        Returns:
            Tensor: Bias of shape [B, N, N].
        """
        probs = qw_probs[:, -1] if qw_probs.dim() == 4 else qw_probs
        # Normalise over dim -2 (each column sums to ~1), then log1p; for
        # non-negative probabilities the bias therefore lies in [0, log 2).
        # NOTE(review): softmax runs over the last dim (keys); confirm that
        # column- rather than row-normalisation is the intended axis.
        bias = torch.log1p(probs / (probs.sum(dim=-2, keepdim=True) + 1e-8))
        # Cast after the log1p so it is computed at the input's precision, and
        # so a float64 bias cannot promote the float32 attention scores.
        bias = bias.to(dtype=like.dtype, device=like.device)
        if self.bias_scale is not None:
            bias = self.bias_scale * bias
        return bias

    def forward(self, x, mask=None, qw_probs=None):
        """
        Forward pass with an optional quantum-walk probability bias.

        Args:
            x (Tensor): Input node features [B, N, C].
            mask (Tensor or None): Optional attention mask, forwarded unchanged
                to the attention module (documented there as [B, N]).
            qw_probs (Tensor or None): CTQW evolution tensor [B, T, N, N] (last
                time step used) or a single [B, N, N] probability snapshot.

        Returns:
            Tensor: Output features [B, N, C].
        """
        q_bias = None if qw_probs is None else self._structural_bias(qw_probs, x)

        x = x + self.attn(self.norm1(x), mask, q_bias)
        x = x + self.mlp(self.norm2(x))
        return x
