import torch.nn as nn
from graph_learning.module import ModuleConfig
from graph_learning.utils import dict_merge_rec

@ModuleConfig.register('graph-enc-seq',
                       help='[Encoder] encoder pipeline.')
class GraEncSequenceModuleConfig(ModuleConfig):
    """Config for a sequential graph-encoder pipeline (``graph-enc-seq``).

    The sub-encoders are named on the command line via ``--layers``/``-l``
    (stored as ``layer_vars``). Each name is looked up as an attribute of
    ``context.module`` when the config is constructed, and the resolved
    objects are kept, in order, in ``self.layers``.
    """

    def __init__(self, args, context):
        super().__init__(args, context)

        # NOTE(review): ``self.layer_vars`` is presumably populated from the
        # parsed ``--layers`` option by ModuleConfig.__init__ — confirm.
        # ``context.module`` is re-read per name, as before, so an empty list
        # never touches it.
        self.layers = [getattr(context.module, var_name)
                       for var_name in self.layer_vars]

    @property
    def builder(self):
        """The module class this config builds (``GraphEncoder``)."""
        return GraphEncoder

    @classmethod
    def define_parser(cls, parser):
        """Extend the base parser with the ``--layers``/``-l`` option."""
        super().define_parser(parser)

        parser.add_argument(
            '--layers', '-l',
            nargs='+',
            dest='layer_vars',
            help='sub-encoders',
        )

class GraphEncoder(nn.Module):
    """Run a sequence of graph sub-encoders, threading the hidden state through.

    Each layer is called as ``layer(graph, hidden)`` and is expected to return
    a dict with a ``'hidden'`` entry (fed to the next layer) and an
    ``'outputs'`` entry (merged into the accumulated outputs with
    ``dict_merge_rec``).
    """

    def __init__(self, layers):
        """
        Args:
            layers: iterable of sub-encoder modules, applied in order. They are
                wrapped in ``nn.ModuleList`` so their parameters are registered
                on this module.
        """
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, graph, hidden):
        """Apply every layer in order.

        Args:
            graph: passed unchanged to every layer.
            hidden: initial hidden state given to the first layer.

        Returns:
            dict with
              - ``'hidden'``: the hidden state returned by the last layer
                (the input ``hidden`` if there are no layers);
              - ``'outputs'``: every layer's ``'outputs'`` folded together, in
                order, with ``dict_merge_rec`` (``{}`` if there are no layers).
        """
        outputs = {}
        for layer in self.layers:
            layer_outputs = layer(graph, hidden)
            # Read from this layer's result, not from the accumulator: the
            # accumulator starts empty (KeyError on the first layer) and never
            # holds a 'hidden' key, so the state would otherwise not advance.
            hidden = layer_outputs['hidden']
            # NOTE(review): assumes dict_merge_rec returns the recursively
            # merged dict, as the assignment implies — confirm.
            outputs = dict_merge_rec(outputs, layer_outputs['outputs'])
        return {'hidden': hidden,
                'outputs': outputs}
