import dgl
import torch
import torch.nn.functional as F

def to_dense(g):
    """Convert the data object into dense node feature, node label, and
    adjacency matrices.

    Parameters
    ----------
    g : dgl.DGLGraph
        Graph whose node data holds the entries ``'feat'`` and ``'label'``.

    Returns
    -------
    X : torch.Tensor of shape (|V|, d)
        Node feature matrix, returned exactly as stored in
        ``g.ndata['feat']``. It is expected to have binary entries; no
        binarization is performed here.
    Y : torch.Tensor of shape (|V|)
        Node label vector.
    E : torch.Tensor of shape (|V|, |V|)
        Adjacency matrix with binary entries and self-loops removed.
        ``E[v, u]`` is 1 for every edge ``u -> v`` of the graph.
    """
    X = g.ndata['feat']
    Y = g.ndata['label']

    # Drop self-loops so that the diagonal of E stays zero.
    g = dgl.remove_self_loop(g)
    src, dst = g.edges()
    num_nodes = g.num_nodes()

    # Allocate E on the same device as the edge indices: indexing a CPU
    # tensor with indices that live on another device raises an error.
    E = torch.zeros(num_nodes, num_nodes, device=src.device)
    E[dst, src] = 1.

    return X, Y, E

def preprocess(g):
    """Prepare one-hot encodings of the node features, node labels, and edge
    types, together with their empirical marginal distributions.

    Parameters
    ----------
    g : dgl.DGLGraph
        Graph accepted by ``to_dense``. Node features are expected to be
        binary (values in {0, 1} after casting to integers).

    Returns
    -------
    X_one_hot : torch.Tensor of shape (d, |V|, 2)
        X_one_hot[i, :, :] is the one-hot encoding of the i-th node feature.
    Y : torch.Tensor of shape (|V|)
        Node labels.
    E_one_hot : torch.Tensor of shape (|V|, |V|, 2)
        - E_one_hot[:, :, 0] indicates the absence of an edge. Self-loops
          are removed by ``to_dense``, so the diagonal counts as absent.
        - E_one_hot[:, :, 1] is the dense adjacency matrix returned by
          ``to_dense``.
    X_marginal : torch.Tensor of shape (d, 2)
        X_marginal[i, :] is the marginal probabilities for the i-th node
        feature.
    Y_marginal : torch.Tensor of shape (C)
        Marginal probabilities for the node labels.
    E_marginal : torch.Tensor of shape (2)
        Marginal probabilities for the edge existence, computed over all
        |V| * |V| entries of the adjacency matrix.
    X_cond_Y_marginals : torch.Tensor of shape (d, C, 2)
        X_cond_Y_marginals[i, c, :] is the marginal probabilities for the
        i-th node feature given node label c. Entries are NaN for a label
        c that no node carries (0 / 0).
    """
    X, Y, E = to_dense(g)

    # Pin the number of categories to 2. Letting one_hot infer it from the
    # data yields a single channel for an all-zero feature column or an
    # edgeless graph, which breaks the documented (..., 2) shapes.
    # Transposing first gives the feature-major layout (d, |V|, 2).
    X_one_hot = F.one_hot(X.long().t(), num_classes=2).float()
    Y_one_hot = F.one_hot(Y).float()
    E_one_hot = F.one_hot(E.long(), num_classes=2).float()

    # Compute marginal probabilities: count each category over the nodes,
    # then normalize per feature.
    X_one_hot_count = X_one_hot.sum(dim=1)
    X_marginal = X_one_hot_count / X_one_hot_count.sum(dim=1, keepdim=True)

    # Compute marginal probabilities for X | Y by restricting the counts to
    # the nodes of each label. The class count is taken from Y_one_hot
    # (Y.max() + 1) so it always agrees with Y_marginal.
    num_classes = Y_one_hot.size(1)
    X_cond_Y_marginals = []
    for y in range(num_classes):
        nodes_y = Y == y
        X_one_hot_count_y = X_one_hot[:, nodes_y].sum(dim=1)
        X_marginal_y = X_one_hot_count_y / X_one_hot_count_y.sum(
            dim=1, keepdim=True)
        X_cond_Y_marginals.append(X_marginal_y)
    X_cond_Y_marginals = torch.stack(X_cond_Y_marginals, dim=1)

    Y_one_hot_count = Y_one_hot.sum(dim=0)
    Y_marginal = Y_one_hot_count / Y_one_hot_count.sum()

    E_one_hot_count = E_one_hot.sum(dim=0).sum(dim=0)
    E_marginal = E_one_hot_count / E_one_hot_count.sum()

    return (X_one_hot, Y, E_one_hot, X_marginal, Y_marginal, E_marginal,
            X_cond_Y_marginals)
