import torch
import torch.nn as nn
from torch_scatter import scatter_add
from models.utils import EntangledLinearLayer


class GINLayer(nn.Module):
    """Message-passing layer: sum incoming neighbour states, then transform.

    For every node the layer computes ``update(state + sum of sender states)``,
    i.e. the GIN update rule with a fixed epsilon of 0. ``update`` is a
    two-layer network with a ReLU in between; which layer type is used
    depends on ``ent_deg`` (see ``__init__``).
    """

    def __init__(self, feat_hidden: int, pos_hidden: int, state_type: str, ent_deg: int) -> None:
        """Build the update network.

        Args:
            feat_hidden: Size of the feature part of the state.
            pos_hidden: Size of the positional part of the state.
            state_type: Only consulted when ``ent_deg == 0``. ``'concat'``
                gives a state width of ``feat_hidden + pos_hidden``; any
                other value gives ``feat_hidden * pos_hidden``.
            ent_deg: ``0`` selects plain bias-free ``nn.Linear`` layers; any
                other value selects ``EntangledLinearLayer`` and is forwarded
                to it as its last argument.
        """
        super().__init__()

        if ent_deg == 0:
            # Flat state vector: the width follows from how the feature and
            # positional parts were combined upstream.
            if state_type == 'concat':
                state_size = feat_hidden + pos_hidden
            else:
                state_size = feat_hidden * pos_hidden
            self.update = nn.Sequential(
                nn.Linear(state_size, state_size, bias=False),
                nn.ReLU(),
                nn.Linear(state_size, state_size, bias=False),
            )
        else:
            # NOTE(review): the state layout EntangledLinearLayer expects is
            # defined in models.utils and is not visible here; confirm there.
            self.update = nn.Sequential(
                EntangledLinearLayer(feat_hidden, pos_hidden, feat_hidden, pos_hidden, ent_deg),
                nn.ReLU(),
                EntangledLinearLayer(feat_hidden, pos_hidden, feat_hidden, pos_hidden, ent_deg),
            )

    def forward(self, state: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Aggregate neighbour states along the edges and apply the update network.

        Args:
            state: Per-node states, indexed by node along dimension 0.
            edge_index: Unpacks into ``(send, rec)`` index tensors; the state
                of ``send[i]`` is added into the aggregate of ``rec[i]``.

        Returns:
            ``self.update(state + aggr)``, where ``aggr`` holds, for each
            node, the sum of the states sent to it (zeros if it receives none).
        """
        send, rec = edge_index
        # Accumulate into a buffer with exactly one row per node. Sizing the
        # output from max(rec) + 1 instead would yield too few rows whenever
        # the last node(s) have no incoming edge (or there are no edges at
        # all), and the addition below would then fail on a shape mismatch.
        aggr = torch.zeros_like(state).index_add_(0, rec, state[send])
        return self.update(state + aggr)
