import torch
import torch.nn as nn

class SelfAttentionBlock(nn.Module):
    """
    Pre-norm self-attention block: LayerNorm -> multi-head self-attention ->
    residual add, then LayerNorm -> MLP -> residual add.

    Args:
        dim: feature size of the input (``embed_dim`` of the attention layer);
            ``nn.MultiheadAttention`` requires it to be divisible by ``num_heads``.
        num_heads: number of attention heads.
        mlp_dim: hidden width of the GELU feed-forward MLP.
        dropout: dropout rate used inside the attention layer and the MLP.
    """

    def __init__(self, dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        Apply the block.

        Args:
            x: tensor of shape ``(seq, batch, dim)``; the attention layer is
                built without ``batch_first``, so sequence comes first.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # One LayerNorm pass shared by query, key and value (they are identical
        # in self-attention), instead of normalising the same input three times.
        h = self.norm1(x)
        # The attention weights are thrown away, so don't ask for them.
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

class PReLU(nn.PReLU):
    """
    PReLU whose learnable slopes are applied along the *last* dimension.

    ``nn.PReLU`` treats dim 1 as the channel dimension. For the
    ``(seq, batch, features)`` tensors used in this file, that is the batch
    axis. This subclass flattens every leading dimension into one axis so the
    ``num_parameters`` slopes line up with the trailing feature dimension.
    It then restores the input's original shape.

    Args:
        num_parameters: number of learnable slopes; either 1 (one slope shared
            by all features) or the size of the input's last dimension.
        init: initial value of every slope.
    """

    def __init__(self, num_parameters=1, init=0.25):
        super().__init__(num_parameters=num_parameters, init=init)

    def forward(self, input):
        # A 0-d tensor has no feature axis; defer to the stock nn.PReLU behaviour.
        if input.dim() == 0:
            return super().forward(input)
        original_shape = input.shape
        # reshape (not view) so non-contiguous inputs, e.g. after a transpose,
        # are accepted; folding *all* leading dims keeps the features on
        # dim 1 for inputs of any rank.
        flat = input.reshape(-1, original_shape[-1])
        return super().forward(flat).reshape(original_shape)
    

class CrossAttentionBlock(nn.Module):
    """
    Pre-norm cross-attention block: queries attend to a separate key/value
    sequence (residual add), then a PReLU feed-forward MLP (residual add).

    Args:
        dim_q: feature size of the query stream and of the output;
            ``nn.MultiheadAttention`` requires it to be divisible by ``num_heads``.
        dim_kv: feature size of the key/value stream; may differ from ``dim_q``.
        num_heads: number of attention heads.
        mlp_dim: hidden width of the feed-forward MLP.
        dropout: dropout rate used inside the attention layer and the MLP.
    """

    def __init__(self, dim_q, dim_kv, num_heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim_q)
        self.norm_kv = nn.LayerNorm(dim_kv)
        # kdim/vdim let keys and values keep their own width. Without them
        # MultiheadAttention expects dim_q-sized keys and fails whenever
        # dim_kv != dim_q. When the two are equal, the same parameters
        # (a single packed in_proj_weight) are created as before.
        self.cross_attn = nn.MultiheadAttention(
            dim_q, num_heads, dropout=dropout, kdim=dim_kv, vdim=dim_kv
        )
        self.norm2 = nn.LayerNorm(dim_q)
        self.mlp = nn.Sequential(
            nn.Linear(dim_q, mlp_dim),
            PReLU(mlp_dim),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim_q),
            nn.Dropout(dropout),
        )

    def forward(self, q, kv):
        """
        Apply the block.

        Args:
            q: query tensor of shape ``(seq_q, batch, dim_q)``.
            kv: key/value tensor of shape ``(seq_kv, batch, dim_kv)``.

        Returns:
            Tensor of shape ``(seq_q, batch, dim_q)``.
        """
        # Keys and values are the same normalised tensor; compute it once.
        kv_norm = self.norm_kv(kv)
        # The attention weights are thrown away, so don't ask for them.
        attn_out, _ = self.cross_attn(
            self.norm_q(q), kv_norm, kv_norm, need_weights=False
        )
        x = q + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

class ViT:
    """Empty placeholder class; no attributes or methods are defined yet.

    NOTE(review): presumably intended to become a Vision Transformer model
    built from the blocks above — TODO confirm and implement.
    """

