"""
General GNN framework
"""
from copy import deepcopy as c
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import BatchNorm, LayerNorm, InstanceNorm, PairNorm, GraphSizeNorm, global_add_pool
from layers.gine import GINEConv
from layers.feature_encoder import FeatureConcatEncoder


def clones(module, N):
    """Build a ``nn.ModuleList`` holding ``N`` independent deep copies of a layer.

    Args:
        module (nn.Module): the layer to duplicate.
        N (int): how many copies to create.

    Returns:
        nn.ModuleList: ``N`` deep copies of ``module``; none of them is
        ``module`` itself.
    """
    copies = nn.ModuleList()
    for _ in range(N):
        copies.append(c(module))
    return copies


class GNN(nn.Module):
    """A generalized GNN framework.

    Stacks ``num_layer`` deep copies of ``gnn_layer``, each followed by a
    normalization layer, with optional virtual node, residual connections and
    a jumping-knowledge (JK) readout over the per-layer node representations.

    Args:
        num_layer (int): the number of GNN layers.
        gnn_layer (nn.Module): prototype GNN layer; it is deep-copied
            ``num_layer`` times. Must expose ``output_size``, ``K`` and
            ``output_dk`` attributes.
        init_emb (nn.Module): initial node feature encoder, called with the
            whole ``data`` object. Must provide ``reset_parameters()``.
        init_k_emb (nn.Module): initial kernel encoder, called with the whole
            ``data`` object.
        num_hop1_edge (int): number of edge types at 1 hop (kept for interface
            compatibility; not used by this class).
        JK (str): jumping-knowledge method: "last", "concat", "max", "sum" or
            "attention".
        norm_type (str): normalization method: "batch", "layer", "instance",
            "graphsize" or "pair" (matched case-insensitively).
        virtual_node (bool): whether to add a virtual node to the model.
        residual (bool): whether to add residual connections.
        use_rd (bool): whether to add resistance distance as an additional
            feature (only applied when ``data`` carries an ``rd`` entry).
        drop_prob (float): dropout rate.

    Raises:
        ValueError: if ``norm_type`` is not one of the supported methods.
    """

    def __init__(self, num_layer, gnn_layer, init_emb, init_k_emb, num_hop1_edge,
                 JK="last", norm_type="batch", virtual_node=True,
                 residual=False, use_rd=False,
                 drop_prob=0.1):
        super(GNN, self).__init__()

        # Original author's usage note:
        #   gnn_layer = make_gnn_layer(args)
        #       (KPGCNConv, KPGINConv, KPGraphSAGEConv, KPGINPlusConv)
        #   init_emb = LinearEncoder(args.input_size, args.hidden_size)

        self.num_layer = num_layer
        self.hidden_size = gnn_layer.output_size
        self.K = gnn_layer.K
        self.output_dk = gnn_layer.output_dk
        self.dropout = nn.Dropout(drop_prob)
        self.JK = JK
        self.residual = residual
        self.use_rd = use_rd
        self.virtual_node = virtual_node

        # "concat" feeds the initial embedding plus every layer output
        # (num_layer + 1 tensors) into the output projection.
        if self.JK == "concat":
            self.output_proj = nn.Sequential(nn.Linear((self.num_layer + 1) * self.hidden_size, self.hidden_size),
                                             nn.ReLU(), nn.Dropout(drop_prob))
        else:
            self.output_proj = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(),
                                             nn.Dropout(drop_prob))

        if self.JK == "attention":
            self.attention_lstm = nn.LSTM(self.hidden_size, self.num_layer, 1, batch_first=True, bidirectional=True,
                                          dropout=0.)

        self.init_proj = init_emb
        self.init_kernel_proj = init_k_emb

        if self.use_rd:
            self.rd_projection = torch.nn.Linear(1, self.hidden_size)
        if self.virtual_node:
            # set the initial virtual node embedding to 0.
            self.virtualnode_embedding = torch.nn.Embedding(1, self.hidden_size)
            torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

            # List of MLPs to transform the virtual node after every layer but the last
            self.mlp_virtualnode_list = torch.nn.ModuleList()
            for _ in range(num_layer - 1):
                self.mlp_virtualnode_list.append(torch.nn.Sequential(
                    torch.nn.Linear(self.hidden_size, self.hidden_size),
                    torch.nn.BatchNorm1d(self.hidden_size),
                    torch.nn.ReLU(),
                    torch.nn.Linear(self.hidden_size, self.hidden_size),
                    torch.nn.BatchNorm1d(self.hidden_size),
                    torch.nn.ReLU()))

        # gnn layer list
        self.gnns = clones(gnn_layer, num_layer)

        # norm list; matched case-insensitively so that the lowercase default
        # ("batch") and the capitalised spellings ("Batch") are both accepted.
        norm_key = str(norm_type).lower()
        if norm_key == "batch":
            self.norms = clones(BatchNorm(self.hidden_size), num_layer)
        elif norm_key == "layer":
            self.norms = clones(LayerNorm(self.hidden_size), num_layer)
        elif norm_key == "instance":
            self.norms = clones(InstanceNorm(self.hidden_size), num_layer)
        elif norm_key == "graphsize":
            self.norms = clones(GraphSizeNorm(), num_layer)
        elif norm_key == "pair":
            self.norms = clones(PairNorm(), num_layer)
        else:
            raise ValueError("Not supported norm method")

        self.reset_parameters()

    def weights_init(self, m):
        """Reset ``m`` if it exposes ``reset_parameters`` (used with ``Module.apply``)."""
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the node encoder, GNN layers, readout and optional parts.

        The kernel encoder and the normalization layers are not touched here.
        """
        self.init_proj.reset_parameters()
        for g in self.gnns:
            g.reset_parameters()
        if self.JK == "attention":
            self.attention_lstm.reset_parameters()

        self.output_proj.apply(self.weights_init)
        if self.use_rd:
            self.rd_projection.reset_parameters()
        if self.virtual_node:
            torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
            self.mlp_virtualnode_list.apply(self.weights_init)

    def forward(self, data):
        """Compute node representations for a batch of graphs.

        Args:
            data: batch object exposing ``edge_index``, ``edge_attr`` and
                ``batch`` attributes, supporting ``"key" in data``, and
                optionally carrying ``pe_attr`` and ``rd``. It is also passed
                as-is to the node and kernel encoders.

        Returns:
            The output of ``output_proj`` applied to the JK-combined node
            representation.

        Raises:
            ValueError: if ``self.JK`` is not a supported method.
        """
        edge_index, edge_attr, batch = data.edge_index, data.edge_attr, data.batch

        pe_attr = data.pe_attr if "pe_attr" in data else None
        rd = data.rd if "rd" in data else None

        # initial projection
        # NOTE(review): squeeze() without a dim also drops the node axis when
        # the batch holds a single node - confirm this cannot happen.
        x = self.init_proj(data).squeeze()
        ker_emb = self.init_kernel_proj(data)

        if self.use_rd and rd is not None:
            x = x + self.rd_projection(rd).squeeze()

        if self.virtual_node:
            # One virtual node per graph, all looked up at embedding index 0.
            # The graph count is read from the last entry of `batch`
            # (assumes batch ids are sorted ascending - TODO confirm).
            num_graphs = batch[-1].item() + 1
            virtualnode_embedding = self.virtualnode_embedding(
                torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device)
            )

        # forward in gnn layer
        h_list = [x]
        for layer in range(self.num_layer):
            if self.virtual_node:
                # broadcast each graph's virtual node to its member nodes
                h_list[layer] = h_list[layer] + virtualnode_embedding[batch]
            h = self.gnns[layer](h_list[layer], ker_emb, edge_index, edge_attr, pe_attr)
            h = self.norms[layer](h)
            # if not the last gnn layer, add dropout layer
            if layer != self.num_layer - 1:
                h = self.dropout(h)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            # update the virtual nodes (not needed after the last layer)
            if self.virtual_node and layer < self.num_layer - 1:
                # pools this layer's *input* (which already includes the
                # virtual-node contribution), not the freshly computed h
                virtualnode_embedding_temp = global_add_pool(
                    h_list[layer], batch
                ) + virtualnode_embedding

                # transform virtual nodes using MLP
                update = self.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp))
                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + update
                else:
                    virtualnode_embedding = update

        # JK connection over the num_layer + 1 representations in h_list
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim=1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            h_list = [h.unsqueeze(-1) for h in h_list]
            node_representation = F.max_pool1d(torch.cat(h_list, dim=-1), kernel_size=self.num_layer + 1).squeeze()
        elif self.JK == "sum":
            h_list = [h.unsqueeze(0) for h in h_list]
            node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)
        elif self.JK == "attention":
            h_list = [h.unsqueeze(0) for h in h_list]
            h_list = torch.cat(h_list, dim=0).transpose(0, 1)  # N * (num_layer + 1) * H
            self.attention_lstm.flatten_parameters()
            attention_score, _ = self.attention_lstm(h_list)  # N * (num_layer + 1) * (2 * num_layer)
            attention_score = torch.softmax(torch.sum(attention_score, dim=-1), dim=1).unsqueeze(
                -1)  # N * (num_layer + 1) * 1
            node_representation = torch.sum(h_list * attention_score, dim=1)
        else:
            # previously fell through and failed with an unbound local
            raise ValueError("Not supported JK method")

        return self.output_proj(node_representation)


class GNNPlus(nn.Module):
    """A generalized GNN framework with GINE+ color refinement.

    Layer ``l`` receives a stack of the most recent ``min(l + 1, K)`` node
    representations (newest first) instead of only the latest one.

    Args:
        num_layer (int): the number of GNN layers; must be at least ``K``.
        gnn_layer (sequence of nn.Module): one GNN layer per model layer. The
            last entry must expose ``output_size`` and ``K``.
        init_emb (nn.Module): initial node feature encoder, called with the
            whole ``data`` object. Must provide ``reset_parameters()``.
        init_k_emb (nn.Module): initial kernel encoder, called with the whole
            ``data`` object; its output is sliced as ``[:, :k, :]``.
        num_hop1_edge (int): number of edge types at 1 hop (kept for interface
            compatibility; not used by this class).
        JK (str): jumping-knowledge method: "last", "concat", "max", "sum" or
            "attention".
        norm_type (str): normalization method: "batch", "layer", "instance",
            "graphsize" or "pair" (matched case-insensitively).
        virtual_node (bool): whether to add a virtual node to the model.
        residual (bool): whether to add residual connections.
        use_rd (bool): whether to add resistance distance as an additional
            feature (only applied when ``data`` carries an ``rd`` entry).
        drop_prob (float): dropout rate.

    Raises:
        ValueError: if ``norm_type`` is not one of the supported methods.
    """

    def __init__(self, num_layer, gnn_layer, init_emb, init_k_emb, num_hop1_edge,
                 JK="last", norm_type="batch", virtual_node=True,
                 residual=False, use_rd=False,
                 drop_prob=0.1):
        super(GNNPlus, self).__init__()
        self.num_layer = num_layer
        self.hidden_size = gnn_layer[-1].output_size
        self.K = gnn_layer[-1].K
        # for GNN+, number of layer must be at least equal to K to get all information up to K-hop.
        assert num_layer >= self.K
        self.dropout = nn.Dropout(drop_prob)
        self.JK = JK
        self.residual = residual
        self.use_rd = use_rd
        self.virtual_node = virtual_node

        # "concat" feeds the initial embedding plus every layer output
        # (num_layer + 1 tensors) into the output projection.
        if self.JK == "concat":
            self.output_proj = nn.Sequential(nn.Linear((self.num_layer + 1) * self.hidden_size, self.hidden_size),
                                             nn.ReLU(), nn.Dropout(drop_prob))
        else:
            self.output_proj = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(),
                                             nn.Dropout(drop_prob))

        if self.JK == "attention":
            self.attention_lstm = nn.LSTM(self.hidden_size, self.num_layer, 1, batch_first=True, bidirectional=True,
                                          dropout=0.)

        self.init_proj = init_emb
        self.init_kernel_proj = init_k_emb

        if self.use_rd:
            self.rd_projection = nn.Linear(1, self.hidden_size)

        if self.virtual_node:
            # set the initial virtual node embedding to 0.
            self.virtualnode_embedding = torch.nn.Embedding(1, self.hidden_size)
            torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

            # List of MLPs to transform the virtual node after every layer but the last
            self.mlp_virtualnode_list = torch.nn.ModuleList()
            for _ in range(num_layer - 1):
                self.mlp_virtualnode_list.append(torch.nn.Sequential(
                    torch.nn.Linear(self.hidden_size, self.hidden_size),
                    torch.nn.BatchNorm1d(self.hidden_size),
                    torch.nn.ReLU(),
                    torch.nn.Linear(self.hidden_size, self.hidden_size),
                    torch.nn.BatchNorm1d(self.hidden_size),
                    torch.nn.ReLU()))

        # gnn layer list (the caller supplies one layer per depth)
        self.gnns = nn.ModuleList(gnn_layer)

        # norm list; matched case-insensitively so that the lowercase default
        # ("batch") and the capitalised spellings ("Batch") are both accepted.
        norm_key = str(norm_type).lower()
        if norm_key == "batch":
            self.norms = clones(BatchNorm(self.hidden_size), num_layer)
        elif norm_key == "layer":
            self.norms = clones(LayerNorm(self.hidden_size), num_layer)
        elif norm_key == "instance":
            self.norms = clones(InstanceNorm(self.hidden_size), num_layer)
        elif norm_key == "graphsize":
            self.norms = clones(GraphSizeNorm(), num_layer)
        elif norm_key == "pair":
            self.norms = clones(PairNorm(), num_layer)
        else:
            raise ValueError("Not supported norm method")

        self.reset_parameters()

    def weights_init(self, m):
        """Reset ``m`` if it exposes ``reset_parameters`` (used with ``Module.apply``)."""
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the node encoder, readout, optional parts and GNN layers.

        The kernel encoder and the normalization layers are not touched here.
        """
        self.init_proj.reset_parameters()
        if self.JK == "attention":
            self.attention_lstm.reset_parameters()
        self.output_proj.apply(self.weights_init)
        if self.use_rd:
            self.rd_projection.reset_parameters()
        if self.virtual_node:
            torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
            self.mlp_virtualnode_list.apply(self.weights_init)

        for g in self.gnns:
            g.reset_parameters()

    def forward(self, data):
        """Compute node representations for a batch of graphs.

        Args:
            data: batch object exposing ``edge_index``, ``edge_attr`` and
                ``batch`` attributes, supporting ``"key" in data``, and
                optionally carrying ``pe_attr`` and ``rd``. It is also passed
                as-is to the node and kernel encoders.

        Returns:
            The output of ``output_proj`` applied to the JK-combined node
            representation.

        Raises:
            ValueError: if ``self.JK`` is not a supported method.
        """
        edge_index, edge_attr, batch = data.edge_index, data.edge_attr, data.batch

        pe_attr = data.pe_attr if "pe_attr" in data else None
        rd = data.rd if "rd" in data else None

        # initial projection
        # NOTE(review): squeeze() without a dim also drops the node axis when
        # the batch holds a single node - confirm this cannot happen.
        x = self.init_proj(data).squeeze()
        ker_emb = self.init_kernel_proj(data)

        if self.use_rd and rd is not None:
            x = self.rd_projection(rd).squeeze() + x

        if self.virtual_node:
            # One virtual node per graph, all looked up at embedding index 0.
            # The graph count is read from the last entry of `batch`
            # (assumes batch ids are sorted ascending - TODO confirm).
            num_graphs = batch[-1].item() + 1
            virtualnode_embedding = self.virtualnode_embedding(
                torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device)
            )

        # forward in gnn layer
        h_list = [x]
        # residual source: starts as the initial embedding *without* the
        # virtual-node term and is only advanced when residual is enabled
        last_h = x
        for layer in range(self.num_layer):
            if self.virtual_node:
                h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            # window size: how many past representations this layer sees
            k = min(layer + 1, self.K)
            # not copy like KPGIN: stack the last k representations, newest first
            stacked = torch.cat(
                [h_list[j].unsqueeze(1) for j in range(layer, layer - k, -1)], dim=1)

            pek = pe_attr[:, :k - 1] if pe_attr is not None else None

            k_layer = ker_emb[:, :k, :]
            h = self.gnns[layer](stacked, k_layer, edge_index, edge_attr[:, :k], pek)
            h = self.norms[layer](h)
            # if not the last gnn layer, add dropout layer
            if layer != self.num_layer - 1:
                h = self.dropout(h)
            if self.residual:
                h = h + last_h
                last_h = h

            h_list.append(h)

            # update the virtual nodes (not needed after the last layer)
            if self.virtual_node and layer < self.num_layer - 1:
                # pools this layer's *input* (which already includes the
                # virtual-node contribution), not the freshly computed h
                virtualnode_embedding_temp = global_add_pool(
                    h_list[layer], batch
                ) + virtualnode_embedding

                # transform virtual nodes using MLP
                update = self.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp))
                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + update
                else:
                    virtualnode_embedding = update

        # JK connection over the num_layer + 1 representations in h_list
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim=1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            h_list = [h.unsqueeze(-1) for h in h_list]
            node_representation = F.max_pool1d(torch.cat(h_list, dim=-1), kernel_size=self.num_layer + 1).squeeze()
        elif self.JK == "sum":
            h_list = [h.unsqueeze(0) for h in h_list]
            node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)
        elif self.JK == "attention":
            h_list = [h.unsqueeze(0) for h in h_list]
            h_list = torch.cat(h_list, dim=0).transpose(0, 1)  # N * (num_layer + 1) * H
            self.attention_lstm.flatten_parameters()
            attention_score, _ = self.attention_lstm(h_list)  # N * (num_layer + 1) * (2 * num_layer)
            attention_score = torch.softmax(torch.sum(attention_score, dim=-1), dim=1).unsqueeze(
                -1)  # N * (num_layer + 1) * 1
            node_representation = torch.sum(h_list * attention_score, dim=1)
        else:
            # previously fell through and failed with an unbound local
            raise ValueError("Not supported JK method")

        return self.output_proj(node_representation)
    
    
    
