import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
import torch.nn.functional as F
import igraph

# Some utility functions

# Vertex-type vocabulary of the circuit graphs: type name -> integer id.
NODE_TYPE = {
    name: type_id
    for type_id, name in enumerate(
        ('R', 'C', '+gm+', '-gm+', '+gm-', '-gm-', 'sudo_in', 'sudo_out', 'In', 'Out')
    )
}

# One row per subgraph type (the row position is the type id): the element
# types that make up the subgraph, and how they are connected (None when the
# subgraph has a single element).
# NOTE: 'parral' is the exact runtime value used for a parallel connection;
# it is compared as a string, so keep the spelling.
_SUBG_TABLE = (
    (('In',), None),                # 0
    (('Out',), None),               # 1
    (('R',), None),                 # 2
    (('C',), None),                 # 3
    (('R', 'C'), 'series'),         # 4
    (('R', 'C'), 'parral'),         # 5
    (('+gm+',), None),              # 6
    (('-gm+',), None),              # 7
    (('+gm-',), None),              # 8
    (('-gm-',), None),              # 9
    (('C', '+gm+'), 'parral'),      # 10
    (('C', '-gm+'), 'parral'),      # 11
    (('C', '+gm-'), 'parral'),      # 12
    (('C', '-gm-'), 'parral'),      # 13
    (('R', '+gm+'), 'parral'),      # 14
    (('R', '-gm+'), 'parral'),      # 15
    (('R', '+gm-'), 'parral'),      # 16
    (('R', '-gm-'), 'parral'),      # 17
    (('C', 'R', '+gm+'), 'parral'), # 18
    (('C', 'R', '-gm+'), 'parral'), # 19
    (('C', 'R', '+gm-'), 'parral'), # 20
    (('C', 'R', '-gm-'), 'parral'), # 21
    (('C', 'R', '+gm+'), 'series'), # 22
    (('C', 'R', '-gm+'), 'series'), # 23
    (('C', 'R', '+gm-'), 'series'), # 24
    (('C', 'R', '-gm-'), 'series'), # 25
)

# Subgraph type id -> list of element types (a fresh list per entry).
SUBG_NODE = {type_id: list(nodes) for type_id, (nodes, _) in enumerate(_SUBG_TABLE)}
# Subgraph type id -> connection kind ('series', 'parral' or None).
SUBG_CON = {type_id: con for type_id, (_, con) in enumerate(_SUBG_TABLE)}

# Regression slot of each element type: 0 for R, 1 for C, 2 for any gm element
# (the same order as CVAE_dec.regs = [fc_r, fc_c, fc_gm]); 'In'/'Out' have none.
_REG_SLOT = {'R': 0, 'C': 1, '+gm+': 2, '-gm+': 2, '+gm-': 2, '-gm-': 2}
# Subgraph type id -> regression slots of its elements, in element order.
SUBG_INDI = {
    type_id: [_REG_SLOT[name] for name in nodes if name in _REG_SLOT]
    for type_id, nodes in SUBG_NODE.items()
}

# The helper tables above are construction-only; keep the module namespace unchanged.
del _SUBG_TABLE, _REG_SLOT


def one_hot(idx, length):
    """Return a float one-hot encoding of ``idx`` over ``length`` classes.

    Args:
        idx: a single integer index, or a list / tuple / range of integer indices.
        length: number of classes (width of each one-hot row).

    Returns:
        A ``(len(idx), length)`` float tensor for a sequence, a ``(1, length)``
        float tensor for a single index, or ``None`` for an empty sequence
        (including an empty ``range``, which the old ``== []`` test missed).
    """
    if isinstance(idx, (list, tuple, range)):
        if len(idx) == 0:
            return None
        rows = torch.LongTensor(list(idx)).unsqueeze(1)  # (n, 1) column of class ids
        return torch.zeros((len(rows), length)).scatter_(1, rows, 1)
    single = torch.LongTensor([idx]).unsqueeze(0)  # (1, 1)
    return torch.zeros((1, length)).scatter_(1, single, 1)

def inverse_adj(adj, device=None):
    """Return the inverse degree matrix of ``adj`` with self-loops added.

    Computes ``diag(1 / d)`` where ``d`` holds the column sums of ``adj + I``.
    Despite the name, this is *not* the matrix inverse of ``adj``; callers use
    it to normalise ``(A + I) x``.

    Args:
        adj: square ``(n, n)`` adjacency tensor.
        device: accepted for backward compatibility and now optional. The
            identity matrix is created directly on ``adj``'s device and dtype,
            so the result always lives where ``adj`` lives.

    Returns:
        ``(n, n)`` diagonal tensor of reciprocal (self-loop augmented) degrees.
    """
    n_node = adj.size(0)
    self_loops = torch.eye(n_node, dtype=adj.dtype, device=adj.device)
    degree = torch.sum(adj + self_loops, dim=0)
    return torch.diag(1.0 / degree)



class subcGNN_dis(nn.Module):
    """Subgraph GNN for *discrete* subgraph node values.

    Every vertex of every input igraph graph carries a small subgraph described
    by three vertex attributes:

    * ``subg_ntypes`` -- list of categorical node types (one-hot over ``num_cat``),
    * ``subg_nfeats`` -- list of categorical node values (one-hot over ``num_feat``),
    * ``subg_adj``    -- flattened ``n x n`` adjacency matrix of that subgraph.

    The two one-hot inputs are each projected to ``out_feat`` and concatenated
    (``emb_dim = 2 * out_feat``). They are then refined by ``num_layer`` rounds
    of ``x <- ReLU(Dropout(D^-1 (A + I) W x))`` message passing, with no dropout
    on the last round. Each subgraph is finally pooled (``readout`` = ``'sum'``
    or ``'mean'``) into the vertex attribute ``subg_feat`` of a copy of the
    input graph.
    """

    def __init__(self, num_cat, out_feat, num_feat=102, dropout=0.5, num_layer=2, readout='sum', device=None):
        super(subcGNN_dis, self).__init__()
        self.catag_lin = nn.Linear(num_cat, out_feat)   # projects one-hot node types
        self.numer_lin = nn.Linear(num_feat, out_feat)  # projects one-hot node values
        self.layers = nn.ModuleList()
        self.emb_dim = 2 * out_feat
        self.dropout = dropout
        self.num_cat = num_cat
        self.num_feat = num_feat
        # Layout per message-passing layer: Linear, [Dropout], ReLU. The Dropout
        # entry only exists when dropout > 1e-4, so one layer spans 3 or 2 slots.
        # The order is kept as before so saved state_dict keys still load.
        for _ in range(num_layer):
            self.layers.append(nn.Linear(self.emb_dim, self.emb_dim))
            if self.dropout > 0.0001:
                self.layers.append(nn.Dropout(dropout))
            self.layers.append(nn.ReLU())
        self.device = device  # resolved lazily by get_device() when None
        self.readout = readout
        self.num_layer = num_layer

    def forward(self, G):
        """Embed the subgraph of every vertex in the batch ``G``.

        Args:
            G: list of igraph graphs whose vertices carry ``subg_ntypes``,
                ``subg_nfeats`` and ``subg_adj``.

        Returns:
            A new list with a ``copy()`` of every graph, where each vertex has a
            ``subg_feat`` tensor of size ``emb_dim`` (the pooled subgraph embedding).

        Raises:
            ValueError: if ``self.readout`` is neither ``'sum'`` nor ``'mean'``.
        """
        device = self.get_device()
        sub_nodes_types, sub_nodes_feats, all_adj = self._collect_subgraphs(G)
        all_adj = all_adj.to(device)
        in_categ = self.catag_lin(self._one_hot(sub_nodes_types, self.num_cat))
        in_numer = self.numer_lin(self._one_hot(sub_nodes_feats, self.num_feat))
        x = torch.cat([in_categ, in_numer], dim=1)
        x = self._message_passing(x, all_adj, inverse_adj(all_adj, device))
        return self._readout_subgraphs(G, x)

    @staticmethod
    def _collect_subgraphs(G):
        """Flatten all subgraphs of the batch into node-type/value lists and one block-diagonal adjacency."""
        sub_nodes_types, sub_nodes_feats, blocks = [], [], []
        for g in G:
            for vertex in g.vs:
                n = len(vertex['subg_ntypes'])
                sub_nodes_types += vertex['subg_ntypes']
                sub_nodes_feats += vertex['subg_nfeats']
                blocks.append(torch.FloatTensor(vertex['subg_adj']).reshape(n, n))
        total = sum(block.size(0) for block in blocks)
        all_adj = torch.zeros(total, total)
        offset = 0
        for block in blocks:
            n = block.size(0)
            all_adj[offset:offset + n, offset:offset + n] = block
            offset += n
        return sub_nodes_types, sub_nodes_feats, all_adj

    def _message_passing(self, x, all_adj, inv_deg):
        """Run ``num_layer`` rounds of Linear -> ``inv_deg @ (x + A x)`` -> [Dropout] -> ReLU."""
        has_dropout = self.dropout > 0.0001
        stride = 3 if has_dropout else 2  # entries of self.layers per layer
        for layer in range(self.num_layer):
            base = stride * layer
            x = self.layers[base](x)              # Linear
            x = x + torch.matmul(all_adj, x)      # (A + I) x: add neighbour messages
            x = torch.matmul(inv_deg, x)          # normalise by (self-loop) degree
            if has_dropout and layer < self.num_layer - 1:
                x = self.layers[base + 1](x)      # Dropout, skipped on the final layer
            x = self.layers[base + stride - 1](x) # ReLU
        return x

    def _readout_subgraphs(self, G, out):
        """Pool the node embeddings ``out`` per subgraph into ``subg_feat`` on copies of ``G``."""
        if self.readout == 'sum':
            pool = torch.sum
        elif self.readout == 'mean':
            pool = torch.mean
        else:
            raise ValueError('Undefined pool method')
        new_G = []
        offset = 0
        for graph in G:
            g = graph.copy()
            for vertex in g.vs:
                n = len(vertex['subg_ntypes'])
                vertex['subg_feat'] = pool(out[offset:offset + n, :], dim=0)
                offset += n
            new_G.append(g)
        return new_G

    def get_device(self):
        """Return (and cache) the device of the module's parameters unless one was given."""
        if self.device is None:
            self.device = next(self.parameters()).device
        return self.device

    def _one_hot(self, idx, length):
        """One-hot encode ``idx`` on the module's device; ``None`` for an empty sequence."""
        if isinstance(idx, (list, tuple, range)):
            if len(idx) == 0:
                return None
            rows = torch.LongTensor(list(idx)).unsqueeze(1)
            return torch.zeros((len(rows), length)).scatter_(1, rows, 1).to(self.get_device())
        single = torch.LongTensor([idx]).unsqueeze(0)
        return torch.zeros((1, length)).scatter_(1, single, 1).to(self.get_device())

class subc_GNN(nn.Module):
    """Subgraph GNN for *scalar* subgraph node values.

    Every vertex of every input igraph graph carries a small subgraph described
    by three vertex attributes:

    * ``subg_ntypes`` -- list of categorical node types (one-hot over ``num_cat``),
    * ``subg_nfeats`` -- list of scalar node values, fed as a 1-dim input,
    * ``subg_adj``    -- flattened ``n x n`` adjacency matrix of that subgraph.

    Types and values are each projected to ``out_feat`` and concatenated
    (``emb_dim = 2 * out_feat``). They are then refined by ``num_layer`` rounds
    of ``x <- ReLU(Dropout(D^-1 (A + I) W x))``, with no dropout on the last
    round. Each subgraph is pooled (``readout`` = ``'sum'`` or ``'mean'``) into
    the vertex attribute ``subg_feat`` of a copy of the input graph.

    NOTE(review): this module defines ``subc_GNN`` a second time further down;
    that later class statement rebinds the name, so this one is shadowed.
    """

    def __init__(self, num_cat, out_feat, dropout=0.5, num_layer=2, readout='sum', device=None):
        super(subc_GNN, self).__init__()
        self.catag_lin = nn.Linear(num_cat, out_feat)  # projects one-hot node types
        self.numer_lin = nn.Linear(1, out_feat)        # projects the scalar node value
        self.layers = nn.ModuleList()
        self.emb_dim = 2 * out_feat
        self.dropout = dropout
        self.num_cat = num_cat
        # Layout per message-passing layer: Linear, [Dropout], ReLU. The Dropout
        # entry only exists when dropout > 1e-4, so one layer spans 3 or 2 slots.
        # The order is kept as before so saved state_dict keys still load.
        for _ in range(num_layer):
            self.layers.append(nn.Linear(self.emb_dim, self.emb_dim))
            if self.dropout > 0.0001:
                self.layers.append(nn.Dropout(dropout))
            self.layers.append(nn.ReLU())
        self.device = device  # resolved lazily by get_device() when None
        self.readout = readout
        self.num_layer = num_layer

    def forward(self, G):
        """Embed the subgraph of every vertex in the batch ``G``.

        Args:
            G: list of igraph graphs whose vertices carry ``subg_ntypes``,
                ``subg_nfeats`` and ``subg_adj``.

        Returns:
            A new list with a ``copy()`` of every graph, where each vertex has a
            ``subg_feat`` tensor of size ``emb_dim`` (the pooled subgraph embedding).

        Raises:
            ValueError: if ``self.readout`` is neither ``'sum'`` nor ``'mean'``.
        """
        device = self.get_device()
        sub_nodes_types, sub_nodes_feats, all_adj = self._collect_subgraphs(G)
        all_adj = all_adj.to(device)
        in_categ = self.catag_lin(self._one_hot(sub_nodes_types, self.num_cat))
        # one scalar value per subgraph node -> column of shape (N, 1)
        in_numer = self.numer_lin(torch.FloatTensor(sub_nodes_feats).to(device).unsqueeze(1))
        x = torch.cat([in_categ, in_numer], dim=1)
        x = self._message_passing(x, all_adj, inverse_adj(all_adj, device))
        return self._readout_subgraphs(G, x)

    @staticmethod
    def _collect_subgraphs(G):
        """Flatten all subgraphs of the batch into node-type/value lists and one block-diagonal adjacency."""
        sub_nodes_types, sub_nodes_feats, blocks = [], [], []
        for g in G:
            for vertex in g.vs:
                n = len(vertex['subg_ntypes'])
                sub_nodes_types += vertex['subg_ntypes']
                sub_nodes_feats += vertex['subg_nfeats']
                blocks.append(torch.FloatTensor(vertex['subg_adj']).reshape(n, n))
        total = sum(block.size(0) for block in blocks)
        all_adj = torch.zeros(total, total)
        offset = 0
        for block in blocks:
            n = block.size(0)
            all_adj[offset:offset + n, offset:offset + n] = block
            offset += n
        return sub_nodes_types, sub_nodes_feats, all_adj

    def _message_passing(self, x, all_adj, inv_deg):
        """Run ``num_layer`` rounds of Linear -> ``inv_deg @ (x + A x)`` -> [Dropout] -> ReLU."""
        has_dropout = self.dropout > 0.0001
        stride = 3 if has_dropout else 2  # entries of self.layers per layer
        for layer in range(self.num_layer):
            base = stride * layer
            x = self.layers[base](x)              # Linear
            x = x + torch.matmul(all_adj, x)      # (A + I) x: add neighbour messages
            x = torch.matmul(inv_deg, x)          # normalise by (self-loop) degree
            if has_dropout and layer < self.num_layer - 1:
                x = self.layers[base + 1](x)      # Dropout, skipped on the final layer
            x = self.layers[base + stride - 1](x) # ReLU
        return x

    def _readout_subgraphs(self, G, out):
        """Pool the node embeddings ``out`` per subgraph into ``subg_feat`` on copies of ``G``."""
        if self.readout == 'sum':
            pool = torch.sum
        elif self.readout == 'mean':
            pool = torch.mean
        else:
            raise ValueError('Undefined pool method')
        new_G = []
        offset = 0
        for graph in G:
            g = graph.copy()
            for vertex in g.vs:
                n = len(vertex['subg_ntypes'])
                vertex['subg_feat'] = pool(out[offset:offset + n, :], dim=0)
                offset += n
            new_G.append(g)
        return new_G

    def get_device(self):
        """Return (and cache) the device of the module's parameters unless one was given."""
        if self.device is None:
            self.device = next(self.parameters()).device
        return self.device

    def _one_hot(self, idx, length):
        """One-hot encode ``idx`` on the module's device; ``None`` for an empty sequence."""
        if isinstance(idx, (list, tuple, range)):
            if len(idx) == 0:
                return None
            rows = torch.LongTensor(list(idx)).unsqueeze(1)
            return torch.zeros((len(rows), length)).scatter_(1, rows, 1).to(self.get_device())
        single = torch.LongTensor([idx]).unsqueeze(0)
        return torch.zeros((1, length)).scatter_(1, single, 1).to(self.get_device())

class subc_GNN(nn.Module):
    """Subgraph GNN for *scalar* subgraph node values.

    Every vertex of every input igraph graph carries a small subgraph described
    by three vertex attributes:

    * ``subg_ntypes`` -- list of categorical node types (one-hot over ``num_cat``),
    * ``subg_nfeats`` -- list of scalar node values, fed as a 1-dim input,
    * ``subg_adj``    -- flattened ``n x n`` adjacency matrix of that subgraph.

    Types and values are each projected to ``out_feat`` and concatenated
    (``emb_dim = 2 * out_feat``). They are then refined by ``num_layer`` rounds
    of ``x <- ReLU(Dropout(D^-1 (A + I) W x))``, with no dropout on the last
    round. Each subgraph is pooled (``readout`` = ``'sum'`` or ``'mean'``) into
    the vertex attribute ``subg_feat`` of a copy of the input graph.

    NOTE(review): an earlier ``subc_GNN`` class with the same name exists in
    this module; this later definition is the one the name is bound to.
    """

    def __init__(self, num_cat, out_feat, dropout=0.5, num_layer=2, readout='sum', device=None):
        super(subc_GNN, self).__init__()
        self.catag_lin = nn.Linear(num_cat, out_feat)  # projects one-hot node types
        self.numer_lin = nn.Linear(1, out_feat)        # projects the scalar node value
        self.layers = nn.ModuleList()
        self.emb_dim = 2 * out_feat
        self.dropout = dropout
        self.num_cat = num_cat
        # Layout per message-passing layer: Linear, [Dropout], ReLU. The Dropout
        # entry only exists when dropout > 1e-4, so one layer spans 3 or 2 slots.
        # The order is kept as before so saved state_dict keys still load.
        for _ in range(num_layer):
            self.layers.append(nn.Linear(self.emb_dim, self.emb_dim))
            if self.dropout > 0.0001:
                self.layers.append(nn.Dropout(dropout))
            self.layers.append(nn.ReLU())
        self.device = device  # resolved lazily by get_device() when None
        self.readout = readout
        self.num_layer = num_layer

    def forward(self, G):
        """Embed the subgraph of every vertex in the batch ``G``.

        Args:
            G: list of igraph graphs whose vertices carry ``subg_ntypes``,
                ``subg_nfeats`` and ``subg_adj``.

        Returns:
            A new list with a ``copy()`` of every graph, where each vertex has a
            ``subg_feat`` tensor of size ``emb_dim`` (the pooled subgraph embedding).

        Raises:
            ValueError: if ``self.readout`` is neither ``'sum'`` nor ``'mean'``.
        """
        device = self.get_device()
        sub_nodes_types, sub_nodes_feats, all_adj = self._collect_subgraphs(G)
        all_adj = all_adj.to(device)
        in_categ = self.catag_lin(self._one_hot(sub_nodes_types, self.num_cat))
        # one scalar value per subgraph node -> column of shape (N, 1)
        in_numer = self.numer_lin(torch.FloatTensor(sub_nodes_feats).to(device).unsqueeze(1))
        x = torch.cat([in_categ, in_numer], dim=1)
        # inverse_adj needs the device argument (it was previously omitted here)
        x = self._message_passing(x, all_adj, inverse_adj(all_adj, device))
        return self._readout_subgraphs(G, x)

    @staticmethod
    def _collect_subgraphs(G):
        """Flatten all subgraphs of the batch into node-type/value lists and one block-diagonal adjacency."""
        sub_nodes_types, sub_nodes_feats, blocks = [], [], []
        for g in G:
            for vertex in g.vs:
                n = len(vertex['subg_ntypes'])
                sub_nodes_types += vertex['subg_ntypes']
                sub_nodes_feats += vertex['subg_nfeats']
                blocks.append(torch.FloatTensor(vertex['subg_adj']).reshape(n, n))
        total = sum(block.size(0) for block in blocks)
        all_adj = torch.zeros(total, total)
        offset = 0
        for block in blocks:
            n = block.size(0)
            all_adj[offset:offset + n, offset:offset + n] = block
            offset += n
        return sub_nodes_types, sub_nodes_feats, all_adj

    def _message_passing(self, x, all_adj, inv_deg):
        """Run ``num_layer`` rounds of Linear -> ``inv_deg @ (x + A x)`` -> [Dropout] -> ReLU."""
        has_dropout = self.dropout > 0.0001
        stride = 3 if has_dropout else 2  # entries of self.layers per layer
        for layer in range(self.num_layer):
            base = stride * layer
            x = self.layers[base](x)              # Linear
            x = x + torch.matmul(all_adj, x)      # (A + I) x: add neighbour messages
            x = torch.matmul(inv_deg, x)          # normalise by (self-loop) degree
            if has_dropout and layer < self.num_layer - 1:
                x = self.layers[base + 1](x)      # Dropout, skipped on the final layer
            x = self.layers[base + stride - 1](x) # ReLU
        return x

    def _readout_subgraphs(self, G, out):
        """Pool the node embeddings ``out`` per subgraph into ``subg_feat`` on copies of ``G``."""
        if self.readout == 'sum':
            pool = torch.sum
        elif self.readout == 'mean':
            pool = torch.mean
        else:
            raise ValueError('Undefined pool method')
        new_G = []
        offset = 0
        for graph in G:
            g = graph.copy()
            for vertex in g.vs:
                n = len(vertex['subg_ntypes'])
                vertex['subg_feat'] = pool(out[offset:offset + n, :], dim=0)
                offset += n
            new_G.append(g)
        return new_G

    def get_device(self):
        """Return (and cache) the device of the module's parameters unless one was given."""
        if self.device is None:
            self.device = next(self.parameters()).device
        return self.device

    def _one_hot(self, idx, length):
        """One-hot encode ``idx`` on the module's device; ``None`` for an empty sequence."""
        if isinstance(idx, (list, tuple, range)):
            if len(idx) == 0:
                return None
            rows = torch.LongTensor(list(idx)).unsqueeze(1)
            return torch.zeros((len(rows), length)).scatter_(1, rows, 1).to(self.get_device())
        single = torch.LongTensor([idx]).unsqueeze(0)
        return torch.zeros((1, length)).scatter_(1, single, 1).to(self.get_device())

