import torch.nn as nn
import torch.nn.functional as F
from layers import *
import torch
import numpy as np
from torch_geometric.nn import GCNConv, SAGEConv, SGConv, GATConv
from torch.nn import Linear
from torch_geometric.nn import MessagePassing, APPNP
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch.nn import Parameter

class GCN(nn.Module):
    """Node classifier built from two ``GraphConvolution`` layers.

    Args:
        nfeat: input feature size of the first layer.
        nhid: output size of the first layer / input size of the second.
        nclass: output size of the second layer.
        dropout: dropout probability applied between the two layers.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN, self).__init__()
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        """Return the log-softmax (over dim 1) of the second layer's output."""
        hidden = F.relu(self.gc1(x, adj))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        logits = self.gc2(hidden, adj)
        return F.log_softmax(logits, dim=1)

class GCN1(nn.Module):
    """Two-layer ``GraphConvolution`` classifier that also returns its hidden layer.

    ``forward`` returns ``(log_probs, embed)``. ``embed`` is the first layer's
    output after ReLU and dropout, i.e. exactly what is fed to the second
    layer.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN1, self).__init__()
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        hidden = F.dropout(F.relu(self.gc1(x, adj)), self.dropout,
                           training=self.training)
        log_probs = F.log_softmax(self.gc2(hidden, adj), dim=1)
        return log_probs, hidden

class GCN1_lp(nn.Module):
    """Two-layer ``GraphConvolution1`` network that also scores node pairs.

    ``forward`` returns the second layer's node outputs plus ``m(score)``,
    where ``score`` holds, for each pair ``(u, v)`` in ``train_edge``, the
    dot product of the output rows of ``u`` and ``v``.

    Args:
        nfeat: input feature size.
        nhid: hidden size.
        nclass: output size of the second layer.
        dropout: dropout probability between the two layers.
        cnt_train: length of the initial ``sim_train`` placeholder; ``forward``
            overwrites ``sim_train`` on every call.
    """

    def __init__(self, nfeat, nhid, nclass, dropout, cnt_train):
        super(GCN1_lp, self).__init__()
        self.gc1 = GraphConvolution1(nfeat, nhid)
        self.gc2 = GraphConvolution1(nhid, nclass)
        # Plain tensor (not a Parameter). torch.autograd.Variable is deprecated;
        # a tensor created with requires_grad=True is its exact replacement.
        self.sim_train = torch.zeros(cnt_train, requires_grad=True)
        self.dropout = dropout

    def forward(self, x, adj, train_edge, train_label, m, criterion):
        """Run both layers and score the node pairs in ``train_edge``.

        Args:
            x: node features for the first layer.
            adj: adjacency passed to both layers.
            train_edge: indexable pair ``(src, dst)`` of node-index tensors,
                e.g. a ``[2, num_edges]`` tensor. A single edge given as a
                ``[2]`` tensor also works.
            train_label: unused; kept for call compatibility.
            m: callable applied to the tensor of raw pair scores.
            criterion: unused; kept for call compatibility.

        Returns:
            ``(x, sim)``: the second layer's output and ``m`` applied to the
            per-pair dot products. ``sim`` is also stored on ``self.sim_train``.
        """
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)

        src = train_edge[0].long()
        dst = train_edge[1].long()
        # Row-wise dot product. torch.dot only accepts 1-D tensors and raised
        # for a batch of edges; for a single edge this equals torch.dot.
        scores = (x[src] * x[dst]).sum(dim=-1)
        self.sim_train = m(scores)
        return x, self.sim_train


class GCN_pia(nn.Module):
    """Two-layer ``GraphConvolution_pia`` network.

    ``forward`` returns ``(log_softmax(out), out)``, where ``out`` is the raw
    output of the second layer.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN_pia, self).__init__()
        self.gc1 = GraphConvolution_pia(nfeat, nhid)
        self.gc2 = GraphConvolution_pia(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        hidden = F.relu(self.gc1(x, adj))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        out = self.gc2(hidden, adj)
        return F.log_softmax(out, dim=1), out

class GCN_pia1(nn.Module):
    """Single ``GraphConvolution_pia`` layer mapping ``nfeat`` to ``nhid``.

    ``nclass`` is accepted for signature parity with the other models and is
    unused. ``dropout`` is stored but not applied. ``forward`` returns
    ``(log_softmax(out), out)``.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN_pia1, self).__init__()
        self.gc1 = GraphConvolution_pia(nfeat, nhid)
        self.dropout = dropout

    def forward(self, x, adj):
        out = self.gc1(x, adj)
        return F.log_softmax(out, dim=1), out



class GCN_pia_unlearn_baseline(nn.Module):
    """Two-layer GCN whose ``GraphConvolution_pia_unlearn1`` layers receive given arrays.

    Args:
        npara: sequence of at least four array-likes. ``npara[0]`` and
            ``npara[1]`` (converted with ``np.array``) go to the first layer,
            ``npara[2]`` and ``npara[3]`` to the second. Presumably these are
            weight/bias pairs; confirm against ``layers``.
        nfeat: input feature size.
        nhid: hidden size.
        nclass: output size.
        dropout: dropout probability between the two layers.

    ``forward`` returns ``(log_softmax(out), embed1, embed2)``. ``embed1`` is
    the first layer's raw output and ``embed2`` is ``out``.
    """

    def __init__(self, npara, nfeat, nhid, nclass, dropout):
        # Must name this class: super(GCN_pia_unlearn, self) raises TypeError
        # because self is not an instance of GCN_pia_unlearn.
        super(GCN_pia_unlearn_baseline, self).__init__()

        npara1 = np.array(npara[0])
        npara2 = np.array(npara[1])
        self.gc1 = GraphConvolution_pia_unlearn1(npara1, npara2, nfeat, nhid)

        npara1 = np.array(npara[2])
        npara2 = np.array(npara[3])
        self.gc2 = GraphConvolution_pia_unlearn1(npara1, npara2, nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        x = self.gc1(x, adj)
        embed1 = x
        x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        embed2 = x
        return F.log_softmax(x, dim=1), embed1, embed2




class GCN_feature_selection(nn.Module):
    """Two-layer GCN of ``GraphConvolution_feature_selection`` layers.

    The first layer is built and called with the flag ``True`` and returns
    ``(output, featureSelector)``. The second layer is built and called with
    ``False`` and returns only its output. ``temp`` is forwarded to both.

    ``forward`` returns ``(log_softmax(out), embed1, embed2, featureSelector)``.
    """

    def __init__(self, nfeat, nhid, nclass, dropout, featuresSelected):
        super(GCN_feature_selection, self).__init__()
        self.gc1 = GraphConvolution_feature_selection(nfeat, nhid, featuresSelected, True)
        self.gc2 = GraphConvolution_feature_selection(nhid, nclass, featuresSelected, False)
        self.dropout = dropout

    def forward(self, x, adj, temp):
        first_out, selector = self.gc1(x, adj, True, temp)
        hidden = F.dropout(F.relu(first_out), self.dropout, training=self.training)
        second_out = self.gc2(hidden, adj, False, temp)
        return F.log_softmax(second_out, dim=1), first_out, second_out, selector


class GCN_pia2(nn.Module):
    """Three-layer ``GraphConvolution_pia`` network that exposes every layer's output.

    ``forward`` returns ``(log_softmax(out), embed1, embed2, embed3)``.
    ``embedN`` is layer N's raw output, taken before the ReLU/dropout that
    feeds the next layer. ``embed3`` is the final output ``out``.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN_pia2, self).__init__()
        self.gc1 = GraphConvolution_pia(nfeat, nhid)
        self.gc2 = GraphConvolution_pia(nhid, nhid)
        self.gc3 = GraphConvolution_pia(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        embeds = [self.gc1(x, adj)]
        for layer in (self.gc2, self.gc3):
            h = F.relu(embeds[-1])
            h = F.dropout(h, self.dropout, training=self.training)
            embeds.append(layer(h, adj))
        return (F.log_softmax(embeds[-1], dim=1), *embeds)


class GCN_pia3(nn.Module):
    """Four-layer ``GraphConvolution_pia`` network that exposes every layer's output.

    ``forward`` returns ``(log_softmax(out), embed1, ..., embed4)``.
    ``embedN`` is layer N's raw output, taken before the ReLU/dropout that
    feeds the next layer. ``embed4`` is the final output ``out``.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN_pia3, self).__init__()
        self.gc1 = GraphConvolution_pia(nfeat, nhid)
        self.gc2 = GraphConvolution_pia(nhid, nhid)
        self.gc3 = GraphConvolution_pia(nhid, nhid)
        self.gc4 = GraphConvolution_pia(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        embeds = [self.gc1(x, adj)]
        for layer in (self.gc2, self.gc3, self.gc4):
            h = F.relu(embeds[-1])
            h = F.dropout(h, self.dropout, training=self.training)
            embeds.append(layer(h, adj))
        return (F.log_softmax(embeds[-1], dim=1), *embeds)

class GCN_pia4(nn.Module):
    """Five-layer ``GraphConvolution_pia`` network that exposes every layer's output.

    ``forward`` returns ``(log_softmax(out), embed1, ..., embed5)``.
    ``embedN`` is layer N's raw output, taken before the ReLU/dropout that
    feeds the next layer. ``embed5`` is the final output ``out``.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN_pia4, self).__init__()
        self.gc1 = GraphConvolution_pia(nfeat, nhid)
        self.gc2 = GraphConvolution_pia(nhid, nhid)
        self.gc3 = GraphConvolution_pia(nhid, nhid)
        self.gc4 = GraphConvolution_pia(nhid, nhid)
        self.gc5 = GraphConvolution_pia(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        embeds = [self.gc1(x, adj)]
        for layer in (self.gc2, self.gc3, self.gc4, self.gc5):
            h = F.relu(embeds[-1])
            h = F.dropout(h, self.dropout, training=self.training)
            embeds.append(layer(h, adj))
        return (F.log_softmax(embeds[-1], dim=1), *embeds)

class GCN_pia6(nn.Module):
    """Seven-layer ``GraphConvolution_pia`` network that exposes every layer's output.

    ``forward`` returns ``(log_softmax(out), embed1, ..., embed7)``.
    ``embedN`` is layer N's raw output, taken before the ReLU/dropout that
    feeds the next layer. ``embed7`` is the final output ``out``.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN_pia6, self).__init__()
        self.gc1 = GraphConvolution_pia(nfeat, nhid)
        self.gc2 = GraphConvolution_pia(nhid, nhid)
        self.gc3 = GraphConvolution_pia(nhid, nhid)
        self.gc4 = GraphConvolution_pia(nhid, nhid)
        self.gc5 = GraphConvolution_pia(nhid, nhid)
        self.gc6 = GraphConvolution_pia(nhid, nhid)
        self.gc7 = GraphConvolution_pia(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        embeds = [self.gc1(x, adj)]
        remaining = (self.gc2, self.gc3, self.gc4, self.gc5, self.gc6, self.gc7)
        for layer in remaining:
            h = F.relu(embeds[-1])
            h = F.dropout(h, self.dropout, training=self.training)
            embeds.append(layer(h, adj))
        return (F.log_softmax(embeds[-1], dim=1), *embeds)


class GCN_pia8(nn.Module):
    """Nine-layer ``GraphConvolution_pia`` network that exposes every layer's output.

    ``forward`` returns ``(log_softmax(out), embed1, ..., embed9)``.
    ``embedN`` is layer N's raw output, taken before the ReLU/dropout that
    feeds the next layer. ``embed9`` is the final output ``out``.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN_pia8, self).__init__()
        self.gc1 = GraphConvolution_pia(nfeat, nhid)
        self.gc2 = GraphConvolution_pia(nhid, nhid)
        self.gc3 = GraphConvolution_pia(nhid, nhid)
        self.gc4 = GraphConvolution_pia(nhid, nhid)
        self.gc5 = GraphConvolution_pia(nhid, nhid)
        self.gc6 = GraphConvolution_pia(nhid, nhid)
        self.gc7 = GraphConvolution_pia(nhid, nhid)
        self.gc8 = GraphConvolution_pia(nhid, nhid)
        self.gc9 = GraphConvolution_pia(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        embeds = [self.gc1(x, adj)]
        remaining = (self.gc2, self.gc3, self.gc4, self.gc5,
                     self.gc6, self.gc7, self.gc8, self.gc9)
        for layer in remaining:
            h = F.relu(embeds[-1])
            h = F.dropout(h, self.dropout, training=self.training)
            embeds.append(layer(h, adj))
        return (F.log_softmax(embeds[-1], dim=1), *embeds)




class GCN_pia2_unlearn(nn.Module):
    """Three-layer ``GraphConvolution_pia_unlearn`` network exposing each layer's output.

    ``forward`` returns ``(log_softmax(out), embed1, embed2, embed3)``.
    ``embedN`` is layer N's raw output, taken before the ReLU/dropout that
    feeds the next layer. ``embed3`` is the final output ``out``.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN_pia2_unlearn, self).__init__()
        self.gc1 = GraphConvolution_pia_unlearn(nfeat, nhid)
        self.gc2 = GraphConvolution_pia_unlearn(nhid, nhid)
        self.gc3 = GraphConvolution_pia_unlearn(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        embeds = [self.gc1(x, adj)]
        for layer in (self.gc2, self.gc3):
            h = F.relu(embeds[-1])
            h = F.dropout(h, self.dropout, training=self.training)
            embeds.append(layer(h, adj))
        return (F.log_softmax(embeds[-1], dim=1), *embeds)


class GCN_pia_unlearn(nn.Module):
    """Two-layer ``GraphConvolution_pia_unlearn`` network.

    ``forward`` returns ``(log_softmax(out), embed1, embed2)``. ``embed1`` is
    the first layer's raw output (before ReLU/dropout) and ``embed2`` is the
    final output ``out``.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        # Must name this class: super(GCN_pia2_unlearn, self) raises TypeError
        # because self is not an instance of GCN_pia2_unlearn.
        super(GCN_pia_unlearn, self).__init__()

        self.gc1 = GraphConvolution_pia_unlearn(nfeat, nhid)
        self.gc2 = GraphConvolution_pia_unlearn(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        x = self.gc1(x, adj)
        embed1 = x
        x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        embed2 = x
        return F.log_softmax(x, dim=1), embed1, embed2
    
class Net(nn.Module):
    """Five-layer fully connected classifier with ReLU + dropout between layers.

    Args:
        in_dim: input feature size.
        h_dim: sequence of (at least) four hidden sizes.
        dropout: dropout probability after each of the four hidden layers.
        out_dim: number of outputs of the last layer (default 3, the value
            that was previously hard-coded).

    ``forward`` returns softmax probabilities over ``dim=1``.
    """

    def __init__(self, in_dim, h_dim, dropout, out_dim=3):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(in_dim, h_dim[0])
        self.fc2 = nn.Linear(h_dim[0], h_dim[1])
        self.fc3 = nn.Linear(h_dim[1], h_dim[2])
        self.fc4 = nn.Linear(h_dim[2], h_dim[3])
        self.fc5 = nn.Linear(h_dim[3], out_dim)
        self.dropout = dropout
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        for fc in (self.fc1, self.fc2, self.fc3, self.fc4):
            # training=self.training is required: F.dropout defaults to
            # training=True, which kept dropout active after model.eval().
            x = F.dropout(F.relu(fc(x)), p=self.dropout, training=self.training)
        return self.softmax(self.fc5(x))

class GCN_Net(torch.nn.Module):
    """Two-layer ``GCNConv`` classifier.

    Args:
        dataset: object providing ``num_features`` and ``num_classes``.
        args: object providing ``hidden`` and ``dropout``.

    ``forward(data)`` reads ``data.x`` and ``data.edge_index`` and returns
    ``(log_softmax(out), out)``.
    """

    def __init__(self, dataset, args):
        super(GCN_Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, args.hidden)
        self.conv2 = GCNConv(args.hidden, dataset.num_classes)
        self.dropout = args.dropout

    def reset_parameters(self):
        for conv in (self.conv1, self.conv2):
            conv.reset_parameters()

    def forward(self, data):
        feats = data.x
        edges = data.edge_index
        hidden = F.relu(self.conv1(feats, edges))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        logits = self.conv2(hidden, edges)
        return F.log_softmax(logits, dim=1), logits

class GAT_Net(torch.nn.Module):
    """Two-layer ``GATConv`` classifier.

    Args:
        dataset: object providing ``num_features`` and ``num_classes``.
        args: object providing ``hidden``, ``heads``, ``output_heads`` and
            ``dropout``. ``dropout`` is used both as the attention dropout
            inside each ``GATConv`` and as the feature dropout in ``forward``.

    ``forward(data)`` returns ``(log_softmax(out), out)``.
    """

    def __init__(self, dataset, args):
        super(GAT_Net, self).__init__()
        p = args.dropout
        self.conv1 = GATConv(dataset.num_features, args.hidden,
                             heads=args.heads, dropout=p)
        # The first layer concatenates its heads, so the input width is hidden * heads.
        self.conv2 = GATConv(args.hidden * args.heads, dataset.num_classes,
                             heads=args.output_heads, concat=False, dropout=p)
        self.dropout = p

    def reset_parameters(self):
        for conv in (self.conv1, self.conv2):
            conv.reset_parameters()

    def forward(self, data):
        feats = data.x
        edges = data.edge_index
        feats = F.dropout(feats, p=self.dropout, training=self.training)
        feats = F.elu(self.conv1(feats, edges))
        feats = F.dropout(feats, p=self.dropout, training=self.training)
        logits = self.conv2(feats, edges)
        return F.log_softmax(logits, dim=1), logits

class SGC_Net(torch.nn.Module):
    """Two stacked ``SGConv`` layers (``K=2``, uncached).

    Args:
        dataset: object providing ``num_node_features`` and ``num_classes``.
        args: object providing ``hidden`` and ``dropout``.

    ``forward(data)`` returns ``(log_softmax(out), out)``.
    """

    def __init__(self, dataset, args):
        super(SGC_Net, self).__init__()
        self.conv1 = SGConv(dataset.num_node_features, args.hidden, K=2, cached=False)
        self.conv2 = SGConv(args.hidden, dataset.num_classes, K=2, cached=False)
        self.dropout = args.dropout

    def reset_parameters(self):
        for conv in (self.conv1, self.conv2):
            conv.reset_parameters()

    def forward(self, data):
        feats = data.x
        edges = data.edge_index
        hidden = F.dropout(F.relu(self.conv1(feats, edges)),
                           p=self.dropout, training=self.training)
        logits = self.conv2(hidden, edges)
        return F.log_softmax(logits, dim=1), logits

class SAGE_Net(torch.nn.Module):
    """Two-layer ``SAGEConv`` model (hidden width 256) for sampled mini-batches.

    ``forward`` iterates over ``data.edge_index`` as a sequence of
    ``(edge_index, _, size)`` triples, one per layer. The first ``size[1]``
    rows of the current features are treated as the target nodes. Presumably
    this is a neighbour-sampler adjacency list; confirm with the caller.

    Args:
        dataset: object providing ``num_node_features`` and ``num_classes``.
        args: object providing ``dropout`` (applied between the layers).
    """

    def __init__(self, dataset, args):
        super(SAGE_Net, self).__init__()
        self.num_layers = 2
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(dataset.num_node_features, 256))
        self.convs.append(SAGEConv(256, dataset.num_classes))
        self.dropout = args.dropout

    def reset_parameters(self):
        # The layers live in self.convs; there are no self.conv1/self.conv2
        # attributes, so the previous version raised AttributeError.
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        for i, (edge_ind, _, size) in enumerate(edge_index):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_ind)
            if i != self.num_layers - 1:
                x = F.relu(x)
                # Use the configured rate; it was hard-coded to 0.5, which left
                # args.dropout unused.
                x = F.dropout(x, p=self.dropout, training=self.training)
        embed2 = x
        return x.log_softmax(dim=-1), embed2
    
