import torch
import torch.nn as nn
import torch_geometric as pyg


class SGC(nn.Module):
    """Linear classifier on two-hop diffused node features (SGC-style).

    Node features are propagated twice through the symmetrically normalized
    dense adjacency matrix and then multiplied by a single weight matrix.
    The weights are either the random initial values or the closed-form
    ridge-regression solution computed by :meth:`fit`.

    Args:
        in_channels: Number of input features per node.
        out_channels: Number of output classes.
        **kwargs: Accepted for interface compatibility and ignored.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        **kwargs,
    ):
        super().__init__()
        self.out_channels = out_channels
        # when no training nodes available initialize weights randomly
        weights = pyg.nn.dense.linear.reset_weight_(
            torch.empty(out_channels, in_channels), in_channels
        )
        # Stored as (in_channels, out_channels) so forward can right-multiply.
        self.weights = nn.Parameter(weights.T)

    @staticmethod
    def _normalized_adjacency(adj):
        """Return ``D^-1/2 (A + missing self-loops) D^-1/2`` for a dense ``A``.

        A unit self-loop is added only to nodes whose diagonal entry is zero,
        so graphs with no, some, or all self-loops are normalized the same way
        during fitting and inference. The input tensor is not modified.
        """
        missing_loops = (adj.diagonal() == 0).to(adj.dtype)
        adj = adj + torch.diag(missing_loops)

        deg_inv_sqrt = adj.sum(1).pow(-0.5)
        # Zero-degree rows would produce inf; treat them as isolated.
        deg_inv_sqrt = deg_inv_sqrt.masked_fill(
            deg_inv_sqrt == float("inf"), 0.0
        )
        # Row/column scaling is equivalent to diag(d) @ adj @ diag(d)
        # without materializing two dense diagonal matrices.
        return deg_inv_sqrt.unsqueeze(1) * adj * deg_inv_sqrt.unsqueeze(0)

    @staticmethod
    def _ridge_regression(features, targets, reg=1.0):
        """Solve ``(F^T F + reg * I) W = F^T Y`` for ``W``.

        The identity matrix is created with the dtype and device of
        ``features`` so the solve also works for non-CPU inputs.
        """
        gram = features.T @ features
        gram = gram + reg * torch.eye(
            gram.shape[0], dtype=features.dtype, device=features.device
        )
        # A linear solve is more stable than forming the explicit inverse.
        return torch.linalg.solve(gram, features.T @ targets)

    def _propagate(self, x, edge_index):
        """Diffuse ``x`` over two hops of the normalized adjacency."""
        # TODO: make it completely sparse
        adj = pyg.utils.to_dense_adj(edge_index, max_num_nodes=x.shape[0])[0]
        lap = self._normalized_adjacency(adj)
        # Right-to-left association avoids the N x N by N x N product.
        return lap @ (lap @ x)

    def fit(self, x, edge_index, y, train_mask):
        """Set the weights to the ridge-regression solution on training nodes.

        Args:
            x: Node feature matrix with one row per node.
            edge_index: Edge list accepted by ``pyg.utils.to_dense_adj``.
            y: Integer class label per node, used to index one-hot rows.
            train_mask: Mask or index tensor selecting the training nodes.

        If ``train_mask`` selects no nodes, the current weights are kept
        rather than being overwritten with an all-zero solution.
        """
        with torch.no_grad():
            diff_x = self._propagate(x, edge_index)[train_mask]
            if diff_x.shape[0] == 0:
                # No training nodes: keep the existing (random) weights.
                return

            one_hot = torch.eye(
                self.out_channels, dtype=diff_x.dtype, device=diff_x.device
            )[y[train_mask]]
            weights = self._ridge_regression(diff_x, one_hot)

        self.weights = nn.Parameter(weights)

    def forward(self, x, edge_index, mask=None):
        """Return class scores for all nodes, or for those selected by ``mask``.

        Args:
            x: Node feature matrix with one row per node.
            edge_index: Edge list accepted by ``pyg.utils.to_dense_adj``.
            mask: Optional mask or index tensor; when given, only the
                selected rows are scored.
        """
        diff_x = self._propagate(x, edge_index)
        if mask is not None:
            diff_x = diff_x[mask]

        return diff_x @ self.weights
