#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : conv_block.py
# Author : Anonymous1
# Email  : anonymous1@anon
#
# Distributed under terms of the MIT license.

import torch
import torch.nn as nn

__all__ = ["ConvBlock"]

from dgl import DGLGraph
from megraph.representation import MultiFeatures


class ConvBlock(nn.Module):
    """Wrap a convolution and post-process its output features.

    ``conv`` is called on ``(graph, features)``. The resulting
    ``MultiFeatures`` is then passed through ``norms``, ``act`` and
    ``dropout``, strictly in that order, each via ``MultiFeatures.apply_fn``.

    Args:
        conv: Callable invoked as ``conv(graph, features)``. Its
            ``get_output_dims()`` backs :meth:`get_output_dims`.
        norms: First stage handed to ``apply_fn``. Defaults to ``None``.
        act: Second stage handed to ``apply_fn``. Defaults to ``None``.
        dropout: Last stage handed to ``apply_fn``. Defaults to ``None``.
    """

    def __init__(self, conv, norms=None, act=None, dropout=None):
        super().__init__()
        # Keep this assignment order: nn.Module registers submodules in the
        # order they are set, which determines state_dict key order and repr.
        self.conv = conv
        self.norms = norms
        self.act = act
        self.dropout = dropout

    def forward(self, graph: DGLGraph, features: MultiFeatures) -> MultiFeatures:
        """Apply ``conv``, then each post-processing stage in sequence.

        NOTE(review): stages left as ``None`` are still forwarded to
        ``apply_fn``. Presumably ``apply_fn`` treats ``None`` as a no-op;
        confirm in megraph.representation.
        """
        result = self.conv(graph, features)
        # Attributes are looked up one at a time, right before each stage
        # runs, mirroring a chained .apply_fn(...) call.
        for stage_name in ("norms", "act", "dropout"):
            result = result.apply_fn(getattr(self, stage_name))
        return result

    def get_output_dims(self):
        """Return whatever the wrapped ``conv.get_output_dims()`` reports."""
        return self.conv.get_output_dims()