class GPR_prop(MessagePassing):
    '''
    Propagation layer used by GPRGNN (generalized PageRank propagation).

    ``forward`` returns ``sum_{k=0..K} temp[k] * A^k x``. ``A`` is the
    propagation defined by ``gcn_norm`` weights and ``message``, and ``temp``
    is a learnable vector of K + 1 coefficients.

    Args:
        K: number of propagation steps.
        alpha: for 'SGC', the integer index of the single coefficient set to
            1; for 'PPR' / 'NPPR', the decay factor used to build ``temp``.
        Init: initialisation scheme, one of 'SGC', 'PPR', 'NPPR', 'Random', 'WS'.
        Gamma: explicit coefficients; only used when ``Init == 'WS'``.
        bias: accepted for API compatibility; unused.
    '''
    def __init__(self, K, alpha, Init, Gamma=None, bias=True, **kwargs):
        super(GPR_prop, self).__init__(aggr='add', **kwargs)
        self.K = K
        self.Init = Init
        self.alpha = alpha
        self.Gamma = Gamma

        assert Init in ['SGC', 'PPR', 'NPPR', 'Random', 'WS']
        if Init == 'SGC':
            # SGC-like, note that in this case, alpha has to be a integer. It means where the peak at when initializing GPR weights.
            TEMP = 0.0*np.ones(K+1)
            TEMP[alpha] = 1.0
        elif Init == 'PPR':
            # PPR-like
            TEMP = alpha*(1-alpha)**np.arange(K+1)
            TEMP[-1] = (1-alpha)**K
        elif Init == 'NPPR':
            # Negative PPR
            TEMP = (alpha)**np.arange(K+1)
            TEMP = TEMP/np.sum(np.abs(TEMP))
        elif Init == 'Random':
            # Random
            bound = np.sqrt(3/(K+1))
            TEMP = np.random.uniform(-bound, bound, K+1)
            TEMP = TEMP/np.sum(np.abs(TEMP))
        elif Init == 'WS':
            # Specify Gamma
            TEMP = Gamma

        self.temp = Parameter(torch.tensor(TEMP))

    def reset_parameters(self):
        """Re-initialise ``temp`` according to ``self.Init`` (mirrors ``__init__``)."""
        torch.nn.init.zeros_(self.temp)
        if self.Init == 'SGC':
            self.temp.data[self.alpha]= 1.0
        elif self.Init == 'PPR':
            for k in range(self.K+1):
                self.temp.data[k] = self.alpha*(1-self.alpha)**k
            self.temp.data[-1] = (1-self.alpha)**self.K
        elif self.Init == 'NPPR':
            for k in range(self.K+1):
                self.temp.data[k] = self.alpha**k
            self.temp.data = self.temp.data/torch.sum(torch.abs(self.temp.data))
        elif self.Init == 'Random':
            bound = np.sqrt(3/(self.K+1))
            torch.nn.init.uniform_(self.temp,-bound,bound)
            self.temp.data = self.temp.data/torch.sum(torch.abs(self.temp.data))
        elif self.Init == 'WS':
            # .data only accepts a tensor; Gamma may be a list or ndarray
            # (__init__ already accepts those via torch.tensor). Clone so temp
            # never aliases the caller's Gamma.
            self.temp.data = torch.as_tensor(
                self.Gamma, dtype=self.temp.dtype, device=self.temp.device).clone()

    def forward(self, x, edge_index, edge_weight=None):
        edge_index, norm = gcn_norm(
            edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)

        hidden = x*(self.temp[0])
        for k in range(self.K):
            x = self.propagate(edge_index, x=x, norm=norm)
            gamma = self.temp[k+1]
            hidden = hidden + gamma*x
        return hidden

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def __repr__(self):
        return '{}(K={}, temp={})'.format(self.__class__.__name__, self.K,
                                          self.temp)
    
