import torch
import torch.nn as nn
import dgl.function as fn
from dgl.nn.pytorch import edge_softmax, GATConv

from graph_learning.module import Module, ModuleConfig, register_module

import torch.nn.functional as F

from .multi_layer_mp import CommonMultiLayerMPConfig

class GATConvF(nn.Module):
    """Wrap a layer and flatten its output from dimension 1 onward.

    NOTE(review): presumably used to merge the per-head axis produced by a
    multi-head ``GATConv`` into a single feature axis -- confirm against
    the wrapped layer's output shape.
    """

    def __init__(self, gat_layer):
        """Store ``gat_layer``, the callable invoked by :meth:`forward`."""
        super().__init__()
        self.gat_layer = gat_layer

    def forward(self, g, x):
        """Call the wrapped layer on ``(g, x)`` and flatten the result.

        Dimension 0 of the wrapped layer's output is kept; every remaining
        dimension is collapsed into one.
        """
        return self.gat_layer(g, x).flatten(1)

@ModuleConfig.register('gat')
class GATModuleConfig(CommonMultiLayerMPConfig):
    """Configuration registered under the name ``'gat'``.

    Supplies the input, hidden and output layer builders, each returning a
    ``GATConv`` wrapped in ``GATConvF`` so that the layer output is
    flattened from dimension 1 onward.
    """

    def __init__(self, args, context):
        super().__init__(args, context)

        # Fallback hyper-parameters; presumably ``set_when_empty`` only
        # assigns when no value was supplied -- verify against the base
        # class.
        for option, fallback in (
            ('feat_drop', 0.2),
            ('attn_drop', 0.2),
            ('negative_slope', 0.2),
        ):
            self.set_when_empty(option, fallback)

    def _build_gat_layer(self, in_size, head_size, num_heads):
        """Return a flattened ``GATConv`` with this config's shared options."""
        conv = GATConv(
            in_feats=in_size,
            out_feats=head_size,
            num_heads=num_heads,
            feat_drop=self.feat_drop,
            attn_drop=self.attn_drop,
            negative_slope=self.negative_slope,
            residual=self.residual,
            allow_zero_in_degree=True,
        )
        return GATConvF(conv)

    def _input_layer_builder(self, in_size, out_size):
        """Build the first layer, giving each head ``out_size // heads`` features."""
        # NOTE(review): floor division drops any remainder, so the flattened
        # width is below ``out_size`` when ``heads`` does not divide it.
        per_head = out_size // self.heads
        return self._build_gat_layer(in_size, per_head, self.heads)

    def _hidden_layer_builder(self, in_size, out_size):
        """Build a hidden layer; same construction as the input layer."""
        per_head = out_size // self.heads
        return self._build_gat_layer(in_size, per_head, self.heads)

    def _output_layer_builder(self, in_size, out_size):
        """Build the last layer: one head carrying all ``out_size`` features."""
        return self._build_gat_layer(in_size, out_size, 1)

    @classmethod
    def define_parser(cls, parser):
        """Add the GAT command-line options after the parent's options."""
        super().define_parser(parser)
        for flag in ('--feat-drop', '--attn-drop', '--negative-slope'):
            parser.add_argument(flag, type=float)
        parser.add_argument('--heads', type=int)
        parser.add_argument('--residual', action='store_true')
