import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import create_state, state_size
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn import global_mean_pool, MessagePassing, GINConv
from typing import Tuple
from torch_scatter import scatter_mean
from models.utils import EntangledLinearLayer


class SageConvLayer(nn.Module):
    """GraphSAGE-style layer: ``update_1(state) + update_2(mean of incoming neighbour states)``.

    Args:
        feat_hidden: Size of the feature part of a node state.
        pos_hidden: Size of the positional part of a node state.
        state_type: ``'concat'`` means a node state has ``feat_hidden + pos_hidden``
            channels; any other value means ``feat_hidden * pos_hidden`` channels.
            Only consulted when ``ent_deg == 0``.
        ent_deg: ``0`` selects bias-free ``nn.Linear`` updates; any other value
            selects ``EntangledLinearLayer`` updates, with ``ent_deg`` passed through
            as its last argument.
    """

    def __init__(self, feat_hidden: int, pos_hidden: int, state_type: str, ent_deg: int) -> None:
        super().__init__()

        if ent_deg == 0:
            # Named ``hidden_size`` (not ``state_size``) so it does not shadow the
            # ``state_size`` helper imported from ``utils`` at module level.
            if state_type == 'concat':
                hidden_size = feat_hidden + pos_hidden
            else:
                hidden_size = feat_hidden * pos_hidden
            self.update_1 = self._relu_block(nn.Linear(hidden_size, hidden_size, bias=False))
            self.update_2 = self._relu_block(nn.Linear(hidden_size, hidden_size, bias=False))
        else:
            self.update_1 = self._relu_block(
                EntangledLinearLayer(feat_hidden, pos_hidden, feat_hidden, pos_hidden, ent_deg)
            )
            self.update_2 = self._relu_block(
                EntangledLinearLayer(feat_hidden, pos_hidden, feat_hidden, pos_hidden, ent_deg)
            )

    @staticmethod
    def _relu_block(linear: nn.Module) -> nn.Sequential:
        """Wrap ``linear`` as ``Sequential(linear, ReLU())``."""
        return nn.Sequential(linear, nn.ReLU())

    def forward(self, state, edge_index) -> torch.Tensor:
        """Combine each node's own state with the mean of its incoming neighbours.

        Args:
            state: Node states indexed along dim 0 (one row per node). For
                ``ent_deg == 0`` the last dim must match the ``nn.Linear`` input size.
            edge_index: Pair ``(send, rec)`` of node-index tensors; edge ``i`` carries
                ``state[send[i]]`` to node ``rec[i]`` (presumably a ``(2, E)`` long
                tensor, as in PyG — confirm against callers).

        Returns:
            ``update_1(state) + update_2(aggr)``, where ``aggr[v]`` is the mean of
            ``state[u]`` over all edges ``u -> v``. Nodes with no incoming edges
            get a zero aggregate.
        """
        send, rec = edge_index
        # dim_size pins the aggregate to one row per node. Without it, scatter_mean
        # sizes the result by rec.max() + 1. That is too short when the
        # highest-numbered nodes receive no messages: the sum below then hits a shape
        # mismatch, or silently broadcasts when only one row is produced. It is also
        # empty when there are no edges at all.
        aggr = scatter_mean(state[send], rec, dim=0, dim_size=state.size(0))
        return self.update_1(state) + self.update_2(aggr)
