from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch

class GCN(torch.nn.Module):
    """Residual graph convolutional network built from ``GCNConv`` layers.

    The forward pass applies, in order:

    1. an input ``GCNConv`` (``in_channels -> hidden_channels``) whose output is
       summed with a linear projection of the raw features (``sc_lin``), so the
       residual has matching width;
    2. ``num_layers`` hidden blocks ``z -> conv(dropout(z)) + z`` with
       ``hidden_channels -> hidden_channels`` convolutions. With
       ``shared_weights`` set, one convolution (``conv2``) is reused for every
       block; otherwise each block owns its own (``conv_layers[i]``);
    3. a final block of the same form using ``conv3``.

    Args:
        in_channels: Size of the last dimension of the input node features.
        hidden_channels: Width of every hidden and output representation.
        num_layers: Number of hidden residual blocks between the input and
            output convolutions.
        shared_weights: If truthy, all hidden blocks share one convolution's
            parameters; otherwise each block gets independent parameters.
        dropout: Dropout probability applied before every hidden and output
            convolution (only active in training mode). Defaults to 0.5, the
            value previously implied by ``F.dropout``'s default.

    NOTE(review): no activation function is applied between layers, so apart
    from dropout the network composes (per the PyG docs) affine GCN
    propagations and additions — confirm this is intentional.
    """

    def __init__(self, in_channels, hidden_channels, num_layers,
                 shared_weights=True, dropout=0.5):
        super().__init__()
        self.num_layers = num_layers
        self.shared_weights = shared_weights
        self.dropout = dropout

        # Submodules are created in the original order so that parameter
        # initialization consumes the RNG identically.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        if self.shared_weights:
            # One convolution reused by every hidden block.
            self.conv2 = GCNConv(hidden_channels, hidden_channels)
        else:
            # An independent convolution per hidden block.
            self.conv_layers = torch.nn.ModuleList([
                GCNConv(hidden_channels, hidden_channels)
                for _ in range(num_layers)
            ])
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        # Projects raw input features to hidden width so they can be added to
        # conv1's output (the widths otherwise differ).
        self.sc_lin = torch.nn.Linear(in_channels, hidden_channels)

    def forward(self, x, edge_index):
        """Compute node representations.

        Args:
            x: Node features whose last dimension is ``in_channels`` (consumed
                by both ``conv1`` and ``sc_lin``).
            edge_index: Graph connectivity, passed unchanged to every
                ``GCNConv``.

        Returns:
            Node representations with ``hidden_channels`` features each.
        """
        # Input block: residual through a learned projection of x.
        z = self.conv1(x, edge_index) + self.sc_lin(x)

        for i in range(self.num_layers):
            conv = self.conv2 if self.shared_weights else self.conv_layers[i]
            z = self._residual_block(conv, z, edge_index)

        return self._residual_block(self.conv3, z, edge_index)

    def _residual_block(self, conv, z, edge_index):
        """Return ``conv(dropout(z), edge_index) + z``; dropout is training-only."""
        h = F.dropout(z, p=self.dropout, training=self.training)
        # Out-of-place add: never mutates a tensor autograd may have saved.
        return conv(h, edge_index) + z