class GPRGNN(torch.nn.Module):
    """Two-layer MLP followed by a propagation step (APPNP or GPR_prop).

    Args:
        dataset: object providing ``num_features`` and ``num_classes``.
        args: object providing ``hidden``, ``ppnp`` ('PPNP' or 'GPR_prop'),
            ``K``, ``alpha``, ``Init``, ``Gamma``, ``dprate`` (dropout before
            propagation) and ``dropout`` (dropout inside the MLP).

    Raises:
        ValueError: if ``args.ppnp`` is not 'PPNP' or 'GPR_prop'. Previously
            such a value left ``prop1`` undefined, and the failure only
            surfaced later in ``forward``/``reset_parameters``.
    """

    def __init__(self, dataset, args):
        super(GPRGNN, self).__init__()
        self.lin1 = Linear(dataset.num_features, args.hidden)
        self.lin2 = Linear(args.hidden, dataset.num_classes)

        if args.ppnp == 'PPNP':
            self.prop1 = APPNP(args.K, args.alpha)
        elif args.ppnp == 'GPR_prop':
            self.prop1 = GPR_prop(args.K, args.alpha, args.Init, args.Gamma)
        else:
            raise ValueError(
                "unknown args.ppnp {!r}; expected 'PPNP' or 'GPR_prop'".format(args.ppnp))

        self.Init = args.Init
        self.dprate = args.dprate
        self.dropout = args.dropout

    def reset_parameters(self):
        self.prop1.reset_parameters()

    def forward(self, data):
        """Return ``(log_softmax(out), out)`` for ``data.x`` / ``data.edge_index``."""
        x, edge_index = data.x, data.edge_index

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)

        # Dropout before propagation is skipped entirely when dprate is 0.
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        x = self.prop1(x, edge_index)
        return F.log_softmax(x, dim=1), x
        