def subg_loss(reg_vals_list, G_true, v_true, device=None, subg_indi=SUBG_INDI):
    """Sum, over graphs, the MSE between regressed and true subgraph values at vertex ``v_true``.

    Args:
        reg_vals_list: per-graph prediction tensors; ``reg_vals_list[i]`` is
            indexed with the rows ``subg_indi[type]`` (presumably one row per
            regression slot r / c / gm -- TODO confirm shape against caller).
        G_true: list of ground-truth igraph graphs.
        v_true: vertex index being scored.
        device: device for the index / target tensors.
        subg_indi: subgraph type id -> regression slots to compare.

    Returns:
        The summed MSE (a tensor), or the int ``0`` when no graph contributes.
        Graphs with at most ``v_true`` vertices, and vertices whose ``type``
        is below 2, are skipped.
    """
    total = 0
    for graph_idx, g in enumerate(G_true):
        if g.vcount() <= v_true:
            continue
        true_type = g.vs[v_true]['type']
        if true_type < 2:
            continue
        slots = torch.LongTensor(subg_indi[true_type]).to(device)
        pred_val = reg_vals_list[graph_idx][slots]
        true_val = torch.FloatTensor(g.vs[v_true]['subg_nfeats']).to(device)
        total += F.mse_loss(pred_val, true_val, reduction='mean')
    return total

def subg_loss_dis(reg_vals_list, G_true, v_true, device=None, subg_indi=SUBG_INDI):
    """Sum the log-probabilities of the true discrete subgraph values at vertex ``v_true``.

    Args:
        reg_vals_list: per-graph score tensors (one entry per graph in
            ``G_true``); a log-softmax is taken over dim 1 of each.
        G_true: list of ground-truth igraph graphs.
        v_true: vertex index being scored.
        device: device for the index / target tensors.
        subg_indi: subgraph type id -> rows of the score tensor to read.

    Returns:
        The summed log-probabilities (not negated), or the int ``0`` when no
        graph contributes. Graphs with at most ``v_true`` vertices, and vertices
        whose ``type`` is below 2, are skipped.
    """
    total = 0
    for graph_idx, g in enumerate(G_true):
        if g.vcount() <= v_true:
            continue
        true_type = g.vs[v_true]['type']
        if true_type < 2:
            continue
        rows = torch.LongTensor(subg_indi[true_type]).to(device)
        # the first and last entries of subg_nfeats are dropped -- presumably
        # boundary/padding values (TODO confirm the stored layout)
        targets = torch.LongTensor(g.vs[v_true]['subg_nfeats']).to(device)[1:-1]
        log_probs = F.log_softmax(reg_vals_list[graph_idx], dim=1)
        total += log_probs[rows, targets].sum()
    return total

def subn_loss(reg_vals_list, G_true, v_true, device=None):
    """Sum the log-probabilities of the true ``r``, ``c`` and ``gm`` value classes at vertex ``v_true``.

    Args:
        reg_vals_list: per-graph score tensors; row ``k`` of
            ``reg_vals_list[i]`` scores the k-th of the attributes
            ``('r', 'c', 'gm')``. A log-softmax is taken over dim 1.
        G_true: list of ground-truth igraph graphs.
        v_true: vertex index being scored.
        device: device for the target tensor.

    Returns:
        The summed log-probabilities (not negated), or the int ``0`` when no
        graph has more than ``v_true`` vertices.
    """
    res = 0
    for idx, g in enumerate(G_true):
        if g.vcount() <= v_true:
            continue  # this graph has no vertex v_true
        vertex = g.vs[v_true]
        true_types = torch.LongTensor([vertex['r'], vertex['c'], vertex['gm']]).to(device)
        # one row per target attribute, built as a torch index on the targets'
        # device instead of a hard-coded numpy np.arange(3)
        rows = torch.arange(true_types.numel(), device=true_types.device)
        res += F.log_softmax(reg_vals_list[idx], dim=1)[rows, true_types].sum()
    return res

