import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np
from torch_geometric.nn import global_mean_pool

# --- Helper Functions ---
def adjacency_matrix(edge_index, edge_attr, num_nodes=None):
    """Build a dense weighted adjacency matrix from an edge list.

    Args:
        edge_index: Integer tensor of shape (2, E); row 0 holds source node
            indices and row 1 holds target node indices.
        edge_attr: Tensor with one scalar weight per edge (E elements).
        num_nodes: Optional size of the square output. When omitted it is
            inferred as ``edge_index.max() + 1`` (or 0 for an empty edge
            list), which cannot account for trailing isolated nodes.

    Returns:
        A ``(num_nodes, num_nodes)`` tensor in the default float dtype on
        ``edge_index.device`` with ``adj[u, v]`` set to the weight of edge
        ``(u, v)``. When the same ``(u, v)`` pair occurs several times, the
        weight of its last occurrence is kept.
    """
    num_edges = edge_index.size(1)
    if num_nodes is None:
        # ``max()`` raises on an empty tensor, so guard the edgeless case.
        num_nodes = int(edge_index.max().item()) + 1 if num_edges > 0 else 0

    adj = torch.zeros((num_nodes, num_nodes), device=edge_index.device)
    if num_edges == 0:
        return adj

    rows, cols = edge_index[0], edge_index[1]
    # Indexed assignment needs matching dtypes; one weight per edge.
    weights = edge_attr[:num_edges].to(device=adj.device, dtype=adj.dtype)
    weights = weights.reshape(num_edges)

    # Indexed assignment with repeated indices has no guaranteed winner, so
    # keep only the last occurrence of each (u, v) pair. A stable sort keeps
    # equal keys in their original order, making the final entry of each run
    # the latest edge.
    flat = rows * num_nodes + cols
    sorted_flat, perm = torch.sort(flat, stable=True)
    is_last = torch.ones(num_edges, dtype=torch.bool, device=flat.device)
    is_last[:-1] = sorted_flat[1:] != sorted_flat[:-1]
    keep = perm[is_last]

    adj[rows[keep], cols[keep]] = weights[keep]
    return adj

def normalize_adjacency_matrix(adj):
    """Symmetrically normalise a dense adjacency matrix with self-loops.

    Computes ``D^{-1/2} (A + I) D^{-1/2}`` where ``D`` is the row-sum of
    ``A + I``.

    Args:
        adj: Square dense adjacency tensor of shape (N, N).

    Returns:
        The normalised (N, N) dense tensor. Rows whose degree is zero get a
        scaling factor of 0 instead of ``inf``.
    """
    eye = torch.eye(adj.size(0), device=adj.device)
    a_hat = adj + eye
    degree = torch.sum(a_hat, dim=1)
    inv_sqrt_degree = torch.pow(degree, -0.5)
    # A zero degree yields inf; zero it so the affected rows/columns vanish.
    inv_sqrt_degree[torch.isinf(inv_sqrt_degree)] = 0.
    # Scaling rows and columns by broadcasting equals multiplying by the
    # diagonal matrices on both sides, without the two dense matmuls.
    return inv_sqrt_degree.unsqueeze(1) * a_hat * inv_sqrt_degree.unsqueeze(0)

def sparse_matrix_to_torch_sparse_tensor(adj):
    """Convert a dense 2-D tensor into a sparse COO tensor.

    Args:
        adj: Dense 2-D tensor.

    Returns:
        A ``torch.sparse_coo_tensor`` of the same size holding only the
        non-zero entries of ``adj``.
    """
    # Coordinates of the non-zero entries, laid out as (2, nnz).
    coords = torch.nonzero(adj, as_tuple=False).t()
    row_idx = coords[0]
    col_idx = coords[1]
    nonzero_values = adj[row_idx, col_idx]
    return torch.sparse_coo_tensor(coords, nonzero_values, size=adj.size())

