import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SGConv
from .newConv import newConv, newSGConv

class GCN(nn.Module):
    """Two stacked ``GCNConv`` layers with ReLU and dropout in between.

    Args:
        num_features: Input channel count of the first convolution.
        hidden_features: Output channels of the first convolution and input
            channels of the second.
        num_classes: Output channel count of the second convolution.
        dropout: Drop probability handed to ``F.dropout`` on the hidden
            activations. Defaults to ``0.0``.
    """

    def __init__(self, num_features, hidden_features, num_classes, dropout=0.0):
        super().__init__()
        self.dropout = dropout
        self.gc1 = GCNConv(in_channels=num_features, out_channels=hidden_features)
        self.gc2 = GCNConv(in_channels=hidden_features, out_channels=num_classes)

    def forward(self, x, edge_index, edge_weight=None):
        """Run both convolutions over the same graph.

        Args:
            x: Node feature input for the first convolution.
            edge_index: Graph connectivity, forwarded to both convolutions.
            edge_weight: Optional edge weights, forwarded to both
                convolutions.

        Returns:
            The raw output of the second convolution (no activation applied).
        """
        hidden = self.gc1(x, edge_index, edge_weight)
        hidden = F.relu(hidden)
        # Dropout only takes effect while the module is in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.gc2(hidden, edge_index, edge_weight)

class GCNL(nn.Module):
    """Three ``GCNConv`` layers, each summed with a bias-free linear branch.

    Every layer computes ``conv(h, graph) + linear(h)`` on the same input
    ``h``. The first two layers are followed by ELU; the last one is returned
    without an activation.

    Args:
        num_features: Input size of the first convolution / linear pair.
        hidden_features: Width of the two hidden layers.
        num_classes: Output size of the final convolution / linear pair.
    """

    def __init__(self, num_features, hidden_features, num_classes):
        super().__init__()
        self.gc1 = GCNConv(in_channels=num_features, out_channels=hidden_features)
        self.ln1 = nn.Linear(in_features=num_features, out_features=hidden_features, bias=False)
        self.gc2 = GCNConv(in_channels=hidden_features, out_channels=hidden_features)
        self.ln2 = nn.Linear(in_features=hidden_features, out_features=hidden_features, bias=False)
        self.gc3 = GCNConv(in_channels=hidden_features, out_channels=num_classes)
        self.ln3 = nn.Linear(in_features=hidden_features, out_features=num_classes, bias=False)

    def forward(self, x, edge_index, edge_weight=None):
        """Apply the three conv-plus-linear layers in sequence.

        Args:
            x: Node feature input for the first layer.
            edge_index: Graph connectivity, forwarded to every convolution.
            edge_weight: Optional edge weights, forwarded to every
                convolution.

        Returns:
            Sum of the third convolution and third linear branch, with no
            activation applied.
        """
        hidden = x
        hidden_layers = ((self.gc1, self.ln1), (self.gc2, self.ln2))
        for conv, linear in hidden_layers:
            mixed = conv(hidden, edge_index, edge_weight) + linear(hidden)
            hidden = F.elu(mixed)
        return self.gc3(hidden, edge_index, edge_weight) + self.ln3(hidden)

class SGC(nn.Module):
    """Single ``SGConv`` layer preceded by dropout on its input.

    Args:
        num_features: Input channel count of the convolution.
        K: Passed through as the ``K`` argument of ``SGConv``.
        num_classes: Output channel count of the convolution.
        dropout: Drop probability handed to ``F.dropout`` on the input
            features. Defaults to ``0.0``.
    """

    def __init__(self, num_features, K, num_classes, dropout=0.0):
        super().__init__()
        self.dropout = dropout
        self.gc = SGConv(in_channels=num_features, out_channels=num_classes, K=K)

    def forward(self, x, edge_index, edge_weight=None):
        """Drop out the input features, then apply the convolution.

        Args:
            x: Node feature input.
            edge_index: Graph connectivity, forwarded to the convolution.
            edge_weight: Optional edge weights, forwarded to the convolution.

        Returns:
            The raw output of the convolution (no activation applied).
        """
        # Dropout only takes effect while the module is in training mode.
        dropped = F.dropout(x, p=self.dropout, training=self.training)
        return self.gc(dropped, edge_index, edge_weight)

