import torch.nn as nn
from graph_learning.module import ModuleConfig, get_module
from graph_learning.utils.misc import dict_merge_rec

@ModuleConfig.register(
    'graph-dec-multi',
    help='[Decoder] Paralleling multiple decoders.',
)
class GraDecMultiModuleConfig(ModuleConfig):
    """Config for a decoder made of several sub-decoders run side by side.

    The sub-decoders are named on the command line via ``--layers``; the
    values are read from ``self.layer_vars`` (the argparse ``dest``), which
    is presumably populated by the base-class constructor -- confirm
    against ``ModuleConfig``.
    """

    def __init__(self, args, context):
        """Look up every requested sub-decoder through ``get_module``.

        Args:
            args: Forwarded unchanged to ``ModuleConfig.__init__``.
            context: Forwarded to ``ModuleConfig.__init__`` and to each
                ``get_module`` call.
        """
        super().__init__(args, context)

        # One entry per value of ``--layers``, in the order given.
        self.layers = [
            get_module(context, layer_var) for layer_var in self.layer_vars
        ]

    @property
    def builder(self):
        """Return the class used to build the module: ``GraphDecoder``."""
        return GraphDecoder

    @classmethod
    def define_parser(cls, parser):
        """Add the ``--layers`` option on top of the base-class options."""
        super().define_parser(parser)
        parser.add_argument(
            '--layers',
            nargs='+',
            dest='layer_vars',
            help='sub-decoders',
        )

class GraphDecoder(nn.Module):
    """Apply several sub-decoders to the same inputs and combine the results.

    Every sub-decoder is called as ``decoder(data, hidden)`` and must return
    a dict that contains a ``'loss'`` entry.
    """

    def __init__(self, layers):
        """Store the sub-decoders.

        Args:
            layers: Iterable of sub-decoder modules; wrapped in an
                ``nn.ModuleList``.
        """
        super().__init__()
        self.decoders = nn.ModuleList(layers)

    def forward(self, data, hidden):
        """Run all sub-decoders and return their merged outputs.

        Args:
            data: Passed unchanged as the first argument of each sub-decoder.
            hidden: Passed unchanged as the second argument of each
                sub-decoder.

        Returns:
            A dict holding the non-loss outputs of all sub-decoders, combined
            one after another with ``dict_merge_rec``, plus a ``'loss'`` key
            whose value is the sum of the individual losses. With no
            sub-decoders this is ``{'loss': 0}`` (``sum`` of an empty list).

        Raises:
            KeyError: If a sub-decoder's output has no ``'loss'`` entry.
        """
        outputs = {}
        losses = []
        for decoder in self.decoders:
            decoder_output = decoder(data, hidden)
            # Take the loss out before merging so only the combined loss,
            # set below, ends up under 'loss' in the result.
            losses.append(decoder_output.pop('loss'))
            outputs = dict_merge_rec(outputs, decoder_output)
        outputs['loss'] = sum(losses)

        return outputs


