import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv
from config import args

class GCN(nn.Module):
    def __init__(
        self, g, in_feats, n_classes, n_hidden, n_layers, dropout=0.5
    ):
        # 输入：g:图  in_feats:输入特征维度  n_classed:输出维度（类别数）  n_hidden:隐层维度  n_layers:网络层数
        super(GCN, self).__init__()
        self.g = g
        self.n_layers = n_layers
        self.layers = nn.ModuleList()
        # 输入层, 这里activation设为None是为了在forward中可以得到中间的logits作为中间特征
        # 在forward中保存每一层的logits后再通过激活函数
        self.layers.append(GraphConv(in_feats, n_hidden, activation=F.relu, allow_zero_in_degree=True))
        # 隐层
        for i in range(n_layers - 1):
            self.layers.append(
                GraphConv(n_hidden, n_hidden, activation=F.relu, allow_zero_in_degree=True)
            )
        # 输出层
        self.layers.append(GraphConv(n_hidden, n_classes, allow_zero_in_degree=True))
        # dropout：一个元素被置0的概率，这里先默认设为0
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, features, middle=False):
        # 输入：features:当前所有节点特征  middle:是否返回节点中间层特征
        h = features
        middle_feats = []

        # range(self.n_layers) 通过输入层和中间层，未通过输出层
        for i in range(self.n_layers):
            h = self.layers[i](self.g.to(args.device), h)
            # 将中间层产生的logits放入middle_feats列表中
            middle_feats.append(h)
            # 然后再通过激活函数
            h = F.relu(h)
        # 输出层未通过激活函数的logits，使用这个logits作为软标签相当于将温度设置到无穷大
        logits = self.layers[-1](self.g.to(args.device), h)
        # 如果middle为True，返回最终的logits 中间各层输出的logits
        if middle:
            return logits, middle_feats
        return logits