class GNNEPlus(nn.Module):
    """A generalized GNN framework with GINE+ color refinement and per-layer
    kernel embeddings.

    Unlike ``GNNPlus``-style models that take a kernel encoder from the caller,
    this class builds one ``Linear(1, hidden_size)`` per layer and applies it to
    column ``l`` of ``data.node_entropy`` at layer ``l``.

    Args:
        num_layer (int): the number of GNN layers; must be at least ``K``.
        gnn_layer (sequence of nn.Module): one GNN layer per model layer. The
            last entry must expose ``output_size`` and ``K``.
        init_emb (nn.Module): initial node feature encoder, called with the
            whole ``data`` object. Must provide ``reset_parameters()``.
        init_k_emb (nn.Module): accepted for interface compatibility; not used
            by this class.
        num_hop1_edge (int): number of edge types at 1 hop (kept for interface
            compatibility; not used by this class).
        JK (str): jumping-knowledge method: "last", "concat", "max", "sum" or
            "attention".
        norm_type (str): normalization method: "batch", "layer", "instance",
            "graphsize" or "pair" (matched case-insensitively).
        virtual_node (bool): whether to add a virtual node to the model.
        residual (bool): whether to add residual connections.
        use_rd (bool): whether to add resistance distance as an additional
            feature (only applied when ``data`` carries an ``rd`` entry).
        drop_prob (float): dropout rate.

    Raises:
        ValueError: if ``norm_type`` is not one of the supported methods.
    """

    def __init__(self, num_layer, gnn_layer, init_emb, init_k_emb, num_hop1_edge,
                 JK="last", norm_type="batch", virtual_node=True,
                 residual=False, use_rd=False,
                 drop_prob=0.1):
        super(GNNEPlus, self).__init__()
        self.num_layer = num_layer
        self.hidden_size = gnn_layer[-1].output_size
        self.K = gnn_layer[-1].K
        # for GNN+, number of layer must be at least equal to K to get all information up to K-hop.
        assert num_layer >= self.K
        self.dropout = nn.Dropout(drop_prob)
        self.JK = JK
        self.residual = residual
        self.use_rd = use_rd
        self.virtual_node = virtual_node

        # "concat" feeds the initial embedding plus every layer output
        # (num_layer + 1 tensors) into the output projection.
        if self.JK == "concat":
            self.output_proj = nn.Sequential(nn.Linear((self.num_layer + 1) * self.hidden_size, self.hidden_size),
                                             nn.ReLU(), nn.Dropout(drop_prob))
        else:
            self.output_proj = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(),
                                             nn.Dropout(drop_prob))

        if self.JK == "attention":
            self.attention_lstm = nn.LSTM(self.hidden_size, self.num_layer, 1, batch_first=True, bidirectional=True,
                                          dropout=0.)

        self.init_proj = init_emb
        # one scalar -> hidden_size projection per layer (init_k_emb is ignored)
        self.init_kernel_proj = torch.nn.ModuleList()
        for _ in range(num_layer):
            self.init_kernel_proj.append(torch.nn.Linear(1, self.hidden_size))

        if self.use_rd:
            self.rd_projection = nn.Linear(1, self.hidden_size)

        if self.virtual_node:
            # set the initial virtual node embedding to 0.
            self.virtualnode_embedding = torch.nn.Embedding(1, self.hidden_size)
            torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

            # List of MLPs to transform the virtual node after every layer but the last
            self.mlp_virtualnode_list = torch.nn.ModuleList()
            for _ in range(num_layer - 1):
                self.mlp_virtualnode_list.append(torch.nn.Sequential(
                    torch.nn.Linear(self.hidden_size, self.hidden_size),
                    torch.nn.BatchNorm1d(self.hidden_size),
                    torch.nn.ReLU(),
                    torch.nn.Linear(self.hidden_size, self.hidden_size),
                    torch.nn.BatchNorm1d(self.hidden_size),
                    torch.nn.ReLU()))

        # gnn layer list (the caller supplies one layer per depth)
        self.gnns = nn.ModuleList(gnn_layer)

        # norm list; matched case-insensitively so that the lowercase default
        # ("batch") and the capitalised spellings ("Batch") are both accepted.
        norm_key = str(norm_type).lower()
        if norm_key == "batch":
            self.norms = clones(BatchNorm(self.hidden_size), num_layer)
        elif norm_key == "layer":
            self.norms = clones(LayerNorm(self.hidden_size), num_layer)
        elif norm_key == "instance":
            self.norms = clones(InstanceNorm(self.hidden_size), num_layer)
        elif norm_key == "graphsize":
            self.norms = clones(GraphSizeNorm(), num_layer)
        elif norm_key == "pair":
            self.norms = clones(PairNorm(), num_layer)
        else:
            raise ValueError("Not supported norm method")

        self.reset_parameters()

    def weights_init(self, m):
        """Reset ``m`` if it exposes ``reset_parameters`` (used with ``Module.apply``)."""
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the node encoder, readout, optional parts, GNN layers
        and the per-layer kernel projections.

        The normalization layers are not touched here.
        """
        self.init_proj.reset_parameters()
        if self.JK == "attention":
            self.attention_lstm.reset_parameters()
        self.output_proj.apply(self.weights_init)
        if self.use_rd:
            self.rd_projection.reset_parameters()
        if self.virtual_node:
            torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
            self.mlp_virtualnode_list.apply(self.weights_init)

        for g in self.gnns:
            g.reset_parameters()

        for ker in self.init_kernel_proj:
            ker.reset_parameters()

    def forward(self, data):
        """Compute node representations for a batch of graphs.

        Args:
            data: batch object exposing ``edge_index``, ``edge_attr``,
                ``batch`` and ``node_entropy`` attributes, supporting
                ``"key" in data``, and optionally carrying ``pe_attr`` and
                ``rd``. ``node_entropy`` is indexed as ``[:, layer]`` for every
                layer. ``data`` is also passed as-is to the node encoder.

        Returns:
            The output of ``output_proj`` applied to the JK-combined node
            representation.

        Raises:
            ValueError: if ``self.JK`` is not a supported method.
        """
        edge_index, edge_attr, batch = data.edge_index, data.edge_attr, data.batch

        pe_attr = data.pe_attr if "pe_attr" in data else None
        rd = data.rd if "rd" in data else None

        # initial projection
        # NOTE(review): squeeze() without a dim also drops the node axis when
        # the batch holds a single node - confirm this cannot happen.
        x = self.init_proj(data).squeeze()

        if self.use_rd and rd is not None:
            x = self.rd_projection(rd).squeeze() + x

        if self.virtual_node:
            # One virtual node per graph, all looked up at embedding index 0.
            # The graph count is read from the last entry of `batch`
            # (assumes batch ids are sorted ascending - TODO confirm).
            num_graphs = batch[-1].item() + 1
            virtualnode_embedding = self.virtualnode_embedding(
                torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device)
            )

        # forward in gnn layer
        h_list = [x]
        # residual source: starts as the initial embedding *without* the
        # virtual-node term and is only advanced when residual is enabled
        last_h = x
        ker_list = []

        for layer in range(self.num_layer):
            if self.virtual_node:
                h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            # project this layer's entropy column to a kernel embedding
            ker_emb = self.init_kernel_proj[layer](data.node_entropy[:, layer].unsqueeze(1))
            ker_list.append(ker_emb.unsqueeze(1))

            # window size: how many past representations this layer sees
            k = min(layer + 1, self.K)

            # NOTE(review): `ker` stacks every kernel embedding so far
            # (layer + 1 entries, oldest first) while `stacked` is capped at K
            # entries (newest first) - confirm the GNN layer expects this.
            ker = torch.cat(ker_list, dim=1)
            # not copy like KPGIN: stack the last k representations, newest first
            stacked = torch.cat(
                [h_list[j].unsqueeze(1) for j in range(layer, layer - k, -1)], dim=1)

            pek = pe_attr[:, :k - 1] if pe_attr is not None else None

            h = self.gnns[layer](stacked, ker, edge_index, edge_attr[:, :k], pek)
            h = self.norms[layer](h)
            # if not the last gnn layer, add dropout layer
            if layer != self.num_layer - 1:
                h = self.dropout(h)
            if self.residual:
                h = h + last_h
                last_h = h

            h_list.append(h)

            # update the virtual nodes (not needed after the last layer)
            if self.virtual_node and layer < self.num_layer - 1:
                # pools this layer's *input* (which already includes the
                # virtual-node contribution), not the freshly computed h
                virtualnode_embedding_temp = global_add_pool(
                    h_list[layer], batch
                ) + virtualnode_embedding

                # transform virtual nodes using MLP
                update = self.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp))
                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + update
                else:
                    virtualnode_embedding = update

        # JK connection over the num_layer + 1 representations in h_list
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim=1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            h_list = [h.unsqueeze(-1) for h in h_list]
            node_representation = F.max_pool1d(torch.cat(h_list, dim=-1), kernel_size=self.num_layer + 1).squeeze()
        elif self.JK == "sum":
            h_list = [h.unsqueeze(0) for h in h_list]
            node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)
        elif self.JK == "attention":
            h_list = [h.unsqueeze(0) for h in h_list]
            h_list = torch.cat(h_list, dim=0).transpose(0, 1)  # N * (num_layer + 1) * H
            self.attention_lstm.flatten_parameters()
            attention_score, _ = self.attention_lstm(h_list)  # N * (num_layer + 1) * (2 * num_layer)
            attention_score = torch.softmax(torch.sum(attention_score, dim=-1), dim=1).unsqueeze(
                -1)  # N * (num_layer + 1) * 1
            node_representation = torch.sum(h_list * attention_score, dim=1)
        else:
            # previously fell through and failed with an unbound local
            raise ValueError("Not supported JK method")

        return self.output_proj(node_representation)



class GNNPrime(nn.Module):
    """A generalized GNN framework with l1 K-hop message passing layers
    followed by l2 GINE layers.

    Args:
        num_layer (int): total number of GNN layers (K-hop + GINE); at least 2.
        gnn_layer (nn.Module): prototype K-hop layer; it is deep-copied
            ``num_l1_layer`` times. Must expose ``output_size``, ``K`` and
            ``output_dk``.
        init_emb (nn.Module): initial node feature encoder, called with the
            whole ``data`` object. Must provide ``reset_parameters()``.
        init_k_emb (nn.Module): initial kernel encoder, called with the whole
            ``data`` object; its output is passed to the K-hop layers only.
        num_hop1_edge (int): number of edge types at 1 hop, forwarded to
            ``GINEConv``.
        num_l1_layer (int): the number of K-hop message passing layers; the
            remaining ``num_layer - num_l1_layer`` layers are ``GINEConv``.
        JK (str): jumping-knowledge method: "last", "concat", "max", "sum" or
            "attention".
        norm_type (str): normalization method: "batch", "layer", "instance",
            "graphsize" or "pair" (matched case-insensitively).
        virtual_node (bool): whether to add a virtual node to the model.
        residual (bool): whether to add residual connections.
        use_rd (bool): whether to add resistance distance as an additional
            feature (only applied when ``data`` carries an ``rd`` entry).
        drop_prob (float): dropout rate.

    Raises:
        ValueError: if ``norm_type`` is not one of the supported methods.
    """

    def __init__(self, num_layer, gnn_layer, init_emb,init_k_emb, num_hop1_edge, num_l1_layer=1,
                 JK="last", norm_type="batch", virtual_node=True,
                 residual=False, use_rd=False,
                 drop_prob=0.1):
        super(GNNPrime, self).__init__()
        assert num_l1_layer > 0
        assert num_layer >= 2
        self.num_l1_layer = num_l1_layer
        self.num_l2_layer = num_layer - num_l1_layer
        self.num_layer = num_layer
        self.hidden_size = gnn_layer.output_size
        self.K = gnn_layer.K
        self.output_dk = gnn_layer.output_dk
        self.dropout = nn.Dropout(drop_prob)
        self.JK = JK
        self.residual = residual
        self.use_rd = use_rd
        self.virtual_node = virtual_node

        # "concat" feeds the initial embedding plus every layer output
        # (num_layer + 1 tensors) into the output projection.
        if self.JK == "concat":
            self.output_proj = nn.Sequential(nn.Linear((self.num_layer + 1) * self.hidden_size, self.hidden_size),
                                             nn.ReLU(), nn.Dropout(drop_prob))
        else:
            self.output_proj = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(),
                                             nn.Dropout(drop_prob))

        if self.JK == "attention":
            self.attention_lstm = nn.LSTM(self.hidden_size, self.num_layer, 1, batch_first=True, bidirectional=True,
                                          dropout=0.)

        self.init_proj = init_emb
        self.init_kernel_proj = init_k_emb

        if self.use_rd:
            self.rd_projection = torch.nn.Linear(1, self.hidden_size)
        if self.virtual_node:
            # set the initial virtual node embedding to 0.
            self.virtualnode_embedding = torch.nn.Embedding(1, self.hidden_size)
            torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

            # List of MLPs to transform the virtual node after every layer but the last
            self.mlp_virtualnode_list = torch.nn.ModuleList()
            for _ in range(self.num_layer - 1):
                self.mlp_virtualnode_list.append(torch.nn.Sequential(
                    torch.nn.Linear(self.hidden_size, self.hidden_size),
                    torch.nn.BatchNorm1d(self.hidden_size),
                    torch.nn.ReLU(),
                    torch.nn.Linear(self.hidden_size, self.hidden_size),
                    torch.nn.BatchNorm1d(self.hidden_size),
                    torch.nn.ReLU()))

        # gnn layer lists: K-hop layers first, then plain GINE layers
        self.khop_gnns = clones(gnn_layer, num_l1_layer)

        self.gins = clones(GINEConv(self.hidden_size, self.hidden_size, num_hop1_edge=num_hop1_edge), self.num_l2_layer)

        # norm list; matched case-insensitively so that the lowercase default
        # ("batch") and the capitalised spellings ("Batch") are both accepted.
        norm_key = str(norm_type).lower()
        if norm_key == "batch":
            self.norms = clones(BatchNorm(self.hidden_size), self.num_layer)
        elif norm_key == "layer":
            self.norms = clones(LayerNorm(self.hidden_size), self.num_layer)
        elif norm_key == "instance":
            self.norms = clones(InstanceNorm(self.hidden_size), self.num_layer)
        elif norm_key == "graphsize":
            self.norms = clones(GraphSizeNorm(), self.num_layer)
        elif norm_key == "pair":
            self.norms = clones(PairNorm(), self.num_layer)
        else:
            raise ValueError("Not supported norm method")

        self.reset_parameters()

    def weights_init(self, m):
        """Reset ``m`` if it exposes ``reset_parameters`` (used with ``Module.apply``)."""
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the node encoder, both layer stacks, readout and optional parts.

        The kernel encoder and the normalization layers are not touched here.
        """
        self.init_proj.reset_parameters()
        for g in self.khop_gnns:
            g.reset_parameters()
        for g in self.gins:
            g.reset_parameters()
        if self.JK == "attention":
            self.attention_lstm.reset_parameters()

        self.output_proj.apply(self.weights_init)
        if self.use_rd:
            self.rd_projection.reset_parameters()
        if self.virtual_node:
            torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
            self.mlp_virtualnode_list.apply(self.weights_init)

    def forward(self, data):
        """Compute node representations for a batch of graphs.

        Args:
            data: batch object exposing ``edge_index``, ``edge_attr`` and
                ``batch`` attributes, supporting ``"key" in data``, and
                optionally carrying ``pe_attr`` and ``rd``. It is also passed
                as-is to the node and kernel encoders.

        Returns:
            The output of ``output_proj`` applied to the JK-combined node
            representation.

        Raises:
            ValueError: if ``self.JK`` is not a supported method.
        """
        edge_index, edge_attr, batch = data.edge_index, data.edge_attr, data.batch

        pe_attr = data.pe_attr if "pe_attr" in data else None
        rd = data.rd if "rd" in data else None

        # initial projection
        # NOTE(review): squeeze() without a dim also drops the node axis when
        # the batch holds a single node - confirm this cannot happen.
        x = self.init_proj(data).squeeze()
        ker_emb = self.init_kernel_proj(data)

        if self.use_rd and rd is not None:
            x = x + self.rd_projection(rd).squeeze()

        if self.virtual_node:
            # One virtual node per graph, all looked up at embedding index 0.
            # The graph count is read from the last entry of `batch`
            # (assumes batch ids are sorted ascending - TODO confirm).
            num_graphs = batch[-1].item() + 1
            virtualnode_embedding = self.virtualnode_embedding(
                torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device)
            )

        h_list = [x]

        # stage 1: K-hop message passing layers (dropout after every layer)
        for layer in range(self.num_l1_layer):
            if self.virtual_node:
                h_list[layer] = h_list[layer] + virtualnode_embedding[batch]
            h = self.khop_gnns[layer](h_list[layer], ker_emb, edge_index, edge_attr, pe_attr)
            h = self.norms[layer](h)
            h = self.dropout(h)
            if self.residual:
                h = h + h_list[layer]
            h_list.append(h)

            # update the virtual nodes (not needed after the overall last layer)
            if self.virtual_node and layer < self.num_layer - 1:
                # pools this layer's *input* (which already includes the
                # virtual-node contribution), not the freshly computed h
                virtualnode_embedding_temp = global_add_pool(
                    h_list[layer], batch
                ) + virtualnode_embedding

                # transform virtual nodes using MLP
                update = self.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp))
                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + update
                else:
                    virtualnode_embedding = update

        # stage 2: GINE layers, using only the first edge-attribute column
        for layer in range(self.num_l1_layer, self.num_layer):
            if self.virtual_node:
                h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            h = self.gins[layer - self.num_l1_layer](h_list[layer], edge_index, edge_attr[:, :1])
            h = self.norms[layer](h)
            # if not the last gnn layer, add dropout layer
            if layer != self.num_layer - 1:
                h = self.dropout(h)
            if self.residual:
                h = h + h_list[layer]
            h_list.append(h)

            # update the virtual nodes (not needed after the last layer)
            if self.virtual_node and layer < self.num_layer - 1:
                virtualnode_embedding_temp = global_add_pool(
                    h_list[layer], batch
                ) + virtualnode_embedding

                # transform virtual nodes using MLP
                update = self.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp))
                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + update
                else:
                    virtualnode_embedding = update

        # JK connection over the num_layer + 1 representations in h_list
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim=1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            h_list = [h.unsqueeze(-1) for h in h_list]
            node_representation = F.max_pool1d(torch.cat(h_list, dim=-1), kernel_size=self.num_layer + 1).squeeze()
        elif self.JK == "sum":
            h_list = [h.unsqueeze(0) for h in h_list]
            node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)
        elif self.JK == "attention":
            h_list = [h.unsqueeze(0) for h in h_list]
            h_list = torch.cat(h_list, dim=0).transpose(0, 1)  # N * (num_layer + 1) * H
            self.attention_lstm.flatten_parameters()
            attention_score, _ = self.attention_lstm(h_list)  # N * (num_layer + 1) * (2 * num_layer)
            attention_score = torch.softmax(torch.sum(attention_score, dim=-1), dim=1).unsqueeze(
                -1)  # N * (num_layer + 1) * 1
            node_representation = torch.sum(h_list * attention_score, dim=1)
        else:
            # previously fell through and failed with an unbound local
            raise ValueError("Not supported JK method")

        return self.output_proj(node_representation)

