from typing import Callable, Optional

import torch 
from torch import nn
from torch.nn import LayerNorm, Dropout


from .multihead_attention import MultiheadAttention
from .utils import get_activation_fn

class GraphTransformerLayer(nn.Module):
    """Transformer encoder layer whose self-attention accepts an additive bias.

    The layer is made of two residual sub-blocks: multi-head self-attention
    followed by a two-layer feed-forward network. Each sub-block has its own
    ``LayerNorm``; ``pre_layernorm`` selects whether that norm is applied to
    the sub-block input (pre-norm) or to the residual sum (post-norm).

    Args:
        embed_dim: Size of the input/output feature dimension.
        ffn_embed_dim: Hidden size of the feed-forward network.
        num_attn_heads: Number of attention heads.
        dropout: Dropout probability applied to each sub-block output before
            it is added to the residual.
        attn_dropout: Dropout probability forwarded to ``MultiheadAttention``.
        activation_dropout: Dropout probability applied after the
            feed-forward activation.
        activation_fn: Name of the activation, resolved by
            ``get_activation_fn``.
        pre_layernorm: If ``True`` use pre-norm sub-blocks, otherwise
            post-norm.
    """

    def __init__(
        self,
        embed_dim: int = 768,
        ffn_embed_dim: int = 3072,
        num_attn_heads: int = 8,
        dropout: float = 0.1,
        attn_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        pre_layernorm: bool = False
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_attn_heads = num_attn_heads
        self.attn_dropout = attn_dropout
        self.pre_layernorm = pre_layernorm

        self.dropout_module = Dropout(dropout)
        self.activation_dropout_module = Dropout(activation_dropout)

        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_attn_heads,
            dropout=attn_dropout,
        )

        # One LayerNorm per residual sub-block (attention / feed-forward).
        self.self_attn_layer_norm = LayerNorm(embed_dim)
        self.fc1 = nn.Linear(embed_dim, ffn_embed_dim)
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim)
        self.final_layer_norm = LayerNorm(embed_dim)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_bias: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run the attention and feed-forward sub-blocks on ``x``.

        Args:
            x: Input features; the last dimension must be ``embed_dim``.
                NOTE(review): the remaining layout is whatever
                ``MultiheadAttention`` expects -- confirm against that module.
            self_attn_bias: Optional tensor passed to the attention module as
                ``attn_bias``.
            self_attn_padding_mask: Optional tensor passed to the attention
                module as ``key_padding_mask``.

        Returns:
            The transformed tensor produced by the two residual sub-blocks.
        """
        # --- self-attention sub-block ---
        residual = x
        if self.pre_layernorm:
            x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            attn_bias=self_attn_bias,
            key_padding_mask=self_attn_padding_mask,
            need_weights=False,
        )
        x = self.dropout_module(x)
        x = residual + x
        # Post-norm only: in pre-norm mode the input was already normalised
        # above, so normalising the residual sum again would apply this
        # LayerNorm twice.
        if not self.pre_layernorm:
            x = self.self_attn_layer_norm(x)

        # --- feed-forward sub-block ---
        residual = x
        # Pre-norm mode normalises the feed-forward input with its own norm.
        if self.pre_layernorm:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = residual + x
        if not self.pre_layernorm:
            x = self.final_layer_norm(x)
        return x


        
    

