import torch
import torch.nn.functional as F

from torch.nn import Module
from torch_geometric.nn.conv import GCNConv
from typing import TypeVar


# Constrained type variables used as type annotations (e.g. in GCN.forward).
# Each one may be bound to either the CPU or the CUDA variant of the
# corresponding legacy typed-tensor class.
FloatTensor = TypeVar(
    "FloatTensor",
    torch.FloatTensor,
    torch.cuda.FloatTensor,
)
LongTensor = TypeVar(
    "LongTensor",
    torch.LongTensor,
    torch.cuda.LongTensor,
)


class GCN(Module):
    """Two-layer graph convolutional network: GCNConv -> ReLU -> GCNConv.

    The second convolution's output is returned as-is, with no activation
    applied after it.

    Args:
        in_channels: Input feature size expected by the first convolution
            (presumably the per-node feature dimension — confirm with callers).
        hidden_channels: Output size of the first convolution and input size
            of the second.
        out_channels: Output size of the second convolution.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        # Attribute names are deliberately left unchanged: they become the
        # parameter-name prefixes in ``state_dict`` (checkpoint compatibility).
        self.conv_1 = GCNConv(in_channels=in_channels, out_channels=hidden_channels)
        self.conv_2 = GCNConv(in_channels=hidden_channels, out_channels=out_channels)

    def forward(self, x: FloatTensor, edge_index: LongTensor) -> FloatTensor:
        """Run both graph convolutions over the same graph connectivity.

        Args:
            x: Node feature tensor fed to the first convolution.
            edge_index: Graph connectivity passed unchanged to both
                convolutions (assumed to follow PyG's ``(2, num_edges)``
                layout — TODO confirm against callers).

        Returns:
            The raw output of ``conv_2``; no final activation is applied.
        """
        hidden = self.conv_1(x, edge_index)
        hidden = F.relu(hidden)
        return self.conv_2(hidden, edge_index)