# Circuit-VAE (CVAE) continuous version and general version
class CVAE_dec(nn.Module):
    def __init__(self, max_n, nvt, subg_nvt, subn_nvt ,START_TYPE, END_TYPE, emb_dim = 128, hs=301, nz=56, bidirectional=False, vid=True):
        """Build the encoder, decoder, regression heads and gating networks of the circuit VAE.

        Args:
            max_n: maximum number of vertices per graph.
            nvt: number of vertex types (input width of the GRU cells).
            subg_nvt: number of node types inside a vertex's subgraph.
            subn_nvt: number of value classes per subgraph node; also the
                output width of the fc_r / fc_c / fc_gm heads.
            START_TYPE: vertex type given to the first vertex when decoding.
            END_TYPE: vertex type forced on the last vertex when decoding.
            emb_dim: width of the subgraph embedding concatenated into the
                vertex state (the subgraph GNNs use out_feat = int(emb_dim/2)).
            hs: hidden state size of each vertex; also used as graph-state size.
            nz: size of the latent representation z.
            bidirectional: whether encoding also runs a backward propagation.
            vid: whether a one-hot vertex id (size max_n) is appended to the
                predecessor states used for gating.

        The construction order below fixes parameter-initialisation order and
        state_dict layout; do not reorder it.
        """
        super(CVAE_dec, self).__init__()
        self.max_n = max_n  # maximum number of vertices
        self.nvt = nvt  # number of vertex types
        self.subg_nvt = subg_nvt # number of node types in the subgraphs
        self.subn_nvt = subn_nvt # number of value types of each node in subgraphs
        self.START_TYPE = START_TYPE
        self.END_TYPE = END_TYPE
        self.emb_dim = emb_dim
        self.hs = hs  # hidden state size of each vertex
        #assert(self.hs = 2 * self.emb_dim)
        self.nz = nz  # size of latent representation z
        self.gs = hs  # size of graph state
        self.bidir = bidirectional  # whether to use bidirectional encoding
        self.vid = vid
        self.device = None  # resolved lazily by get_device()

        # Width of the per-predecessor vector fed to the gates/mappers:
        # hidden state + subgraph embedding (+ one-hot vertex id when vid).
        if self.vid:
            self.vs = hs + max_n + emb_dim  # vertex state size = hidden state + vid + subgraph embedding
        else:
            self.vs = hs + emb_dim 

        # 0. encoding-related
        self.grue_forward = nn.GRUCell(nvt, hs)  # encoder GRU
        self.grue_backward = nn.GRUCell(nvt, hs)  # backward encoder GRU
        # NOTE(review): the GNN output width is 2 * int(emb_dim/2); with an odd
        # emb_dim it is one less than the emb_dim assumed in self.vs -- keep emb_dim even.
        self.subgnn = subcGNN_dis(num_cat = self.subg_nvt, out_feat = int(self.emb_dim/2), num_feat = self.subn_nvt,
                                  dropout=0.5, num_layer=2, readout='sum', device=self.device)
        self.fc1 = nn.Linear(self.gs, nz)  # latent mean
        self.fc2 = nn.Linear(self.gs, nz)  # latent logvar
            
        # 1. decoding-related
        self.grud = nn.GRUCell(nvt, hs)  # decoder GRU
        self.fc3 = nn.Linear(nz, hs)  # from latent z to initial hidden state h0
        self.subgnn_decode = subcGNN_dis(num_cat = self.subg_nvt, out_feat = int(self.emb_dim/2), num_feat = self.subn_nvt,
                                  dropout=0.5, num_layer=2, readout='sum', device=self.device)
        self.add_vertex = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, nvt)
                )  # which type of new vertex to add f(h0, hg)
        self.add_edge = nn.Sequential(
                nn.Linear(hs * 2, hs * 4), 
                nn.ReLU(), 
                nn.Linear(hs * 4, 1)
                )  # whether to add edge between v_i and v_new, f(hvi, hnew)
        self.fc_r = nn.Sequential(
                nn.Linear(hs + self.nvt, hs),
                nn.ReLU(),
                nn.Linear(hs, subn_nvt)
                )  # Classification head for r: takes the hidden representation and type scores as input
        self.fc_c = nn.Sequential(
                nn.Linear(hs + self.nvt, hs),
                nn.ReLU(),
                nn.Linear(hs, subn_nvt)
                )  # Classification head for c (same inputs as fc_r)
        self.fc_gm = nn.Sequential(
                nn.Linear(hs + self.nvt, hs),
                nn.ReLU(),
                nn.Linear(hs, subn_nvt)
                )  # Classification head for gm (same inputs as fc_r)
        # Plain list for iteration in a fixed (r, c, gm) order; the heads are
        # already registered as submodules through the attributes above.
        self.regs = [self.fc_r, self.fc_c, self.fc_gm]
        
        # 2. gate-related
        self.gate_forward = nn.Sequential(
                nn.Linear(self.vs, hs), 
                nn.Sigmoid()
                )
        self.gate_backward = nn.Sequential(
                nn.Linear(self.vs, hs), 
                nn.Sigmoid()
                )
        self.mapper_forward = nn.Sequential(
                nn.Linear(self.vs, hs, bias=False),
                )  # disable bias to ensure padded zeros also mapped to zeros
        self.mapper_backward = nn.Sequential(
                nn.Linear(self.vs, hs, bias=False), 
                )

        # 3. bidir-related, to unify sizes
        if self.bidir:
            self.hv_unify = nn.Sequential(
                    nn.Linear(hs * 2, hs), 
                    )
            self.hg_unify = nn.Sequential(
                    nn.Linear(self.gs * 2, self.gs), 
                    )

        # 4. other
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.logsoftmax1 = nn.LogSoftmax(1)

    def get_device(self):
        """Return the model's device, caching the first parameter's device on first use."""
        if self.device is not None:
            return self.device
        self.device = next(self.parameters()).device
        return self.device
    
    def _get_zeros(self, n, length):
        """Return an ``(n, length)`` float tensor of zeros on the model's device."""
        return torch.zeros(n, length, device=self.get_device())

    def _get_zero_hidden(self, n=1):
        """Return an all-zero batch of hidden states of shape ``(n, self.hs)``."""
        hidden_size = self.hs
        return self._get_zeros(n, hidden_size)

    def _one_hot(self, idx, length):
        """One-hot encode ``idx`` over ``length`` classes on the model's device.

        Args:
            idx: a single integer index, or a list / tuple / range of indices
                (e.g. the vertex lists returned by igraph ``predecessors``).
            length: number of classes.

        Returns:
            ``(len(idx), length)`` tensor for a sequence, ``(1, length)`` for a
            single index, or ``None`` for any empty sequence (including an
            empty ``range``, which the old ``== []`` test missed).
        """
        if isinstance(idx, (list, tuple, range)):
            if len(idx) == 0:
                return None
            rows = torch.LongTensor(list(idx)).unsqueeze(1)  # (n, 1) column of class ids
            return torch.zeros((len(rows), length)).scatter_(1, rows, 1).to(self.get_device())
        single = torch.LongTensor([idx]).unsqueeze(0)  # (1, 1)
        return torch.zeros((1, length)).scatter_(1, single, 1).to(self.get_device())

    def _gated(self, h, gate, mapper):
        """Return the elementwise product ``gate(h) * mapper(h)``."""
        gate_values = gate(h)
        return torch.mul(gate_values, mapper(h))

    def _collate_fn(self, G):
        """Return a new list holding an independent ``copy()`` of every graph in ``G``."""
        return list(map(lambda graph: graph.copy(), G))

    def _propagate_to(self, G, v, propagator, H=None, reverse=False, decode=False):
        """Compute and store the new hidden state of vertex ``v`` in every graph of ``G`` that has it.

        Args:
            G: list of igraph graphs; graphs with ``vcount() <= v`` are ignored.
            v: vertex index to update.
            propagator: cell called as ``propagator(X, H)`` (e.g. a GRUCell),
                where ``X`` is the one-hot vertex type of ``v``.
            H: optional input hidden states, one row per graph in ``G``. When
                ``None``, the gated sum of the neighbours' states is used
                (zeros if no graph has neighbours).
            reverse: use successors and ``H_backward`` (backward pass) instead
                of predecessors and ``H_forward``.
            decode: use all-zero subgraph features for neighbours instead of
                their ``subg_feat`` attribute.

        Returns:
            Tensor with one row per graph that contains ``v``, or ``None`` when
            no graph does. The rows are also written in place to
            ``g.vs[v]['H_forward' | 'H_backward']``.
        """
        # Pick the H rows belonging to graphs that contain v *before* filtering
        # G; filtering first made the indices 0..k-1 and misaligned H with G.
        keep = [i for i, g in enumerate(G) if g.vcount() > v]
        if not keep:
            return
        G = [G[i] for i in keep]
        if H is not None:  # H: previous hidden state
            H = H[keep]
        device = self.get_device()
        X = self._one_hot([g.vs[v]['type'] for g in G], self.nvt)
        if reverse:
            H_name = 'H_backward'  # name of the hidden states attribute
            neighbours = [g.successors(v) for g in G]  # the 'predecessors' of the reversed graph
            gate, mapper = self.gate_backward, self.mapper_backward
        else:
            H_name = 'H_forward'  # name of the hidden states attribute
            neighbours = [g.predecessors(v) for g in G]
            gate, mapper = self.gate_forward, self.mapper_forward
        H_pred = [[g.vs[x][H_name] for x in nbrs] for g, nbrs in zip(G, neighbours)]
        if decode:
            F_pred = [[torch.zeros(1, self.emb_dim).to(device) for _ in nbrs] for nbrs in neighbours]
        else:
            F_pred = [[g.vs[x]['subg_feat'].unsqueeze(0) for x in nbrs] for g, nbrs in zip(G, neighbours)]
        # Build the per-neighbour vectors of width self.vs:
        # [hidden state, subgraph feature(, one-hot vertex id)].
        if self.vid:
            vids = [self._one_hot(nbrs, self.max_n) for nbrs in neighbours]
            H_pred = [[torch.cat([h[i], f[i], ids[i:i+1]], 1) for i in range(len(h))]
                      for h, f, ids in zip(H_pred, F_pred, vids)]
        else:
            # self.vs = hs + emb_dim here, so the subgraph feature must be appended too
            H_pred = [[torch.cat([h_i, f_i], 1) for h_i, f_i in zip(h, f)]
                      for h, f in zip(H_pred, F_pred)]
        # if H is not provided, use gated sum of v's neighbours' states as the input hidden state
        if H is None:
            max_n_pred = max(len(h) for h in H_pred)  # maximum number of neighbours
            if max_n_pred == 0:  # no graph has neighbours: start point
                H = self._get_zero_hidden(len(G))
            else:
                padded = [torch.cat(h + [self._get_zeros(max_n_pred - len(h), self.vs)], 0).unsqueeze(0)
                          for h in H_pred]  # pad all to same length
                H_pred = torch.cat(padded, 0)  # batch * max_n_pred * vs
                H = self._gated(H_pred, gate, mapper).sum(1)  # batch * hs
        Hv = propagator(X, H)
        for i, g in enumerate(G):
            g.vs[v][H_name] = Hv[i:i+1]
        return Hv

    def _propagate_from(self, G, v, propagator, H0=None, reverse=False, decode=False):
        """Propagate from vertex ``v`` onwards, one vertex at a time in index order.

        Vertex indices are assumed to be a topological order. Forward passes visit
        ``v, v+1, ..., max_n-1``; reverse passes visit ``v, v-1, ..., 0``. Only
        the first vertex receives ``H0``; later vertices gather their neighbours'
        states. Returns the states computed for the first vertex ``v``.
        """
        step, stop = (-1, -1) if reverse else (1, self.max_n)
        first_states = self._propagate_to(G, v, propagator, H0, reverse=reverse, decode=decode)
        for next_v in range(v + step, stop, step):
            self._propagate_to(G, next_v, propagator, reverse=reverse, decode=decode)
        return first_states

    def _update_v(self, G, v, H0=None, decode=False):
        """Refresh vertex ``v``'s forward state with the decoder GRU ``self.grud``.

        The new state is stored in place on the graphs' ``H_forward`` attribute;
        nothing is returned.
        """
        self._propagate_to(G, v, self.grud, H0, reverse=False, decode=decode)
    
    def _get_vertex_state(self, G, v):
        """Stack vertex ``v``'s ``H_forward`` across ``G``, using a zero row for graphs lacking ``v``."""
        states = [
            g.vs[v]['H_forward'] if v < g.vcount() else self._get_zero_hidden()
            for g in G
        ]
        return torch.cat(states, 0)

    def _get_graph_state(self, G, decode=False):
        """Return one graph-state row per graph in ``G``.

        The state is the ``H_forward`` of the last vertex. When encoding a
        bidirectional model, it is concatenated with the ``H_backward`` of
        vertex 0 and mapped back to ``gs`` by ``hg_unify``. Decoding never uses
        backward states.
        """
        use_backward = self.bidir and not decode
        states = []
        for g in G:
            last_state = g.vs[g.vcount() - 1]['H_forward']
            if use_backward:
                last_state = torch.cat([last_state, g.vs[0]['H_backward']], 1)
            states.append(last_state)
        stacked = torch.cat(states, 0)
        return self.hg_unify(stacked) if use_backward else stacked

    def encode(self, G):
        """Encode a graph (or list of graphs) into latent Gaussian parameters.

        Subgraph embeddings are computed first by ``self.subgnn`` on copies of
        the graphs. A forward pass (and a backward pass when bidirectional)
        follows, and then the graph state is read out.

        Returns:
            ``(mu, logvar)`` from ``fc1`` and ``fc2`` applied to the graph state.
        """
        if type(G) is not list:
            G = [G]
        G = self.subgnn(G)
        self._propagate_from(G, 0, self.grue_forward,
                             H0=self._get_zero_hidden(len(G)), reverse=False, decode=False)
        if self.bidir:
            self._propagate_from(G, self.max_n - 1, self.grue_backward,
                                 H0=self._get_zero_hidden(len(G)), reverse=True, decode=False)
        graph_state = self._get_graph_state(G)
        return self.fc1(graph_state), self.fc2(graph_state)

    def reparameterize(self, mu, logvar, eps_scale=0.01):
        """Sample ``z = mu + eps * std`` while training (``eps ~ N(0, eps_scale^2)``); return ``mu`` in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std) * eps_scale
        return mu + eps * std

    def _get_edge_score(self, Hvi, H, H0):
        """Return sigmoid edge probabilities from the concatenation ``[Hvi, H]``.

        ``H0`` is accepted for interface compatibility but not used: ``Hvi`` and
        ``H`` already carry its information.
        """
        pair = torch.cat([Hvi, H], -1)
        return self.sigmoid(self.add_edge(pair))

    def decode(self, z, stochastic=True, node_type_dic=NODE_TYPE, subg_node=SUBG_NODE, subg_con=SUBG_CON, subg_indi=SUBG_INDI):
        """Decode latent vectors ``z`` back into graphs, adding one vertex per step.

        At each step the decoder chooses:
          1. the new vertex type;
          2. its subgraph description (inner node types, inner node values and
             flattened adjacency);
          3. its incoming edges.
        The last step (``idx == max_n - 1``) always adds an END_TYPE vertex.
        That vertex is wired to every vertex whose out-degree is 0.

        Args:
            z: latent batch; ``len(z)`` graphs are generated.
            stochastic: if True, sample every action from its predicted
                distribution; otherwise take the argmax action.
            node_type_dic: subgraph node name -> subgraph node type id.
            subg_node: vertex type -> names of the subgraph's inner nodes.
            subg_con: vertex type -> connection style passed to
                ``subg_flaten_adj``.
            subg_indi: vertex type -> indices into the per-head predictions
                (in the order of ``self.regs``) that become the inner nodes'
                values.

        Returns:
            list of ``igraph.Graph``. Each vertex carries ``type``,
            ``subg_ntypes``, ``subg_nfeats`` and ``subg_adj``.
        """
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
            g.vs[0]['subg_ntypes'] = [8]  # single 'In' subgraph node
            g.vs[0]['subg_nfeats'] = [0.0]
            g.vs[0]['subg_adj'] = [1]
        G = self.subgnn_decode(G)
        self._update_v(G, 0, H0)  # only the starting vertex is seeded with H0
        finished = [False] * len(G)
        for idx in range(1, self.max_n):
            # Recompute the graph state and type scores on every step. The
            # subgraph value heads below consume them even on the forced final
            # step, where they used to be stale (or undefined when max_n == 2).
            Hg = self._get_graph_state(G, decode=True)
            type_scores = self.add_vertex(Hg)

            # decide the type of the next added vertex
            if idx == self.max_n - 1:  # force the last node to be end_type
                new_types = [self.END_TYPE] * len(G)
            elif stochastic:
                type_probs = F.softmax(type_scores, 1).cpu().detach().numpy()
                new_types = [np.random.choice(range(self.nvt), p=type_probs[i])
                             for i in range(len(G))]
            else:
                new_types = torch.argmax(type_scores, 1).flatten().tolist()

            # Predict subgraph node values from the graph state + type scores.
            # Hg is used here (not the new vertex state) because the new vertex
            # has not been added yet; note the loss conditions these heads on
            # the new vertex's state instead.
            H_reg = torch.cat([Hg, type_scores], dim=1)
            reg_vals = [func(H_reg) for func in self.regs]

            for j, g in enumerate(G):
                if finished[j]:
                    continue
                g.add_vertex(type=new_types[j])
                if new_types[j] == 0:
                    # a single 'In' subgraph node
                    pred_types = [8]
                    g_vals = [0.0]
                    pred_adj = [1]
                elif new_types[j] == 1:
                    # a single 'Out' subgraph node
                    pred_types = [9]
                    g_vals = [0.0]
                    pred_adj = [1]
                else:
                    # One discrete value per head in self.regs, shifted by +1
                    # (presumably so that 0 stays free for the padding entries
                    # of the sudo_in / sudo_out nodes below).
                    g_val_ = []
                    for reg_v in reg_vals:
                        subn_scores = reg_v[j, :]
                        if stochastic:
                            type_prob = F.softmax(subn_scores, dim=0).cpu().detach().numpy()
                            new_val = np.random.choice(range(self.subn_nvt), p=type_prob) + 1
                        else:
                            new_val = torch.argmax(subn_scores, dim=0).tolist() + 1
                        g_val_.append(new_val)
                    vtype = new_types[j]
                    # keep only the values this vertex type's inner nodes use,
                    # framed by 0 for the sudo_in / sudo_out nodes
                    g_vals = [0] + [g_val_[feat_id] for feat_id in subg_indi[vtype]] + [0]
                    # 6 = 'sudo_in', 7 = 'sudo_out' in NODE_TYPE
                    pred_types = [6] + [node_type_dic[ty] for ty in subg_node[vtype]] + [7]
                    pred_adj = subg_flaten_adj(len(subg_node[vtype]), subg_con[vtype])
                g.vs[idx]['subg_ntypes'] = pred_types
                g.vs[idx]['subg_nfeats'] = g_vals
                g.vs[idx]['subg_adj'] = pred_adj

            G = self.subgnn_decode(G)
            self._update_v(G, idx)
            # decide connections from every earlier vertex vi to the new vertex
            for vi in range(idx - 1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, idx)
                ei_score = self._get_edge_score(Hvi, H, H0)
                if stochastic:
                    random_score = torch.rand_like(ei_score)
                    decisions = random_score < ei_score
                else:
                    decisions = ei_score > 0.5
                for i, g in enumerate(G):
                    if finished[i]:
                        continue
                    if new_types[i] == self.END_TYPE:
                        # connect the end vertex to all loose-end vertices (out_degree == 0)
                        end_vertices = set([v.index for v in g.vs.select(_outdegree_eq=0)
                                            if v.index != g.vcount() - 1])
                        for v in end_vertices:
                            g.add_edge(v, g.vcount() - 1)
                        finished[i] = True
                        continue
                    if decisions[i, 0]:
                        g.add_edge(vi, g.vcount() - 1)
                self._update_v(G, idx)

        for g in G:
            del g.vs['H_forward']  # delete hidden states to save GPU memory
        return G

    def loss(self, mu, logvar, G_true, beta=0.005, reg_scale=0.01):
        """Teacher-forced VAE loss of decoding (mu, logvar) into the graphs G_true.

        Step ``v_true`` is scored with the decoder conditioned on the true
        vertices 0 .. v_true-1, which are copied from ``G_true``. The per-step
        terms are the vertex-type log likelihood, the subgraph-value term
        ``vl2`` (from ``subg_loss_dis``) and the edge log likelihood.

        Args:
            mu, logvar: encoder outputs, one row per graph in ``G_true``.
            G_true: list of target igraph graphs whose vertices carry
                'type', 'subg_ntypes', 'subg_nfeats' and 'subg_adj'.
            beta: weight of the KL term in the returned total.
            reg_scale: currently unused.

        Returns:
            ``(res + beta * kld, res, kld)``:
              - ``res`` is the negated sum of the per-step terms;
              - ``kld`` is the KL divergence of N(mu, exp(logvar)) from N(0, I).
        """
        # compute the loss of decoding mu and logvar to true graphs using teacher forcing
        # ensure when computing the loss of step i, steps 0 to i-1 are correct
        z = self.reparameterize(mu, logvar) # (bsize, hidden)
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
            g.vs[0]['subg_ntypes'] = [8]
            g.vs[0]['subg_nfeats'] = [0.0]
            g.vs[0]['subg_adj'] = [1]
        G = self.subgnn_decode(G)
        self._update_v(G, 0, H0)
        res = 0  # log likelihood
        for v_true in range(1, self.max_n):
            # calculate the likelihood of adding true types of nodes
            # use start type to denote padding vertices since start type only appears for vertex 0 
            # and will never be a true type for later vertices, thus it's free to use
            true_types = [g_true.vs[v_true]['type'] if v_true < g_true.vcount()  # list of length bsize
                          else self.START_TYPE for g_true in G_true]
            Hg = self._get_graph_state(G, decode=True) 
            
            type_scores = self.add_vertex(Hg) # (bsize, number of vertex types)
            # vertex log likelihood
            vll = self.logsoftmax1(type_scores)[np.arange(len(G)), true_types].sum()  
            res = res + vll
            for i, g in enumerate(G):
                if true_types[i] != self.START_TYPE:
                    g.add_vertex(type=true_types[i])
                    g.vs[v_true]['subg_ntypes'] = G_true[i].vs[v_true]['subg_ntypes']
                    g.vs[v_true]['subg_nfeats'] = G_true[i].vs[v_true]['subg_nfeats']
                    g.vs[v_true]['subg_adj'] = G_true[i].vs[v_true]['subg_adj']
            G = self.subgnn_decode(G)     
            self._update_v(G, v_true)
            # score the predicted subgraph node values (discrete heads in self.regs)
            # against the true ones; conditioned on the new vertex state H + type scores
            H = self._get_vertex_state(G, v_true)
            H_reg = torch.cat([H, type_scores],dim=1)
            reg_vals = []
            for func in self.regs:
                subg_score = func(H_reg)
                reg_vals.append(subg_score) 
            reg_vals_list = []
            for i in range(len(G_true)):
                # per-graph stack of every head's scores for vertex v_true
                reg_vals_list.append(torch.cat([val[i].unsqueeze(0) for val in reg_vals], dim=0))
            vl2 = subg_loss_dis(reg_vals_list,G_true,v_true,device=self.get_device())  ######   
            # NOTE(review): vl2 is added to the log likelihood `res` before `res = -res`
            # below, so it enters the returned loss with a negative sign. That is only
            # correct if subg_loss_dis returns a log likelihood (higher = better); if it
            # returns a positive loss, training maximizes it. Confirm against subg_loss_dis.
            res += vl2
            # calculate the likelihood of adding true edges
            true_edges = []
            for i, g_true in enumerate(G_true):
                true_edges.append(g_true.get_adjlist(igraph.IN)[v_true] if v_true < g_true.vcount() 
                                  else []) # get_adjlist(IN)[v]: indices of v's predecessors; true_edges[i] = predecessors of v_true in graph i
            edge_scores = []
            for vi in range(v_true-1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, v_true)
                ei_score = self._get_edge_score(Hvi, H, H0) # size: batch size, 1
                edge_scores.append(ei_score)
                for i, g in enumerate(G):
                    if vi in true_edges[i]:
                        g.add_edge(vi, v_true)
                self._update_v(G, v_true, decode=True)
            # scores were collected for vi = v_true-1 .. 0; reversing makes column k correspond to vertex k
            edge_scores = torch.cat(edge_scores[::-1], 1)  # (batch size, v_true): columns: 0, ..., v_true-1

            # 1.0 at (graph i, predecessor vertex index) for every true incoming edge
            ground_truth = torch.zeros_like(edge_scores)
            idx1 = [i for i, x in enumerate(true_edges) for _ in range(len(x))]
            idx2 = [xx for x in true_edges for xx in x]
            ground_truth[idx1, idx2] = 1.0

            # edges log-likelihood
            ell = - F.binary_cross_entropy(edge_scores, ground_truth, reduction='sum') 
            res = res + ell

        res = -res  # convert likelihood to loss
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return res + beta*kld, res, kld

    def encode_decode(self, G):
        """Encode ``G``, sample a latent code from the posterior, and decode it.

        Returns whatever ``self.decode`` produces for that latent code.
        """
        return self.decode(self.reparameterize(*self.encode(G)))

    def forward(self, G):
        """Return the total training loss for ``G``.

        The total is the first element returned by ``self.loss``.
        """
        mu, logvar = self.encode(G)
        total_loss, _recon, _kld = self.loss(mu, logvar, G)
        return total_loss
    
    def generate_sample(self, n):
        """Generate ``n`` graphs by decoding latent codes drawn from N(0, I).

        Args:
            n: number of graphs to generate.

        Returns:
            the result of ``self.decode`` on an ``(n, self.nz)`` latent batch.
        """
        latent = torch.randn(n, self.nz)
        latent = latent.to(self.get_device())
        return self.decode(latent)

class CVAE_simple(nn.Module):
    """Sequential graph VAE whose vertices carry discrete (r, c, gm, vid) features.

    Each vertex has the igraph attributes:
      - ``type``, in ``range(nvt)``;
      - ``r``, ``c`` and ``gm``, each in ``range(subn_nvt + 1)``;
      - ``vid``, in ``range(max_n)``.
    Vertex indices are assumed to be a topological order.

    The encoder runs a GRU over the vertices, feeding each vertex the gated
    sum of its predecessors' states. The decoder adds one vertex per step and
    predicts its type, vid, (r, c, gm) values and incoming edges.
    """

    def __init__(self, max_n, nvt, subn_nvt, START_TYPE, END_TYPE, emb_dim=64, hs=301, nz=56, bidirectional=False, vid=True):
        """Build the encoder and decoder layers.

        Args:
            max_n: maximum number of vertices per graph (also the vid range).
            nvt: number of vertex types.
            subn_nvt: number of values per (r, c, gm) slot. One extra class is
                allocated, so the stored size is ``subn_nvt + 1``.
            START_TYPE, END_TYPE: vertex types of the first and last vertex.
            emb_dim: size of the (r, c, gm) feature embedding.
            hs: hidden state size of each vertex (also the graph state size).
            nz: size of the latent vector z.
            bidirectional: also run a backward encoder and merge both states.
            vid: stored as ``self.vid``; not read by this class's methods.
        """
        super(CVAE_simple, self).__init__()
        self.max_n = max_n  # maximum number of vertices
        self.nvt = nvt  # number of vertex types
        self.subn_nvt = subn_nvt + 1  # number of value classes of each (r, c, gm) feature
        self.START_TYPE = START_TYPE
        self.END_TYPE = END_TYPE
        self.emb_dim = emb_dim
        self.hs = hs  # hidden state size of each vertex
        self.nz = nz  # size of latent representation z
        self.gs = hs  # size of graph state
        self.bidir = bidirectional  # whether to use bidirectional encoding
        self.vid = vid
        self.device = None  # refreshed by get_device()

        self.vs = hs  # size of the neighbour states fed to the gates/mappers

        # 0. encoding-related
        # embeds the concatenated one-hot (r, c, gm) features into emb_dim
        self.feat_enc = nn.Sequential(
                nn.Linear(3 * self.subn_nvt, emb_dim * 2),
                nn.ReLU(),
                nn.Linear(emb_dim * 2, emb_dim)
                )
        # GRU input = [one-hot type (nvt) | feature embedding (emb_dim) | one-hot vid (max_n)]
        self.grue_forward = nn.GRUCell(nvt + emb_dim + max_n, hs)  # encoder GRU
        self.grue_backward = nn.GRUCell(nvt + emb_dim + max_n, hs)  # backward encoder GRU
        self.fc1 = nn.Linear(self.gs, nz)  # latent mean
        self.fc2 = nn.Linear(self.gs, nz)  # latent logvar

        # 1. decoding-related
        # NOTE(review): feat_dec is not referenced by any method of this class;
        # _propagate_to uses feat_enc for both encoding and decoding.
        self.feat_dec = nn.Sequential(
                nn.Linear(3 * self.subn_nvt, emb_dim * 2),
                nn.ReLU(),
                nn.Linear(emb_dim * 2, emb_dim)
                )
        self.grud = nn.GRUCell(nvt + emb_dim + max_n, hs)  # decoder GRU
        self.fc3 = nn.Linear(nz, hs)  # from latent z to initial hidden state h0
        self.add_vertex = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, nvt)
                )  # which type of new vertex to add f(h0, hg)
        self.add_edge = nn.Sequential(
                nn.Linear(hs * 2, hs * 4),
                nn.ReLU(),
                nn.Linear(hs * 4, 1)
                )  # whether to add edge between v_i and v_new, f(hvi, hnew)
        self.vid_fc = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, max_n)
                )  # logits over the vid of the new vertex
        self.fc_r = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, self.subn_nvt)
                )  # logits over the r value, from the graph state Hg
        self.fc_c = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, self.subn_nvt)
                )  # logits over the c value, from the graph state Hg
        self.fc_gm = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, self.subn_nvt)
                )  # logits over the gm value, from the graph state Hg
        # order matters: decode stores the outputs as r, c, gm
        self.regs = [self.fc_r, self.fc_c, self.fc_gm]

        # 2. gate-related
        self.gate_forward = nn.Sequential(
                nn.Linear(self.vs, hs),
                nn.Sigmoid()
                )
        self.gate_backward = nn.Sequential(
                nn.Linear(self.vs, hs),
                nn.Sigmoid()
                )
        self.mapper_forward = nn.Sequential(
                nn.Linear(self.vs, hs, bias=False),
                )  # disable bias to ensure padded zeros also mapped to zeros
        self.mapper_backward = nn.Sequential(
                nn.Linear(self.vs, hs, bias=False),
                )

        # 3. bidir-related, to unify sizes
        if self.bidir:
            self.hv_unify = nn.Sequential(
                    nn.Linear(hs * 2, hs),
                    )
            self.hg_unify = nn.Sequential(
                    nn.Linear(self.gs * 2, self.gs),
                    )

        # 4. other
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.logsoftmax1 = nn.LogSoftmax(1)

    def get_device(self):
        """Return the device of the model's parameters.

        The device is re-read on every call, because a cached value would go
        stale after the module is moved with ``.to()`` / ``.cuda()``.
        ``self.device`` is kept updated for callers that read it directly.
        """
        self.device = next(self.parameters()).device
        return self.device

    def _get_zeros(self, n, length):
        """Return an ``(n, length)`` zero tensor on the model device."""
        return torch.zeros(n, length, device=self.get_device())

    def _get_zero_hidden(self, n=1):
        """Return an ``(n, hs)`` zero hidden state."""
        return self._get_zeros(n, self.hs)

    def _one_hot(self, idx, length):
        """One-hot encode ``idx`` into rows of width ``length`` on the model device.

        A list or range gives one row per entry and returns ``None`` when it
        is empty. Any other value is treated as a single index and gives a
        ``(1, length)`` row.
        """
        if isinstance(idx, (list, range)):
            # len() also catches an empty range, which `idx == []` missed
            if len(idx) == 0:
                return None
            idx = torch.LongTensor(idx).unsqueeze(0).t()
            x = torch.zeros((len(idx), length)).scatter_(1, idx, 1)
        else:
            idx = torch.LongTensor([idx]).unsqueeze(0)
            x = torch.zeros((1, length)).scatter_(1, idx, 1)
        return x.to(self.get_device())

    def _gated(self, h, gate, mapper):
        """Apply the gated mapping ``gate(h) * mapper(h)``."""
        return gate(h) * mapper(h)

    def _collate_fn(self, G):
        """Return shallow copies of the graphs in ``G``."""
        return [g.copy() for g in G]

    def _propagate_to(self, G, v, propagator, H=None, reverse=False, decode=False):
        """Run one propagation step at vertex index ``v`` for the graphs in ``G``.

        Only graphs with more than ``v`` vertices take part. The new state of
        ``v`` is written to each such graph as ``'H_forward'`` (or
        ``'H_backward'`` when ``reverse``) and is also returned.

        Args:
            G: list of igraph graphs.
            v: vertex index to update.
            propagator: GRU cell, called as ``propagator(X, H)``.
            H: optional input hidden state with one row per graph in ``G``.
                When None, the gated sum of the neighbours' states is used.
            reverse: aggregate the successors' backward states instead of the
                predecessors' forward states.
            decode: unused.

        Returns:
            ``(k, hs)`` tensor for the ``k`` participating graphs, or ``None``
            if no graph has vertex ``v``.
        """
        # Choose the rows of H *before* filtering G. Filtering first made the
        # index list cover every remaining graph, which paired H rows with the
        # wrong graphs whenever a graph shorter than v + 1 preceded others.
        keep = [i for i, g in enumerate(G) if g.vcount() > v]
        if not keep:
            return
        G = [G[i] for i in keep]
        if H is not None:  # H: previous hidden state
            H = H[keep]
        v_types = [g.vs[v]['type'] for g in G]
        r_feats = [g.vs[v]['r'] for g in G]
        c_feats = [g.vs[v]['c'] for g in G]
        gm_feats = [g.vs[v]['gm'] for g in G]
        vid_feats = [g.vs[v]['vid'] for g in G]
        X_v = self._one_hot(v_types, self.nvt)
        X_r = self._one_hot(r_feats, self.subn_nvt)
        X_c = self._one_hot(c_feats, self.subn_nvt)
        X_gm = self._one_hot(gm_feats, self.subn_nvt)
        X_vid = self._one_hot(vid_feats, self.max_n)
        X_feat = self.feat_enc(torch.cat([X_r, X_c, X_gm], dim=1))
        X = torch.cat([X_v, X_feat, X_vid], dim=1)

        if reverse:
            H_name = 'H_backward'  # name of the hidden states attribute
            # in the reverse pass the successors play the role of predecessors
            H_pred = [[g.vs[x][H_name] for x in g.successors(v)] for g in G]
            gate, mapper = self.gate_backward, self.mapper_backward
        else:
            H_name = 'H_forward'  # name of the hidden states attribute
            H_pred = [[g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
            gate, mapper = self.gate_forward, self.mapper_forward
        # if H is not provided, use the gated sum of the neighbours' states as the input hidden state
        if H is None:
            max_n_pred = max(len(x) for x in H_pred)  # maximum number of neighbours
            if max_n_pred == 0:  # no participating graph has a neighbour (e.g. the start vertex)
                H = self._get_zero_hidden(len(G))
            else:
                # zero-pad every neighbour list to max_n_pred rows; mapper has
                # no bias, so the padded rows contribute zero to the sum
                H_pred = [torch.cat(h_pred +
                            [self._get_zeros(max_n_pred - len(h_pred), self.vs)], 0).unsqueeze(0)
                            for h_pred in H_pred]
                H_pred = torch.cat(H_pred, 0)  # batch * max_n_pred * vs
                H = self._gated(H_pred, gate, mapper).sum(1)  # batch * hs
        Hv = propagator(X, H)
        for i, g in enumerate(G):
            g.vs[v][H_name] = Hv[i:i+1]
        return Hv

    def _propagate_from(self, G, v, propagator, H0=None, reverse=False, decode=False):
        """Propagate from vertex ``v`` through every later index (earlier, if ``reverse``).

        Vertex indices are assumed to be a topological order, so visiting them
        by index respects edge directions. ``H0`` seeds the state of ``v``
        only. Returns the new state of ``v``.
        """
        if reverse:
            prop_order = range(v, -1, -1)
        else:
            prop_order = range(v, self.max_n)
        Hv = self._propagate_to(G, v, propagator, H0, reverse=reverse, decode=decode)  # the initial vertex
        for v_ in prop_order[1:]:
            self._propagate_to(G, v_, propagator, reverse=reverse, decode=decode)
        return Hv

    def _update_v(self, G, v, H0=None, decode=False):
        """Run one decoder-GRU forward step at ``v`` to refresh v's state."""
        self._propagate_to(G, v, self.grud, H0, reverse=False, decode=decode)
        return

    def _get_vertex_state(self, G, v):
        """Stack each graph's forward state at ``v``.

        Graphs with fewer than ``v + 1`` vertices contribute a zero row.
        """
        Hv = []
        for g in G:
            if v >= g.vcount():
                hv = self._get_zero_hidden()
            else:
                hv = g.vs[v]['H_forward']
            Hv.append(hv)
        return torch.cat(Hv, 0)

    def _get_graph_state(self, G, decode=False):
        """Return one graph-state row per graph.

        The row is the forward state of the last vertex. When bidirectional
        and not decoding, it is merged with the start vertex's backward state
        through ``hg_unify``.
        """
        Hg = []
        for g in G:
            hg = g.vs[g.vcount()-1]['H_forward']
            if self.bidir and not decode:  # decoding never uses backward propagation
                hg_b = g.vs[0]['H_backward']
                hg = torch.cat([hg, hg_b], 1)
            Hg.append(hg)
        Hg = torch.cat(Hg, 0)
        if self.bidir and not decode:
            Hg = self.hg_unify(Hg)  # a linear model
        return Hg

    def encode(self, G):
        """Encode a graph or list of graphs into ``(mu, logvar)``."""
        if not isinstance(G, list):
            G = [G]
        self._propagate_from(G, 0, self.grue_forward, H0=self._get_zero_hidden(len(G)),
                             reverse=False, decode=False)
        if self.bidir:
            self._propagate_from(G, self.max_n-1, self.grue_backward,
                                 H0=self._get_zero_hidden(len(G)), reverse=True, decode=False)
        Hg = self._get_graph_state(G)
        mu, logvar = self.fc1(Hg), self.fc2(Hg)
        return mu, logvar

    def reparameterize(self, mu, logvar, eps_scale=0.01):
        """Return ``mu + eps_scale * N(0, 1) * std`` in training mode, else ``mu``."""
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std) * eps_scale
            return eps.mul(std).add_(mu)
        else:
            return mu

    def _get_edge_score(self, Hvi, H, H0):
        """Return the sigmoid probability of an edge vi -> current vertex.

        ``H0`` is unused: ``Hvi`` and ``H`` already contain its information.
        """
        return self.sigmoid(self.add_edge(torch.cat([Hvi, H], -1)))

    def decode(self, z, stochastic=True, node_type_dic=NODE_TYPE, subg_node=SUBG_NODE, subg_con=SUBG_CON, subg_indi=SUBG_INDI):
        """Decode latent vectors ``z`` into graphs, one vertex per step.

        At each step the decoder chooses the new vertex's type, vid,
        (r, c, gm) values and incoming edges. The last step
        (``idx == max_n - 1``) always adds an END_TYPE vertex, which is wired
        to every vertex whose out-degree is 0.

        Args:
            z: latent batch; ``len(z)`` graphs are generated.
            stochastic: if True, sample every action from its predicted
                distribution; otherwise take the argmax.
            node_type_dic, subg_node, subg_con, subg_indi: not used by this
                method.

        Returns:
            list of ``igraph.Graph`` with vertex attributes ``type``, ``vid``,
            ``r``, ``c`` and ``gm``.
        """
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
            g.vs[0]['r'] = 0
            g.vs[0]['c'] = 0
            g.vs[0]['gm'] = 0
            g.vs[0]['vid'] = 0
        self._update_v(G, 0, H0)  # only the starting vertex is seeded with H0
        finished = [False] * len(G)
        for idx in range(1, self.max_n):
            # Recompute the graph state every step. The vid and (r, c, gm)
            # heads need it even on the forced final step, where they used to
            # reuse the previous step's Hg/new_vids (NameError if max_n == 2).
            Hg = self._get_graph_state(G, decode=True)
            vid_scores = self.vid_fc(Hg)

            # decide the type of the next added vertex
            if idx == self.max_n - 1:  # force the last node to be end_type
                new_types = [self.END_TYPE] * len(G)
            else:
                type_scores = self.add_vertex(Hg)
                if stochastic:
                    type_probs = F.softmax(type_scores, 1).cpu().detach().numpy()
                    new_types = [np.random.choice(range(self.nvt), p=type_probs[i])
                                 for i in range(len(G))]
                else:
                    new_types = torch.argmax(type_scores, 1).flatten().tolist()

            # decide the vid of the next added vertex
            if stochastic:
                vid_probs = F.softmax(vid_scores, 1).cpu().detach().numpy()
                new_vids = [np.random.choice(range(self.max_n), p=vid_probs[i])
                            for i in range(len(G))]
            else:
                new_vids = torch.argmax(vid_scores, 1).flatten().tolist()

            # decide the (r, c, gm) values, in the order of self.regs
            pred_vals = [func(Hg) for func in self.regs]

            for j, g in enumerate(G):
                if finished[j]:
                    continue
                g.add_vertex(type=new_types[j])
                g.vs[idx]['vid'] = int(new_vids[j])
                g_val_ = []
                for reg_v in pred_vals:
                    subn_scores = reg_v[j, :]
                    if stochastic:
                        type_prob = F.softmax(subn_scores, dim=0).cpu().detach().numpy()
                        new_val = np.random.choice(range(0, self.subn_nvt), p=type_prob)
                    else:
                        new_val = torch.argmax(subn_scores, dim=0).tolist()
                    g_val_.append(new_val)
                g.vs[idx]['r'] = int(g_val_[0])
                g.vs[idx]['c'] = int(g_val_[1])
                g.vs[idx]['gm'] = int(g_val_[2])

            self._update_v(G, idx)
            # decide connections from every earlier vertex vi to the new vertex
            for vi in range(idx - 1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, idx)
                ei_score = self._get_edge_score(Hvi, H, H0)
                if stochastic:
                    random_score = torch.rand_like(ei_score)
                    decisions = random_score < ei_score
                else:
                    decisions = ei_score > 0.5
                for i, g in enumerate(G):
                    if finished[i]:
                        continue
                    if new_types[i] == self.END_TYPE:
                        # connect the end vertex to all loose-end vertices (out_degree == 0)
                        end_vertices = set([v.index for v in g.vs.select(_outdegree_eq=0)
                                            if v.index != g.vcount() - 1])
                        for v in end_vertices:
                            g.add_edge(v, g.vcount() - 1)
                        finished[i] = True
                        continue
                    if decisions[i, 0]:
                        g.add_edge(vi, g.vcount() - 1)
                self._update_v(G, idx)

        for g in G:
            del g.vs['H_forward']  # delete hidden states to save GPU memory
        return G

    def loss(self, mu, logvar, G_true, beta=0.005, reg_scale=0.1):
        """Teacher-forced VAE loss of decoding (mu, logvar) into the graphs G_true.

        Step ``v_true`` is scored with the decoder conditioned on the true
        vertices ``0 .. v_true - 1``, which are copied from ``G_true``.

        Args:
            mu, logvar: encoder outputs, one row per graph in ``G_true``.
            G_true: list of target igraph graphs (vertex attributes ``type``,
                ``r``, ``c``, ``gm``, ``vid``).
            beta: weight of the KL term in the returned total.
            reg_scale: currently unused.

        Returns:
            ``(res + beta * kld, res, kld)``:
              - ``res`` is the negated sum of the per-step type, vid, feature
                (``subn_loss``) and edge terms;
              - ``kld`` is the KL divergence to N(0, I).
        """
        z = self.reparameterize(mu, logvar)  # (bsize, nz)
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
            g.vs[0]['r'] = 0
            g.vs[0]['c'] = 0
            g.vs[0]['gm'] = 0
            g.vs[0]['vid'] = 0
        self._update_v(G, 0, H0)
        res = 0  # accumulated log likelihood
        for v_true in range(1, self.max_n):
            # Graphs without a vertex v_true are padded with START_TYPE (it only
            # appears at vertex 0, so it is free to use as padding) and vid 1.
            true_types = [g_true.vs[v_true]['type'] if v_true < g_true.vcount()
                          else self.START_TYPE for g_true in G_true]
            true_vids = [g_true.vs[v_true]['vid'] if v_true < g_true.vcount()
                         else 1 for g_true in G_true]
            Hg = self._get_graph_state(G, decode=True)

            type_scores = self.add_vertex(Hg)  # (bsize, nvt)
            vid_scores = self.vid_fc(Hg)  # (bsize, max_n)
            # vertex-type and vid log likelihoods
            vll = self.logsoftmax1(type_scores)[np.arange(len(G)), true_types].sum()
            vl3 = self.logsoftmax1(vid_scores)[np.arange(len(G)), true_vids].sum()
            res = res + vll + vl3
            for i, g in enumerate(G):
                if true_types[i] != self.START_TYPE:
                    g.add_vertex(type=true_types[i])
                    g.vs[v_true]['r'] = G_true[i].vs[v_true]['r']
                    g.vs[v_true]['c'] = G_true[i].vs[v_true]['c']
                    g.vs[v_true]['gm'] = G_true[i].vs[v_true]['gm']
                    g.vs[v_true]['vid'] = G_true[i].vs[v_true]['vid']
            self._update_v(G, v_true)
            # Score the (r, c, gm) heads. As in decode, they read the graph
            # state from before the vertex was added.
            reg_vals = [func(Hg) for func in self.regs]
            # per graph: stack of every head's scores, (len(self.regs), subn_nvt)
            reg_vals_list = [torch.cat([val[i].unsqueeze(0) for val in reg_vals], dim=0)
                             for i in range(len(G_true))]
            vl2 = subn_loss(reg_vals_list, G_true, v_true, device=self.get_device())
            # NOTE(review): vl2 is added to the log likelihood before `res = -res`,
            # so it enters the returned loss with a negative sign. That is only
            # correct if subn_loss returns a log likelihood (higher = better).
            # Confirm against subn_loss.
            res += vl2
            # calculate the likelihood of adding true edges;
            # true_edges[i] = predecessors of v_true in graph i
            true_edges = [g_true.get_adjlist(igraph.IN)[v_true] if v_true < g_true.vcount() else []
                          for g_true in G_true]
            edge_scores = []
            for vi in range(v_true - 1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, v_true)
                ei_score = self._get_edge_score(Hvi, H, H0)  # (bsize, 1)
                edge_scores.append(ei_score)
                for i, g in enumerate(G):
                    if vi in true_edges[i]:
                        g.add_edge(vi, v_true)
                self._update_v(G, v_true, decode=True)
            # scores were collected for vi = v_true-1 .. 0; reversing makes
            # column k correspond to vertex k
            edge_scores = torch.cat(edge_scores[::-1], 1)  # (bsize, v_true)

            # 1.0 at (graph i, predecessor vertex index) for every true incoming edge
            ground_truth = torch.zeros_like(edge_scores)
            idx1 = [i for i, x in enumerate(true_edges) for _ in range(len(x))]
            idx2 = [xx for x in true_edges for xx in x]
            ground_truth[idx1, idx2] = 1.0

            # edges log-likelihood
            ell = - F.binary_cross_entropy(edge_scores, ground_truth, reduction='sum')
            res = res + ell

        res = -res  # convert likelihood to loss
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return res + beta*kld, res, kld

    def encode_decode(self, G):
        """Encode ``G``, sample a latent code and decode it back to graphs."""
        mu, logvar = self.encode(G)
        z = self.reparameterize(mu, logvar)
        return self.decode(z)

    def forward(self, G):
        """Return the total training loss (first element of ``self.loss``) for ``G``."""
        mu, logvar = self.encode(G)
        loss, _, _ = self.loss(mu, logvar, G)
        return loss

    def generate_sample(self, n):
        """Decode ``n`` latent codes drawn from N(0, I) into graphs."""
        sample = torch.randn(n, self.nz).to(self.get_device())
        G = self.decode(sample)
        return G


