
from utils.constants import Cte
import torch_geometric.nn as geom_nn

import torch.nn as nn
def get_graph_layer(layer_name):
    """Return the PyTorch Geometric layer constructor for ``layer_name``.

    Args:
        layer_name: A layer identifier from ``Cte``; one of ``Cte.GCN``,
            ``Cte.GAT``, ``Cte.GC``, ``Cte.GIN`` or ``Cte.GINE``.

    Returns:
        A callable that builds the layer:

        * ``Cte.GCN`` / ``Cte.GAT`` / ``Cte.GC``: the ``torch_geometric.nn``
          class itself (``GCNConv``, ``GATConv``, ``GraphConv``).
        * ``Cte.GIN``: a factory ``(in_channels, out_channels, **kwargs)``
          that builds a ``GINConv`` around a single
          ``nn.Linear(in_channels, out_channels)``.
        * ``Cte.GINE``: the ``GINEConv`` class, returned unwrapped.

    Raises:
        ValueError: If ``layer_name`` is not one of the supported identifiers
            (previously this silently returned ``None``).
    """
    if layer_name == Cte.GCN:
        return geom_nn.GCNConv
    if layer_name == Cte.GAT:
        return geom_nn.GATConv
    if layer_name == Cte.GC:
        return geom_nn.GraphConv
    if layer_name == Cte.GIN:
        # GINConv is parameterised by an ``nn`` module rather than channel
        # sizes; adapt it to an (in_channels, out_channels, **kwargs) call
        # signature by wrapping a single linear map.
        def gin_layer(in_channels, out_channels, **kwargs):
            return geom_nn.GINConv(nn=nn.Linear(in_channels, out_channels),
                                   **kwargs)
        return gin_layer
    if layer_name == Cte.GINE:
        # NOTE(review): returned unwrapped, so its first positional argument
        # is ``nn`` (a module), not ``in_channels`` as with the GIN factory
        # above. Confirm call sites construct it accordingly.
        return geom_nn.GINEConv
    # Fail loudly here instead of returning None and failing later with an
    # opaque "'NoneType' object is not callable" at the call site.
    raise ValueError(f"Unknown graph layer: {layer_name!r}")