import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, BatchNorm

class GCN3(torch.nn.Module):
    """Stack of four ``GCNConv`` layers with ReLU and dropout in between.

    The first three convolutions are each followed by ReLU and dropout; the
    fourth maps ``hidden_channels3`` to ``num_classes`` and its output is
    returned without any activation.

    Args:
        num_node_features: Input channel count of the first convolution.
        num_classes: Output channel count of the last convolution.
        hidden_channels: Output width of the first convolution.
        hidden_channels2: Output width of the second convolution.
        hidden_channels3: Output width of the third convolution.
    """

    def __init__(self, num_node_features, num_classes, hidden_channels,
                 hidden_channels2, hidden_channels3):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels2)
        self.conv3 = GCNConv(hidden_channels2, hidden_channels3)
        self.conv4 = GCNConv(hidden_channels3, num_classes)

        # NOTE(review): these normalisation layers are constructed but never
        # applied in forward(). They are kept so the set of registered
        # submodules is unchanged; confirm whether they were meant to be
        # applied after each hidden convolution.
        self.norm1 = BatchNorm(hidden_channels)
        self.norm2 = BatchNorm(hidden_channels2)
        self.norm3 = BatchNorm(hidden_channels3)

    def forward(self, data, p=(0.0, 0.0, 0.0)):
        """Run the four convolutions over ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``edge_attr``.
                ``edge_attr`` is passed to every convolution as its third
                positional argument (presumably per-edge scalar weights --
                TODO confirm against the dataset).
            p: Indexable with at least three dropout probabilities, one per
                hidden layer. The default is an immutable tuple so it cannot
                be shared and mutated across calls; lists are still accepted.

        Returns:
            The output of the fourth convolution, with no activation applied.

        Raises:
            IndexError: If ``p`` has fewer than three entries.
        """
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr

        hidden_convs = (self.conv1, self.conv2, self.conv3)
        for i, conv in enumerate(hidden_convs):
            x = conv(x, edge_index, edge_weight)
            x = F.relu(x)
            # Dropout is only active while the module is in training mode.
            x = F.dropout(x, p=p[i], training=self.training)

        return self.conv4(x, edge_index, edge_weight)