import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class Propagation(nn.Module):
    """Multi-step neighbour aggregation over a dense, weighted graph.

    Each of the ``depth`` steps averages the features of every node's
    neighbours (self-connections removed, rows normalised to sum to ~1),
    concatenates that average with the node's own features and projects the
    result back to ``dim`` with a linear layer.  A ``LayerNorm`` follows every
    step except the last, and the configured activation is applied once,
    after the final step.

    Args:
        dim: Feature size ``D`` of each node.
        n_nodes: Number of nodes ``N``; fixes the size of the self-connection
            mask.
        n_layers: Stored on the instance but not used by this module.
        dropout: Drop probability applied to the normalised adjacency matrix.
        activation: ``'relu'``, ``'tanh'`` or ``'leaky'``.  Any other value
            leaves the output of the last step unchanged.
        depth: Number of propagation steps.
    """

    def __init__(self, dim: int, n_nodes: int, n_layers: int, dropout: float,
                 activation, depth: int):
        super().__init__()

        self.dim = dim
        self.n_nodes = n_nodes
        self.n_layers = n_layers
        self.depth = depth
        self.activation = activation

        # One projection per step; its input is [own features, neighbour mean].
        self.w = nn.ModuleList(
            nn.Linear(2 * self.dim, self.dim) for _ in range(self.depth)
        )
        # No normalisation after the last step, hence depth - 1 entries.
        self.ln = nn.ModuleList(
            nn.LayerNorm(self.dim) for _ in range(self.depth - 1)
        )

        self.dropout = nn.Dropout(dropout)

        # Registered as a buffer so the mask follows the module across
        # devices; non-persistent so the state_dict keys stay unchanged.
        self.register_buffer(
            "eye", torch.eye(n_nodes, dtype=torch.bool), persistent=False
        )

        # Not used in forward(); kept so the module's parameters and
        # state_dict keys ("norm.*") stay the same.
        self.norm = nn.LayerNorm(self.dim)

    def forward(self, x: torch.Tensor, graph: torch.Tensor) -> torch.Tensor:
        """Propagate node features along ``graph``.

        Args:
            x: Node features of shape ``(..., N, D)``.
            graph: Edge weights of shape ``(..., N, N)``.

        Returns:
            Updated node features of shape ``(..., N, D)``.
        """
        # The (N, N) mask broadcasts over any leading dimensions of `graph`.
        # `.to` is a no-op when the buffer already lives on x's device.
        off_diagonal = ~self.eye.to(x.device)
        graph = graph * off_diagonal  # Drop self-connection
        # Row-normalise; the epsilon guards nodes without any neighbour.
        graph = graph / (torch.sum(graph, dim=-1, keepdim=True) + 1e-7)  # (..., N, N)
        graph = self.dropout(graph)

        n_norms = len(self.ln)
        for step, linear in enumerate(self.w):
            neighbor = torch.matmul(graph, x)  # (..., N, D)
            x = linear(torch.cat([x, neighbor], dim=-1))
            if step < n_norms:
                x = self.ln[step](x)

        if self.activation == 'relu':
            x = torch.relu(x)
        elif self.activation == 'tanh':
            x = torch.tanh(x)
        elif self.activation == 'leaky':
            x = F.leaky_relu(x)

        return x
