from typing import Optional, Tuple

import torch
import torch.nn as nn

from models.model_utils import index_select_ND


class FFNEncoder(nn.Module):
    """Feed-forward node encoder over aggregated message features.

    There is no iterative message passing: each node's hidden state is a single
    GELU-activated linear projection of the sum of the message-feature rows
    that ``agraph`` selects for it.
    """

    def __init__(self, args, input_size: int, node_fdim: int):
        """
        Parameters
        ----------
            args: config object; ``encoder_hidden_size`` and ``encoder_num_layers``
                are read from it
            input_size: int, feature width of ``fmess`` (input width of ``W_o``)
            node_fdim: int, node feature dimension; stored but not used by this class
        """
        super().__init__()
        self.args = args

        self.h_size = args.encoder_hidden_size
        # Stored but not used here: this encoder applies a single projection
        # rather than iterating over ``depth`` layers.
        self.depth = args.encoder_num_layers
        self.input_size = input_size
        self.node_fdim = node_fdim

        self.W_o = nn.Sequential(nn.Linear(self.input_size, self.h_size), nn.GELU())

    def forward(self, fnode: torch.Tensor, fmess: torch.Tensor,
                agraph: torch.Tensor, bgraph: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the FFNEncoder.

        Parameters
        ----------
            fnode: torch.Tensor, node feature tensor; not used in the computation
                (presumably kept for signature compatibility with the MPN encoder)
            fmess: torch.Tensor, message features; the last dimension must equal
                ``input_size`` because the summed rows are fed to ``W_o``
            agraph: torch.Tensor, neighborhood of an atom; used as row indices
                into ``fmess`` via ``index_select_ND``
            bgraph: torch.Tensor, neighborhood of a bond, except the directed bond
                from the destination node to the source node; not used here
            mask: torch.Tensor or None, masks on nodes; must broadcast against the
                (num_nodes, h_size) output. When None, a mask of ones is built
                with only the first (padding) node zeroed.

        Returns
        -------
            node_hiddens: torch.Tensor, masked node hidden states of width ``h_size``
            nei_message: torch.Tensor, the gathered neighbor messages
        """
        # assumes index_select_ND gathers fmess rows into
        # (num_nodes, max_neighbors, input_size) — TODO confirm against model_utils
        nei_message = index_select_ND(fmess, 0, agraph)
        node_hiddens = self.W_o(nei_message.sum(dim=1))

        if mask is None:
            # Match node_hiddens' dtype/device: torch.ones defaults to float32,
            # which would promote fp16/bf16 hidden states in the multiply below.
            mask = torch.ones(node_hiddens.size(0), 1,
                              dtype=node_hiddens.dtype, device=node_hiddens.device)
            # First node is padding; the slice is a no-op when there are no rows.
            mask[:1] = 0

        return node_hiddens * mask, nei_message
