import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch_geometric.nn import global_mean_pool

# --- GCNLayer Definition ---
class GCNLayer(nn.Module):
    """Graph convolution with an initial-residual and an identity-mapping term.

    The layer owns one ``(in_features, out_features)`` weight matrix.
    ``forward`` blends the propagated features with their own linear
    transform, which is only well defined when ``in_features`` equals
    ``out_features``.

    Args:
        in_features: Number of feature columns expected by ``forward``.
        out_features: Number of columns of the weight matrix.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        # Single source of truth for initialisation (one RNG draw, as before).
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw ``weight`` in place with Kaiming-uniform (ReLU) initialisation."""
        init.kaiming_uniform_(self.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, input, adj, h_0, lamda, alpha, l):
        """Apply one propagation step.

        Computes ``features = (1 - alpha) * (adj @ input) + alpha * h_0`` and
        returns ``(1 - beta) * features + beta * (features @ weight)`` with
        ``beta = log(lamda / l + 1)``.

        Args:
            input: Dense feature matrix that ``adj`` is multiplied with.
                Presumably ``(num_nodes, in_features)`` -- confirm with caller.
            adj: Matrix passed as the first operand of ``torch.spmm``
                (presumably a sparse normalised adjacency -- confirm).
            h_0: Tensor mixed in with weight ``alpha``; must be broadcastable
                against ``adj @ input``.
            lamda: Numerator of the ratio inside the logarithm.
            alpha: Mixing weight of ``h_0``.
            l: Denominator of the ratio inside the logarithm; must be
                non-zero (the caller in this file passes a 1-based index).

        Returns:
            Tensor with as many rows as ``adj @ input`` and ``out_features``
            columns.

        Raises:
            ValueError: If ``in_features != out_features``; the identity
                term would otherwise fail or broadcast to a wrong result.
        """
        if self.in_features != self.out_features:
            raise ValueError(
                "GCNLayer requires in_features == out_features for the "
                f"identity mapping, got {self.in_features} and {self.out_features}"
            )

        propagated = torch.spmm(adj, input)
        features = (1 - alpha) * propagated + alpha * h_0

        # Plain Python float: avoids mixing a NumPy scalar into tensor math.
        beta = math.log(lamda / l + 1)

        # Equal to features @ ((1 - beta) * I + beta * W), but without
        # allocating an identity matrix on every call and without depending
        # on the default dtype of ``torch.eye``.
        return (1 - beta) * features + beta * torch.mm(features, self.weight)

# --- GCNII Model Definition ---
class GCNII(nn.Module):
    """Stack of ``GCNLayer`` blocks between an input and an output linear layer.

    Args:
        dataset: Object exposing ``num_features`` (input width of ``pre_fc``).
        num_layers: Number of ``GCNLayer`` blocks to stack.
        hidden: Width of the hidden representation.
        dropout: Dropout probability used throughout ``forward``.
        lamda: Forwarded unchanged to every ``GCNLayer`` call.
        alpha: Forwarded unchanged to every ``GCNLayer`` call.
    """

    def __init__(self, dataset, num_layers, hidden, dropout=0.6, lamda=0.5, alpha=0.1):
        super().__init__()
        # Creation order (input -> conv stack -> output) is kept so that a
        # seeded construction draws parameters in the same sequence.
        self.pre_fc = nn.Linear(dataset.num_features, hidden)
        conv_stack = [GCNLayer(hidden, hidden) for _ in range(num_layers)]
        self.graph_convs = nn.ModuleList(conv_stack)
        self.post_fc = nn.Linear(hidden, 1)
        self.relu = nn.ReLU()
        self.dropout, self.lamda, self.alpha = dropout, lamda, alpha

    def reset_parameters(self):
        """Re-initialise ``pre_fc``, every conv layer, then ``post_fc``."""
        for module in (self.pre_fc, *self.graph_convs, self.post_fc):
            module.reset_parameters()

    def forward(self, data):
        """Run the network on ``data``.

        Reads ``data.x`` and ``data.norm_adj``. If ``data`` has a ``batch``
        attribute, the node representations are reduced with
        ``global_mean_pool(h, data.batch)`` before the output layer.

        Returns:
            Output of ``post_fc`` (one column).
        """
        features = data.x
        adjacency = data.norm_adj

        features = F.dropout(features, self.dropout, training=self.training)
        initial = self.relu(self.pre_fc(features))

        hidden_state = initial
        for depth, layer in enumerate(self.graph_convs, start=1):
            dropped = F.dropout(hidden_state, self.dropout, training=self.training)
            hidden_state = self.relu(
                layer(dropped, adjacency, initial, self.lamda, self.alpha, depth)
            )

        if hasattr(data, 'batch'):
            hidden_state = global_mean_pool(hidden_state, data.batch)

        hidden_state = F.dropout(hidden_state, self.dropout, training=self.training)
        return self.post_fc(hidden_state)

    def __repr__(self):
        """Return just the class name."""
        return type(self).__name__
