"""
[Predict then Propagate: Graph Neural Networks meet Personalized PageRank]
(https://arxiv.org/abs/1810.05997)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import CoraGraphDataset
from dgl.mock_sparse import create_from_coo, diag, identity
from torch.optim import Adam


class APPNP(nn.Module):
    """Predict-then-propagate model (APPNP).

    A two-layer MLP ``f_theta`` produces per-node predictions ``Z_0``, which
    are then refined by ``num_hops`` propagation steps of the form
    ``Z <- (1 - alpha) * A_hat @ Z + alpha * Z_0``.

    Parameters
    ----------
    in_size : int
        Input feature size of the first linear layer.
    out_size : int
        Output size of the second linear layer.
    hidden_size : int, optional
        Width of the hidden layer. Default: 64.
    dropout : float, optional
        Dropout probability used by both MLP dropout layers and by the
        dropout applied to the values of ``A_hat``. Default: 0.1.
    num_hops : int, optional
        Number of propagation steps. Default: 10.
    alpha : float, optional
        Weight of the initial prediction ``Z_0`` mixed back in at every
        step. Default: 0.1.
    """

    def __init__(
        self,
        in_size,
        out_size,
        hidden_size=64,
        dropout=0.1,
        num_hops=10,
        alpha=0.1,
    ):
        super().__init__()

        self.f_theta = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, out_size),
        )
        self.num_hops = num_hops
        self.A_dropout = nn.Dropout(dropout)
        self.alpha = alpha

    def forward(self, A_hat, X):
        """Run the MLP on ``X`` and propagate the result through ``A_hat``.

        Parameters
        ----------
        A_hat
            Matrix-like object exposing a writable ``val`` attribute and
            supporting scalar multiplication and ``@``. Its ``val`` is
            replaced temporarily during the call and restored before
            returning, including when an exception is raised.
        X
            Input passed to ``f_theta`` (presumably node features with one
            row per node; verify against caller).

        Returns
        -------
        The propagated predictions ``Z``.
        """
        A_val_0 = A_hat.val
        Z_0 = Z = self.f_theta(X)
        try:
            for _ in range(self.num_hops):
                # Apply dropout to the original values at every hop rather
                # than to the already-dropped values of the previous hop.
                A_hat.val = self.A_dropout(A_val_0)
                Z = (1 - self.alpha) * A_hat @ Z + self.alpha * Z_0
        finally:
            # Always restore A_hat.val so a failing step cannot leave the
            # caller's matrix holding dropout-perturbed values.
            A_hat.val = A_val_0
        return Z


def evaluate(g, pred):
    """Return the validation and test accuracy of ``pred``.

    Parameters
    ----------
    g
        Object whose ``ndata`` mapping holds ``"label"``, ``"val_mask"``
        and ``"test_mask"`` entries.
    pred
        Predicted class per node, indexable by the masks above.

    Returns
    -------
    tuple
        ``(val_acc, test_acc)``, each the mean of the element-wise match
        between predictions and labels on the respective split.
    """
    node_data = g.ndata
    truth = node_data["label"]

    def masked_accuracy(mask):
        # Fraction of masked nodes whose prediction equals the label.
        hits = pred[mask] == truth[mask]
        return hits.float().mean()

    return (
        masked_accuracy(node_data["val_mask"]),
        masked_accuracy(node_data["test_mask"]),
    )


def train(model, g, A_hat, X, num_epochs=50, lr=1e-2, weight_decay=5e-4):
    """Train ``model`` with Adam and report accuracy after every epoch.

    Parameters
    ----------
    model
        Module called as ``model(A_hat, X)`` that returns per-node logits.
    g
        Graph whose ``ndata`` holds ``"label"`` and ``"train_mask"`` (plus
        whatever ``evaluate`` reads).
    A_hat
        Matrix forwarded unchanged to ``model``.
    X
        Input features forwarded unchanged to ``model``.
    num_epochs : int, optional
        Number of training epochs. Default: 50.
    lr : float, optional
        Adam learning rate. Default: 1e-2.
    weight_decay : float, optional
        Adam weight decay. Default: 5e-4.
    """
    label = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(num_epochs):
        # Forward.
        model.train()
        logits = model(A_hat, X)

        # Compute loss with nodes in training set.
        loss = F.cross_entropy(logits[train_mask], label[train_mask])

        # Backward.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute prediction. No gradients are needed here, so skip building
        # the autograd graph for this second forward pass.
        model.eval()
        with torch.no_grad():
            logits = model(A_hat, X)
            pred = logits.argmax(dim=1)

        # Evaluate the prediction.
        val_acc, test_acc = evaluate(g, pred)
        print(
            f"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test"
            f" acc: {test_acc:.3f}"
        )


if __name__ == "__main__":
    # Prefer the first CUDA device when one is available; otherwise run on
    # the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Fetch the Cora dataset and move its single graph to the device.
    dataset = CoraGraphDataset()
    g = dataset[0].to(device)

    # Build the sparse adjacency matrix A from the edge list.
    edge_src, edge_dst = g.edges()
    num_nodes = g.num_nodes()
    A = create_from_coo(edge_dst, edge_src, shape=(num_nodes, num_nodes))

    # Add self-loops, then normalize symmetrically:
    # A_hat = D^-1/2 (A + I) D^-1/2.
    eye = identity(A.shape, device=device)
    A_self_loop = A + eye
    row_sums = A_self_loop.sum(dim=1)
    D_inv_sqrt = diag(row_sums) ** -0.5
    A_hat = D_inv_sqrt @ A_self_loop @ D_inv_sqrt

    # Instantiate APPNP with sizes taken from the features and the dataset.
    X = g.ndata["feat"]
    model = APPNP(X.shape[1], dataset.num_classes).to(device)

    # Start training.
    train(model, g, A_hat, X)