class GCN_PLUS_Net(torch.nn.Module):
    """Deep ``GCNConv`` stack with BatchNorm + ReLU + dropout between layers.

    Args:
        in_channels: input feature size.
        hidden: width of every hidden layer.
        out_channels: output size of the last layer.
        num_layers: total number of ``GCNConv`` layers (at least 2 are built).
        dropout: dropout probability after each hidden layer.

    ``forward(data)`` returns ``(log_softmax(out), out)``.
    """

    def __init__(self, in_channels, hidden, out_channels, num_layers,
                 dropout):
        super(GCN_PLUS_Net, self).__init__()
        self.convs = torch.nn.ModuleList([GCNConv(in_channels, hidden, cached=True)])
        self.bns = torch.nn.ModuleList([torch.nn.BatchNorm1d(hidden)])
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden, hidden, cached=True))
            self.bns.append(torch.nn.BatchNorm1d(hidden))
        self.convs.append(GCNConv(hidden, out_channels, cached=True))
        self.dropout = dropout

    def reset_parameters(self):
        for module in list(self.convs) + list(self.bns):
            module.reset_parameters()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        # One BatchNorm per hidden conv; the final conv has none.
        *hidden_convs, last_conv = self.convs
        for conv, bn in zip(hidden_convs, self.bns):
            x = F.relu(bn(conv(x, edge_index)))
            x = F.dropout(x, p=self.dropout, training=self.training)
        out = last_conv(x, edge_index)
        return out.log_softmax(dim=-1), out

