import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor, nn

from src.utils.gnn_utils import make_graph_conv


class ResidualGNN(nn.Module):
    """Sequential stack of residual blocks operating on node features."""

    def __init__(self, conv: str, n_blocks: int, d_hidden: int, d_edge: int) -> None:
        """Create the stack of residual blocks.

        :param conv: Name of the graph convolution, forwarded to every block.
        :param n_blocks: How many residual blocks to stack.
        :param d_hidden: Width of the node features, forwarded to every block.
        :param d_edge: Number of edge features. Accepted for interface
            compatibility; this constructor does not use it.
        """
        super().__init__()

        # Every block gets the same convolution type and feature width.
        self.residual_blocks = nn.ModuleList(
            ResidualBlock(conv, d_hidden) for _ in range(n_blocks)
        )

    def forward(
        self,
        x: Float[Tensor, "n_nodes d_hidden"],
        edge_index: Float[Tensor, "2 n_edges"],
        edge_attr: Float[Tensor, "n_edges"],
    ) -> Float[Tensor, "n_nodes d_hidden"]:
        """Run the node features through every residual block in order.

        :param x: Node features.
        :param edge_index: Edge endpoints, handed unchanged to every block.
            NOTE(review): annotated as ``Float`` although edge indices are
            usually integer tensors -- confirm the intended dtype.
        :param edge_attr: Per-edge values, handed unchanged to every block.
        :return: Node features produced by the last block; with zero blocks
            the input is returned as is.
        """
        hidden = x
        for block in self.residual_blocks:
            # Each block consumes the previous block's output.
            hidden = block(hidden, edge_index, edge_attr)
        return hidden


class ResidualBlock(nn.Module):
    """Two bidirectional graph-convolution stages wrapped in a skip connection."""

    def __init__(self, conv: str, d: int) -> None:
        """Create the convolutions and batch norms of both stages.

        :param conv: Name of the graph convolution, passed to ``make_graph_conv``.
        :param d: Number of hidden channels, used for every convolution and
            batch norm in the block.
        """
        super().__init__()

        # Stage 1: one convolution per edge direction, then a shared batch norm.
        self.conv_1 = make_graph_conv(conv, d)
        self.conv_rev_1 = make_graph_conv(conv, d)
        self.bn_1 = nn.BatchNorm1d(d)

        # Stage 2: same layout as stage 1, with its own parameters.
        self.conv_2 = make_graph_conv(conv, d)
        self.conv_rev_2 = make_graph_conv(conv, d)
        self.bn_2 = nn.BatchNorm1d(d)

    def forward(
        self,
        x: Float[Tensor, "n_nodes d_hidden"],
        edge_index: Float[Tensor, "2 n_edges"],
        edge_attr: Float[Tensor, "n_edges"],
    ) -> Float[Tensor, "n_nodes d_hidden"]:
        """Apply both convolution stages and add the input back in.

        :param x: Node features; also used as the skip connection.
        :param edge_index: Edge endpoints. Its two rows are swapped to obtain
            the reversed edges fed to the ``conv_rev_*`` layers.
            NOTE(review): annotated as ``Float`` although edge indices are
            usually integer tensors -- confirm the intended dtype.
        :param edge_attr: Per-edge values, passed unchanged to every
            convolution for both edge directions.
        :return: Updated node features.
        """
        skip = x
        # Swapping the two rows makes every edge point the other way.
        reversed_edges = edge_index[[1, 0], :]

        # Stage 1: sum both directions, normalise, activate.
        along = self.conv_1(x, edge_index, edge_attr)
        against = self.conv_rev_1(x, reversed_edges, edge_attr)
        hidden = F.leaky_relu(self.bn_1(along + against))

        # Stage 2: as stage 1, but the activation waits for the skip connection.
        along = self.conv_2(hidden, edge_index, edge_attr)
        against = self.conv_rev_2(hidden, reversed_edges, edge_attr)
        normed = self.bn_2(along + against)

        return F.leaky_relu(normed + skip)
