import torch
import torch.nn as nn
import torch.nn.functional as F

from graph_learning.module import Module, ModuleConfig, register_module, get_module
from graph_learning.utils import dict_merge_rec

@ModuleConfig.register('node-cls-decoder',
                       help='[Decoder] node classification.')
class NodeClassificationDecoderModuleConfig(ModuleConfig):
    """Configuration for the node-classification decoder.

    Registered with ``ModuleConfig`` under the name ``'node-cls-decoder'``.
    It builds a ``NodeClassificationDecoder`` (see ``builder``).
    """

    def __init__(self, args, context):
        super().__init__(args, context)

        self.device = context.global_.device

        # NOTE(review): ``self.classifier`` / ``self.loss_func`` are read here
        # before this method assigns them. Presumably ``ModuleConfig.__init__``
        # populates them from the parsed CLI options -- confirm.
        # Each spec is replaced by the object ``get_module`` resolves it to.
        # The classifier is resolved first, then the loss function.
        for attr_name in ('classifier', 'loss_func'):
            spec = getattr(self, attr_name)
            setattr(self, attr_name, get_module(context, spec))

    @property
    def builder(self):
        """The class this config constructs: ``NodeClassificationDecoder``."""
        return NodeClassificationDecoder

    @classmethod
    def define_parser(cls, parser):
        """Add the decoder's CLI options on top of the base-class options.

        Args:
            parser: argparse-style parser that receives ``--classifier`` and
                ``--loss-func``.
        """
        super().define_parser(parser)
        option_table = (
            ('--classifier',
             'classification layer, map hidden representation from graph encoder to logits.'),
            ('--loss-func',
             'Loss layer.'),
        )
        for flag, description in option_table:
            parser.add_argument(flag, help=description)

class NodeClassificationDecoder(nn.Module):
    """Decoder that turns node representations into a loss value.

    An optional classifier maps the encoder's hidden node representations to
    logits. The logits, the node labels and the node mask are then passed to
    the loss function.
    """

    def __init__(self, classifier, loss_func):
        """
        Args:
            classifier: callable that maps hidden representations to logits,
                or ``None`` to use the hidden representations as logits.
                If it is an ``nn.Module``, it is registered as a submodule.
            loss_func: callable invoked as ``loss_func(logits, labels, mask)``.
                Its return value is what ``forward`` returns.
        """
        super().__init__()

        self.classifier = classifier
        self.loss_func = loss_func

    def forward(self, data, hidden):
        """Compute the node-classification loss.

        Args:
            data: object exposing ``node_labels()`` and ``node_mask()``.
                Presumably a graph batch -- confirm against callers.
            hidden: node representations from the graph encoder.

        Returns:
            Whatever ``self.loss_func`` returns.
        """
        logits = hidden if self.classifier is None else self.classifier(hidden)
        # Arguments are evaluated left to right, so the classifier runs first,
        # then node_labels(), then node_mask(), then the loss.
        return self.loss_func(logits, data.node_labels(), data.node_mask())
