from torch_geometric.nn import SGConv
import torch
from torch_geometric.nn import global_mean_pool

class SGCReg(torch.nn.Module):
    """Graph-level regressor: one SGConv layer followed by per-graph mean pooling.

    The SGConv maps every node to a single scalar after ``K = num_layers``
    propagation steps. Those node scalars are then averaged within each
    graph. The convolution is cast to float64 with ``.double()``, so node
    features must be passed in double precision.
    """

    def __init__(self, num_node_features, num_layers):
        super().__init__()
        # out_channels=1: one scalar per node; .double() makes the conv's
        # weights float64, so inputs must be float64 as well.
        self.conv = SGConv(num_node_features, 1, K=num_layers).double()

    def forward(self, x, edge_index, edge_attr, batch):
        """Return one prediction per graph in ``batch``.

        Args:
            x: node features whose last dimension is ``num_node_features``
                (float64, to match the conv's weights).
            edge_index: graph connectivity, passed straight to SGConv.
            edge_attr: passed as SGConv's third positional argument.
                NOTE(review): in PyG that parameter is ``edge_weight``, a
                scalar per edge; confirm edge_attr has that form.
            batch: node-to-graph assignment consumed by global_mean_pool.

        Returns:
            A 1-D tensor with one value per graph.
        """
        # 1. Node embeddings: one scalar per node.
        node_out = self.conv(x, edge_index, edge_attr)
        # 2. Average the node scalars within each graph -> (num_graphs, 1).
        pooled = global_mean_pool(node_out, batch)
        # squeeze(-1) drops only the size-1 feature axis. A bare squeeze()
        # would also drop the graph axis for a single-graph batch and
        # return a 0-d tensor.
        return pooled.squeeze(-1)


class LinearGC(torch.nn.Module):
    """Linear graph-convolution regressor on dense, batched adjacency matrices.

    For each graph it computes ``mean_i((S^K X theta)_i) + bias``, where:

    - ``A'`` is the adjacency with its diagonal set to 1 (any existing
      diagonal weights are overwritten, not incremented);
    - ``d`` holds the row sums of ``A'``;
    - ``S[i, j] = A'[i, j] / sqrt(d_i * d_j)``;
    - ``K = num_layers``.
    """

    def __init__(self, num_node_features, num_layers):
        super().__init__()
        # Weight vector stored as a (1, F) row; zero-initialised.
        self.theta = torch.nn.Parameter(torch.zeros(1, num_node_features), requires_grad=True)
        # Scalar output offset.
        self.bias = torch.nn.Parameter(torch.tensor(0.), requires_grad=True)
        # Number of propagation steps (power applied to S).
        self.nlayers = num_layers

    def forward(self, xs, adjs):
        """Return one prediction per graph.

        Args:
            xs: node features, assumed (batch, num_nodes, num_node_features).
                Their dtype must match the model's parameters.
            adjs: dense adjacency matrices of shape (batch, num_nodes,
                num_nodes), of any numeric/bool dtype. They are not modified.

        Returns:
            A tensor of shape (batch,).

        NOTE(review): the mean runs over every node slot. If graphs are
        zero-padded to a common size, padded nodes count in the denominator.
        Confirm that this is intended.
        """
        # Cast once to the parameters' dtype, so the matmuls below agree
        # after .double(). copy=True guarantees a private tensor even when
        # no cast is needed, so the in-place diagonal write cannot touch
        # the caller's adjs.
        adj = adjs.to(dtype=self.theta.dtype, copy=True)
        # Self-loops: set A[b, i, i] = 1 through an on-device diagonal view.
        adj.diagonal(dim1=1, dim2=2).fill_(1)
        # Row-sum degrees; each is at least 1 for a nonnegative adjacency,
        # because of the unit diagonal.
        deg = adj.sum(dim=2)
        # Symmetric normalisation: divide entry (i, j) by sqrt(d_i * d_j).
        adj = adj / (deg[:, :, None] * deg[:, None, :]).sqrt()
        propagated = torch.matrix_power(adj, self.nlayers)
        # (B, N, N) @ (B, N, F) @ (1, F, 1) -> (B, N, 1): one scalar per node.
        node_out = propagated @ xs @ self.theta[0][None, :, None]
        return node_out[:, :, 0].mean(dim=1) + self.bias