class GAT_PLUS_Net(torch.nn.Module):
    """Deep ``GATConv`` stack with BatchNorm + ReLU + dropout between layers.

    Hidden layers concatenate ``heads`` attention heads, so their width is
    ``hidden * heads``. The last layer averages ``output_heads`` heads
    (``concat=False``).

    ``forward(data)`` returns ``(log_softmax(out), out)``.
    """

    def __init__(self, in_channels, hidden, out_channels, num_layers,
                 dropout, heads, output_heads):
        super(GAT_PLUS_Net, self).__init__()
        width = hidden * heads
        self.convs = torch.nn.ModuleList([GATConv(in_channels, hidden, heads=heads)])
        self.bns = torch.nn.ModuleList([torch.nn.BatchNorm1d(width)])
        for _ in range(num_layers - 2):
            self.convs.append(GATConv(width, hidden, heads=heads))
            self.bns.append(torch.nn.BatchNorm1d(width))
        self.convs.append(GATConv(width, out_channels, concat=False, heads=output_heads))
        self.dropout = dropout

    def reset_parameters(self):
        for module in list(self.convs) + list(self.bns):
            module.reset_parameters()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        # One BatchNorm per hidden conv; the final conv has none.
        *hidden_convs, last_conv = self.convs
        for conv, bn in zip(hidden_convs, self.bns):
            x = F.relu(bn(conv(x, edge_index)))
            x = F.dropout(x, p=self.dropout, training=self.training)
        out = last_conv(x, edge_index)
        return out.log_softmax(dim=-1), out

