import torch.nn as nn
import torch
import torch.nn.functional as F
from torch_geometric.nn import SGConv
from torch.nn import Linear
import numpy as np
from torch.nn import Sequential



class SGC_LP(torch.nn.Module):
    """Single-layer SGConv encoder followed by an activation and dropout.

    Args:
        args: Config object. Attributes read here:
            ``activation_fn`` -- one of ``'relu'``, ``'leaky_relu'``,
            ``'tanh'``, ``'sigmoid'``; selects the activation used in
            ``forward``.
            ``K`` -- passed to ``SGConv`` as its ``K`` argument.
            ``dropout`` -- dropout probability applied in ``forward``.
            ``rest_param`` -- if truthy, ``reset_parameter`` is called after
            the layer is built.
        in_channels: Input feature size given to ``SGConv``.
        hidden_channels: Output feature size of ``SGConv`` and therefore of
            this module.
        out_channels: Not used by this module; kept so the constructor
            signature stays unchanged.

    Raises:
        ValueError: If ``args.activation_fn`` is not a supported name.
    """

    # Name (as given in ``args.activation_fn``) -> function applied in forward().
    _ACTIVATIONS = {
        'relu': F.relu,
        'leaky_relu': F.leaky_relu,
        'tanh': torch.tanh,
        'sigmoid': torch.sigmoid,
    }

    def __init__(self, args, in_channels, hidden_channels, out_channels):
        super(SGC_LP, self).__init__()

        # Keep both the configured name and the resolved callable as attributes.
        self.activation = args.activation_fn
        self.activation_fn = self._resolve_activation(self.activation)

        self.layers = SGConv(in_channels, hidden_channels, K=args.K)
        self.dropout = args.dropout

        if args.rest_param:
            self.reset_parameter()

    @classmethod
    def _resolve_activation(cls, name):
        """Return the activation function registered under *name*.

        Raises:
            ValueError: If *name* is not one of the supported activations.
        """
        try:
            return cls._ACTIVATIONS[name]
        except (KeyError, TypeError):
            # TypeError covers unhashable names, which are equally unsupported.
            raise ValueError(f"Unsupported activation function: {name}") from None

    def reset_parameter(self):
        """Re-initialise the SGConv linear layer: Xavier-uniform weight, zero bias."""
        lin = self.layers.lin
        # nn.init functions already run under no_grad, so there is no need
        # to touch ``.data`` directly.
        nn.init.xavier_uniform_(lin.weight)
        if lin.bias is not None:
            nn.init.zeros_(lin.bias)

    def forward(self, x, adj_t):
        """Run SGConv, the configured activation, then dropout.

        Args:
            x: Node feature tensor passed to ``SGConv``.
            adj_t: Graph connectivity passed unchanged to ``SGConv``
                (edge_index or sparse adjacency -- TODO confirm which the
                caller supplies).

        Returns:
            The activated (and, in training mode, dropped-out) SGConv output.
        """
        x = self.layers(x, adj_t)
        # Apply the activation chosen via ``args.activation_fn``.
        x = self.activation_fn(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x
