import torch.nn
import torch_geometric

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import AddSelfLoops, NormalizeFeatures
from torch_geometric.nn import GCNConv,GNNFF, aggr,SGConv,GCN2Conv



class SGCN(torch.nn.Module):
    """Stack of SGConv layers with ReLU activations and a log-softmax head.

    Twelve hidden layers ``conv1``..``conv12`` are always allocated, in the
    same order as the original explicit assignments. This keeps the parameter
    list and ``state_dict`` keys independent of depth. Only the first
    ``num_layers`` of them are applied in ``forward``.

    Args:
        num_features: Width of the input node features.
        num_classes: Number of output classes (output channels of ``out``).
        num_layers: How many hidden SGConv layers ``forward`` applies, in
            ``[1, MAX_LAYERS]``. Defaults to 2, the depth that was previously
            hard-coded in ``forward``.

    Raises:
        ValueError: If ``num_layers`` is outside ``[1, MAX_LAYERS]``.
    """

    MAX_LAYERS = 12  # number of hidden conv layers allocated (conv1..conv12)

    def __init__(self, num_features, num_classes, *, num_layers=2):
        super().__init__()
        if not 1 <= num_layers <= self.MAX_LAYERS:
            raise ValueError(
                f"num_layers must be in [1, {self.MAX_LAYERS}], got {num_layers}"
            )
        self.num_features = num_features
        self.num_classes = num_classes
        self.num_layers = num_layers
        # alternatives noted by the author: int(num_features/2), int(num_features/4)
        self.num_features_modi = int(num_features)
        # NOTE(review): self.lin is not applied in forward; it is kept only so
        # the parameter set / state_dict keys stay unchanged.
        self.lin = torch.nn.Linear(self.num_features, self.num_features_modi)
        # NOTE(review): `name` is not a keyword of stock PyG convs; presumably it
        # is consumed by the project's torch_geometric fork (which also
        # provides GNNFF) -- confirm.
        self.conv1 = SGConv(self.num_features_modi, 16, name='conv1')
        # setattr registers each layer as a submodule named conv2..conv12,
        # preserving the original registration (and initialization) order.
        for idx in range(2, self.MAX_LAYERS + 1):
            setattr(self, f'conv{idx}', SGConv(16, 16))
        self.out = SGConv(16, self.num_classes)

    def forward(self, x, edge_index):
        """Apply ``num_layers`` SGConv+ReLU steps, then ``out`` and log-softmax.

        Returns log-probabilities over the last dimension (``num_classes`` wide).
        """
        for idx in range(1, self.num_layers + 1):
            x = getattr(self, f'conv{idx}')(x, edge_index).relu()
        x = self.out(x, edge_index)
        return x.log_softmax(dim=-1)





class GCNII(torch.nn.Module):
    """GCNII-style network: input projection, GCN2Conv stack, linear head.

    Every GCN2Conv layer receives the current hidden state and the projected
    input features ``x_0`` (initial residual). Layer ``l`` (1-based) is built
    with ``layer=l`` so its identity-mapping strength follows
    ``beta_l = log(theta / l + 1)``. Stock PyG's GCN2Conv requires ``theta``
    and ``layer`` to be supplied together. A ReLU follows every GCN2Conv, as in
    the GCNII propagation rule and the sibling models in this file.

    Twelve layers ``conv1``..``conv12`` are always allocated, in the original
    order, so ``state_dict`` keys do not depend on depth. Only the first
    ``num_layers`` are applied in ``forward``.

    Args:
        num_features: Width of the input node features.
        num_classes: Number of output classes.
        num_layers: How many GCN2Conv layers ``forward`` applies, in
            ``[1, MAX_LAYERS]``. Defaults to 2, the depth previously hard-coded.

    Raises:
        ValueError: If ``num_layers`` is outside ``[1, MAX_LAYERS]``.
    """

    MAX_LAYERS = 12  # number of GCN2Conv layers allocated (conv1..conv12)

    def __init__(self, num_features, num_classes, *, num_layers=2):
        super().__init__()
        if not 1 <= num_layers <= self.MAX_LAYERS:
            raise ValueError(
                f"num_layers must be in [1, {self.MAX_LAYERS}], got {num_layers}"
            )

        self.num_features = num_features
        self.num_classes = num_classes
        self.num_layers = num_layers
        # NOTE(review): not used by this model; kept for attribute compatibility.
        self.num_features_modi = int(num_features)
        self.hidden = 64
        self.alpha = 0.5  # weight of the initial residual x_0
        self.theta = 0.1  # identity-mapping hyperparameter (with per-layer index)

        self.linfea = torch.nn.Linear(self.num_features, self.hidden)
        # NOTE(review): `name` is not a keyword of stock PyG convs; presumably it
        # is consumed by the project's torch_geometric fork -- confirm.
        self.conv1 = GCN2Conv(self.hidden, alpha=self.alpha, theta=self.theta,
                              layer=1, name='conv1')
        # setattr registers conv2..conv12 in the original order; each gets its
        # own layer index so beta decays with depth.
        for idx in range(2, self.MAX_LAYERS + 1):
            setattr(self, f'conv{idx}',
                    GCN2Conv(self.hidden, alpha=self.alpha, theta=self.theta,
                             layer=idx))

        self.out = torch.nn.Linear(self.hidden, self.num_classes)

    def forward(self, x, edge_index):
        """Project ``x``, run ``num_layers`` GCN2Conv+ReLU steps, classify.

        Returns log-probabilities over the last dimension (``num_classes`` wide).
        """
        x_0 = self.linfea(x)  # initial representation fed to every layer
        h = x_0
        for idx in range(1, self.num_layers + 1):
            h = getattr(self, f'conv{idx}')(h, x_0, edge_index).relu()
        return self.out(h).log_softmax(dim=-1)




