import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCN2Conv, Linear
from .newConv import new2Conv

class GCN2(nn.Module):
    """Node-level model: linear encoder, a stack of ``GCN2Conv`` layers, linear decoder.

    Every conv layer receives both the current hidden state and the encoder
    output ``x_0`` (the initial representation), as ``GCN2Conv`` expects.

    Args:
        num_features: Size of each input node feature vector.
        hidden_features: Width of the hidden representation shared by all conv layers.
        num_classes: Size of each output vector produced by the decoder.
        num_layers: How many ``GCN2Conv`` layers to stack.
        alpha: Forwarded unchanged to every ``GCN2Conv``.
        theta: Forwarded unchanged to every ``GCN2Conv``.
        shared_weights: Forwarded unchanged to every ``GCN2Conv``.
        dropout: Dropout probability used before the encoder, before each conv
            layer and before the decoder; only active while ``self.training``.
    """

    def __init__(self, num_features, hidden_features, num_classes, num_layers, alpha=0.1, theta=0.5,
                 shared_weights=True, dropout=0.0):
        super().__init__()

        # Creation order (pre, post, convs) is kept so seeded initialisation and
        # state_dict layout stay the same.
        self.pre = Linear(num_features, hidden_features)
        self.post = Linear(hidden_features, num_classes)

        stack = []
        for depth in range(1, num_layers + 1):
            # GCN2Conv takes a 1-based layer index.
            stack.append(GCN2Conv(hidden_features, alpha, theta, depth, shared_weights))
        self.convs = nn.ModuleList(stack)
        self.dropout = dropout

    def _drop(self, tensor):
        """Apply this model's dropout (a no-op outside training mode)."""
        return F.dropout(tensor, self.dropout, training=self.training)

    def forward(self, x, edge_index, edge_weight=None):
        """Run the model.

        Args:
            x: Input node features.
            edge_index: Graph connectivity, passed to every conv layer.
            edge_weight: Optional edge weights, passed to every conv layer.

        Returns:
            The decoder output (one ``num_classes``-sized vector per node row of ``x``).
        """
        initial = self.pre(self._drop(x)).relu()

        hidden = initial
        for layer in self.convs:
            hidden = layer(self._drop(hidden), initial, edge_index, edge_weight).relu()

        return self.post(self._drop(hidden))

class newGCN2(nn.Module):
    """Node-level model: linear encoder, a stack of ``new2Conv`` layers, linear decoder.

    Each conv layer receives the current hidden state, the encoder output
    ``x_0``, and two graphs: the regular one (``edge_index``/``edge_weight``)
    and an additional one (``inf_edge_index``/``inf_edge_weight``).

    Args:
        num_features: Size of each input node feature vector.
        hidden_features: Width of the hidden representation shared by all conv layers.
        num_classes: Size of each output vector produced by the decoder.
        num_layers: How many ``new2Conv`` layers to stack.
        alpha: Forwarded unchanged to every ``new2Conv``.
        theta: Forwarded unchanged to every ``new2Conv``.
        shared_weights: Forwarded unchanged to every ``new2Conv``.
        dropout: Dropout probability used before the encoder, before each conv
            layer and before the decoder; only active while ``self.training``.
        K: First positional argument of every ``new2Conv``. It was previously
            hard-coded to 1, which remains the default. Its meaning is defined
            by ``new2Conv`` (not visible here).
    """

    def __init__(self, num_features, hidden_features, num_classes, num_layers, alpha=0.1, theta=0.5,
                 shared_weights=True, dropout=0.0, K=1):
        super().__init__()
        self.pre = Linear(num_features, hidden_features)
        self.post = Linear(hidden_features, num_classes)

        # The layer index handed to new2Conv is 1-based (layer + 1).
        self.convs = nn.ModuleList(
            new2Conv(K, hidden_features, alpha, theta, layer + 1, shared_weights)
            for layer in range(num_layers)
        )
        self.dropout = dropout

    def forward(self, x, edge_index, inf_edge_index=None, edge_weight=None, inf_edge_weight=None):
        """Run the model.

        Args:
            x: Input node features.
            edge_index: Regular graph connectivity.
            inf_edge_index: Connectivity of the additional graph (may be ``None``).
            edge_weight: Optional weights for ``edge_index``.
            inf_edge_weight: Optional weights for ``inf_edge_index``.

        Returns:
            The decoder output (one ``num_classes``-sized vector per node row of ``x``).
        """
        x = F.dropout(x, self.dropout, training=self.training)
        x = x_0 = self.pre(x).relu()

        for conv in self.convs:
            x = F.dropout(x, self.dropout, training=self.training)
            # NOTE(review): the conv receives the inf_* graph BEFORE the regular one,
            # unlike this method's own argument order; presumably this matches
            # new2Conv.forward's signature. Confirm against newConv.py.
            x = conv(x, x_0, inf_edge_index, edge_index, inf_edge_weight, edge_weight)
            x = x.relu()

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.post(x)

        return x