import torch
import torch.nn as nn
from dgl.nn.pytorch.conv import GraphConv

from graph_learning.module import ModuleConfig

from .multi_layer_mp import CommonMultiLayerMPConfig

@ModuleConfig.register('gcn',
                       help='[Encoder] GCN')
class GCNModuleConfig(CommonMultiLayerMPConfig):
    """Encoder config that builds its message-passing layers from DGL ``GraphConv``.

    Registered with ``ModuleConfig`` under the name ``'gcn'``. This class only
    supplies the per-layer constructor (``_layer_builder``) and the ``--norm``
    command-line option; the multi-layer assembly presumably lives in
    ``CommonMultiLayerMPConfig`` (not visible here — TODO confirm).
    """

    def _layer_builder(self, in_size, out_size):
        """Build one ``GraphConv`` layer mapping ``in_size`` -> ``out_size`` features.

        Args:
            in_size: input feature size of the layer.
            out_size: output feature size of the layer.

        Returns:
            A ``dgl.nn.pytorch.conv.GraphConv`` module using ``self.norm`` as
            its normalizer.
        """
        # NOTE(review): ``self.norm`` is assumed to be filled in from the parsed
        # ``--norm`` argument by the base config machinery — verify.
        return GraphConv(
            in_size, out_size,
            norm=self.norm,
        )

    @classmethod
    def define_parser(cls, parser):
        """Add GCN-specific options to ``parser`` after the base-class options.

        Args:
            parser: an ``argparse``-style parser supporting ``add_argument``.
        """
        super().define_parser(parser)
        # Default to 'both' (GraphConv's own default). Without a default, an
        # omitted flag yields norm=None, which GraphConv rejects as invalid.
        parser.add_argument(
            '--norm',
            choices=['none', 'both', 'right'],
            default='both',
            help="GraphConv normalizer: 'both' (symmetric, default), "
                 "'right' (average over in-neighbours) or 'none'.",
        )
