from torch import nn
from torch_geometric.nn.conv import GATConv, GCNConv, GINConv, SAGEConv
from torch_geometric.nn.models import GAT, GCN, GIN, GraphSAGE


def make_graph_conv(conv: str, d: int) -> nn.Module:
    """Build a single graph convolutional layer of width ``d``.

    The layer maps ``d`` input channels to ``d`` output channels.

    :param conv: Layer type; one of ``"GAT"``, ``"GCN"``, ``"GIN"`` or ``"SAGE"``.
    :param d: Number of hidden channels (used as both input and output width).
    :return: The newly constructed graph convolutional layer.
    :raises ValueError: If ``conv`` names an unsupported layer type.
    """
    # Ordered (name, builder) table; builders are lazy so only the selected
    # layer is ever instantiated.
    builders = (
        ("GAT", lambda: GATConv(d, d)),
        ("GCN", lambda: GCNConv(d, d)),
        # GINConv takes the MLP it applies to aggregated features rather than
        # channel sizes, so a two-layer d -> d -> d MLP is supplied here.
        (
            "GIN",
            lambda: GINConv(
                nn.Sequential(nn.Linear(d, d), nn.ReLU(), nn.Linear(d, d))
            ),
        ),
        ("SAGE", lambda: SAGEConv(d, d)),
    )
    for name, build in builders:
        if conv == name:
            return build()
    raise ValueError(f"Unknown graph convolutional layer type: {conv}")


def make_gnn(conv: str, n_layers: int, d: int) -> nn.Module:
    """Create a multi-layer graph neural network.

    Each variant is one of PyG's ready-made stacked models, built with input,
    hidden and output width all equal to ``d``. The accepted types match those
    accepted by ``make_graph_conv``.

    :param conv: Model type; one of ``"GAT"``, ``"GCN"``, ``"GIN"`` or ``"SAGE"``.
    :param n_layers: Number of layers (forwarded as PyG's ``num_layers``).
    :param d: Number of channels for the input, hidden and output features.
    :return: Graph neural network.
    :raises ValueError: If ``conv`` names an unsupported model type.
    """
    # Keyword arguments make the (in, hidden, num_layers, out) ordering of
    # PyG's BasicGNN constructors explicit instead of positional.
    if conv == "GAT":
        return GAT(in_channels=d, hidden_channels=d, num_layers=n_layers, out_channels=d)
    elif conv == "GCN":
        return GCN(in_channels=d, hidden_channels=d, num_layers=n_layers, out_channels=d)
    elif conv == "GIN":
        return GIN(in_channels=d, hidden_channels=d, num_layers=n_layers, out_channels=d)
    elif conv == "SAGE":
        # make_graph_conv supports "SAGE"; keep the model factory in sync.
        return GraphSAGE(
            in_channels=d, hidden_channels=d, num_layers=n_layers, out_channels=d
        )
    else:
        raise ValueError(f"Unknown graph neural network type: {conv}")
