import torch
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.layer import (new_layer_config,
                                                   BatchNorm1dNode,
                                                   BatchNorm1dEdge)


class FeatureEncoder(torch.nn.Module):
    """
    Encode node and edge features into the GNN hidden dimension.

    The node and edge encoders are looked up by name in the GraphGym
    registry (``cfg.dataset.node_encoder_name`` and
    ``cfg.dataset.edge_encoder_name``). Each encoder can optionally be
    followed by a batch-norm layer (``cfg.dataset.node_encoder_bn`` /
    ``cfg.dataset.edge_encoder_bn``).

    Note:
        Constructing this module overwrites the global ``cfg.gnn.dim_edge``
        with ``cfg.gnn.dim_inner``.

    Args:
        dim_in (int): Input feature dimension. It is only stored
            temporarily; after construction ``self.dim_in`` equals
            ``cfg.gnn.dim_inner``.
    """
    def __init__(self, dim_in):
        super().__init__()
        self.dim_in = dim_in

        # NOTE: ``forward`` applies the child modules in the order they are
        # assigned below (node encoder, node BN, edge encoder, edge BN), so
        # the assignment order is significant.

        # Encode node features with the registered node encoder
        # (assumed to embed integer features; depends on the registry entry).
        self.node_encoder = register.node_encoder_dict[
            cfg.dataset.node_encoder_name](cfg.gnn.dim_inner)
        if cfg.dataset.node_encoder_bn:
            self.node_encoder_bn = BatchNorm1dNode(
                new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False,
                                 has_bias=False, cfg=cfg))
        # Update dim_in to reflect the new dimension of the node features.
        self.dim_in = cfg.gnn.dim_inner
        # Edge features are encoded to the same width as node features.
        # This mutates the shared global config.
        cfg.gnn.dim_edge = cfg.gnn.dim_inner

        # Encode edge features with the registered edge encoder.
        self.edge_encoder = register.edge_encoder_dict[
            cfg.dataset.edge_encoder_name](cfg.gnn.dim_edge)
        if cfg.dataset.edge_encoder_bn:
            # Use the *edge* batch-norm so the encoded edge features are
            # normalized; a node batch-norm here would normalize the node
            # features a second time and leave the edge features untouched.
            self.edge_encoder_bn = BatchNorm1dEdge(
                new_layer_config(cfg.gnn.dim_edge, -1, -1, has_act=False,
                                 has_bias=False, cfg=cfg))

    def forward(self, batch):
        """
        Apply every registered submodule to ``batch`` in registration order.

        Args:
            batch: The graph batch object; each submodule receives it and
                returns the (possibly updated) batch.

        Returns:
            The batch as returned by the last submodule.
        """
        for module in self.children():
            batch = module(batch)
        return batch