class GCN(torch.nn.Module):
    """Linear projection, GCNConv+ReLU stack, GCNConv head, log-softmax.

    Twelve hidden layers ``conv1``..``conv12`` are always allocated, in the
    original order, so ``state_dict`` keys do not depend on depth. Only the
    first ``num_layers`` are applied in ``forward``.

    Args:
        num_features: Width of the input node features.
        num_classes: Number of output classes (output channels of ``out``).
        num_layers: How many hidden GCNConv layers ``forward`` applies, in
            ``[1, MAX_LAYERS]``. Defaults to 2, the depth previously hard-coded.

    Raises:
        ValueError: If ``num_layers`` is outside ``[1, MAX_LAYERS]``.
    """

    MAX_LAYERS = 12  # number of hidden conv layers allocated (conv1..conv12)

    def __init__(self, num_features, num_classes, *, num_layers=2):
        super().__init__()
        if not 1 <= num_layers <= self.MAX_LAYERS:
            raise ValueError(
                f"num_layers must be in [1, {self.MAX_LAYERS}], got {num_layers}"
            )
        self.num_features = num_features
        self.num_classes = num_classes
        self.num_layers = num_layers
        # alternatives noted by the author: int(num_features/2), int(num_features/4)
        self.num_features_modi = int(num_features)
        self.lin = torch.nn.Linear(self.num_features, self.num_features_modi)
        # NOTE(review): `name` is not a keyword of stock PyG convs; presumably it
        # is consumed by the project's torch_geometric fork -- confirm.
        self.conv1 = GCNConv(self.num_features_modi, 16, name='conv1')
        # setattr registers conv2..conv12 in the original order.
        for idx in range(2, self.MAX_LAYERS + 1):
            setattr(self, f'conv{idx}', GCNConv(16, 16, bias=True))
        self.out = GCNConv(16, self.num_classes)

    def forward(self, x, edge_index):
        """Apply ``lin``, ``num_layers`` GCNConv+ReLU steps, ``out``, log-softmax.

        Returns log-probabilities over the last dimension (``num_classes`` wide).
        """
        x = self.lin(x)
        for idx in range(1, self.num_layers + 1):
            x = getattr(self, f'conv{idx}')(x, edge_index).relu()
        x = self.out(x, edge_index)
        return x.log_softmax(dim=-1)




class GCN_Res(torch.nn.Module):
    """GCNConv stack with additive residual connections and a GCNConv head.

    Each hidden step computes ``h = relu(conv(h)) + h``. The output is
    ``out(h)`` followed by log-softmax. Twelve hidden layers
    ``conv1``..``conv12`` are always allocated, in the original order, so
    ``state_dict`` keys do not depend on depth. Only the first ``num_layers``
    are applied in ``forward``.

    Args:
        num_features: Width of the input node features.
        num_classes: Number of output classes (output channels of ``out``).
        num_layers: How many residual GCNConv steps ``forward`` applies, in
            ``[1, MAX_LAYERS]``. Defaults to 2, the depth previously hard-coded.

    Raises:
        ValueError: If ``num_layers`` is outside ``[1, MAX_LAYERS]``.
    """

    MAX_LAYERS = 12  # number of hidden conv layers allocated (conv1..conv12)

    def __init__(self, num_features, num_classes, *, num_layers=2):
        super().__init__()
        if not 1 <= num_layers <= self.MAX_LAYERS:
            raise ValueError(
                f"num_layers must be in [1, {self.MAX_LAYERS}], got {num_layers}"
            )
        self.num_features = num_features
        self.num_classes = num_classes
        self.num_layers = num_layers
        # alternatives noted by the author: int(num_features/2), int(num_features/4)
        self.num_features_modi = int(num_features)
        # NOTE(review): self.lin is not applied in forward; it is kept only so
        # the parameter set / state_dict keys stay unchanged.
        self.lin = torch.nn.Linear(self.num_features, self.num_features_modi)

        self.conv1 = GCNConv(self.num_features_modi, 16)
        # setattr registers conv2..conv12 in the original order.
        for idx in range(2, self.MAX_LAYERS + 1):
            setattr(self, f'conv{idx}', GCNConv(16, 16))
        self.out = GCNConv(16, self.num_classes)

    def forward(self, x, edge_index):
        """Run ``num_layers`` residual GCNConv steps, then ``out`` and log-softmax.

        Returns log-probabilities over the last dimension (``num_classes`` wide).
        """
        h = self.conv1(x, edge_index)
        # The first skip adds conv1's own pre-activation output (relu(h) + h),
        # presumably because x (num_features wide) cannot be added to a
        # 16-wide hidden state.
        h = h.relu() + h
        for idx in range(2, self.num_layers + 1):
            h = getattr(self, f'conv{idx}')(h, edge_index).relu() + h
        out = self.out(h, edge_index)
        return out.log_softmax(dim=-1)