class SGC_PLUS_Net(torch.nn.Module):
    """Deep ``SGConv`` (``K=2``, uncached) stack with BatchNorm + ReLU + dropout.

    Args:
        in_channels: input feature size.
        hidden: width of every hidden layer.
        out_channels: output size of the last layer.
        num_layers: total number of ``SGConv`` layers (at least 2 are built).
        dropout: dropout probability after each hidden layer.

    ``forward(data)`` returns ``(log_softmax(out), out)``.
    """

    def __init__(self, in_channels, hidden, out_channels, num_layers,
                 dropout):
        super(SGC_PLUS_Net, self).__init__()
        self.convs = torch.nn.ModuleList([SGConv(in_channels, hidden, K=2, cached=False)])
        self.bns = torch.nn.ModuleList([torch.nn.BatchNorm1d(hidden)])
        for _ in range(num_layers - 2):
            self.convs.append(SGConv(hidden, hidden, K=2, cached=False))
            self.bns.append(torch.nn.BatchNorm1d(hidden))
        self.convs.append(SGConv(hidden, out_channels, K=2, cached=False))
        self.dropout = dropout

    def reset_parameters(self):
        for module in list(self.convs) + list(self.bns):
            module.reset_parameters()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        # One BatchNorm per hidden conv; the final conv has none.
        *hidden_convs, last_conv = self.convs
        for conv, bn in zip(hidden_convs, self.bns):
            x = F.relu(bn(conv(x, edge_index)))
            x = F.dropout(x, p=self.dropout, training=self.training)
        out = last_conv(x, edge_index)
        return out.log_softmax(dim=-1), out

