import torch
import torch.nn as nn

from graph_learning.module import Module, ModuleConfig, get_module
from graph_learning.utils import dict_merge_rec

@ModuleConfig.register('encoder-decoder',
                       help="""
                       [Framework] encoder-decoder framework.
                       Encoder: graph encoder, return node representations.
                       Decoder: down-stream task, compute loss from node representations.
                       """)
class EncDecModuleConfig(ModuleConfig):
    """Config for the 'encoder-decoder' framework module.

    Resolves the ``encoder`` and ``decoder`` sub-module configs through
    ``get_module``, records ``context.global_.device``, and builds an
    ``EncoderDecoder`` instance (see ``builder``).
    """

    # Attribute names shared by the CLI flags and the resolved sub-configs.
    _ROLES = ('encoder', 'decoder')

    def __init__(self, args, context):
        super().__init__(args, context)

        # NOTE(review): assumes ModuleConfig.__init__ has already copied the
        # parsed --encoder/--decoder values onto self; each raw value is
        # replaced by the resolved sub-module config (encoder first, then
        # decoder) — confirm against ModuleConfig.
        for role in self._ROLES:
            setattr(self, role, get_module(context, getattr(self, role)))

        self.device = context.global_.device

    @property
    def builder(self):
        """Class used to instantiate the module described by this config."""
        return EncoderDecoder

    @classmethod
    def define_parser(cls, parser):
        """Add ``--encoder`` and ``--decoder`` on top of the base options."""
        super().define_parser(parser)
        for role in cls._ROLES:
            parser.add_argument(f'--{role}',
                                help=f'{role} layer')

class EncoderDecoder(Module):
    """Chain a graph encoder with a downstream decoder.

    ``forward`` feeds node features through the encoder, passes the
    encoder's ``'hidden'`` entry to the decoder, and returns the decoder's
    result dict with ``'hidden'`` and the encoder's ``'outputs'`` merged into
    its ``'outputs'`` entry.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()

        # NOTE(review): if Module derives from torch.nn.Module, assigning the
        # sub-modules as attributes registers their parameters — confirm.
        self.encoder = encoder
        self.decoder = decoder

        # Kept for callers; not read anywhere in this class.
        self.device = device

    def forward(self, data):
        """Run encoder then decoder on ``data``.

        Args:
            data: graph batch object exposing ``node_feature()``.

        Returns:
            The dict returned by the decoder, whose ``'outputs'`` entry has
            been given a ``'hidden'`` key and then replaced by
            ``dict_merge_rec(<that dict>, encoder_result['outputs'])``.
            NOTE(review): precedence of ``dict_merge_rec`` on conflicting
            keys is defined in graph_learning.utils — confirm.
        """
        enc_result = self.encoder(data, data.node_feature())
        hidden = enc_result['hidden']
        result = self.decoder(data, hidden)

        # The decoder's own 'outputs' dict is mutated in place before merging,
        # so callers holding a reference to it also see the 'hidden' entry.
        merged = result['outputs']
        merged.update({'hidden': hidden})
        result['outputs'] = dict_merge_rec(merged, enc_result['outputs'])
        return result
