import torch
import torch.nn as nn
from torch_scatter import scatter_add
from models.utils import EntangledLinearLayer
from torch_geometric.utils import degree, to_dense_adj


class GCNLayer(nn.Module):
    """Graph convolution with symmetric degree normalisation and a ReLU update.

    Each node sums the states of the nodes that send edges to it. Each message
    is scaled by ``1 / sqrt(deg(sender) * deg(receiver))``. The aggregate is
    then passed through ``self.update``: either a bias-free ``nn.Linear`` or
    an ``EntangledLinearLayer``, followed by ReLU. No self-loops are added, so
    a node's own state only contributes if ``edge_index`` contains ``(i, i)``.
    """

    def __init__(self, feat_hidden, pos_hidden, feat_out, pos_out, state_type, ent_deg) -> None:
        """Build the update network.

        Args:
            feat_hidden: size of the feature part of the input state.
            pos_hidden: size of the positional part of the input state.
            feat_out: size of the feature part of the output state.
            pos_out: size of the positional part of the output state.
            state_type: ``'concat'`` sizes the linear map for a state width of
                ``feat + pos``. Any other value uses ``feat * pos``. Only
                consulted when ``ent_deg == 0``.
            ent_deg: ``0`` selects a plain bias-free ``nn.Linear``. Any other
                value is forwarded to ``EntangledLinearLayer`` together with
                the four sizes above.
        """
        super().__init__()

        if ent_deg == 0:
            state_in = feat_hidden + pos_hidden if state_type == 'concat' else feat_hidden * pos_hidden
            state_out = feat_out + pos_out if state_type == 'concat' else feat_out * pos_out

            self.update = nn.Sequential(
                nn.Linear(state_in, state_out, bias=False),
                nn.ReLU()
            )
        else:
            self.update = nn.Sequential(
                EntangledLinearLayer(feat_hidden, pos_hidden, feat_out, pos_out, ent_deg),
                nn.ReLU()
            )

    def forward(self, state, edge_index) -> torch.Tensor:
        """Aggregate neighbour states and apply the update network.

        Args:
            state: node states with nodes along dim 0. Any trailing shape is
                accepted (e.g. ``(N, F)`` or ``(N, C, F)``).
            edge_index: integer tensor of shape ``(2, E)`` holding
                ``(sender, receiver)`` index pairs into ``state``.

        Returns:
            ``self.update(aggr)``, where ``aggr`` has exactly ``state.size(0)``
            rows. Nodes that receive no messages get an all-zero row.
        """
        send, rec = edge_index
        num_nodes = state.size(0)

        # The degree is counted on the sender row and used for both endpoints.
        # This equals the usual GCN degree when every undirected edge is listed
        # in both directions. NOTE(review): presumably true for callers; confirm.
        # minlength=num_nodes keeps the vector aligned with `state` even when
        # the highest-index nodes never appear as senders. The default float
        # dtype mirrors torch_geometric.utils.degree.
        degrees = torch.bincount(send, minlength=num_nodes).to(torch.get_default_dtype())
        inv_sqrt = degrees.pow(-0.5)
        # A receiver with zero out-degree would give inf, and inf * 0 gives NaN.
        # Zero its coefficient instead, as torch_geometric's gcn_norm does.
        inv_sqrt.masked_fill_(torch.isinf(inv_sqrt), 0.0)

        # Per-edge coefficient, reshaped to (E, 1, ..., 1) so it broadcasts
        # over every trailing dimension of the state, whatever its rank.
        deg_comp = (inv_sqrt[send] * inv_sqrt[rec]).view(-1, *([1] * (state.dim() - 1)))
        messages = deg_comp * state[send]

        # Sum into a buffer sized for all nodes, so receivers with no incoming
        # edges (including trailing isolated nodes) keep a zero row.
        aggr = messages.new_zeros((num_nodes, *state.shape[1:]))
        aggr.index_add_(0, rec, messages)
        return self.update(aggr)