class NLGCN(torch.nn.Module):
    """Two-layer GCN with a non-local branch (node sorting + 1-D convolutions).

    The GCN produces ``x1`` of shape ``[num_nodes, num_classes]``. The
    non-local branch orders nodes by a learned scalar score, runs two
    ``Conv1d`` layers (kernel 5) along that ordering and maps the result back
    to the original node order. ``x1`` and the branch output are concatenated
    and mixed by ``self.lin``.

    Args:
        dataset: object providing ``num_features`` and ``num_classes``.
        args: object providing ``hidden``, ``dropout1`` (GCN dropout) and
            ``dropout2`` (dropout between the two Conv1d layers).
    """

    def __init__(self, dataset, args):
        super(NLGCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, args.hidden)
        self.conv2 = GCNConv(args.hidden, dataset.num_classes)
        self.proj = nn.Linear(dataset.num_classes, 1)
        self.kernel = 5
        self.conv1d = nn.Conv1d(dataset.num_classes, dataset.num_classes, self.kernel, padding=int((self.kernel-1)/2))
        self.conv1d_2 = nn.Conv1d(dataset.num_classes, dataset.num_classes, self.kernel, padding=int((self.kernel-1)/2))
        self.lin = nn.Linear(2*dataset.num_classes, dataset.num_classes)
        self.dropout1 = args.dropout1
        self.dropout2 = args.dropout2

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.proj.reset_parameters()
        self.conv1d.reset_parameters()
        self.conv1d_2.reset_parameters()
        self.lin.reset_parameters()

    def _nonlocal_branch(self, x1):
        """Sort nodes by ``proj(x1)``, convolve along that order, then unsort.

        Args:
            x1: tensor of shape ``[num_nodes, num_classes]``.

        Returns:
            Tensor of shape ``[num_nodes, num_classes]`` in the original node order.
        """
        g_score = self.proj(x1)  # [num_nodes, 1]
        g_score_sorted, sort_idx = torch.sort(g_score, dim=0)
        _, inverse_idx = torch.sort(sort_idx, dim=0)
        # Squeeze only the known singleton axes. A bare .squeeze() also drops
        # num_nodes == 1 or num_classes == 1 and breaks the shapes below.
        sort_idx = sort_idx.squeeze(1)        # [num_nodes]
        inverse_idx = inverse_idx.squeeze(1)  # [num_nodes]

        sorted_x = g_score_sorted * x1[sort_idx]                  # [num_nodes, num_classes]
        sorted_x = torch.transpose(sorted_x, 0, 1).unsqueeze(0)   # [1, num_classes, num_nodes]
        sorted_x = F.relu(self.conv1d(sorted_x))
        sorted_x = F.dropout(sorted_x, p=self.dropout2, training=self.training)
        sorted_x = self.conv1d_2(sorted_x)
        sorted_x = torch.transpose(sorted_x.squeeze(0), 0, 1)     # [num_nodes, num_classes]
        return sorted_x[inverse_idx]

    def forward(self, data):
        """Return ``(log_softmax(out), out)`` for ``data.x`` / ``data.edge_index``."""
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout1, training=self.training)
        x1 = self.conv2(x, edge_index)
        x2 = self._nonlocal_branch(x1)

        out = self.lin(torch.cat([x1, x2], dim=1))
        return F.log_softmax(out, dim=1), out
    

class NLGAT(torch.nn.Module):
    """Two-layer GAT with a non-local branch (node sorting + 1-D convolutions).

    The GAT produces ``x1`` of shape ``[num_nodes, num_classes]``. The
    non-local branch orders nodes by a learned scalar score, runs two
    ``Conv1d`` layers (kernel 5) along that ordering and maps the result back
    to the original node order. ``x1`` and the branch output are concatenated
    and mixed by ``self.lin``.

    Args:
        dataset: object providing ``num_features`` and ``num_classes``.
        args: object providing ``hidden``, ``heads``, ``output_heads``,
            ``dropout1`` (GAT attention + feature dropout) and ``dropout2``
            (dropout between the two Conv1d layers).
    """

    def __init__(self, dataset, args):
        super(NLGAT, self).__init__()
        self.conv1 = GATConv(
            dataset.num_features,
            args.hidden,
            heads=args.heads,
            dropout=args.dropout1)
        self.conv2 = GATConv(
            args.hidden * args.heads,
            dataset.num_classes,
            heads=args.output_heads,
            concat=False,
            dropout=args.dropout1)
        self.proj = nn.Linear(dataset.num_classes, 1)
        self.kernel = 5
        self.conv1d = nn.Conv1d(dataset.num_classes, dataset.num_classes, self.kernel, padding=int((self.kernel-1)/2))
        self.conv1d_2 = nn.Conv1d(dataset.num_classes, dataset.num_classes, self.kernel, padding=int((self.kernel-1)/2))
        self.lin = nn.Linear(2*dataset.num_classes, dataset.num_classes)
        self.dropout1 = args.dropout1
        self.dropout2 = args.dropout2

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.proj.reset_parameters()
        self.conv1d.reset_parameters()
        self.conv1d_2.reset_parameters()
        self.lin.reset_parameters()

    def _nonlocal_branch(self, x1):
        """Sort nodes by ``proj(x1)``, convolve along that order, then unsort.

        Args:
            x1: tensor of shape ``[num_nodes, num_classes]``.

        Returns:
            Tensor of shape ``[num_nodes, num_classes]`` in the original node order.
        """
        g_score = self.proj(x1)  # [num_nodes, 1]
        g_score_sorted, sort_idx = torch.sort(g_score, dim=0)
        _, inverse_idx = torch.sort(sort_idx, dim=0)
        # Squeeze only the known singleton axes. A bare .squeeze() also drops
        # num_nodes == 1 or num_classes == 1 and breaks the shapes below.
        sort_idx = sort_idx.squeeze(1)        # [num_nodes]
        inverse_idx = inverse_idx.squeeze(1)  # [num_nodes]

        sorted_x = g_score_sorted * x1[sort_idx]                  # [num_nodes, num_classes]
        sorted_x = torch.transpose(sorted_x, 0, 1).unsqueeze(0)   # [1, num_classes, num_nodes]
        sorted_x = F.relu(self.conv1d(sorted_x))
        sorted_x = F.dropout(sorted_x, p=self.dropout2, training=self.training)
        sorted_x = self.conv1d_2(sorted_x)
        sorted_x = torch.transpose(sorted_x.squeeze(0), 0, 1)     # [num_nodes, num_classes]
        return sorted_x[inverse_idx]

    def forward(self, data):
        """Return ``(log_softmax(out), out)`` for ``data.x`` / ``data.edge_index``."""
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout1, training=self.training)
        x1 = self.conv2(x, edge_index)
        x2 = self._nonlocal_branch(x1)

        out = self.lin(torch.cat([x1, x2], dim=1))
        return F.log_softmax(out, dim=1), out
    
   