# --- GCNLayer Definition ---
class GCNLayer(nn.Module):
    """Graph convolution layer with initial residual and identity mapping.

    For propagated features ``F = (1 - alpha) * (adj @ input) + alpha * h_0``
    the layer returns ``F @ ((1 - beta) * I + beta * W)`` with
    ``beta = log(lamda / l + 1)``. The identity term only lines up with the
    weight when ``in_features == out_features``.

    Args:
        in_features: Number of input feature columns.
        out_features: Number of output feature columns.
    """

    def __init__(self, in_features, out_features):
        super(GCNLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        # Single source of truth for the initialisation scheme.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the weight with Kaiming-uniform (fan_in, ReLU)."""
        init.kaiming_uniform_(self.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, input, adj, h_0, lamda, alpha, l):
        """Apply one propagation step.

        Args:
            input: Node features passed to ``torch.spmm`` as the dense operand.
            adj: Adjacency operand for ``torch.spmm(adj, input)``.
            h_0: Initial node representation mixed in with weight ``alpha``.
            lamda: Numerator of the layer-dependent decay ``lamda / l``.
            alpha: Mixing weight of the initial residual ``h_0``.
            l: Layer index used as divisor; must be non-zero.

        Returns:
            The transformed node features (no activation applied here).
        """
        h_l = torch.spmm(adj, input)
        features = (1 - alpha) * h_l + alpha * h_0
        # Plain Python float keeps the scalar out of numpy/tensor dispatch.
        beta = float(np.log((lamda / l) + 1))
        # features @ ((1 - beta) * I + beta * W) expanded, so no dense
        # identity matrix has to be allocated on every call.
        return (1 - beta) * features + beta * torch.mm(features, self.weight)

# --- GCNII Model Definition ---
class GCNII(nn.Module):
    """Stack of GCNLayer blocks between two linear layers.

    The forward pass is: input dropout -> ``pre_fc`` + ReLU -> ``num_layers``
    graph convolutions (each preceded by dropout and followed by ReLU) ->
    optional mean pooling over ``data.batch`` -> dropout -> ``post_fc`` ->
    ``log_softmax``.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
        num_layers: Number of GCNLayer blocks.
        hidden: Width of the hidden representation.
        dropout: Dropout probability used throughout the network.
        lamda: Decay hyper-parameter forwarded to every GCNLayer.
        alpha: Initial-residual weight forwarded to every GCNLayer.
    """

    def __init__(self, dataset, num_layers, hidden, dropout=0.6, lamda=0.5, alpha=0.1):
        super(GCNII, self).__init__()
        self.pre_fc = nn.Linear(dataset.num_features, hidden)
        self.graph_convs = nn.ModuleList([GCNLayer(hidden, hidden) for _ in range(num_layers)])
        self.post_fc = nn.Linear(hidden, dataset.num_classes)
        self.relu = nn.ReLU()
        self.dropout = dropout
        self.lamda = lamda
        self.alpha = alpha

    def reset_parameters(self):
        """Re-initialise both linear layers and every graph convolution."""
        self.pre_fc.reset_parameters()
        for conv in self.graph_convs:
            conv.reset_parameters()
        self.post_fc.reset_parameters()

    def forward(self, data):
        """Compute log-probabilities for ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``edge_attr``
                (which may be ``None``), and optionally ``batch``.

        Returns:
            ``log_softmax`` over dim 1 of the ``post_fc`` output.
        """
        x, edge_index = data.x, data.edge_index
        edge_attr = data.edge_attr
        if edge_attr is None:
            # Unweighted graph: every edge gets weight 1.
            edge_attr = torch.ones(edge_index.size(1), device=edge_index.device)

        num_nodes = x.size(0)
        if edge_index.numel() == 0:
            # No edges at all: the helper cannot infer a size from an empty
            # index tensor, so start from an all-zero adjacency.
            adj = torch.zeros((num_nodes, num_nodes), device=edge_index.device)
        else:
            adj = adjacency_matrix(edge_index, edge_attr)
            # The helper sizes the matrix by the largest edge index, so
            # trailing isolated nodes are missing; pad with zero rows/columns
            # to keep the adjacency aligned with the rows of ``x``.
            missing = num_nodes - adj.size(0)
            if missing > 0:
                adj = F.pad(adj, (0, missing, 0, missing))
        adj = normalize_adjacency_matrix(adj)
        adj = sparse_matrix_to_torch_sparse_tensor(adj)

        x = F.dropout(x, self.dropout, training=self.training)
        h_0 = self.relu(self.pre_fc(x))
        h = h_0
        # Layer indices start at 1 because GCNLayer divides by the index.
        for i, conv in enumerate(self.graph_convs, start=1):
            h = F.dropout(h, self.dropout, training=self.training)
            h = self.relu(conv(h, adj, h_0, self.lamda, self.alpha, i))
        # NOTE(review): depending on the torch_geometric version, ``batch``
        # may exist as an attribute yet be None for a single graph - confirm
        # that pooling is intended in that case.
        if hasattr(data, 'batch'):
            h = global_mean_pool(h, data.batch)
        h = F.dropout(h, self.dropout, training=self.training)
        h = self.post_fc(h)
        return F.log_softmax(h, dim=1)

    def __repr__(self):
        return self.__class__.__name__