class CVAE_simple_v2(nn.Module):
    """D-VAE style graph VAE whose vertices carry a type, a vertex id ('vid')
    and three discrete sub-values ('r', 'c', 'gm').

    * Encoder GRU input: one-hot type, an embedding (``feat_enc``) of the three
      one-hot sub-values, and the one-hot vid.
    * Decoder GRU input: one-hot type and one-hot vid only. The type, vid and
      sub-values of a new vertex are predicted from the graph state by
      ``add_vertex``, ``vid_fc`` and ``fc_r``/``fc_c``/``fc_gm``.
    """

    def __init__(self, max_n, nvt, subn_nvt, START_TYPE, END_TYPE, emb_dim=64, hs=301, nz=56, bidirectional=False, vid=True):
        super(CVAE_simple_v2, self).__init__()
        self.max_n = max_n  # maximum number of vertices
        self.nvt = nvt  # number of vertex types
        self.subn_nvt = subn_nvt + 1  # number of classes for each sub-value ('r', 'c', 'gm')
        self.START_TYPE = START_TYPE
        self.END_TYPE = END_TYPE
        self.emb_dim = emb_dim
        self.hs = hs  # hidden state size of each vertex
        self.nz = nz  # size of latent representation z
        self.gs = hs  # size of graph state
        self.bidir = bidirectional  # whether to use bidirectional encoding
        # stored flag; _propagate_to always feeds the one-hot vid to the GRUs
        self.vid = vid
        self.device = None  # cached lazily by get_device()
        # size of a neighbour state fed to the gates/mappers (no vid concatenation)
        self.vs = hs

        # NOTE: the construction order below fixes the parameter-initialisation
        # RNG stream and the state_dict layout; keep it stable.

        # 0. encoding-related
        self.feat_enc = nn.Sequential(
                nn.Linear(3 * self.subn_nvt, emb_dim * 2),
                nn.ReLU(),
                nn.Linear(emb_dim * 2, emb_dim)
                )  # embeds the concatenated one-hot r, c, gm
        self.grue_forward = nn.GRUCell(nvt + emb_dim + max_n, hs)  # encoder GRU
        self.grue_backward = nn.GRUCell(nvt + emb_dim + max_n, hs)  # backward encoder GRU
        self.fc1 = nn.Linear(self.gs, nz)  # latent mean
        self.fc2 = nn.Linear(self.gs, nz)  # latent logvar

        # 1. decoding-related
        # NOTE: feat_dec is not used by any method of this class; it is kept so
        # existing checkpoints still load.
        self.feat_dec = nn.Sequential(
                nn.Linear(3 * self.subn_nvt, emb_dim * 2),
                nn.ReLU(),
                nn.Linear(emb_dim * 2, emb_dim)
                )
        self.grud = nn.GRUCell(nvt + max_n, hs)  # decoder GRU: one-hot type + one-hot vid
        self.fc3 = nn.Linear(nz, hs)  # from latent z to initial hidden state h0
        self.add_vertex = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, nvt)
                )  # which type of new vertex to add f(hg)
        self.add_edge = nn.Sequential(
                nn.Linear(hs * 2, hs * 4),
                nn.ReLU(),
                nn.Linear(hs * 4, 1)
                )  # whether to add edge between v_i and v_new, f(hvi, hnew)
        self.vid_fc = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, max_n)
                )  # which vertex id the new vertex gets, f(hg)
        self.fc_r = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, self.subn_nvt)
                )  # class logits for 'r' from the graph state
        self.fc_c = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, self.subn_nvt)
                )  # class logits for 'c' from the graph state
        self.fc_gm = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, self.subn_nvt)
                )  # class logits for 'gm' from the graph state
        # plain list (not nn.ModuleList); the heads are registered via the attributes above
        self.regs = [self.fc_r, self.fc_c, self.fc_gm]

        # 2. gate-related
        self.gate_forward = nn.Sequential(
                nn.Linear(self.vs, hs),
                nn.Sigmoid()
                )
        self.gate_backward = nn.Sequential(
                nn.Linear(self.vs, hs),
                nn.Sigmoid()
                )
        self.mapper_forward = nn.Sequential(
                nn.Linear(self.vs, hs, bias=False),
                )  # disable bias to ensure padded zeros also mapped to zeros
        self.mapper_backward = nn.Sequential(
                nn.Linear(self.vs, hs, bias=False),
                )

        # 3. bidir-related, to unify sizes
        if self.bidir:
            self.hv_unify = nn.Sequential(
                    nn.Linear(hs * 2, hs),
                    )
            self.hg_unify = nn.Sequential(
                    nn.Linear(self.gs * 2, self.gs),
                    )

        # 4. other
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.logsoftmax1 = nn.LogSoftmax(1)

    def get_device(self):
        """Return the device of the model parameters, cached after the first call.

        NOTE(review): the cache is not invalidated by ``.to()``/``.cuda()``, so
        call this only after the model is on its final device.
        """
        if self.device is None:
            self.device = next(self.parameters()).device
        return self.device

    def _get_zeros(self, n, length):
        """Return an ``(n, length)`` zero tensor on the model's device."""
        return torch.zeros(n, length).to(self.get_device())

    def _get_zero_hidden(self, n=1):
        """Return an ``(n, hs)`` zero hidden state on the model's device."""
        return self._get_zeros(n, self.hs)

    def _one_hot(self, idx, length):
        """One-hot encode ``idx`` into rows of width ``length``.

        A list/range gives one row per element (``None`` for an empty list).
        A scalar gives a single row.
        """
        if type(idx) in [list, range]:
            if idx == []:
                return None
            idx = torch.LongTensor(idx).unsqueeze(0).t()
            x = torch.zeros((len(idx), length)).scatter_(1, idx, 1).to(self.get_device())
        else:
            idx = torch.LongTensor([idx]).unsqueeze(0)
            x = torch.zeros((1, length)).scatter_(1, idx, 1).to(self.get_device())
        return x

    def _gated(self, h, gate, mapper):
        """Element-wise ``gate(h) * mapper(h)``."""
        return gate(h) * mapper(h)

    def _collate_fn(self, G):
        """Return a list of copies of the graphs in ``G``."""
        return [g.copy() for g in G]

    def _propagate_to(self, G, v, propagator, H=None, reverse=False, decode=False):
        """Compute and store the new state of vertex ``v`` in every graph that has it.

        Args:
            G: list of igraph graphs; graphs with ``vcount() <= v`` are skipped.
            v: index of the target vertex.
            propagator: GRU cell, called as ``propagator(X, H)``.
            H: optional input hidden states, one row per graph in ``G`` (the
                original list). When omitted, the gated sum of the neighbours'
                states is used (zeros if no graph has a neighbour).
            reverse: aggregate successors' 'H_backward' states instead of
                predecessors' 'H_forward' states.
            decode: when True the input is ``[type, vid]`` (decoder GRU);
                otherwise ``[type, embedded r/c/gm, vid]`` (encoder GRUs).

        Returns:
            The new states of ``v`` (one row per kept graph), also stored as
            vertex attribute 'H_forward' / 'H_backward'. Returns ``None`` if no
            graph has vertex ``v``.
        """
        keep = [i for i, g in enumerate(G) if g.vcount() > v]
        if not keep:
            return None
        G = [G[i] for i in keep]
        if H is not None:
            # H is aligned with the caller's (unfiltered) list, so select rows
            # by their original positions
            H = H[keep]
        X_v = self._one_hot([g.vs[v]['type'] for g in G], self.nvt)
        X_vid = self._one_hot([g.vs[v]['vid'] for g in G], self.max_n)
        if decode:
            X = torch.cat([X_v, X_vid], dim=1)
        else:
            # sub-values only enter the encoder input, so only embed them here
            X_r = self._one_hot([g.vs[v]['r'] for g in G], self.subn_nvt)
            X_c = self._one_hot([g.vs[v]['c'] for g in G], self.subn_nvt)
            X_gm = self._one_hot([g.vs[v]['gm'] for g in G], self.subn_nvt)
            X_feat = self.feat_enc(torch.cat([X_r, X_c, X_gm], dim=1))
            X = torch.cat([X_v, X_feat, X_vid], dim=1)

        if reverse:
            H_name = 'H_backward'  # name of the hidden states attribute
            H_pred = [[g.vs[x][H_name] for x in g.successors(v)] for g in G]
            gate, mapper = self.gate_backward, self.mapper_backward
        else:
            H_name = 'H_forward'  # name of the hidden states attribute
            H_pred = [[g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
            gate, mapper = self.gate_forward, self.mapper_forward
        # if H is not provided, use the gated sum of the neighbours' states
        if H is None:
            max_n_pred = max(len(x) for x in H_pred)  # maximum number of neighbours
            if max_n_pred == 0:  # no graph has a neighbour: start from zeros
                H = self._get_zero_hidden(len(G))
            else:
                # pad every neighbour list with zero rows to the same length
                H_pred = [torch.cat(h_pred +
                                    [self._get_zeros(max_n_pred - len(h_pred), self.vs)], 0).unsqueeze(0)
                          for h_pred in H_pred]
                H_pred = torch.cat(H_pred, 0)  # batch * max_n_pred * vs
                H = self._gated(H_pred, gate, mapper).sum(1)  # batch * hs
        Hv = propagator(X, H)
        for i, g in enumerate(G):
            g.vs[v][H_name] = Hv[i:i+1]
        return Hv

    def _propagate_from(self, G, v, propagator, H0=None, reverse=False, decode=False):
        """Propagate through vertices ``v, v+1, ..., max_n-1``, or ``v, ..., 0`` if reverse.

        Vertex indices are assumed to be in topological order. Only the first
        step receives ``H0``; its result is returned.
        """
        if reverse:
            prop_order = range(v, -1, -1)
        else:
            prop_order = range(v, self.max_n)
        Hv = self._propagate_to(G, v, propagator, H0, reverse=reverse, decode=decode)  # the initial vertex
        for v_ in prop_order[1:]:
            self._propagate_to(G, v_, propagator, reverse=reverse, decode=decode)
        return Hv

    def _update_v(self, G, v, H0=None, decode=False):
        """Refresh the forward state of vertex ``v`` with the decoder GRU."""
        self._propagate_to(G, v, self.grud, H0, reverse=False, decode=decode)
        return

    def _get_vertex_state(self, G, v):
        """Stack the 'H_forward' state of vertex ``v`` of every graph (zeros if absent)."""
        Hv = []
        for g in G:
            if v >= g.vcount():
                hv = self._get_zero_hidden()
            else:
                hv = g.vs[v]['H_forward']
            Hv.append(hv)
        Hv = torch.cat(Hv, 0)
        return Hv

    def _get_graph_state(self, G, decode=False):
        """Return the graph states, one row per graph.

        The state is the 'H_forward' state of the last vertex. When encoding
        bidirectionally it is concatenated with vertex 0's 'H_backward' state
        and passed through ``hg_unify``.
        """
        Hg = []
        for g in G:
            hg = g.vs[g.vcount()-1]['H_forward']
            if self.bidir and not decode:  # decoding never uses backward propagation
                hg_b = g.vs[0]['H_backward']
                hg = torch.cat([hg, hg_b], 1)
            Hg.append(hg)
        Hg = torch.cat(Hg, 0)
        if self.bidir and not decode:
            Hg = self.hg_unify(Hg)  # a linear model
        return Hg

    def encode(self, G):
        """Encode a graph or list of graphs into ``(mu, logvar)`` of the latent posterior.

        Side effect: stores 'H_forward' (and 'H_backward' if bidirectional) on
        the input graphs' vertices.
        """
        if type(G) != list:
            G = [G]
        self._propagate_from(G, 0, self.grue_forward, H0=self._get_zero_hidden(len(G)),
                             reverse=False, decode=False)
        if self.bidir:
            self._propagate_from(G, self.max_n-1, self.grue_backward,
                                 H0=self._get_zero_hidden(len(G)), reverse=True, decode=False)
        Hg = self._get_graph_state(G)
        mu, logvar = self.fc1(Hg), self.fc2(Hg)
        return mu, logvar

    def reparameterize(self, mu, logvar, eps_scale=0.01):
        """In training mode return ``mu + eps_scale * N(0, 1) * std``; otherwise return ``mu``."""
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std) * eps_scale
            return eps.mul(std).add_(mu)
        else:
            return mu

    def _get_edge_score(self, Hvi, H, H0):
        """Sigmoid edge probability for vi -> current vertex from ``[Hvi, H]`` (``H0`` unused)."""
        return self.sigmoid(self.add_edge(torch.cat([Hvi, H], -1)))

    def decode(self, z, stochastic=True, node_type_dic=NODE_TYPE, subg_node=SUBG_NODE, subg_con=SUBG_CON, subg_indi=SUBG_INDI):
        """Decode latent vectors ``z`` into igraph graphs.

        Args:
            z: latent vectors, one row per graph to generate.
            stochastic: sample every decision from its predicted distribution
                when True; otherwise take the argmax.
            node_type_dic, subg_node, subg_con, subg_indi: currently unused.

        Returns:
            List of ``igraph.Graph`` with vertex attributes 'type', 'vid', 'r',
            'c' and 'gm'. The 'H_forward' attribute is removed.
        """
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
            g.vs[0]['r'] = 0
            g.vs[0]['c'] = 0
            g.vs[0]['gm'] = 0
            g.vs[0]['vid'] = 0
        self._update_v(G, 0, H0, decode=True)  # only the start vertex receives H0
        finished = [False] * len(G)
        for idx in range(1, self.max_n):
            # Graph state of the partial graphs. It is recomputed at EVERY step,
            # including the final one, so all heads see the current graph (as in
            # loss()). Previously the last step reused a stale Hg/new_vids and
            # failed with NameError when max_n == 2.
            Hg = self._get_graph_state(G, decode=True)

            # decide the type of the next added vertex
            if idx == self.max_n - 1:  # force the last node to be end_type
                new_types = [self.END_TYPE] * len(G)
            else:
                type_scores = self.add_vertex(Hg)
                if stochastic:
                    type_probs = F.softmax(type_scores, 1).cpu().detach().numpy()
                    new_types = [np.random.choice(range(self.nvt), p=type_probs[i])
                                 for i in range(len(G))]
                else:
                    new_types = torch.argmax(type_scores, 1).flatten().tolist()

            # decide the vertex id of the next added vertex
            vid_scores = self.vid_fc(Hg)
            if stochastic:
                vid_probs = F.softmax(vid_scores, 1).cpu().detach().numpy()
                new_vids = [np.random.choice(range(self.max_n), p=vid_probs[i])
                            for i in range(len(G))]
            else:
                new_vids = torch.argmax(vid_scores, 1).flatten().tolist()

            # decide sub-value information ('r', 'c', 'gm')
            pred_vals = [func(Hg) for func in self.regs]
            for j, g in enumerate(G):
                if finished[j]:
                    continue
                g.add_vertex(type=new_types[j])
                g.vs[idx]['vid'] = new_vids[j]
                g_val_ = []
                for reg_v in pred_vals:
                    subn_scores = reg_v[j, :]
                    if stochastic:
                        type_prob = F.softmax(subn_scores, dim=0).cpu().detach().numpy()
                        new_val = np.random.choice(range(0, self.subn_nvt), p=type_prob)
                    else:
                        new_val = torch.argmax(subn_scores, dim=0).tolist()
                    g_val_.append(new_val)
                g.vs[idx]['r'] = int(g_val_[0])
                g.vs[idx]['c'] = int(g_val_[1])
                g.vs[idx]['gm'] = int(g_val_[2])

            self._update_v(G, idx, decode=True)
            # decide connections, from the nearest earlier vertex back to vertex 0
            for vi in range(idx-1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, idx)
                ei_score = self._get_edge_score(Hvi, H, H0)
                if stochastic:
                    random_score = torch.rand_like(ei_score)
                    decisions = random_score < ei_score
                else:
                    decisions = ei_score > 0.5
                for i, g in enumerate(G):
                    if finished[i]:
                        continue
                    if new_types[i] == self.END_TYPE:
                        # if new node is end_type, connect it to all loose-end vertices (out_degree==0)
                        end_vertices = set([v.index for v in g.vs.select(_outdegree_eq=0)
                                            if v.index != g.vcount()-1])
                        for v in end_vertices:
                            g.add_edge(v, g.vcount()-1)
                        finished[i] = True
                        continue
                    if decisions[i, 0]:
                        g.add_edge(vi, g.vcount()-1)
                self._update_v(G, idx, decode=True)

        for g in G:
            del g.vs['H_forward']  # delete hidden states to save GPU memory
        return G

    def loss(self, mu, logvar, G_true, beta=0.005, reg_scale=0.05, pos_scale=0.5):
        """Teacher-forced VAE loss of reconstructing ``G_true`` from ``(mu, logvar)``.

        The partial graphs are rebuilt from the true graphs step by step, so the
        prediction at step i always conditions on correct steps 0..i-1.

        Args:
            mu, logvar: latent posterior parameters, one row per graph.
            G_true: list of true igraph graphs.
            beta: weight of the KL term.
            reg_scale, pos_scale: currently unused.

        Returns:
            ``(res + beta * kld, res, kld)``. Here ``res`` is the negated sum of
            the vertex-type, vertex-id, sub-value (``subn_loss``) and edge terms.
        """
        z = self.reparameterize(mu, logvar)  # (bsize, nz)
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
            g.vs[0]['r'] = 0
            g.vs[0]['c'] = 0
            g.vs[0]['gm'] = 0
            g.vs[0]['vid'] = 0
        self._update_v(G, 0, H0, decode=True)
        res = 0  # log likelihood
        for v_true in range(1, self.max_n):
            # START_TYPE marks padding (graph already ended). It is free to use
            # because it only ever occurs at vertex 0.
            true_types = [g_true.vs[v_true]['type'] if v_true < g_true.vcount()
                          else self.START_TYPE for g_true in G_true]
            # padded graphs use vertex id 1 as their target
            true_vids = [g_true.vs[v_true]['vid'] if v_true < g_true.vcount()
                         else 1 for g_true in G_true]
            Hg = self._get_graph_state(G, decode=True)

            type_scores = self.add_vertex(Hg)  # (bsize, nvt)
            vid_scores = self.vid_fc(Hg)  # (bsize, max_n)
            # vertex type / vertex id log likelihoods
            vll = self.logsoftmax1(type_scores)[np.arange(len(G)), true_types].sum()
            vl3 = self.logsoftmax1(vid_scores)[np.arange(len(G)), true_vids].sum()
            res = res + vll + vl3
            for i, g in enumerate(G):
                if true_types[i] != self.START_TYPE:
                    g.add_vertex(type=true_types[i])
                    g.vs[v_true]['r'] = G_true[i].vs[v_true]['r']
                    g.vs[v_true]['c'] = G_true[i].vs[v_true]['c']
                    g.vs[v_true]['gm'] = G_true[i].vs[v_true]['gm']
                    g.vs[v_true]['vid'] = G_true[i].vs[v_true]['vid']
            self._update_v(G, v_true, decode=True)
            # sub-value heads, conditioned on the graph state BEFORE v_true was added
            reg_vals = [func(Hg) for func in self.regs]
            # per graph: (3, subn_nvt) scores for 'r', 'c', 'gm'
            reg_vals_list = [torch.cat([val[i].unsqueeze(0) for val in reg_vals], dim=0)
                             for i in range(len(G_true))]
            # NOTE(review): vl2 is added to the log-likelihood `res`, which is
            # negated below. Confirm subn_loss returns a log-likelihood (higher is
            # better); otherwise this term's sign is inverted.
            vl2 = subn_loss(reg_vals_list, G_true, v_true, device=self.get_device())
            res += vl2
            # likelihood of adding the true edges: true_edges[i] lists v_true's
            # predecessors in the i-th true graph (empty if padded)
            true_edges = []
            for i, g_true in enumerate(G_true):
                true_edges.append(g_true.get_adjlist(igraph.IN)[v_true] if v_true < g_true.vcount()
                                  else [])
            edge_scores = []
            for vi in range(v_true-1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, v_true)
                ei_score = self._get_edge_score(Hvi, H, H0)  # size: batch size, 1
                edge_scores.append(ei_score)
                for i, g in enumerate(G):
                    if vi in true_edges[i]:
                        g.add_edge(vi, v_true)
                self._update_v(G, v_true, decode=True)
            # scores were collected for vi = v_true-1 .. 0; reversing makes
            # column j the score of edge j -> v_true
            edge_scores = torch.cat(edge_scores[::-1], 1)  # (batch size, v_true)

            ground_truth = torch.zeros_like(edge_scores)
            idx1 = [i for i, x in enumerate(true_edges) for _ in range(len(x))]
            idx2 = [xx for x in true_edges for xx in x]
            ground_truth[idx1, idx2] = 1.0

            # edges log-likelihood
            ell = - F.binary_cross_entropy(edge_scores, ground_truth, reduction='sum')
            res = res + ell

        res = -res  # convert likelihood to loss
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return res + beta*kld, res, kld

    def encode_decode(self, G):
        """Encode ``G``, draw a latent code via ``reparameterize`` and decode it."""
        mu, logvar = self.encode(G)
        z = self.reparameterize(mu, logvar)
        return self.decode(z)

    def forward(self, G):
        """Return the total training objective (reconstruction + beta * KL) for ``G``."""
        mu, logvar = self.encode(G)
        loss, _, _ = self.loss(mu, logvar, G)
        return loss

    def generate_sample(self, n):
        """Decode ``n`` latent vectors drawn from the standard normal prior."""
        sample = torch.randn(n, self.nz).to(self.get_device())
        G = self.decode(sample)
        return G

class CVAE_simple2(nn.Module):
    def __init__(self, max_n, nvt, subn_nvt ,START_TYPE, END_TYPE, emb_dim = 64, hs=301, nz=56, bidirectional=False, vid=True):
        """Build the encoder, decoder and gating sub-modules.

        Args:
            max_n: maximum number of vertices per graph.
            nvt: number of vertex types.
            subn_nvt: number of classes for each sub-value 'r', 'c' and 'gm'.
            START_TYPE: type given to vertex 0 when decoding.
            END_TYPE: type forced on the last decoded vertex.
            emb_dim: stored as ``self.emb_dim``; no layer built here uses it.
            hs: hidden state size of each vertex (also the graph state size).
            nz: size of the latent vector z.
            bidirectional: also build/run a backward encoder GRU.
            vid: concatenate one-hot vertex ids to neighbour states before gating.

        NOTE: the module construction order fixes the parameter-initialisation
        RNG stream and the state_dict layout; do not reorder.
        """
        super(CVAE_simple2, self).__init__()
        self.max_n = max_n  # maximum number of vertices
        self.nvt = nvt  # number of vertex types
        self.subn_nvt = subn_nvt # number of classes for each sub-value ('r', 'c', 'gm')
        self.START_TYPE = START_TYPE
        self.END_TYPE = END_TYPE
        self.emb_dim = emb_dim
        self.hs = hs  # hidden state size of each vertex
        #assert(self.hs = 2 * self.emb_dim)
        self.nz = nz  # size of latent representation z
        self.gs = hs  # size of graph state
        self.bidir = bidirectional  # whether to use bidirectional encoding
        self.vid = vid  # append one-hot vertex ids to neighbour states in _propagate_to
        self.device = None  # cached lazily by get_device()

        if self.vid:
            self.vs = hs + max_n   # vertex state size = hidden state + vid
        else:
            self.vs = hs 

        # 0. encoding-related
        #self.feat_enc = nn.Sequential(
        #        nn.Linear(3 * subn_nvt, emb_dim * 2),
        #        nn.ReLU(),
        #        nn.Linear(emb_dim * 2,  emb_dim)
        #        )
        # encoder GRUs; input = one-hot type + one-hot r, c, gm (raw, not embedded)
        self.grue_forward = nn.GRUCell(nvt + 3 * subn_nvt, hs)  # encoder GRU
        self.grue_backward = nn.GRUCell(nvt + 3 * subn_nvt, hs)  # backward encoder GRU
        #self.subgnn = subcGNN_dis(num_cat = self.subg_nvt, out_feat = int(self.emb_dim/2), num_feat = self.subn_nvt,
        #                          dropout=0.5, num_layer=2, readout='sum', device=self.device)
        self.fc1 = nn.Linear(self.gs, nz)  # latent mean
        self.fc2 = nn.Linear(self.gs, nz)  # latent logvar
            
        # 1. decoding-related
        
        self.grud = nn.GRUCell(nvt +  3 * subn_nvt, hs)  # decoder GRU (same input layout as the encoder)
        self.fc3 = nn.Linear(nz, hs)  # from latent z to initial hidden state h0
        self.add_vertex = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, nvt)
                )  # which type of new vertex to add f(h0, hg)
        self.add_edge = nn.Sequential(
                nn.Linear(hs * 2, hs * 4), 
                nn.ReLU(), 
                nn.Linear(hs * 4, 1)
                )  # whether to add edge between v_i and v_new, f(hvi, hnew)
        self.fc_r = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2,  subn_nvt)
                )  # head for 'r': maps a hidden state (size hs) to subn_nvt class logits
        self.fc_c = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2,  subn_nvt)
                )  # head for 'c': subn_nvt class logits
        self.fc_gm = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2,  subn_nvt)
                )  # head for 'gm': subn_nvt class logits
        # plain list (not nn.ModuleList); the heads are registered via the attributes above
        self.regs = [self.fc_r, self.fc_c, self.fc_gm]
        
        # 2. gate-related (inputs have size self.vs, i.e. hidden state [+ vid])
        self.gate_forward = nn.Sequential(
                nn.Linear(self.vs, hs), 
                nn.Sigmoid()
                )
        self.gate_backward = nn.Sequential(
                nn.Linear(self.vs, hs), 
                nn.Sigmoid()
                )
        self.mapper_forward = nn.Sequential(
                nn.Linear(self.vs, hs, bias=False),
                )  # disable bias to ensure padded zeros also mapped to zeros
        self.mapper_backward = nn.Sequential(
                nn.Linear(self.vs, hs, bias=False), 
                )

        # 3. bidir-related, to unify sizes
        if self.bidir:
            self.hv_unify = nn.Sequential(
                    nn.Linear(hs * 2, hs), 
                    )
            self.hg_unify = nn.Sequential(
                    nn.Linear(self.gs * 2, self.gs), 
                    )

        # 4. other
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.logsoftmax1 = nn.LogSoftmax(1)

    def get_device(self):
        """Return the device of the model parameters, memoized in ``self.device``.

        NOTE(review): the memoized value is not reset by ``.to()``/``.cuda()``.
        """
        if self.device is not None:
            return self.device
        first_param = next(self.parameters())
        self.device = first_param.device
        return self.device
    
    def _get_zeros(self, n, length):
        return torch.zeros(n, length).to(self.get_device()) # get a zero hidden state

    def _get_zero_hidden(self, n=1):
        return self._get_zeros(n, self.hs) # get a zero hidden state

    def _one_hot(self, idx, length):
        if type(idx) in [list, range]:
            if idx == []:
                return None
            idx = torch.LongTensor(idx).unsqueeze(0).t()
            x = torch.zeros((len(idx), length)).scatter_(1, idx, 1).to(self.get_device())
        else:
            idx = torch.LongTensor([idx]).unsqueeze(0)
            x = torch.zeros((1, length)).scatter_(1, idx, 1).to(self.get_device())
        return x

    def _gated(self, h, gate, mapper):
        return gate(h) * mapper(h)

    def _collate_fn(self, G):
        return [g.copy() for g in G]

    def _propagate_to(self, G, v, propagator, H=None, reverse=False, decode=False):
        # propagate messages to vertex index v for all graphs in G
        # return the new messages (states) at v
        G = [g for g in G if g.vcount() > v]
        if len(G) == 0:
            return
        if H is not None: # H: previous hidden state 
            idx = [i for i, g in enumerate(G) if g.vcount() > v]
            H = H[idx]
        v_types = [g.vs[v]['type'] for g in G]
        r_feats = [g.vs[v]['r'] for g in G]
        c_feats = [g.vs[v]['c'] for g in G]
        gm_feats = [g.vs[v]['gm'] for g in G]
        X_v = self._one_hot(v_types, self.nvt)
        X_r = self._one_hot(r_feats, self.subn_nvt)
        X_c = self._one_hot(c_feats, self.subn_nvt)
        X_gm = self._one_hot(gm_feats, self.subn_nvt)
        #X_feat = self.feat_enc(torch.cat([X_r, X_c, X_gm], dim=1))
        X = torch.cat([X_v, X_r, X_c, X_gm], dim=1)
        
        if reverse:
            H_name = 'H_backward'  # name of the hidden states attribute
            H_pred = [[g.vs[x][H_name] for x in g.successors(v)] for g in G] # hidden state of 'predecessors'
            if self.vid:
                vids = [self._one_hot(g.successors(v), self.max_n) for g in G] # one hot of vertex index of 'predecessors'
            gate, mapper = self.gate_backward, self.mapper_backward
        else:
            H_name = 'H_forward'  # name of the hidden states attribute
            H_pred = [[g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
            if self.vid:
                vids = [self._one_hot(g.predecessors(v), self.max_n) for g in G]
            gate, mapper = self.gate_forward, self.mapper_forward
        if self.vid:
            H_pred = [[torch.cat([x[i], y[i:i+1]], 1) for i in range(len(x))] for x, y in zip(H_pred, vids)]
            #H_pred = [[torch.cat([x[i], z[i], y[i:i+1]], 1) for i in range(len(x))] for x, y, z in zip(H_pred, vids, F_pred)]
        # if h is not provided, use gated sum of v's predecessors' states as the input hidden state
        if H is None:
            max_n_pred = max([len(x) for x in H_pred])  # maximum number of predecessors
            if max_n_pred == 0: ### start point
                H = self._get_zero_hidden(len(G))
            else:
                H_pred = [torch.cat(h_pred + 
                            [self._get_zeros(max_n_pred - len(h_pred), self.vs)], 0).unsqueeze(0) 
                            for h_pred in H_pred]  # pad all to same length
                H_pred = torch.cat(H_pred, 0)  # batch * max_n_pred * vs
                H = self._gated(H_pred, gate, mapper).sum(1)  # batch * hs
        Hv = propagator(X, H)
        for i, g in enumerate(G):
            g.vs[v][H_name] = Hv[i:i+1]
            #print(g.vs[v][H_name].shape)
        return Hv

    def _propagate_from(self, G, v, propagator, H0=None, reverse=False, decode=False):
        # perform a series of propagation_to steps starting from v following a topo order
        # assume the original vertex indices are in a topological order
        if reverse:
            prop_order = range(v, -1, -1)
        else:
            prop_order = range(v, self.max_n)
        Hv = self._propagate_to(G, v, propagator, H0, reverse=reverse, decode=decode)  # the initial vertex
        for v_ in prop_order[1:]:
            #print(v_)
            self._propagate_to(G, v_, propagator, reverse=reverse, decode=decode)
            # Hv = self._propagate_to(G, v_, propagator, Hv, reverse=reverse) no need
        return Hv

    def _update_v(self, G, v, H0=None, decode=False):
        # perform a forward propagation step at v when decoding to update v's state
        self._propagate_to(G, v, self.grud, H0, reverse=False, decode=decode)
        return
    
    def _get_vertex_state(self, G, v):
        # get the vertex states at v
        Hv = []
        for g in G:
            if v >= g.vcount():
                hv = self._get_zero_hidden()
            else:
                hv = g.vs[v]['H_forward']
            Hv.append(hv)
        Hv = torch.cat(Hv, 0)
        return Hv

    def _get_graph_state(self, G, decode=False):
        # get the graph states
        # when decoding, use the last generated vertex's state as the graph state
        # when encoding, use the ending vertex state or unify the starting and ending vertex states
        Hg = []
        for g in G:
            hg = g.vs[g.vcount()-1]['H_forward']
            if self.bidir and not decode:  # decoding never uses backward propagation
                hg_b = g.vs[0]['H_backward']
                hg = torch.cat([hg, hg_b], 1)
            Hg.append(hg)
        Hg = torch.cat(Hg, 0)
        if self.bidir and not decode:
            Hg = self.hg_unify(Hg) # a linear model
        return Hg

    def encode(self, G):
        """Encode a graph (or list of graphs) into ``(mu, logvar)`` of the latent posterior.

        Side effect: the input graphs' vertices receive 'H_forward' (and, when
        bidirectional, 'H_backward') attributes.
        """
        graphs = G if type(G) == list else [G]
        batch = len(graphs)
        self._propagate_from(graphs, 0, self.grue_forward, H0=self._get_zero_hidden(batch),
                             reverse=False, decode=False)
        if self.bidir:
            self._propagate_from(graphs, self.max_n - 1, self.grue_backward,
                                 H0=self._get_zero_hidden(batch), reverse=True, decode=False)
        graph_state = self._get_graph_state(graphs)
        return self.fc1(graph_state), self.fc2(graph_state)

    def reparameterize(self, mu, logvar, eps_scale=0.01):
        """Return ``mu + eps_scale * N(0, 1) * exp(logvar / 2)`` in training mode, else ``mu``."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        noise = eps_scale * torch.randn_like(std)
        return mu + noise * std

    def _get_edge_score(self, Hvi, H, H0):
        # compute scores for edges from vi based on Hvi, H (current vertex) and H0
        # in most cases, H0 need not be explicitly included since Hvi and H contain its information
        return self.sigmoid(self.add_edge(torch.cat([Hvi, H], -1)))

    def decode(self, z, stochastic=True, node_type_dic=NODE_TYPE, subg_node=SUBG_NODE, subg_con=SUBG_CON, subg_indi=SUBG_INDI):
        """Decode latent vectors ``z`` back into directed igraph graphs.

        One graph is built per row of ``z``. Starting from a START_TYPE vertex,
        each step adds one vertex (type from ``add_vertex``, forced to
        END_TYPE at the last step), assigns its 'r', 'c' and 'gm' attributes
        from the ``self.regs`` heads, then decides its incoming edges. An
        END_TYPE vertex is linked from every vertex without outgoing edges and
        marks its graph as finished.

        Args:
            z: latent batch; ``len(z)`` graphs are produced.
            stochastic: if True, sample every decision from its predicted
                distribution; otherwise take the argmax / threshold at 0.5.
            node_type_dic, subg_node, subg_con, subg_indi: accepted for API
                compatibility; not used by this method.

        Returns:
            List of igraph.Graph with vertex attributes 'type', 'r', 'c', 'gm'
            ('H_forward' hidden states are deleted before returning).
        """
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
            g.vs[0]['r'] = 0
            g.vs[0]['c'] = 0
            g.vs[0]['gm'] = 0
        self._update_v(G, 0, H0)  # only the start vertex is seeded with H0
        finished = [False] * len(G)
        for idx in range(1, self.max_n):
            # The attribute heads below read Hg, so refresh it on every step,
            # including the forced last step (it used to be stale there, and
            # undefined altogether when max_n == 2). Matches `loss`.
            Hg = self._get_graph_state(G, decode=True)
            if idx == self.max_n - 1:  # force the last node to be end_type
                new_types = [self.END_TYPE] * len(G)
            else:
                type_scores = self.add_vertex(Hg)
                if stochastic:
                    type_probs = F.softmax(type_scores, 1).cpu().detach().numpy()
                    new_types = [np.random.choice(range(self.nvt), p=type_probs[i])
                                 for i in range(len(G))]
                else:
                    new_types = torch.argmax(type_scores, 1).flatten().tolist()

            # decide attribute values: one score vector per head in self.regs ('r', 'c', 'gm')
            pred_vals = [func(Hg) for func in self.regs]
            for j, g in enumerate(G):
                if finished[j]:
                    continue
                g.add_vertex(type=new_types[j])
                g_val_ = []
                for reg_v in pred_vals:
                    subn_scores = reg_v[j, :]
                    if stochastic:
                        type_prob = F.softmax(subn_scores, dim=0).cpu().detach().numpy()
                        new_val = np.random.choice(range(0, self.subn_nvt), p=type_prob)
                    else:
                        new_val = torch.argmax(subn_scores, dim=0).tolist()
                    g_val_.append(new_val)
                g.vs[idx]['r'] = int(g_val_[0])
                g.vs[idx]['c'] = int(g_val_[1])
                g.vs[idx]['gm'] = int(g_val_[2])

            self._update_v(G, idx)
            # decide connections, scanning candidate parents from idx-1 down to 0
            for vi in range(idx - 1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, idx)
                ei_score = self._get_edge_score(Hvi, H, H0)
                if stochastic:
                    random_score = torch.rand_like(ei_score)
                    decisions = random_score < ei_score
                else:
                    decisions = ei_score > 0.5
                for i, g in enumerate(G):
                    if finished[i]:
                        continue
                    if new_types[i] == self.END_TYPE:
                        # connect the END vertex to all loose-end vertices (out_degree == 0)
                        end_vertices = set([v.index for v in g.vs.select(_outdegree_eq=0)
                                            if v.index != g.vcount() - 1])
                        for v in end_vertices:
                            g.add_edge(v, g.vcount() - 1)
                        finished[i] = True
                        continue
                    if decisions[i, 0]:
                        g.add_edge(vi, g.vcount() - 1)
                self._update_v(G, idx)

        for g in G:
            del g.vs['H_forward']  # delete hidden states to save GPU memory
        return G

    def loss(self, mu, logvar, G_true, beta=0.005, reg_scale=0.01):
        """Teacher-forced VAE loss for reconstructing ``G_true`` from ``(mu, logvar)``.

        Graphs are rebuilt vertex by vertex. At step ``v_true`` the model is
        scored on the true vertex type, on the vertex attributes (via
        ``subn_loss``, defined elsewhere) and on the true incoming edges; the
        true vertex and edges are then inserted, so each step conditions on a
        correct history.

        Args:
            mu, logvar: encoder outputs, one row per graph.
            G_true: list of ground-truth igraph graphs with vertex attributes
                'type', 'r', 'c', 'gm'.
            beta: weight of the KL term in the returned total.
            reg_scale: not used in this method (only referenced in a commented-out line).

        Returns:
            (res + beta * kld, res, kld), where ``res`` is the negated sum of
            the type log-likelihood, the ``subn_loss`` term and the edge
            log-likelihood, and ``kld`` is the KL divergence to N(0, I).
        """
        # compute the loss of decoding mu and logvar to true graphs using teacher forcing
        # ensure when computing the loss of step i, steps 0 to i-1 are correct
        z = self.reparameterize(mu, logvar) # (bsize, nz)
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
            g.vs[0]['r'] = 0
            g.vs[0]['c'] = 0
            g.vs[0]['gm'] = 0
        self._update_v(G, 0, H0)
        res = 0  # log likelihood
        for v_true in range(1, self.max_n):
            # calculate the likelihood of adding true types of nodes
            # use start type to denote padding vertices since start type only appears for vertex 0 
            # and will never be a true type for later vertices, thus it's free to use
            true_types = [g_true.vs[v_true]['type'] if v_true < g_true.vcount()  # list of length bsize
                          else self.START_TYPE for g_true in G_true]
            Hg = self._get_graph_state(G, decode=True) 
            
            type_scores = self.add_vertex(Hg) # (bsize, nvt)
            # vertex log likelihood
            vll = self.logsoftmax1(type_scores)[np.arange(len(G)), true_types].sum()  
            res = res + vll
            for i, g in enumerate(G):
                if true_types[i] != self.START_TYPE:
                    g.add_vertex(type=true_types[i])
                    g.vs[v_true]['r'] = G_true[i].vs[v_true]['r']
                    g.vs[v_true]['c'] = G_true[i].vs[v_true]['c']
                    g.vs[v_true]['gm'] = G_true[i].vs[v_true]['gm']
                    
            self._update_v(G, v_true)
            # attribute term for the new vertices: the heads read the graph state Hg
            # computed before the vertex was added (same input as in `decode`)
            H = self._get_vertex_state(G, v_true)  # NOTE(review): unused; H_reg variant is disabled
            reg_vals = []
            for func in self.regs:
                subg_score = func(Hg)
                reg_vals.append(subg_score) 
            reg_vals_list = []
            for i in range(len(G_true)):
                # per graph: the heads' rows stacked along dim 0
                reg_vals_list.append(torch.cat([val[i].unsqueeze(0) for val in reg_vals], dim=0))
            vl2 = subn_loss(reg_vals_list,G_true,v_true,device=self.get_device())  # attribute term, defined elsewhere
            # NOTE(review): vl2 is added to the *log-likelihood* and negated below, so if
            # subn_loss returns a positive loss (as its name suggests) it is being maximized.
            # Confirm the sign convention of subn_loss.
            res += vl2
            # calculate the likelihood of adding true edges
            true_edges = []
            for i, g_true in enumerate(G_true):
                true_edges.append(g_true.get_adjlist(igraph.IN)[v_true] if v_true < g_true.vcount() 
                                  else []) # get_adjlist(IN)[v]: indices of v's predecessors; true_edges[i] = predecessors of v_true in graph i
            edge_scores = []
            for vi in range(v_true-1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, v_true)
                ei_score = self._get_edge_score(Hvi, H, H0) # size: batch size, 1
                edge_scores.append(ei_score)
                for i, g in enumerate(G):
                    if vi in true_edges[i]:
                        g.add_edge(vi, v_true)
                self._update_v(G, v_true, decode=True)
            edge_scores = torch.cat(edge_scores[::-1], 1)  # (batch size, v_true); after reversal column k scores edge k -> v_true

            ground_truth = torch.zeros_like(edge_scores)
            idx1 = [i for i, x in enumerate(true_edges) for _ in range(len(x))]
            idx2 = [xx for x in true_edges for xx in x]
            ground_truth[idx1, idx2] = 1.0

            # edges log-likelihood
            ell = - F.binary_cross_entropy(edge_scores, ground_truth, reduction='sum') 
            res = res + ell

        res = -res  # convert likelihood to loss
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return res + beta*kld, res, kld

    def encode_decode(self, G):
        """Encode ``G``, draw a latent sample with ``reparameterize``, and decode it."""
        latent = self.reparameterize(*self.encode(G))
        return self.decode(latent)

    def forward(self, G):
        """Return the total training loss (reconstruction + beta * KL) for graphs ``G``."""
        mean, log_variance = self.encode(G)
        total_loss, _recon, _kld = self.loss(mean, log_variance, G)
        return total_loss
    
    def generate_sample(self, n):
        """Decode ``n`` latent vectors drawn from the standard normal prior."""
        prior_draw = torch.randn(n, self.nz).to(self.get_device())
        return self.decode(prior_draw)

class CVAE_conti(nn.Module):
    """Graph VAE for DAGs whose vertices carry continuous sub-graph attributes.

    Encoding runs GRU message passing over vertex types in index order
    (assumed topological), plus a reverse pass when ``bidirectional``. Each
    message concatenates the neighbour's hidden state, its sub-graph feature
    ('subg_feat') and, if ``vid``, a one-hot vertex id. Decoding grows graphs
    one vertex at a time: a type from ``add_vertex``, three regressed attribute
    values from ``fc_r`` / ``fc_c`` / ``fc_gm``, then incoming edges from
    ``add_edge``.
    """

    def __init__(self, max_n, nvt, subg_nvt, START_TYPE, END_TYPE, emb_dim = 128, hs=301, nz=56, bidirectional=False, vid=True):
        super(CVAE_conti, self).__init__()
        self.max_n = max_n  # maximum number of vertices
        self.nvt = nvt  # number of vertex types
        self.subg_nvt = subg_nvt  # number of node types inside a sub-graph
        self.START_TYPE = START_TYPE
        self.END_TYPE = END_TYPE
        self.emb_dim = emb_dim  # width of the per-vertex sub-graph feature inside messages
        self.hs = hs  # hidden state size of each vertex
        self.nz = nz  # size of latent representation z
        self.gs = hs  # size of graph state
        self.bidir = bidirectional  # whether to use bidirectional encoding
        self.vid = vid  # append a one-hot vertex id to every message
        self.device = None  # resolved lazily by get_device()

        # message width = hidden state + sub-graph feature (+ one-hot vertex id)
        if self.vid:
            self.vs = hs + max_n + emb_dim
        else:
            self.vs = hs + emb_dim

        # 0. encoding-related
        self.grue_forward = nn.GRUCell(nvt, hs)  # encoder GRU
        self.grue_backward = nn.GRUCell(nvt, hs)  # backward encoder GRU
        self.subgnn = subc_GNN(self.subg_nvt, int(self.emb_dim/2), dropout=0.5, num_layer=2, readout='sum', device=self.device)
        self.fc1 = nn.Linear(self.gs, nz)  # latent mean
        self.fc2 = nn.Linear(self.gs, nz)  # latent logvar

        # 1. decoding-related
        self.grud = nn.GRUCell(nvt, hs)  # decoder GRU
        self.fc3 = nn.Linear(nz, hs)  # from latent z to initial hidden state h0
        self.add_vertex = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, nvt)
                )  # which type of new vertex to add f(h0, hg)
        self.add_edge = nn.Sequential(
                nn.Linear(hs * 2, hs * 4),
                nn.ReLU(),
                nn.Linear(hs * 4, 1)
                )  # whether to add edge between v_i and v_new, f(hvi, hnew)
        self.fc_r = nn.Sequential(
                nn.Linear(hs + self.nvt, hs),
                nn.ReLU(),
                nn.Linear(hs, 1)
                )  # regression head for r: input is [vertex state, type scores]
        self.fc_c = nn.Sequential(
                nn.Linear(hs + self.nvt, hs),
                nn.ReLU(),
                nn.Linear(hs, 1)
                )  # regression head for c
        self.fc_gm = nn.Sequential(
                nn.Linear(hs + self.nvt, hs),
                nn.ReLU(),
                nn.Linear(hs, 1)
                )  # regression head for gm
        self.regs = [self.fc_r, self.fc_c, self.fc_gm]

        # 2. gate-related
        self.gate_forward = nn.Sequential(
                nn.Linear(self.vs, hs),
                nn.Sigmoid()
                )
        self.gate_backward = nn.Sequential(
                nn.Linear(self.vs, hs),
                nn.Sigmoid()
                )
        self.mapper_forward = nn.Sequential(
                nn.Linear(self.vs, hs, bias=False),
                )  # disable bias to ensure padded zeros also mapped to zeros
        self.mapper_backward = nn.Sequential(
                nn.Linear(self.vs, hs, bias=False),
                )

        # 3. bidir-related, to unify sizes
        if self.bidir:
            self.hv_unify = nn.Sequential(
                    nn.Linear(hs * 2, hs),
                    )
            self.hg_unify = nn.Sequential(
                    nn.Linear(self.gs * 2, self.gs),
                    )

        # 4. other
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.logsoftmax1 = nn.LogSoftmax(1)

    def get_device(self):
        """Return (and cache) the device of the model's first parameter."""
        if self.device is None:
            self.device = next(self.parameters()).device
        return self.device

    def _get_zeros(self, n, length):
        """Return an ``(n, length)`` zero tensor on the model's device."""
        return torch.zeros(n, length).to(self.get_device())

    def _get_zero_hidden(self, n=1):
        """Return an ``(n, hs)`` zero hidden state."""
        return self._get_zeros(n, self.hs)

    def _one_hot(self, idx, length):
        """One-hot encode ``idx`` (an int, or a list/range of ints) into rows of width ``length``.

        Returns None for an empty list.
        """
        if type(idx) in [list, range]:
            if idx == []:
                return None
            idx = torch.LongTensor(idx).unsqueeze(0).t()
            x = torch.zeros((len(idx), length)).scatter_(1, idx, 1).to(self.get_device())
        else:
            idx = torch.LongTensor([idx]).unsqueeze(0)
            x = torch.zeros((1, length)).scatter_(1, idx, 1).to(self.get_device())
        return x

    def _gated(self, h, gate, mapper):
        """Apply the gated mapping ``gate(h) * mapper(h)``."""
        return gate(h) * mapper(h)

    def _collate_fn(self, G):
        """Return copies of the graphs in ``G``."""
        return [g.copy() for g in G]

    def _propagate_to(self, G, v, propagator, H=None, reverse=False, decode=False):
        """Propagate messages into vertex ``v`` of every graph in ``G`` that has it.

        The new state is ``propagator(one_hot(type of v), H)``. When ``H`` is
        None it is the gated sum over v's predecessors (successors when
        ``reverse``) of ``[hidden state, sub-graph feature, (one-hot vertex id)]``.
        The result is stored on each graph as 'H_forward' or 'H_backward'.

        Args:
            G: list of graphs; vertex indices assumed to be a topological order.
            v: target vertex index.
            propagator: GRU cell mapping (one-hot type, hidden) -> hidden.
            H: optional explicit input hidden state, one row per graph in ``G``.
            reverse: propagate along successors with the backward gate/mapper.
            decode: use zero sub-graph features instead of reading 'subg_feat'.

        Returns:
            Tensor with one new state row per graph that contains ``v``, or
            None when no graph has vertex ``v``.
        """
        # Select H's rows from the *unfiltered* batch so they stay aligned
        # with the graphs that are kept.
        keep = [i for i, g in enumerate(G) if g.vcount() > v]
        if len(keep) == 0:
            return
        G = [G[i] for i in keep]
        if H is not None:  # H: previous hidden state
            H = H[keep]
        v_types = [g.vs[v]['type'] for g in G]
        X = self._one_hot(v_types, self.nvt)
        if reverse:
            H_name = 'H_backward'  # name of the hidden states attribute
            neighbors = [g.successors(v) for g in G]  # message senders in the reverse pass
            gate, mapper = self.gate_backward, self.mapper_backward
        else:
            H_name = 'H_forward'
            neighbors = [g.predecessors(v) for g in G]
            gate, mapper = self.gate_forward, self.mapper_forward
        H_pred = [[g.vs[x][H_name] for x in nbrs] for g, nbrs in zip(G, neighbors)]
        if decode:
            # sub-graph features are not available while decoding: use zeros of the same width
            F_pred = [[self._get_zeros(1, self.emb_dim) for _ in nbrs] for nbrs in neighbors]
        else:
            F_pred = [[g.vs[x]['subg_feat'].unsqueeze(0) for x in nbrs]
                      for g, nbrs in zip(G, neighbors)]
        # Always concatenate the sub-graph feature so every message is self.vs wide,
        # which the gate/mapper layers require in both vid modes.
        if self.vid:
            vids = [self._one_hot(nbrs, self.max_n) for nbrs in neighbors]  # one-hot neighbour indices
            H_pred = [[torch.cat([h, f, vid[k:k+1]], 1) for k, (h, f) in enumerate(zip(h_list, f_list))]
                      for h_list, f_list, vid in zip(H_pred, F_pred, vids)]
        else:
            H_pred = [[torch.cat([h, f], 1) for h, f in zip(h_list, f_list)]
                      for h_list, f_list in zip(H_pred, F_pred)]
        # if H is not provided, use the gated sum of the neighbours' messages
        if H is None:
            max_n_pred = max(len(h_list) for h_list in H_pred)  # maximum number of neighbours
            if max_n_pred == 0:  # no incoming messages (e.g. the start vertex)
                H = self._get_zero_hidden(len(G))
            else:
                # zero-pad each graph's messages to max_n_pred rows: batch * max_n_pred * vs
                H_pred = torch.cat([
                    torch.cat(h_list + [self._get_zeros(max_n_pred - len(h_list), self.vs)], 0).unsqueeze(0)
                    for h_list in H_pred], 0)
                H = self._gated(H_pred, gate, mapper).sum(1)  # batch * hs
        Hv = propagator(X, H)
        for i, g in enumerate(G):
            g.vs[v][H_name] = Hv[i:i+1]
        return Hv

    def _propagate_from(self, G, v, propagator, H0=None, reverse=False, decode=False):
        """Run ``_propagate_to`` from vertex ``v`` through the rest of the index order.

        Vertex indices are assumed to be a topological order; ``reverse`` walks
        from ``v`` down to 0, otherwise from ``v`` up to ``max_n - 1``. Only the
        first vertex receives ``H0``. Returns the states computed at ``v``.
        """
        if reverse:
            prop_order = range(v, -1, -1)
        else:
            prop_order = range(v, self.max_n)
        Hv = self._propagate_to(G, v, propagator, H0, reverse=reverse, decode=decode)  # the initial vertex
        for v_ in prop_order[1:]:
            self._propagate_to(G, v_, propagator, reverse=reverse, decode=decode)
        return Hv

    def _update_v(self, G, v, H0=None, decode=False):
        """Recompute vertex ``v``'s forward state with the decoder GRU."""
        self._propagate_to(G, v, self.grud, H0, reverse=False, decode=decode)
        return

    def _get_vertex_state(self, G, v):
        """Stack 'H_forward' of vertex ``v`` across ``G`` (zeros where the graph has no vertex ``v``)."""
        Hv = []
        for g in G:
            if v >= g.vcount():
                hv = self._get_zero_hidden()
            else:
                hv = g.vs[v]['H_forward']
            Hv.append(hv)
        Hv = torch.cat(Hv, 0)
        return Hv

    def _get_graph_state(self, G, decode=False):
        """Return one graph-state row per graph.

        The state is the last vertex's 'H_forward'. When encoding with
        ``bidir``, it is concatenated with vertex 0's 'H_backward' and passed
        through ``hg_unify``.
        """
        Hg = []
        for g in G:
            hg = g.vs[g.vcount()-1]['H_forward']
            if self.bidir and not decode:  # decoding never uses backward propagation
                hg_b = g.vs[0]['H_backward']
                hg = torch.cat([hg, hg_b], 1)
            Hg.append(hg)
        Hg = torch.cat(Hg, 0)
        if self.bidir and not decode:
            Hg = self.hg_unify(Hg)  # a linear model
        return Hg

    def encode(self, G):
        """Encode a graph (or list of graphs) into latent ``(mu, logvar)``.

        ``self.subgnn`` is applied first. It presumably attaches the
        'subg_feat' vertex features read during propagation; verify against
        subc_GNN.
        """
        if type(G) != list:
            G = [G]
        G = self.subgnn(G)
        self._propagate_from(G, 0, self.grue_forward, H0=self._get_zero_hidden(len(G)),
                             reverse=False, decode=False)
        if self.bidir:
            self._propagate_from(G, self.max_n-1, self.grue_backward,
                                 H0=self._get_zero_hidden(len(G)), reverse=True, decode=False)
        Hg = self._get_graph_state(G)
        mu, logvar = self.fc1(Hg), self.fc2(Hg)
        return mu, logvar

    def reparameterize(self, mu, logvar, eps_scale=0.01):
        """Return mu + eps_scale * N(0, I) * exp(logvar / 2) in training mode, else ``mu``."""
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std) * eps_scale
            return eps.mul(std).add_(mu)
        else:
            return mu

    def _get_edge_score(self, Hvi, H, H0):
        """Sigmoid probability of an edge vi -> current vertex from ``[Hvi, H]``.

        ``H0`` is unused: in most cases Hvi and H already contain its information.
        """
        return self.sigmoid(self.add_edge(torch.cat([Hvi, H], -1)))

    def decode(self, z, stochastic=True):
        """Decode latent vectors ``z`` back into igraph DAGs.

        Each step adds one vertex to every unfinished graph. The type comes from
        ``add_vertex`` and is forced to END_TYPE at the last step. The vertex
        attributes are regressed from ``[new vertex state, type scores]``, the
        same inputs ``loss`` trains them on. Incoming edges come from
        ``add_edge``. An END_TYPE vertex is linked from every vertex without
        outgoing edges and finishes its graph.

        Args:
            z: latent batch; ``len(z)`` graphs are produced.
            stochastic: sample vertex types and edges if True, else use
                argmax / threshold 0.5. Attribute values are always the
                regression outputs.

        Returns:
            List of igraph.Graph. Every added vertex has 'type' and
            'subg_nfeats' = [0, r, c, gm, 0] (floats from the regression heads).
        """
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
        self._update_v(G, 0, H0, decode=True)  # only the start vertex is seeded with H0
        finished = [False] * len(G)
        for idx in range(1, self.max_n):
            # Score types on every step: the attribute heads consume type_scores,
            # so it must be fresh even when the last vertex is forced to END_TYPE
            # (it used to be stale there, and undefined when max_n == 2).
            Hg = self._get_graph_state(G, decode=True)
            type_scores = self.add_vertex(Hg)
            if idx == self.max_n - 1:  # force the last node to be end_type
                new_types = [self.END_TYPE] * len(G)
            elif stochastic:
                type_probs = F.softmax(type_scores, 1).cpu().detach().numpy()
                new_types = [np.random.choice(range(self.nvt), p=type_probs[i])
                             for i in range(len(G))]
            else:
                new_types = torch.argmax(type_scores, 1).flatten().tolist()

            for j, g in enumerate(G):
                if not finished[j]:
                    g.add_vertex(type=new_types[j])
            # Same order as `loss`: update the new vertex first, then regress
            # its attributes from [its state, type scores].
            self._update_v(G, idx, decode=True)
            H = self._get_vertex_state(G, idx)
            H_reg = torch.cat([H, type_scores], dim=1)
            reg_vals = [func(H_reg) for func in self.regs]  # each (bsize, 1)
            for j, g in enumerate(G):
                if not finished[j]:
                    # The heads are scalar regressors, so their predictions are stored
                    # directly. The old code sampled from a softmax over a single logit
                    # and referenced an undefined self.subn_nvt.
                    # NOTE(review): confirm consumers expect the raw regressed scale.
                    attrs = [reg_v[j, 0].item() for reg_v in reg_vals]
                    g.vs[idx]['subg_nfeats'] = [0] + attrs + [0]

            # decide connections, scanning candidate parents from idx-1 down to 0
            for vi in range(idx - 1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, idx)
                ei_score = self._get_edge_score(Hvi, H, H0)
                if stochastic:
                    random_score = torch.rand_like(ei_score)
                    decisions = random_score < ei_score
                else:
                    decisions = ei_score > 0.5
                for i, g in enumerate(G):
                    if finished[i]:
                        continue
                    if new_types[i] == self.END_TYPE:
                        # connect the END vertex to all loose-end vertices (out_degree == 0)
                        end_vertices = set([v.index for v in g.vs.select(_outdegree_eq=0)
                                            if v.index != g.vcount() - 1])
                        for v in end_vertices:
                            g.add_edge(v, g.vcount() - 1)
                        finished[i] = True
                        continue
                    if decisions[i, 0]:
                        g.add_edge(vi, g.vcount() - 1)
                self._update_v(G, idx, decode=True)

        for g in G:
            del g.vs['H_forward']  # delete hidden states to save GPU memory
        return G

    def loss(self, mu, logvar, G_true, beta=0.005, reg_scale=0.01):
        """Teacher-forced VAE loss for reconstructing ``G_true`` from ``(mu, logvar)``.

        At every step the true vertex and true incoming edges are inserted, so
        step i is always scored on a correct history for steps 0..i-1.

        Args:
            mu, logvar: encoder outputs, one row per graph.
            G_true: list of ground-truth igraph graphs.
            beta: weight of the KL term in the returned total.
            reg_scale: weight of the attribute-regression term (``subg_loss``).

        Returns:
            (res + beta * kld, res, kld), where ``res`` is the negative type and
            edge log-likelihood plus ``reg_scale`` times the ``subg_loss`` terms,
            and ``kld`` is the KL divergence to N(0, I).
        """
        z = self.reparameterize(mu, logvar)  # (bsize, nz)
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
        self._update_v(G, 0, H0, decode=True)
        res = 0  # log likelihood
        res_mse = 0  # scaled attribute-regression loss
        for v_true in range(1, self.max_n):
            # calculate the likelihood of adding true types of nodes
            # use start type to denote padding vertices since start type only appears for vertex 0
            # and will never be a true type for later vertices, thus it's free to use
            true_types = [g_true.vs[v_true]['type'] if v_true < g_true.vcount()
                          else self.START_TYPE for g_true in G_true]  # list of length bsize
            Hg = self._get_graph_state(G, decode=True)

            type_scores = self.add_vertex(Hg)  # (bsize, nvt)
            # vertex log likelihood
            vll = self.logsoftmax1(type_scores)[np.arange(len(G)), true_types].sum()
            res = res + vll
            for i, g in enumerate(G):
                if true_types[i] != self.START_TYPE:
                    g.add_vertex(type=true_types[i])
            self._update_v(G, v_true, decode=True)
            # regression loss on the new vertices' attribute values
            H = self._get_vertex_state(G, v_true)
            H_reg = torch.cat([H, type_scores], dim=1)
            reg_vals = [func(H_reg) for func in self.regs]  # each (bsize, 1)
            # Keep the predictions on the autograd graph. The previous
            # .cpu().detach().numpy() round-trip meant fc_r / fc_c / fc_gm never
            # received gradients from this term.
            reg_vals_list = [torch.stack([reg_v[i, 0] for reg_v in reg_vals])
                             for i in range(len(G_true))]
            vl2 = subg_loss(reg_vals_list, G_true, v_true, device=self.get_device())
            res_mse += reg_scale * vl2
            # calculate the likelihood of adding true edges
            true_edges = []
            for i, g_true in enumerate(G_true):
                # get_adjlist(IN)[v]: indices of v's predecessors in graph i
                true_edges.append(g_true.get_adjlist(igraph.IN)[v_true] if v_true < g_true.vcount()
                                  else [])
            edge_scores = []
            for vi in range(v_true-1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, v_true)
                ei_score = self._get_edge_score(Hvi, H, H0)  # size: batch size, 1
                edge_scores.append(ei_score)
                for i, g in enumerate(G):
                    if vi in true_edges[i]:
                        g.add_edge(vi, v_true)
                self._update_v(G, v_true, decode=True)
            # (batch size, v_true); after reversal column k scores edge k -> v_true
            edge_scores = torch.cat(edge_scores[::-1], 1)

            ground_truth = torch.zeros_like(edge_scores)
            idx1 = [i for i, x in enumerate(true_edges) for _ in range(len(x))]
            idx2 = [xx for x in true_edges for xx in x]
            ground_truth[idx1, idx2] = 1.0

            # edges log-likelihood
            ell = - F.binary_cross_entropy(edge_scores, ground_truth, reduction='sum')
            res = res + ell

        res = -res  # convert likelihood to loss
        res += res_mse
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return res + beta*kld, res, kld

    def encode_decode(self, G):
        """Encode ``G``, sample a latent with ``reparameterize``, and decode it."""
        mu, logvar = self.encode(G)
        z = self.reparameterize(mu, logvar)
        return self.decode(z)

    def forward(self, G):
        """Return the total training loss for graphs ``G``."""
        mu, logvar = self.encode(G)
        loss, _, _ = self.loss(mu, logvar, G)
        return loss

    def generate_sample(self, n):
        """Decode ``n`` latent vectors drawn from the standard normal prior."""
        sample = torch.randn(n, self.nz).to(self.get_device())
        G = self.decode(sample)
        return G


class CVAE(nn.Module):
    """Variational autoencoder over DAGs whose vertices carry subgraph features.

    Encoding first passes the graphs through ``subcGNN_dis`` (``encode`` replaces ``G``
    with its output, whose vertices are read for a ``'subg_feat'`` attribute) and then
    runs gated GRU message passing over each DAG in vertex-index order, which is assumed
    to be a topological order. Decoding rebuilds graphs one vertex at a time: a vertex
    type, one value per head in ``self.regs`` (stored as the ``'subg_nfeats'`` vertex
    attribute), and the incoming edges from every earlier vertex.
    """

    def __init__(self, max_n, nvt, subg_nvt, subn_nvt, START_TYPE, END_TYPE,
                 emb_dim=128, hs=301, nz=56, bidirectional=False, vid=True):
        """Build the encoder, decoder, value heads and message-passing layers.

        Args:
            max_n: maximum number of vertices (also the size of the vertex-index one-hot).
            nvt: number of vertex types.
            subg_nvt: number of node types in the subgraphs (passed to ``subcGNN_dis``).
            subn_nvt: number of value classes predicted by each head in ``self.regs``.
            START_TYPE: vertex type of the first vertex.
            END_TYPE: vertex type forced on the last decoded vertex.
            emb_dim: size of the subgraph embedding appended to neighbour states.
            hs: hidden state size of each vertex (also the graph-state size).
            nz: size of the latent vector z.
            bidirectional: additionally run a backward pass when encoding.
            vid: append a one-hot of each neighbour's vertex index to its message.
        """
        super(CVAE, self).__init__()
        self.max_n = max_n  # maximum number of vertices
        self.nvt = nvt  # number of vertex types
        self.subg_nvt = subg_nvt  # number of node types in the subgraphs
        self.subn_nvt = subn_nvt  # number of value classes of each subgraph node
        self.START_TYPE = START_TYPE
        self.END_TYPE = END_TYPE
        self.emb_dim = emb_dim
        self.hs = hs  # hidden state size of each vertex
        self.nz = nz  # size of latent representation z
        self.gs = hs  # size of graph state
        self.bidir = bidirectional  # whether to use bidirectional encoding
        self.vid = vid
        self.device = None

        # size of one neighbour message before gating:
        # hidden state + subgraph embedding (+ vertex-index one-hot when vid is on)
        if self.vid:
            self.vs = hs + max_n + emb_dim
        else:
            self.vs = hs + emb_dim

        # 0. encoding-related
        self.grue_forward = nn.GRUCell(nvt, hs)  # encoder GRU
        self.grue_backward = nn.GRUCell(nvt, hs)  # backward encoder GRU
        # NOTE(review): self.device is still None at this point, so subcGNN_dis is
        # built with device=None; confirm it resolves its device lazily.
        self.subgnn = subcGNN_dis(num_cat=self.subg_nvt, out_feat=int(self.emb_dim/2),
                                  num_feat=self.subn_nvt, dropout=0.5, num_layer=2,
                                  readout='sum', device=self.device)
        self.fc1 = nn.Linear(self.gs, nz)  # latent mean
        self.fc2 = nn.Linear(self.gs, nz)  # latent logvar

        # 1. decoding-related
        self.grud = nn.GRUCell(nvt, hs)  # decoder GRU
        self.fc3 = nn.Linear(nz, hs)  # from latent z to initial hidden state h0
        self.add_vertex = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, nvt)
                )  # which type of new vertex to add f(h0, hg)
        self.add_edge = nn.Sequential(
                nn.Linear(hs * 2, hs * 4),
                nn.ReLU(),
                nn.Linear(hs * 4, 1)
                )  # whether to add edge between v_i and v_new, f(hvi, hnew)
        # value heads: each maps [vertex state, vertex-type scores] to subn_nvt logits
        self.fc_r = nn.Sequential(
                nn.Linear(hs + self.nvt, hs),
                nn.ReLU(),
                nn.Linear(hs, subn_nvt)
                )  # value head (r)
        self.fc_c = nn.Sequential(
                nn.Linear(hs + self.nvt, hs),
                nn.ReLU(),
                nn.Linear(hs, subn_nvt)
                )  # value head (c)
        self.fc_gm = nn.Sequential(
                nn.Linear(hs + self.nvt, hs),
                nn.ReLU(),
                nn.Linear(hs, subn_nvt)
                )  # value head (gm)
        self.regs = [self.fc_r, self.fc_c, self.fc_gm]

        # 2. gate-related
        self.gate_forward = nn.Sequential(
                nn.Linear(self.vs, hs),
                nn.Sigmoid()
                )
        self.gate_backward = nn.Sequential(
                nn.Linear(self.vs, hs),
                nn.Sigmoid()
                )
        self.mapper_forward = nn.Sequential(
                nn.Linear(self.vs, hs, bias=False),
                )  # disable bias to ensure padded zeros also mapped to zeros
        self.mapper_backward = nn.Sequential(
                nn.Linear(self.vs, hs, bias=False),
                )

        # 3. bidir-related, to unify sizes
        if self.bidir:
            self.hv_unify = nn.Sequential(
                    nn.Linear(hs * 2, hs),
                    )
            self.hg_unify = nn.Sequential(
                    nn.Linear(self.gs * 2, self.gs),
                    )

        # 4. other
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.logsoftmax1 = nn.LogSoftmax(1)

    def get_device(self):
        """Return (and cache in ``self.device``) the device of the model parameters."""
        if self.device is None:
            self.device = next(self.parameters()).device
        return self.device

    def _get_zeros(self, n, length):
        """Return an ``(n, length)`` zero tensor on the model device."""
        return torch.zeros(n, length).to(self.get_device())

    def _get_zero_hidden(self, n=1):
        """Return an ``(n, hs)`` zero hidden state on the model device."""
        return self._get_zeros(n, self.hs)

    def _one_hot(self, idx, length):
        """One-hot encode ``idx`` into float rows of size ``length`` on the model device.

        A list/range gives one row per entry (None for an empty list); any other
        scalar gives a single row.
        """
        if type(idx) in [list, range]:
            if idx == []:
                return None
            idx = torch.LongTensor(idx).unsqueeze(0).t()
            x = torch.zeros((len(idx), length)).scatter_(1, idx, 1).to(self.get_device())
        else:
            idx = torch.LongTensor([idx]).unsqueeze(0)
            x = torch.zeros((1, length)).scatter_(1, idx, 1).to(self.get_device())
        return x

    def _gated(self, h, gate, mapper):
        """Return ``gate(h) * mapper(h)``."""
        return gate(h) * mapper(h)

    def _collate_fn(self, G):
        """Return a list of copies of the graphs in ``G``."""
        return [g.copy() for g in G]

    def _propagate_to(self, G, v, propagator, H=None, reverse=False, decode=False):
        """Compute the new state of vertex ``v`` in every graph of ``G`` that has it.

        Args:
            G: list of graphs; graphs with ``vcount() <= v`` are skipped.
            v: index of the vertex to update.
            propagator: GRU cell called as ``propagator(type_one_hot, incoming_state)``.
            H: optional incoming state with one row per graph in ``G``. When None, the
                incoming state is the gated sum of the neighbours' messages.
            reverse: use successors and ``'H_backward'`` instead of predecessors and
                ``'H_forward'``.
            decode: use zero vectors instead of the neighbours' ``'subg_feat'``.

        Returns:
            The new states (one row per updated graph), also stored on each vertex,
            or None when no graph has vertex ``v``.
        """
        if H is not None:
            # choose H's rows BEFORE filtering G so they stay aligned with the graphs
            # that are kept (enumerating the filtered list would pick the first rows)
            keep = [i for i, g in enumerate(G) if g.vcount() > v]
        G = [g for g in G if g.vcount() > v]
        if len(G) == 0:
            return
        if H is not None:
            H = H[keep]
        v_types = [g.vs[v]['type'] for g in G]
        X = self._one_hot(v_types, self.nvt)
        if reverse:
            H_name = 'H_backward'  # name of the hidden states attribute
            neighbours = [g.successors(v) for g in G]  # play the role of 'predecessors'
            gate, mapper = self.gate_backward, self.mapper_backward
        else:
            H_name = 'H_forward'  # name of the hidden states attribute
            neighbours = [g.predecessors(v) for g in G]
            gate, mapper = self.gate_forward, self.mapper_forward
        H_pred = [[g.vs[x][H_name] for x in nbrs] for g, nbrs in zip(G, neighbours)]
        if not decode:
            F_pred = [[g.vs[x]['subg_feat'].unsqueeze(0) for x in nbrs]
                      for g, nbrs in zip(G, neighbours)]
        else:
            # decoded graphs have no subgraph embeddings yet
            F_pred = [[torch.zeros(1, self.emb_dim).to(self.get_device()) for _ in nbrs]
                      for nbrs in neighbours]
        if self.vid:
            vids = [self._one_hot(nbrs, self.max_n) for nbrs in neighbours]
            H_pred = [[torch.cat([x[i], z[i], y[i:i+1]], 1) for i in range(len(x))]
                      for x, y, z in zip(H_pred, vids, F_pred)]
        else:
            # self.vs = hs + emb_dim here, so the subgraph feature must still be
            # appended for the gate/mapper input sizes to match
            H_pred = [[torch.cat([x[i], z[i]], 1) for i in range(len(x))]
                      for x, z in zip(H_pred, F_pred)]
        # if H is not provided, use the gated sum of the neighbours' messages
        if H is None:
            max_n_pred = max([len(x) for x in H_pred])  # maximum number of neighbours
            if max_n_pred == 0:  # start point: no neighbour in any graph
                H = self._get_zero_hidden(len(G))
            else:
                H_pred = [torch.cat(h_pred +
                                    [self._get_zeros(max_n_pred - len(h_pred), self.vs)], 0).unsqueeze(0)
                          for h_pred in H_pred]  # pad all to same length
                H_pred = torch.cat(H_pred, 0)  # batch * max_n_pred * vs
                H = self._gated(H_pred, gate, mapper).sum(1)  # batch * hs
        Hv = propagator(X, H)
        for i, g in enumerate(G):
            g.vs[v][H_name] = Hv[i:i+1]
        return Hv

    def _propagate_from(self, G, v, propagator, H0=None, reverse=False, decode=False):
        """Propagate from vertex ``v`` onward (backward to 0 if ``reverse``).

        Assumes vertex indices follow a topological order. Only the first step uses
        ``H0``; later steps aggregate their neighbours' states.

        Returns:
            The new states of vertex ``v``.
        """
        if reverse:
            prop_order = range(v, -1, -1)
        else:
            prop_order = range(v, self.max_n)
        Hv = self._propagate_to(G, v, propagator, H0, reverse=reverse, decode=decode)  # the initial vertex
        for v_ in prop_order[1:]:
            self._propagate_to(G, v_, propagator, reverse=reverse, decode=decode)
        return Hv

    def _update_v(self, G, v, H0=None, decode=False):
        """Run one forward decoder (``grud``) propagation step at vertex ``v``."""
        self._propagate_to(G, v, self.grud, H0, reverse=False, decode=decode)
        return

    def _get_vertex_state(self, G, v):
        """Stack the ``'H_forward'`` state of vertex ``v`` per graph (zeros if absent)."""
        Hv = []
        for g in G:
            if v >= g.vcount():
                hv = self._get_zero_hidden()
            else:
                hv = g.vs[v]['H_forward']
            Hv.append(hv)
        Hv = torch.cat(Hv, 0)
        return Hv

    def _get_graph_state(self, G, decode=False):
        """Return one graph state per graph.

        Uses the last vertex's ``'H_forward'``; when encoding bidirectionally it is
        concatenated with vertex 0's ``'H_backward'`` and projected by ``hg_unify``.
        """
        Hg = []
        for g in G:
            hg = g.vs[g.vcount()-1]['H_forward']
            if self.bidir and not decode:  # decoding never uses backward propagation
                hg_b = g.vs[0]['H_backward']
                hg = torch.cat([hg, hg_b], 1)
            Hg.append(hg)
        Hg = torch.cat(Hg, 0)
        if self.bidir and not decode:
            Hg = self.hg_unify(Hg)  # a linear model
        return Hg

    def encode(self, G):
        """Encode a graph or list of graphs into ``(mu, logvar)``."""
        if type(G) != list:
            G = [G]
        G = self.subgnn(G)
        self._propagate_from(G, 0, self.grue_forward, H0=self._get_zero_hidden(len(G)),
                             reverse=False, decode=False)
        if self.bidir:
            self._propagate_from(G, self.max_n-1, self.grue_backward,
                                 H0=self._get_zero_hidden(len(G)), reverse=True, decode=False)
        Hg = self._get_graph_state(G)
        mu, logvar = self.fc1(Hg), self.fc2(Hg)
        return mu, logvar

    def reparameterize(self, mu, logvar, eps_scale=0.01):
        """Return ``mu + eps_scale * N(0, 1) * exp(logvar / 2)`` in training, else ``mu``."""
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std) * eps_scale
            return eps.mul(std).add_(mu)
        else:
            return mu

    def _get_edge_score(self, Hvi, H, H0):
        """Probability of an edge vi -> new vertex from their states (``H0`` is unused)."""
        return self.sigmoid(self.add_edge(torch.cat([Hvi, H], -1)))

    def decode(self, z, stochastic=True):
        """Decode latent vectors ``z`` into igraph DAGs.

        Args:
            z: latent vectors, one row per graph.
            stochastic: sample every decision from its predicted distribution if True,
                otherwise take the argmax / threshold at 0.5.

        Returns:
            List of ``igraph.Graph``; each non-start vertex gets ``'type'`` and
            ``'subg_nfeats'`` attributes.
        """
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
        self._update_v(G, 0, H0, decode=True)  # only at the beginning do we need H0
        finished = [False] * len(G)
        for idx in range(1, self.max_n):
            # type scores are computed at every step, including the forced last one,
            # because the value heads below take them as input (as in loss())
            Hg = self._get_graph_state(G, decode=True)
            type_scores = self.add_vertex(Hg)
            if idx == self.max_n - 1:  # force the last node to be end_type
                new_types = [self.END_TYPE] * len(G)
            elif stochastic:
                type_probs = F.softmax(type_scores, 1).cpu().detach().numpy()
                new_types = [np.random.choice(range(self.nvt), p=type_probs[i])
                             for i in range(len(G))]
            else:
                new_types = torch.argmax(type_scores, 1).flatten().tolist()

            for j, g in enumerate(G):
                if not finished[j]:
                    g.add_vertex(type=new_types[j])
            self._update_v(G, idx, decode=True)

            # decide subgraph values from the new vertex's state, matching the state
            # the value heads are trained on in loss()
            H = self._get_vertex_state(G, idx)
            H_reg = torch.cat([H, type_scores], dim=1)
            reg_vals = [func(H_reg) for func in self.regs]
            for j, g in enumerate(G):
                if finished[j]:
                    continue
                g_vals = [0]
                for reg_v in reg_vals:
                    subn_scores = reg_v[j, :]
                    if stochastic:
                        type_prob = F.softmax(subn_scores, dim=0).cpu().detach().numpy()
                        new_val = np.random.choice(range(self.subn_nvt), p=type_prob) + 1
                    else:
                        new_val = torch.argmax(subn_scores, dim=0).tolist() + 1
                    g_vals.append(new_val)
                g_vals.append(0)
                g.vs[idx]['subg_nfeats'] = g_vals

            # decide connections
            for vi in range(idx-1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, idx)
                ei_score = self._get_edge_score(Hvi, H, H0)
                if stochastic:
                    random_score = torch.rand_like(ei_score)
                    decisions = random_score < ei_score
                else:
                    decisions = ei_score > 0.5
                for i, g in enumerate(G):
                    if finished[i]:
                        continue
                    if new_types[i] == self.END_TYPE:
                        # if new node is end_type, connect it to all loose-end vertices (out_degree==0)
                        end_vertices = set([v.index for v in g.vs.select(_outdegree_eq=0)
                                            if v.index != g.vcount()-1])
                        for v in end_vertices:
                            g.add_edge(v, g.vcount()-1)
                        finished[i] = True
                        continue
                    if decisions[i, 0]:
                        g.add_edge(vi, g.vcount()-1)
                self._update_v(G, idx, decode=True)

        for g in G:
            del g.vs['H_forward']  # delete hidden states to save GPU memory
        return G

    def loss(self, mu, logvar, G_true, beta=0.005, reg_scale=0.01):
        """Teacher-forced VAE loss of decoding ``(mu, logvar)`` back into ``G_true``.

        When step ``v_true`` is scored, the partial graphs contain exactly the true
        vertices and edges before it, so every step is conditioned on a correct prefix.

        Args:
            mu, logvar: latent mean and log-variance, one row per graph.
            G_true: list of target graphs.
            beta: weight of the KL term.
            reg_scale: currently unused (kept for call compatibility).

        Returns:
            ``(res + beta * kld, res, kld)``: ``res`` is the negated sum of the vertex-type
            log-likelihood, the ``subg_loss_dis`` term and the edge log-likelihood;
            ``kld`` is the KL divergence to N(0, I).
        """
        z = self.reparameterize(mu, logvar)  # (bsize, hidden)
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
        self._update_v(G, 0, H0, decode=True)
        res = 0  # log likelihood
        for v_true in range(1, self.max_n):
            # likelihood of adding the true vertex types; START_TYPE marks padding,
            # since it only ever appears at vertex 0
            true_types = [g_true.vs[v_true]['type'] if v_true < g_true.vcount()
                          else self.START_TYPE for g_true in G_true]
            Hg = self._get_graph_state(G, decode=True)
            type_scores = self.add_vertex(Hg)  # (bsize, nvt)
            # vertex log likelihood
            vll = self.logsoftmax1(type_scores)[np.arange(len(G)), true_types].sum()
            res = res + vll
            for i, g in enumerate(G):
                if true_types[i] != self.START_TYPE:
                    g.add_vertex(type=true_types[i])
            self._update_v(G, v_true, decode=True)
            # loss of the subgraph node values
            H = self._get_vertex_state(G, v_true)
            H_reg = torch.cat([H, type_scores], dim=1)
            reg_vals = [func(H_reg) for func in self.regs]
            reg_vals_list = [torch.cat([val[i].unsqueeze(0) for val in reg_vals], dim=0)
                             for i in range(len(G_true))]
            vl2 = subg_loss_dis(reg_vals_list, G_true, v_true, device=self.get_device())
            # NOTE(review): vl2 is added to the log-likelihood before the sign flip below,
            # so it enters the returned loss negated; if subg_loss_dis returns a positive
            # loss this maximises it. Confirm subg_loss_dis's sign convention.
            res += vl2
            # likelihood of adding the true edges; true_edges[i] = predecessors of v_true in graph i
            true_edges = []
            for i, g_true in enumerate(G_true):
                true_edges.append(g_true.get_adjlist(igraph.IN)[v_true] if v_true < g_true.vcount()
                                  else [])
            edge_scores = []
            for vi in range(v_true-1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, v_true)
                ei_score = self._get_edge_score(Hvi, H, H0)  # size: batch size, 1
                edge_scores.append(ei_score)
                for i, g in enumerate(G):
                    if vi in true_edges[i]:
                        g.add_edge(vi, v_true)
                self._update_v(G, v_true, decode=True)
            edge_scores = torch.cat(edge_scores[::-1], 1)  # (batch size, v_true); column k = vertex k

            ground_truth = torch.zeros_like(edge_scores)
            idx1 = [i for i, x in enumerate(true_edges) for _ in range(len(x))]
            idx2 = [xx for x in true_edges for xx in x]
            ground_truth[idx1, idx2] = 1.0

            # edges log-likelihood
            ell = - F.binary_cross_entropy(edge_scores, ground_truth, reduction='sum')
            res = res + ell

        res = -res  # convert likelihood to loss
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return res + beta*kld, res, kld

    def encode_decode(self, G):
        """Encode ``G``, sample a latent code and decode it."""
        mu, logvar = self.encode(G)
        z = self.reparameterize(mu, logvar)
        return self.decode(z)

    def forward(self, G):
        """Return the total training loss of ``G``."""
        mu, logvar = self.encode(G)
        loss, _, _ = self.loss(mu, logvar, G)
        return loss

    def generate_sample(self, n):
        """Decode ``n`` latent vectors drawn from a standard normal distribution."""
        sample = torch.randn(n, self.nz).to(self.get_device())
        G = self.decode(sample)
        return G