class NLMLP(torch.nn.Module):
    """Two-layer MLP with a non-local branch (node sorting + 1-D convolutions).

    The MLP produces ``x1`` of shape ``[num_nodes, num_classes]``. The
    non-local branch orders nodes by a learned scalar score, runs two
    ``Conv1d`` layers along that ordering and maps the result back to the
    original node order. ``x1`` and the branch output are concatenated and
    mixed by ``self.lin``.

    Args:
        dataset: object providing ``num_features`` and ``num_classes``.
        args: object providing ``hidden``, ``kernel`` (Conv1d kernel size),
            ``dropout1`` (MLP dropout) and ``dropout2`` (dropout between the
            two Conv1d layers).

    Raises:
        ValueError: if ``args.kernel`` is even. With the
            ``int((kernel - 1) / 2)`` padding used here, an even kernel
            shortens the node axis, and the inverse re-ordering then indexes
            out of range in ``forward``.
    """

    def __init__(self, dataset, args):
        super(NLMLP, self).__init__()
        self.lin1 = nn.Linear(dataset.num_features, args.hidden)
        self.lin2 = nn.Linear(args.hidden, dataset.num_classes)
        self.proj = nn.Linear(dataset.num_classes, 1)
        self.kernel = args.kernel
        if self.kernel % 2 == 0:
            raise ValueError(
                "args.kernel must be odd so Conv1d preserves the node axis, got {}".format(self.kernel))
        self.conv1d = nn.Conv1d(dataset.num_classes, dataset.num_classes, self.kernel, padding=int((self.kernel-1)/2))
        self.conv1d_2 = nn.Conv1d(dataset.num_classes, dataset.num_classes, self.kernel, padding=int((self.kernel-1)/2))
        self.lin = nn.Linear(2*dataset.num_classes, dataset.num_classes)
        self.dropout1 = args.dropout1
        self.dropout2 = args.dropout2

    def reset_parameters(self):
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        self.proj.reset_parameters()
        self.conv1d.reset_parameters()
        self.conv1d_2.reset_parameters()
        self.lin.reset_parameters()

    def _nonlocal_branch(self, x1):
        """Sort nodes by ``proj(x1)``, convolve along that order, then unsort.

        Args:
            x1: tensor of shape ``[num_nodes, num_classes]``.

        Returns:
            Tensor of shape ``[num_nodes, num_classes]`` in the original node order.
        """
        g_score = self.proj(x1)  # [num_nodes, 1]
        g_score_sorted, sort_idx = torch.sort(g_score, dim=0)
        _, inverse_idx = torch.sort(sort_idx, dim=0)
        # Squeeze only the known singleton axes. A bare .squeeze() also drops
        # num_nodes == 1 or num_classes == 1 and breaks the shapes below.
        sort_idx = sort_idx.squeeze(1)        # [num_nodes]
        inverse_idx = inverse_idx.squeeze(1)  # [num_nodes]

        sorted_x = g_score_sorted * x1[sort_idx]                  # [num_nodes, num_classes]
        sorted_x = torch.transpose(sorted_x, 0, 1).unsqueeze(0)   # [1, num_classes, num_nodes]
        sorted_x = F.relu(self.conv1d(sorted_x))
        sorted_x = F.dropout(sorted_x, p=self.dropout2, training=self.training)
        sorted_x = self.conv1d_2(sorted_x)
        sorted_x = torch.transpose(sorted_x.squeeze(0), 0, 1)     # [num_nodes, num_classes]
        return sorted_x[inverse_idx]

    def forward(self, data):
        """Return ``(log_softmax(out), out)``; ``data.edge_index`` is ignored."""
        x = data.x
        x = F.dropout(x, p=self.dropout1, training=self.training)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout1, training=self.training)
        x1 = self.lin2(x)
        x2 = self._nonlocal_branch(x1)

        out = self.lin(torch.cat([x1, x2], dim=1))
        return F.log_softmax(out, dim=1), out

class MLP(nn.Module):
    """Plain multilayer perceptron: ``Linear`` + ``ReLU`` per hidden size, then a ``Linear`` head.

    Args:
        input_dim: input feature size.
        hidden_layers: iterable of hidden widths; may be empty, in which case
            the network is a single ``Linear(input_dim, num_classes)``.
        num_classes: output size (raw scores, no activation).
    """

    def __init__(self, input_dim, hidden_layers, num_classes):
        super(MLP, self).__init__()
        dims = [input_dim, *hidden_layers]
        modules = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            modules += [nn.Linear(d_in, d_out), nn.ReLU()]
        modules.append(nn.Linear(dims[-1], num_classes))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        return self.network(x)