import math
from typing import Optional

import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected


class SparseLinearN(nn.Module):
    """Linear layer applied to a sparse ``num_nodes x num_nodes`` adjacency.

    The layer owns a dense weight ``W`` of shape ``(n, out_features)`` and a
    bias ``b`` of shape ``(out_features,)``. ``forward`` builds a sparse COO
    matrix from an edge list and returns ``A @ W + b``.

    Args:
        n: Number of rows of ``W``; must equal the ``num_nodes`` later passed
            to ``forward`` for the matrix product to be defined.
        out_features: Number of output columns.
    """

    def __init__(self, n: int, out_features: int):
        super().__init__()
        self.W = nn.Parameter(torch.empty(n, out_features))
        self.b = nn.Parameter(torch.empty(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw ``W`` and ``b`` uniformly from ``[-1/sqrt(n), 1/sqrt(n))``.

        When ``W`` has zero rows the bound is ``0.0`` so the square root of
        zero is never used as a divisor.
        """
        fan_in = self.W.shape[0]
        limit = 0.0
        if fan_in > 0:
            limit = 1.0 / math.sqrt(fan_in)
        # Weight first, then bias, so seeded initialisation stays reproducible.
        for param in (self.W, self.b):
            nn.init.uniform_(param, -limit, limit)

    def forward(self, edge_index: torch.Tensor, num_nodes: int, device: torch.device):
        """Return ``A @ W + b`` for the adjacency described by ``edge_index``.

        Args:
            edge_index: Integer tensor with two rows; column ``e`` describes
                one edge.
            num_nodes: Side length of the square adjacency matrix.
            device: Device on which the edge values and the sparse matrix are
                created.

        Returns:
            Tensor with ``num_nodes`` rows and ``out_features`` columns.
        """
        edge_count = edge_index.shape[1]
        unit_weights = torch.ones(edge_count, device=device, dtype=self.W.dtype)
        # Rows are swapped, so edge e sets A[edge_index[1, e], edge_index[0, e]].
        swapped_index = edge_index.flip(0)
        adjacency = torch.sparse_coo_tensor(
            swapped_index, unit_weights, (num_nodes, num_nodes), device=device
        )
        projected = torch.sparse.mm(adjacency, self.W)
        return projected + self.b


class SparseAdjacencyMLP(nn.Module):
    """MLP whose first layer is a ``SparseLinearN`` over the adjacency matrix.

    The first layer depends on the number of nodes, so one ``SparseLinearN``
    is kept per node count in ``first_layers`` under the key
    ``str(num_nodes)``. Its output is passed through ``num_layers - 1`` dense
    blocks of ``Linear -> ReLU -> Dropout`` stored in ``post``.

    Args:
        hidden_dim: Output width of the first layer and of every dense block.
        num_layers: Total number of layers including the sparse first layer;
            values below 1 yield no dense blocks.
        dropout: Dropout probability used in each dense block.
        num_nodes: Optional node count for which the first layer is created
            at construction time. When omitted, the layer is created on the
            first ``forward`` call for a given node count.

    Note:
        A first layer that is created lazily inside ``forward`` is not part
        of ``parameters()`` or ``state_dict()`` before that call. An optimizer
        constructed earlier will not update it, and a strict
        ``load_state_dict`` into a fresh instance will reject its keys. Pass
        ``num_nodes`` when the graph size is known up front to avoid both.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_layers: int,
        dropout: float = 0.5,
        *,
        num_nodes: Optional[int] = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.first_layers = nn.ModuleDict()
        post_layers = []
        for _ in range(max(0, num_layers - 1)):
            post_layers += [
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(p=dropout),
            ]
        self.post = nn.Sequential(*post_layers)
        if num_nodes is not None:
            # Eager registration: the parameters exist before any optimizer or
            # checkpoint load touches the module.
            self.first_layers[str(num_nodes)] = SparseLinearN(num_nodes, hidden_dim)

    def forward(self, edge_index: torch.Tensor, num_nodes: int, device: torch.device):
        """Apply the node-count-specific first layer, then the dense blocks.

        Args:
            edge_index: Edge list forwarded unchanged to ``SparseLinearN``.
            num_nodes: Number of nodes; selects (or creates) the first layer.
            device: Device the first layer is moved to and on which the sparse
                adjacency is built.

        Returns:
            The output of ``post`` applied to the first layer's output.
        """
        key = str(num_nodes)
        if key not in self.first_layers:
            # Lazy fallback for node counts not registered at construction.
            self.first_layers[key] = SparseLinearN(num_nodes, self.hidden_dim)
        first_layer = self.first_layers[key]
        # Module.to moves parameters in place, so no re-registration is needed.
        first_layer.to(device)
        h = first_layer(edge_index, num_nodes, device)
        return self.post(h)


class LINKXBackbone(nn.Module):
    """Two-branch encoder combining node features and graph adjacency.

    ``mlp_x`` embeds the node feature matrix, ``mlp_a`` embeds the
    (symmetrised) adjacency, and ``fuse`` maps their concatenation back to
    ``hidden_dim``. Both branch outputs are added to the fused result as
    residual terms before a final ReLU.

    Args:
        in_features: Width of the node feature matrix ``data.x``.
        hidden_dim: Width of both branches and of the returned embedding.
        x_num_layers: Number of ``Linear -> ReLU -> Dropout`` blocks in the
            feature branch; at least one block is always created.
        a_num_layers: Layer count forwarded to ``SparseAdjacencyMLP``.
        dropout: Dropout probability shared by both branches.
    """

    def __init__(
        self,
        in_features: int,
        hidden_dim: int,
        x_num_layers: int,
        a_num_layers: int,
        dropout: float = 0.5,
    ):
        super().__init__()

        # The first block maps in_features -> hidden_dim; the rest keep width.
        extra_blocks = max(0, x_num_layers - 1)
        input_widths = [in_features] + [hidden_dim] * extra_blocks
        feature_layers = []
        for width in input_widths:
            feature_layers.append(nn.Linear(width, hidden_dim))
            feature_layers.append(nn.ReLU())
            feature_layers.append(nn.Dropout(p=dropout))
        self.mlp_x = nn.Sequential(*feature_layers)

        self.mlp_a = SparseAdjacencyMLP(
            hidden_dim=hidden_dim,
            num_layers=a_num_layers,
            dropout=dropout,
        )
        self.fuse = nn.Linear(2 * hidden_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.embedding_dim = hidden_dim

    def forward(self, data: Data):
        """Compute node embeddings for ``data``.

        Args:
            data: Graph providing ``x`` (node features, one row per node) and
                ``edge_index``.

        Returns:
            Tensor with one row per node and ``hidden_dim`` columns.
        """
        features = data.x
        node_count = features.shape[0]
        # Symmetrise the edge list before building the adjacency branch input.
        sym_edges = to_undirected(data.edge_index, num_nodes=node_count)

        h_x = self.mlp_x(features)
        h_a = self.mlp_a(sym_edges, node_count, features.device)

        fused = self.fuse(torch.cat([h_a, h_x], dim=-1))
        # Residual connections from both branches, then the final activation.
        return self.relu(fused + h_a + h_x)