class newGCN(nn.Module):
    """Two stacked ``newConv`` layers with ReLU and dropout in between.

    ``newConv`` is a project-local layer imported from ``.newConv``; its
    internals are not visible from this module.

    Args:
        num_features: First constructor argument of the first ``newConv``.
        hidden_features: Second constructor argument of the first layer and
            first constructor argument of the second.
        num_classes: Second constructor argument of the second ``newConv``.
        dropout: Drop probability handed to ``F.dropout`` on the hidden
            activations. Defaults to ``0.0``.
    """

    def __init__(self, num_features, hidden_features, num_classes, dropout=0.0):
        super().__init__()
        self.dropout = dropout
        self.gc1 = newConv(num_features, hidden_features)
        self.gc2 = newConv(hidden_features, num_classes)

    def forward(self, x, edge_index, inf_edge_index=None, edge_weight=None, inf_edge_weight=None):
        """Run both layers with the same set of graph arguments.

        Args:
            x: Node feature input for the first layer.
            edge_index: Forwarded to both layers as their third argument.
            inf_edge_index: Forwarded to both layers as their second argument.
            edge_weight: Forwarded to both layers as their fifth argument.
            inf_edge_weight: Forwarded to both layers as their fourth
                argument.

        Returns:
            The raw output of the second layer (no activation applied).
        """
        # The layers take the "inf" index/weight ahead of the plain one, which
        # differs from the parameter order of this method.
        # NOTE(review): whether newConv accepts None for these is not visible
        # here -- confirm against its definition.
        graph_args = (inf_edge_index, edge_index, inf_edge_weight, edge_weight)
        hidden = F.relu(self.gc1(x, *graph_args))
        # Dropout only takes effect while the module is in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.gc2(hidden, *graph_args)

class newGCNL(nn.Module):
    """Three ``newConv`` layers, each summed with a bias-free linear branch.

    Every layer computes ``conv(h, graph...) + linear(h)`` on the same input
    ``h``. The first two layers are followed by ELU; the last one is returned
    without an activation. ``newConv`` is a project-local layer imported from
    ``.newConv``; its internals are not visible from this module.

    Args:
        num_features: Input size of the first conv / linear pair.
        hidden_features: Width of the two hidden layers.
        num_classes: Output size of the final conv / linear pair.
    """

    def __init__(self, num_features, hidden_features, num_classes):
        super().__init__()
        self.gc1 = newConv(num_features, hidden_features)
        self.ln1 = nn.Linear(in_features=num_features, out_features=hidden_features, bias=False)
        self.gc2 = newConv(hidden_features, hidden_features)
        self.ln2 = nn.Linear(in_features=hidden_features, out_features=hidden_features, bias=False)
        self.gc3 = newConv(hidden_features, num_classes)
        self.ln3 = nn.Linear(in_features=hidden_features, out_features=num_classes, bias=False)

    def forward(self, x, edge_index, inf_edge_index=None, edge_weight=None, inf_edge_weight=None):
        """Apply the three conv-plus-linear layers in sequence.

        Args:
            x: Node feature input for the first layer.
            edge_index: Forwarded to every conv as its third argument.
            inf_edge_index: Forwarded to every conv as its second argument.
            edge_weight: Forwarded to every conv as its fifth argument.
            inf_edge_weight: Forwarded to every conv as its fourth argument.

        Returns:
            Sum of the third conv and third linear branch, with no activation
            applied.
        """
        # The convs take the "inf" index/weight ahead of the plain one, which
        # differs from the parameter order of this method.
        graph_args = (inf_edge_index, edge_index, inf_edge_weight, edge_weight)
        hidden = x
        hidden_layers = ((self.gc1, self.ln1), (self.gc2, self.ln2))
        for conv, linear in hidden_layers:
            mixed = conv(hidden, *graph_args) + linear(hidden)
            hidden = F.elu(mixed)
        return self.gc3(hidden, *graph_args) + self.ln3(hidden)

class newSGC(nn.Module):
    """Single ``newSGConv`` layer preceded by dropout on its input.

    ``newSGConv`` is a project-local layer imported from ``.newConv``; its
    internals are not visible from this module.

    Args:
        num_features: First constructor argument of ``newSGConv``.
        K: Passed through as the ``K`` keyword of ``newSGConv``.
        num_classes: Second constructor argument of ``newSGConv``.
        dropout: Drop probability handed to ``F.dropout`` on the input
            features. Defaults to ``0.0``.
    """

    def __init__(self, num_features, K, num_classes, dropout=0.0):
        super().__init__()
        self.dropout = dropout
        self.gc = newSGConv(num_features, num_classes, K=K)

    def forward(self, x, edge_index, inf_edge_index=None, edge_weight=None, inf_edge_weight=None):
        """Drop out the input features, then apply the layer.

        Args:
            x: Node feature input.
            edge_index: Forwarded to the layer as its third argument.
            inf_edge_index: Forwarded to the layer as its second argument.
            edge_weight: Forwarded to the layer as its fifth argument.
            inf_edge_weight: Forwarded to the layer as its fourth argument.

        Returns:
            The raw output of the layer (no activation applied).
        """
        # Dropout only takes effect while the module is in training mode.
        dropped = F.dropout(x, p=self.dropout, training=self.training)
        # The layer takes the "inf" index/weight ahead of the plain one, which
        # differs from the parameter order of this method.
        graph_args = (inf_edge_index, edge_index, inf_edge_weight, edge_weight)
        return self.gc(dropped, *graph_args)