# Other baselines

class DVAE(nn.Module):
    def __init__(self, max_n, nvt, feat_nvt, START_TYPE, END_TYPE, hs=501, nz=56, bidirectional=False, vid=True, max_pos=8):
        super(DVAE, self).__init__()
        self.max_n = max_n  # maximum number of vertices
        self.max_pos = max_pos
        self.nvt = nvt  # number of vertex types 
        self.feat_nvt = feat_nvt + 1 # number of value type of each node in subgraphs
        self.START_TYPE = START_TYPE
        self.END_TYPE = END_TYPE
        self.hs = hs  # hidden state size of each vertex
        #assert(self.hs = 2 * self.emb_dim)
        self.nz = nz  # size of latent representation z
        self.gs = hs  # size of graph state
        self.bidir = bidirectional  # whether to use bidirectional encoding
        self.vid = vid
        self.device = None
        
        self.vs = hs 

        # 0. encoding-related
        self.grue_forward = nn.GRUCell(nvt + self.feat_nvt  + self.max_pos, hs)  # encoder GRU
        self.grue_backward = nn.GRUCell(nvt + self.feat_nvt  + self.max_pos, hs)  # backward encoder GRU
        self.fc1 = nn.Linear(self.gs, nz)  # latent mean
        self.fc2 = nn.Linear(self.gs, nz)  # latent logvar
            
        # 1. decoding-related
        self.grud = nn.GRUCell(nvt + self.feat_nvt  + self.max_pos, hs)  # decoder GRU
        self.fc3 = nn.Linear(nz, hs)  # from latent z to initial hidden state h0
        self.add_vertex = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, nvt)
                )  # which type of new vertex to add f(h0, hg)
        self.add_edge = nn.Sequential(
                nn.Linear(hs * 2, hs * 4), 
                nn.ReLU(), 
                nn.Linear(hs * 4, 1)
                )  # whether to add edge between v_i and v_new, f(hvi, hnew)
        self.fc_feat = nn.Sequential(
                nn.Linear(hs, hs),
                nn.ReLU(),
                nn.Linear(hs, self.feat_nvt)
                ) 
        self.vid_fc = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.ReLU(),
                nn.Linear(hs * 2, self.max_pos)
                )
        # 2. gate-related
        self.gate_forward = nn.Sequential(
                nn.Linear(self.vs, hs), 
                nn.Sigmoid()
                )
        self.gate_backward = nn.Sequential(
                nn.Linear(self.vs, hs), 
                nn.Sigmoid()
                )
        self.mapper_forward = nn.Sequential(
                nn.Linear(self.vs, hs, bias=False),
                )  # disable bias to ensure padded zeros also mapped to zeros
        self.mapper_backward = nn.Sequential(
                nn.Linear(self.vs, hs, bias=False), 
                )

        # 3. bidir-related, to unify sizes
        if self.bidir:
            self.hv_unify = nn.Sequential(
                    nn.Linear(hs * 2, hs), 
                    )
            self.hg_unify = nn.Sequential(
                    nn.Linear(self.gs * 2, self.gs), 
                    )

        # 4. other
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.logsoftmax1 = nn.LogSoftmax(1)

    def get_device(self):
        if self.device is None:
            self.device = next(self.parameters()).device
        return self.device
    
    def _get_zeros(self, n, length):
        return torch.zeros(n, length).to(self.get_device()) # get a zero hidden state

    def _get_zero_hidden(self, n=1):
        return self._get_zeros(n, self.hs) # get a zero hidden state

    def _one_hot(self, idx, length):
        if type(idx) in [list, range]:
            if idx == []:
                return None
            idx = torch.LongTensor(idx).unsqueeze(0).t()
            x = torch.zeros((len(idx), length)).scatter_(1, idx, 1).to(self.get_device())
        else:
            idx = torch.LongTensor([idx]).unsqueeze(0)
            x = torch.zeros((1, length)).scatter_(1, idx, 1).to(self.get_device())
        return x

    def _gated(self, h, gate, mapper):
        return gate(h) * mapper(h)

    def _collate_fn(self, G):
        return [g.copy() for g in G]

    def _propagate_to(self, G, v, propagator, H=None, reverse=False):
        """Compute the new state of vertex ``v`` in every graph of ``G`` that has it.

        The GRU input is the concatenation of one-hots of the vertex's ``'type'``,
        ``'feat'`` and ``'vid'`` attributes.

        Args:
            G: list of graphs; graphs with ``vcount() <= v`` are skipped.
            v: index of the vertex to update.
            propagator: GRU cell called as ``propagator(input_one_hots, incoming_state)``.
            H: optional incoming state with one row per graph in ``G``. When None, the
                incoming state is the gated sum of the neighbours' stored states.
            reverse: use successors and ``'H_backward'`` instead of predecessors and
                ``'H_forward'``.

        Returns:
            The new states (one row per updated graph), also stored on each vertex,
            or None when no graph has vertex ``v``.
        """
        if H is not None:
            # choose H's rows BEFORE filtering G so they stay aligned with the graphs
            # that are kept (enumerating the filtered list would pick the first rows)
            keep = [i for i, g in enumerate(G) if g.vcount() > v]
        G = [g for g in G if g.vcount() > v]
        if len(G) == 0:
            return
        if H is not None:
            H = H[keep]
        v_types = [g.vs[v]['type'] for g in G]
        v_feats = [g.vs[v]['feat'] for g in G]
        vid_feats = [g.vs[v]['vid'] for g in G]
        X = self._one_hot(v_types, self.nvt)
        Y = self._one_hot(v_feats, self.feat_nvt)
        Z = self._one_hot(vid_feats, self.max_pos)
        X = torch.cat([X, Y, Z], dim=1)
        if reverse:
            H_name = 'H_backward'  # name of the hidden states attribute
            H_pred = [[g.vs[x][H_name] for x in g.successors(v)] for g in G]  # successors act as 'predecessors'
            gate, mapper = self.gate_backward, self.mapper_backward
        else:
            H_name = 'H_forward'  # name of the hidden states attribute
            H_pred = [[g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
            gate, mapper = self.gate_forward, self.mapper_forward
        # if H is not provided, use the gated sum of the neighbours' states
        if H is None:
            max_n_pred = max([len(x) for x in H_pred])  # maximum number of neighbours
            if max_n_pred == 0:  # start point: no neighbour in any graph
                H = self._get_zero_hidden(len(G))
            else:
                H_pred = [torch.cat(h_pred +
                                    [self._get_zeros(max_n_pred - len(h_pred), self.vs)], 0).unsqueeze(0)
                          for h_pred in H_pred]  # pad all to same length
                H_pred = torch.cat(H_pred, 0)  # batch * max_n_pred * vs
                H = self._gated(H_pred, gate, mapper).sum(1)  # batch * hs
        Hv = propagator(X, H)
        for i, g in enumerate(G):
            g.vs[v][H_name] = Hv[i:i+1]
        return Hv

    def _propagate_from(self, G, v, propagator, H0=None, reverse=False):
        """Propagate states starting at vertex ``v``, assuming topologically ordered indices.

        Vertex ``v`` is seeded with ``H0``. Every later vertex (``v+1 .. max_n-1``, or
        ``v-1 .. 0`` when ``reverse``) aggregates its neighbours' states.

        Returns:
            The new states of vertex ``v``.
        """
        Hv = self._propagate_to(G, v, propagator, H0, reverse=reverse)
        remaining = range(v - 1, -1, -1) if reverse else range(v + 1, self.max_n)
        for u in remaining:
            self._propagate_to(G, u, propagator, reverse=reverse)
        return Hv

    def _update_v(self, G, v, H0=None):
        """Refresh vertex ``v``'s forward state with the decoder GRU ``grud``; returns None."""
        decoder_cell = self.grud
        self._propagate_to(G, v, decoder_cell, H0, reverse=False)
    
    def _get_vertex_state(self, G, v):
        """Stack vertex ``v``'s ``'H_forward'`` state per graph, using zeros where ``v`` is absent."""
        states = [self._get_zero_hidden() if v >= g.vcount() else g.vs[v]['H_forward']
                  for g in G]
        return torch.cat(states, 0)

    def _get_graph_state(self, G, decode=False):
        """Return one graph state per graph.

        The state is the last vertex's ``'H_forward'``. When encoding bidirectionally
        (``decode`` False), it is concatenated with vertex 0's ``'H_backward'`` and
        projected back to ``gs`` by ``hg_unify``.
        """
        use_backward = self.bidir and not decode  # decoding never uses backward propagation
        states = []
        for g in G:
            last = g.vs[g.vcount() - 1]['H_forward']
            states.append(torch.cat([last, g.vs[0]['H_backward']], 1) if use_backward else last)
        Hg = torch.cat(states, 0)
        return self.hg_unify(Hg) if use_backward else Hg

    def encode(self, G):
        """Encode a graph, or a list of graphs, into the latent ``(mu, logvar)`` pair."""
        if type(G) is not list:
            G = [G]
        batch = len(G)
        self._propagate_from(G, 0, self.grue_forward,
                             H0=self._get_zero_hidden(batch), reverse=False)
        if self.bidir:
            self._propagate_from(G, self.max_n - 1, self.grue_backward,
                                 H0=self._get_zero_hidden(batch), reverse=True)
        Hg = self._get_graph_state(G)
        return self.fc1(Hg), self.fc2(Hg)

    def reparameterize(self, mu, logvar, eps_scale=0.01):
        # return z ~ N(mu, std)
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std) * eps_scale
            return eps.mul(std).add_(mu)
        else:
            return mu

    def _get_edge_score(self, Hvi, H, H0):
        # compute scores for edges from vi based on Hvi, H (current vertex) and H0
        # in most cases, H0 need not be explicitly included since Hvi and H contain its information
        return self.sigmoid(self.add_edge(torch.cat([Hvi, H], -1)))

    def decode(self, z, stochastic=True):
        """Decode latent vectors ``z`` into igraph DAGs.

        Every vertex gets ``'type'``, ``'feat'`` and ``'vid'`` attributes. The start
        vertex uses feat/vid 0, and so does the END vertex forced at index ``max_n - 1``.

        Args:
            z: latent vectors, one row per graph.
            stochastic: sample every decision from its predicted distribution if True,
                otherwise take the argmax / threshold edge probabilities at 0.5.

        Returns:
            List of ``igraph.Graph`` objects.
        """
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
            g.vs[0]['feat'] = 0
            g.vs[0]['vid'] = 0
        self._update_v(G, 0, H0)  # only at the beginning do we need H0
        finished = [False] * len(G)
        for idx in range(1, self.max_n):
            # decide the type, feature and position id of the next vertex
            if idx == self.max_n - 1:  # force the last node to be end_type
                new_types = [self.END_TYPE] * len(G)
                new_feats = [0] * len(G)
                # previously left unset, so the last step reused the previous step's
                # vids (NameError when max_n == 2); use 0 as for the start vertex
                # NOTE(review): confirm 0 is the intended vid for END vertices
                new_vids = [0] * len(G)
            else:
                Hg = self._get_graph_state(G, decode=True)
                type_scores = self.add_vertex(Hg)
                feat_scores = self.fc_feat(Hg)
                vid_scores = self.vid_fc(Hg)
                if stochastic:
                    type_probs = F.softmax(type_scores, 1).cpu().detach().numpy()
                    new_types = [np.random.choice(range(self.nvt), p=type_probs[i])
                                 for i in range(len(G))]
                    feat_probs = F.softmax(feat_scores, 1).cpu().detach().numpy()
                    new_feats = [np.random.choice(range(self.feat_nvt), p=feat_probs[i])
                                 for i in range(len(G))]
                    vid_probs = F.softmax(vid_scores, 1).cpu().detach().numpy()
                    new_vids = [np.random.choice(range(self.max_pos), p=vid_probs[i])
                                for i in range(len(G))]
                else:
                    new_types = torch.argmax(type_scores, 1).flatten().tolist()
                    new_feats = torch.argmax(feat_scores, 1).flatten().tolist()
                    new_vids = torch.argmax(vid_scores, 1).flatten().tolist()

            for i, g in enumerate(G):
                if not finished[i]:
                    g.add_vertex(type=new_types[i])
                    g.vs[idx]['feat'] = new_feats[i]
                    g.vs[idx]['vid'] = new_vids[i]
            self._update_v(G, idx)

            # decide connections
            for vi in range(idx-1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, idx)
                ei_score = self._get_edge_score(Hvi, H, H0)
                if stochastic:
                    random_score = torch.rand_like(ei_score)
                    decisions = random_score < ei_score
                else:
                    decisions = ei_score > 0.5
                for i, g in enumerate(G):
                    if finished[i]:
                        continue
                    if new_types[i] == self.END_TYPE:
                        # if new node is end_type, connect it to all loose-end vertices (out_degree==0)
                        end_vertices = set([v.index for v in g.vs.select(_outdegree_eq=0)
                                            if v.index != g.vcount()-1])
                        for v in end_vertices:
                            g.add_edge(v, g.vcount()-1)
                        finished[i] = True
                        continue
                    if decisions[i, 0]:
                        g.add_edge(vi, g.vcount()-1)
                self._update_v(G, idx)

        for g in G:
            del g.vs['H_forward']  # delete hidden states to save GPU memory
        return G

    def loss(self, mu, logvar, G_true, beta=0.005):
        """Teacher-forced VAE loss of decoding (mu, logvar) into the true graphs.

        A latent ``z`` is sampled with ``self.reparameterize`` and a batch of
        graphs is rebuilt vertex by vertex. Before step ``v_true`` is scored,
        the partially built graphs contain exactly the true vertices and edges
        ``0 .. v_true-1`` (teacher forcing), so each step's likelihood is
        conditioned on a correct prefix.

        For every step the log-likelihood of the true vertex ``type``, ``feat``
        and ``vid`` and of the true incoming edges of vertex ``v_true`` is
        accumulated.

        Args:
            mu: posterior mean, used for sampling ``z`` and for the KL term.
            logvar: posterior log-variance, used likewise.
            G_true: list of ``igraph.Graph`` objects with per-vertex ``type``,
                ``feat`` and ``vid`` attributes (presumably directed DAGs in
                topological order -- confirm against the encoder).
            beta: weight of the KL-divergence term.

        Returns:
            Tuple ``(res + beta * kld, res, kld)``. ``res`` is the negative
            summed reconstruction log-likelihood and ``kld`` is the summed KL
            divergence to a standard normal prior.
        """
        z = self.reparameterize(mu, logvar)  # (bsize, hidden)
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
            g.vs[0]['feat'] = 0
            g.vs[0]['vid'] = 0
        self._update_v(G, 0, H0)
        batch_idx = np.arange(len(G))  # loop-invariant row index for gathering
        res = 0  # log likelihood
        for v_true in range(1, self.max_n):
            # Use START_TYPE to denote padding vertices. The start type only
            # appears for vertex 0 and is never a true type for later vertices,
            # so it is free to use.
            # NOTE(review): padded steps also add the log-likelihood of
            # feat == 0 and vid == 0 below; confirm that this is intended
            # rather than masked out.
            true_types = [g_true.vs[v_true]['type'] if v_true < g_true.vcount()
                          else self.START_TYPE for g_true in G_true]
            true_feats = [g_true.vs[v_true]['feat'] if v_true < g_true.vcount()
                          else 0 for g_true in G_true]
            true_vids = [g_true.vs[v_true]['vid'] if v_true < g_true.vcount()
                         else 0 for g_true in G_true]

            Hg = self._get_graph_state(G, decode=True)
            type_scores = self.add_vertex(Hg)
            feat_scores = self.fc_feat(Hg)
            vid_scores = self.vid_fc(Hg)
            # vertex log likelihoods: gather the log-prob of each graph's true label
            type_ll = self.logsoftmax1(type_scores)[batch_idx, true_types].sum()
            feat_ll = self.logsoftmax1(feat_scores)[batch_idx, true_feats].sum()
            vid_ll = self.logsoftmax1(vid_scores)[batch_idx, true_vids].sum()
            res = res + type_ll + feat_ll + vid_ll
            for i, g in enumerate(G):
                if true_types[i] != self.START_TYPE:
                    g.add_vertex(type=true_types[i])
                    g.vs[v_true]['feat'] = true_feats[i]
                    g.vs[v_true]['vid'] = true_vids[i]
            self._update_v(G, v_true)

            # True predecessors of vertex v_true in each graph (none for padding).
            # neighbors(..., mode=IN) yields the same list as
            # get_adjlist(IN)[v_true] without building the whole graph's
            # adjacency list at every step.
            true_edges = [g_true.neighbors(v_true, mode=igraph.IN)
                          if v_true < g_true.vcount() else []
                          for g_true in G_true]
            true_edge_sets = [set(preds) for preds in true_edges]  # O(1) membership below
            edge_scores = []
            for vi in range(v_true-1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, v_true)
                ei_score = self._get_edge_score(Hvi, H, H0)  # presumably (bsize, 1)
                edge_scores.append(ei_score)
                for i, g in enumerate(G):
                    if vi in true_edge_sets[i]:
                        g.add_edge(vi, v_true)
                self._update_v(G, v_true)
            # edge_scores was filled for vi = v_true-1 ... 0. After reversing,
            # column j holds the score of candidate predecessor j, so vertex ids
            # index ground_truth directly.
            edge_scores = torch.cat(edge_scores[::-1], 1)  # (bsize, v_true)

            ground_truth = torch.zeros_like(edge_scores)
            idx1 = [i for i, preds in enumerate(true_edges) for _ in range(len(preds))]
            idx2 = [p for preds in true_edges for p in preds]
            ground_truth[idx1, idx2] = 1.0

            # edges log-likelihood
            ell = - F.binary_cross_entropy(edge_scores, ground_truth, reduction='sum')
            res = res + ell

        res = -res  # convert likelihood to loss
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return res + beta*kld, res, kld

    def encode_decode(self, G):
        """Round-trip ``G`` through the VAE.

        The posterior parameters come from ``self.encode``. A latent is sampled
        from them with ``self.reparameterize``, and ``self.decode``'s result for
        that latent is returned.
        """
        mean, log_var = self.encode(G)
        return self.decode(self.reparameterize(mean, log_var))

    def forward(self, G):
        """Return the total training loss for ``G``.

        The graphs are encoded, ``self.loss`` is called on the resulting
        posterior with ``G`` as the reconstruction target, and only the first
        element of its 3-tuple result is returned.
        """
        mean, log_var = self.encode(G)
        total, _recon, _kld = self.loss(mean, log_var, G)
        return total
    
    def generate_sample(self, n):
        """Decode ``n`` latent vectors drawn from the standard normal prior.

        The noise has shape ``(n, self.nz)``. It is drawn with ``torch.randn``
        on the default device and only then moved to ``self.get_device()``,
        and the decoded result is returned.
        """
        noise = torch.randn(n, self.nz)
        noise = noise.to(self.get_device())
        return self.decode(noise)









