from jaxtyping import Float, Int
from torch import Tensor, nn
from torch_geometric.nn import TransformerConv


class Transformer(nn.Module):
    """Transformer implemented with PyG framework.

    Stacks ``n_layers - 1`` :class:`TransformerLayer` modules followed by one
    :class:`LastTransformerLayer` and applies them sequentially.
    """

    def __init__(
        self,
        n_layers: int,
        d_hidden: int,
        d_edge: int,
        n_head: int,
        drop_prob: float = 0.0,
    ) -> None:
        """Initialize transformer.

        :param n_layers: Total number of layers, including the final layer.
            Must be at least 1.
        :param d_hidden: Hidden dimension.
        :param d_edge: Edge dimension.
        :param n_head: Number of attention heads.
        :param drop_prob: Dropout probability. Defaults to 0.0.
        :raises ValueError: If ``n_layers`` is smaller than 1.
        """
        super().__init__()

        # The final layer is always appended, so without this check a
        # non-positive ``n_layers`` would silently build a one-layer model.
        if n_layers < 1:
            raise ValueError(f"n_layers must be at least 1, got {n_layers}.")

        layers = [
            TransformerLayer(d_hidden, d_edge, n_head, drop_prob)
            for _ in range(n_layers - 1)
        ]
        layers.append(LastTransformerLayer(d_hidden, d_edge, n_head, drop_prob))
        self.layers = nn.ModuleList(layers)

    def forward(
        self,
        x: Float[Tensor, "n_nodes d_hidden"],
        edge_index: Int[Tensor, "2 n_edges"],
        edge_attr: Float[Tensor, "n_edges d_edge"],
    ) -> Float[Tensor, "n_nodes d_hidden"]:
        """Perform a forward pass through all transformer layers.

        :param x: Tensor containing node features.
        :param edge_index: Integer tensor containing the edge indices.
        :param edge_attr: Tensor containing edge attributes.
        :return: Updated node features.
        """
        # The graph structure and edge attributes are shared by every layer;
        # only the node features are updated from one layer to the next.
        for layer in self.layers:
            x = layer(x, edge_index, edge_attr)
        return x


class TransformerLayer(nn.Module):
    """Transformer layer used by the transformer.

    Applies multi-head graph attention (``TransformerConv``), followed by
    layer normalization and a ReLU activation.
    """

    def __init__(
        self, d_hidden: int, d_edge: int, n_head: int, drop_prob: float = 0.0
    ) -> None:
        """Initialize the transformer layer.

        :param d_hidden: Number of hidden channels. Must be divisible by
            ``n_head``.
        :param d_edge: Number of edge channels.
        :param n_head: Number of attention heads. Must be positive.
        :param drop_prob: Dropout probability. Defaults to 0.0.
        :raises ValueError: If ``n_head`` is not positive or does not divide
            ``d_hidden``.
        """
        super().__init__()

        # Head outputs are concatenated, so the attention output width is
        # ``(d_hidden // n_head) * n_head``. Fail early when that would not
        # equal ``d_hidden``; otherwise the mismatch only surfaces as a shape
        # error in the layer norm during the forward pass.
        if n_head < 1 or d_hidden % n_head != 0:
            raise ValueError(
                f"d_hidden ({d_hidden}) must be divisible by a positive "
                f"n_head ({n_head})."
            )

        self.attention = TransformerConv(
            in_channels=d_hidden,
            out_channels=d_hidden // n_head,  # per-head width
            heads=n_head,
            beta=True,
            dropout=drop_prob,
            edge_dim=d_edge,
            concat=True,
        )
        self.layer_norm = nn.LayerNorm(d_hidden)
        self.relu = nn.ReLU()

    def forward(
        self,
        x: Float[Tensor, "n_nodes d_hidden"],
        edge_index: Int[Tensor, "2 n_edges"],
        edge_attr: Float[Tensor, "n_edges d_edge"],
    ) -> Float[Tensor, "n_nodes d_hidden"]:
        """Forward pass with Transformer Layer.

        :param x: Tensor containing node features.
        :param edge_index: Integer tensor containing the edge indices.
        :param edge_attr: Tensor containing edge attributes.
        :return: Updated node features.
        """
        x = self.attention(x=x, edge_index=edge_index, edge_attr=edge_attr)
        x = self.layer_norm(x)
        return self.relu(x)


class LastTransformerLayer(nn.Module):
    """Last transformer layer used by the transformer.

    Consists of a single ``TransformerConv`` attention module whose output is
    returned directly, without a subsequent layer norm or activation.
    """

    def __init__(
        self, d_hidden: int, d_edge: int, n_head: int, drop_prob: float = 0.0
    ) -> None:
        """Initialize the transformer layer.

        :param d_hidden: Number of hidden channels.
        :param d_edge: Number of edge channels.
        :param n_head: Number of attention heads.
        :param drop_prob: Dropout probability. Defaults to 0.0.
        """
        super().__init__()

        # ``concat=False`` makes the head outputs get averaged rather than
        # concatenated, so each head works at the full ``d_hidden`` width and
        # the layer output keeps ``d_hidden`` channels.
        self.attention = TransformerConv(
            in_channels=d_hidden,
            out_channels=d_hidden,
            heads=n_head,
            beta=True,
            edge_dim=d_edge,
            dropout=drop_prob,
            concat=False,
        )

    def forward(
        self,
        x: Float[Tensor, "n_nodes d_hidden"],
        edge_index: Int[Tensor, "2 n_edges"],
        edge_attr: Float[Tensor, "n_edges d_edge"],
    ) -> Float[Tensor, "n_nodes d_hidden"]:
        """Forward pass with the last Transformer Layer.

        :param x: Tensor containing node features.
        :param edge_index: Integer tensor containing the edge indices.
        :param edge_attr: Tensor containing edge attributes.
        :return: Updated node features.
        """
        return self.attention(x=x, edge_index=edge_index, edge_attr=edge_attr)
