import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from torch.nn import GRU
import torch_geometric.nn as gnn
from torch_geometric.nn import MessagePassing
from .mlp import MLP
import torch_geometric as pyg
import torch_scatter
from torch.nn.functional import triplet_margin_loss
from torch.nn.utils.rnn import pad_sequence
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils
import argparse


def generate_negative_pair_idx(bs):
    """Return indices that pair every batch position with a *different* one.

    ``second_idx`` is drawn by rejection sampling: start from a random
    permutation and re-draw every position that still maps to itself until
    no fixed points remain. The result is therefore not necessarily a
    permutation (a re-drawn value may repeat another position's value); it
    only guarantees ``first_idx[i] != second_idx[i]`` for every ``i``.

    Args:
        bs: number of elements (graphs) in the batch.

    Returns:
        Tensor of shape ``(2, bs)``: row 0 is ``arange(bs)``, row 1 holds the
        mismatched partner index for each position. ``bs == 0`` yields an
        empty ``(2, 0)`` tensor.

    Raises:
        ValueError: if ``bs == 1``. A single element has no other element to
            pair with, so the rejection loop could never terminate.
    """
    if bs == 1:
        raise ValueError(
            "generate_negative_pair_idx needs at least 2 elements to build "
            "negative pairs (got batch size 1)")

    first_idx = torch.arange(0, bs)
    second_idx = torch.randperm(bs)

    # Re-draw only the positions that still point at themselves.
    mask = first_idx == second_idx
    while mask.any():
        second_idx[mask] = torch.randperm(bs)[mask]
        mask = first_idx == second_idx

    return torch.stack([first_idx, second_idx])

class NeuroMatch(pl.LightningModule):
    """LightningModule that trains an :class:`OrderEmbedder` for subgraph matching.

    Every batch is expected to carry four graph views (attribute names as read
    in :meth:`_encode_views`): the original AIG (``gate``/``edge_index``), the
    ``rd_*`` graph that plays the "sub" side of every pair, a ``syn_*`` AIG and
    a ``pm_*`` graph with its own node features (``pm_x``). Three matching
    tasks -- (rd, orig), (rd, syn) and (rd, pm) -- are scored with the
    order-embedding criterion and their losses are summed.
    """

    # (log-name fragment, key path into a step-output dict), in logging order.
    # The "accuarcy" spelling is part of the logged metric names; it is kept
    # so loggers, dashboards and callbacks monitoring these keys keep working.
    _EPOCH_METRIC_PATHS = (
        ('orig_accuarcy', ('orig_acc',)),
        ('syn_accuarcy', ('syn_acc',)),
        ('pm_accuarcy', ('pm_acc',)),
        ('syn_precision', ('metrics', 'syn', 'Precision')),
        ('syn_recall', ('metrics', 'syn', 'Recall')),
        ('syn_f1', ('metrics', 'syn', 'F1-Score')),
        ('pm_precision', ('metrics', 'pm', 'Precision')),
        ('pm_recall', ('metrics', 'pm', 'Recall')),
        ('pm_f1', ('metrics', 'pm', 'F1-Score')),
    )

    def __init__(self, encoder_type):
        super().__init__()
        args = parse_encoder()
        self.args = args
        self.encoder_type = encoder_type
        self.model = OrderEmbedder(input_dim=3, hidden_dim=args.hidden_dim,
                                   encoder_type=encoder_type, args=args)

        # Per-step results, aggregated and cleared in the on_*_epoch_end hooks.
        self.training_step_outputs = []
        self.test_step_outputs = []
        self.val_step_outputs = []

    def compute_metrics(self, preds, labels):
        """Return ``(TP, FP, TN, FN)``, each as a fraction of ``preds.shape[0]``."""
        n = preds.shape[0]
        TP = ((preds == 1) & (labels == 1)).sum().item() / n
        FP = ((preds == 1) & (labels == 0)).sum().item() / n
        TN = ((preds == 0) & (labels == 0)).sum().item() / n
        FN = ((preds == 0) & (labels == 1)).sum().item() / n
        return TP, FP, TN, FN

    def compute_PR(self, predictions, labels):
        """Return ``(precision, recall, f1)`` for binary 0/1 predictions.

        Both inputs are flattened first. Each ratio falls back to ``0.0`` when
        its denominator is zero.
        """
        predictions = predictions.view(-1)
        labels = labels.view(-1)

        true_positives = torch.logical_and(predictions == 1, labels == 1).sum().item()
        false_positives = torch.logical_and(predictions == 1, labels == 0).sum().item()
        false_negatives = torch.logical_and(predictions == 0, labels == 1).sum().item()

        predicted_pos = true_positives + false_positives
        actual_pos = true_positives + false_negatives

        # Precision: TP / (TP + FP); Recall: TP / (TP + FN)
        precision = true_positives / predicted_pos if predicted_pos > 0 else 0.0
        recall = true_positives / actual_pos if actual_pos > 0 else 0.0
        # F1: harmonic mean of precision and recall.
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

        return precision, recall, f1

    def _encode_views(self, batch):
        """Embed the orig, rd, syn and pm graphs of ``batch`` (returned in that order)."""
        def gates(gate_attr):
            # Gate-type ids are one-hot encoded into 3 classes.
            return F.one_hot(gate_attr.squeeze(-1), num_classes=3).float()

        emb = self.model.emb_model
        emb_pm = self.model.emb_model_pm
        if self.encoder_type == 'ABGNN':
            # ABGNN additionally consumes a per-node ``forward_level`` tensor.
            orig_hf = emb(gates(batch.gate), batch.edge_index, batch.forward_level, batch.batch)
            rd_hf = emb(gates(batch.rd_gate), batch.rd_edge_index, batch.rd_forward_level, batch.rd_batch)
            syn_hf = emb(gates(batch.syn_gate), batch.syn_edge_index, batch.syn_forward_level, batch.syn_batch)
            pm_hf = emb_pm(batch.pm_x, batch.pm_edge_index, batch.pm_forward_level, batch.pm_batch)
        else:
            orig_hf = emb(gates(batch.gate), batch.edge_index, batch.batch)
            rd_hf = emb(gates(batch.rd_gate), batch.rd_edge_index, batch.rd_batch)
            syn_hf = emb(gates(batch.syn_gate), batch.syn_edge_index, batch.syn_batch)
            pm_hf = emb_pm(batch.pm_x, batch.pm_edge_index, batch.pm_batch)
        return orig_hf, rd_hf, syn_hf, pm_hf

    def _score_pair_task(self, sub_hf, graph_hf, neg_pair_idx, bs, device):
        """Build positive/negative (sub, graph) pairs and score them.

        Rows ``0..bs-1`` pair ``sub_hf[i]`` with ``graph_hf[i]`` (label 1);
        rows ``bs..2*bs-1`` pair ``sub_hf[neg_pair_idx[0]]`` with
        ``graph_hf[neg_pair_idx[1]]`` (label 0).

        Returns:
            ``(loss, pred, label, acc)``: order-embedding loss, predicted class
            per row, ground-truth labels and accuracy as a Python float.
        """
        pos_pair = torch.cat([sub_hf, graph_hf], dim=-1)
        neg_pair = torch.cat([sub_hf[neg_pair_idx[0]], graph_hf[neg_pair_idx[1]]], dim=-1)
        pair = torch.cat([pos_pair, neg_pair], dim=0)
        label = torch.cat([torch.ones(bs), torch.zeros(bs)], dim=0).long().to(device)

        pair_loss = self.model.criterion(pair, None, label)

        # NOTE(review): clf_model's output only feeds argmax here, and no loss
        # term in this module backpropagates into it, so its weights appear to
        # stay at initialisation; accuracy/precision/recall then reflect an
        # untrained threshold on the violation score. Confirm it is trained
        # somewhere else.
        score = self.model.predict(pair)
        pred = self.model.clf_model(score.unsqueeze(1)).argmax(dim=-1)
        acc = (pred == label).sum().item() / label.shape[0]
        return pair_loss, pred, label, acc

    def forward_ordered(self, batch, batch_idx):
        """Run the three subgraph-matching tasks on ``batch``.

        Returns:
            ``(loss, loss_align, orig_acc, syn_acc, pm_acc, metrics)`` where
            ``loss`` is the summed order-embedding loss, ``loss_align`` is the
            constant ``0`` (no alignment term is computed), the accuracies are
            Python floats, and ``metrics`` maps ``'syn'``/``'pm'`` to their
            ``Precision``/``Recall``/``F1-Score``.
        """
        device = batch.x.device
        bs = batch.batch_size

        orig_hf, rd_hf, syn_hf, pm_hf = self._encode_views(batch)

        # One negative mapping shared by all three tasks.
        neg_pair_idx = generate_negative_pair_idx(bs).to(device)

        L_orig, _, _, orig_acc = self._score_pair_task(rd_hf, orig_hf, neg_pair_idx, bs, device)
        L_syn, syn_pred, syn_label, syn_acc = self._score_pair_task(rd_hf, syn_hf, neg_pair_idx, bs, device)
        L_pm, pm_pred, pm_label, pm_acc = self._score_pair_task(rd_hf, pm_hf, neg_pair_idx, bs, device)

        syn_prec, syn_rec, syn_f1 = self.compute_PR(syn_pred, syn_label)
        pm_prec, pm_rec, pm_f1 = self.compute_PR(pm_pred, pm_label)

        metrics = {
            'syn': {"Precision": syn_prec, "Recall": syn_rec, "F1-Score": syn_f1},
            'pm': {"Precision": pm_prec, "Recall": pm_rec, "F1-Score": pm_f1},
        }

        loss = L_orig + L_syn + L_pm
        loss_align = 0  # placeholder: no alignment loss is computed

        return loss, loss_align, orig_acc, syn_acc, pm_acc, metrics

    def training_step(self, batch, batch_idx):
        loss, loss_align, orig_acc, syn_acc, pm_acc, metrics = self.forward_ordered(batch, batch_idx)
        total = loss + loss_align

        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=False, batch_size=self.args.batch_size)
        self.log('train_align_loss', loss_align, on_step=False, on_epoch=True, prog_bar=True, logger=False, batch_size=self.args.batch_size)

        # Store a detached copy: the live loss tensor would keep its autograd
        # graph referenced until the list is cleared at epoch end.
        self.training_step_outputs.append({'loss': total.detach(), 'orig_acc': orig_acc, 'syn_acc': syn_acc, 'pm_acc': pm_acc, 'metrics': metrics})

        return total

    def validation_step(self, batch, batch_idx):
        loss, loss_align, orig_acc, syn_acc, pm_acc, metrics = self.forward_ordered(batch, batch_idx)
        total = loss + loss_align

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=False, batch_size=self.args.batch_size)
        self.log('val_align_loss', loss_align, on_step=False, on_epoch=True, prog_bar=True, logger=False, batch_size=self.args.batch_size)

        self.val_step_outputs.append({'loss': total.detach(), 'orig_acc': orig_acc, 'syn_acc': syn_acc, 'pm_acc': pm_acc, 'metrics': metrics})

        return total

    def _log_epoch_means(self, prefix, outputs):
        """Log the per-step mean of every metric in ``_EPOCH_METRIC_PATHS``.

        Each value is logged as ``f'{prefix}_{name}_epoch'``, rounded to 4
        decimals. Nothing is logged when ``outputs`` is empty, which avoids a
        division by zero.
        """
        if not outputs:
            return
        n_steps = len(outputs)
        for name, path in self._EPOCH_METRIC_PATHS:
            total = 0
            for out in outputs:
                value = out
                for key in path:
                    value = value[key]
                total += value
            mean = total / n_steps
            self.log(f'{prefix}_{name}_epoch', round(float(mean), 4), on_epoch=True, prog_bar=True, logger=True, batch_size=self.args.batch_size)

    def on_train_epoch_end(self):
        self._log_epoch_means('train', self.training_step_outputs)
        self.training_step_outputs.clear()

    def on_validation_epoch_end(self):
        self._log_epoch_means('val', self.val_step_outputs)
        self.val_step_outputs.clear()

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``args.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr)
        return optimizer
    
# def parse_encoder():
#     enc_parser = argparse.ArgumentParser()

#     enc_parser.add_argument('--conv_type', type=str,
#                         help='type of convolution')
#     enc_parser.add_argument('--encoder_type', type=str,
#                         help='type of encoder')
#     enc_parser.add_argument('--method_type', type=str,
#                         help='type of embedding')
#     enc_parser.add_argument('--batch_size', type=int,
#                         help='Training batch size')
#     enc_parser.add_argument('--n_layers', type=int,
#                         help='Number of graph conv layers')
#     enc_parser.add_argument('--hidden_dim', type=int,
#                         help='Training hidden size')
#     enc_parser.add_argument('--skip', type=str,
#                         help='"all" or "last"')
#     enc_parser.add_argument('--dropout', type=float,
#                         help='Dropout rate')
#     enc_parser.add_argument('--n_batches', type=int,
#                         help='Number of training minibatches')
#     enc_parser.add_argument('--margin', type=float,
#                         help='margin for loss')
#     enc_parser.add_argument('--dataset', type=str,
#                         help='Dataset')
#     enc_parser.add_argument('--test_set', type=str,
#                         help='test set filename')
#     enc_parser.add_argument('--eval_interval', type=int,
#                         help='how often to eval during training')
#     enc_parser.add_argument('--val_size', type=int,
#                         help='validation set size')
#     enc_parser.add_argument('--model_path', type=str,
#                         help='path to save/load model')
#     enc_parser.add_argument('--opt_scheduler', type=str,
#                         help='scheduler name')
#     enc_parser.add_argument('--node_anchored', action="store_true",
#                         help='whether to use node anchoring in training')
#     enc_parser.add_argument('--test', action="store_true")
#     enc_parser.add_argument('--n_workers', type=int)
#     enc_parser.add_argument('--tag', type=str,
#         help='tag to identify the run')

#     enc_parser.set_defaults(conv_type='SAGE',
#                         method_type='order',
#                         dataset='syn',
#                         n_layers=8,
#                         batch_size=64,
#                         hidden_dim=128,
#                         skip="learnable",
#                         dropout=0.2,
#                         n_batches=1000000,
#                         opt='adam',   # opt_enc_parser
#                         opt_scheduler='none',
#                         opt_restart=100,
#                         weight_decay=0.0,
#                         lr=1e-4,
#                         margin=0.1,
#                         test_set='',
#                         eval_interval=1000,
#                         n_workers=4,
#                         model_path="ckpt/model.pt",
#                         tag='',
#                         val_size=4096,
#                         node_anchored=True)

#     return enc_parser.parse_args()

def parse_encoder(**overrides):
    """Return the encoder/training configuration as an ``argparse.Namespace``.

    The defaults are hard-coded rather than parsed from the command line, so
    the module can be constructed without reading ``sys.argv``.

    Args:
        **overrides: optional keyword values that replace (or add to) the
            defaults, e.g. ``parse_encoder(hidden_dim=64, lr=1e-3)``. With no
            arguments the result matches the historical defaults exactly.

    Returns:
        A fresh ``argparse.Namespace``; mutating it does not affect later calls.
    """
    default_args = {
        'conv_type': 'SAGE',
        'method_type': 'order',
        'dataset': 'syn',
        'n_layers': 8,
        'batch_size': 64,
        'hidden_dim': 128,
        'skip': "learnable",
        'dropout': 0.2,
        'n_batches': 1000000,
        'opt': 'adam',
        'opt_scheduler': 'none',
        'opt_restart': 100,
        'weight_decay': 0.0,
        'lr': 1e-4,
        'margin': 0.1,
        'test_set': '',
        'eval_interval': 1000,
        'n_workers': 4,
        'model_path': "ckpt/model.pt",
        'tag': '',
        'val_size': 4096,
        'node_anchored': True,
        'test': False,  # default is False for flag
    }
    return argparse.Namespace(**{**default_args, **overrides})


class OrderEmbedder(nn.Module):
    """Graph encoders plus an order-embedding subgraph criterion.

    Two encoders are selected by ``encoder_type``: ``emb_model`` for graphs
    with ``input_dim`` node features and ``emb_model_pm`` for graphs with
    64-dim node features. Pair tensors given to :meth:`predict` and
    :meth:`criterion` are the concatenation ``[emb_a | emb_b]`` along the last
    dim, each half ``hidden_dim`` wide.

    Raises:
        ValueError: if ``encoder_type`` is not one of ``'NeuroMatch'``,
            ``'Gamora'``, ``'ABGNN'`` or ``'HGCN'``.
    """

    def __init__(self, input_dim, hidden_dim, encoder_type, args):
        super(OrderEmbedder, self).__init__()

        self.hidden_dim = hidden_dim
        if encoder_type == 'NeuroMatch':
            self.emb_model = SkipLastGNN(input_dim, hidden_dim, hidden_dim, args)
            self.emb_model_pm = SkipLastGNN(64, hidden_dim, hidden_dim, args)
        elif encoder_type == 'Gamora':
            self.emb_model = SAGE_MULT(input_dim, hidden_dim)
            self.emb_model_pm = SAGE_MULT(64, hidden_dim)
        elif encoder_type == 'ABGNN':
            self.emb_model = ABGNN(input_dim, hidden_dim, dropout=0.2)
            self.emb_model_pm = ABGNN(64, hidden_dim, dropout=0.2)
        elif encoder_type == 'HGCN':
            # NOTE(review): the gate-graph encoder is a plain GCN while the pm
            # encoder is an HGCN -- confirm this asymmetry is intentional.
            self.emb_model = GCN(input_dim, hidden_dim, 3, dropout=0.2)
            self.emb_model_pm = HGCN(64, hidden_dim, 3, dropout=0.2)
        else:
            # Fail fast instead of leaving emb_model undefined until first use.
            raise ValueError(
                f"unknown encoder_type {encoder_type!r}; expected one of "
                "'NeuroMatch', 'Gamora', 'ABGNN', 'HGCN'")

        self.margin = args.margin
        self.use_intersection = False

        # Maps the scalar violation score to 2-class log-probabilities.
        self.clf_model = nn.Sequential(nn.Linear(1, 2), nn.LogSoftmax(dim=-1))

    def forward(self, emb_as, emb_bs):
        """Identity pass-through: returns ``(emb_as, emb_bs)`` unchanged."""
        return emb_as, emb_bs

    def _violation(self, pred):
        """Per-row order violation ``sum(max(0, emb_b - emb_a) ** 2)``.

        ``pred`` is ``[emb_a | emb_b]``; the score is 0 exactly when
        ``emb_b <= emb_a`` element-wise and grows with the excess.
        """
        emb_as = pred[:, :self.hidden_dim]
        emb_bs = pred[:, self.hidden_dim:]
        return torch.sum(torch.max(torch.zeros_like(emb_as), emb_bs - emb_as) ** 2, dim=1)

    def predict(self, pred):
        """Return the order-violation score for each pair row.

        Args:
            pred: tensor whose last dim is ``[emb_a | emb_b]`` (each half
                ``hidden_dim`` wide).

        Returns:
            1-D float tensor of violation scores (not booleans); 0 means
            ``emb_b <= emb_a`` element-wise, larger means a bigger violation.
        """
        return self._violation(pred)

    def criterion(self, pred, intersect_embs, labels):
        """Order-embedding loss summed over all rows.

        Positive rows (label != 0) contribute their violation score ``e`` so it
        is driven towards 0. Negative rows (label 0) contribute
        ``max(0, margin - e)``, pushing their violation to at least ``margin``.

        Args:
            pred: pair tensor ``[emb_a | emb_b]``, as for :meth:`predict`.
            intersect_embs: not used.
            labels: per-row labels; ``0`` marks a negative pair.

        Returns:
            Scalar loss tensor.
        """
        e = self._violation(pred)
        negatives = labels == 0
        e[negatives] = torch.max(torch.tensor(0.0, device=e.device),
                                 self.margin - e)[negatives]
        return torch.sum(e)

class SkipLastGNN(nn.Module):
    """Message-passing encoder with configurable skip connections.

    Node features go through ``pre_mp`` and then ``args.n_layers`` convolution
    layers. The input of each layer depends on ``args.skip``:

    * ``'learnable'`` -- a sigmoid-gated mix of all earlier layer outputs;
    * ``'all'`` -- the concatenation of all earlier outputs;
    * anything else -- only the previous layer's output.

    All layer outputs are concatenated, sum-pooled per graph with
    ``global_add_pool`` and passed through the ``post_mp`` MLP head.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, args):
        super(SkipLastGNN, self).__init__()
        self.dropout = args.dropout
        self.n_layers = args.n_layers

        self.feat_preprocess = None

        # PNA concatenates sum/mean/max branches, hence the 3x widths.
        self.pre_mp = nn.Sequential(nn.Linear(input_dim, 3*hidden_dim if
            args.conv_type == "PNA" else hidden_dim))

        conv_model = self.build_conv_model(args.conv_type, 1)
        if args.conv_type == "PNA":
            self.convs_sum = nn.ModuleList()
            self.convs_mean = nn.ModuleList()
            self.convs_max = nn.ModuleList()
        else:
            self.convs = nn.ModuleList()

        if args.skip == 'learnable':
            # Row i gates the outputs of layers 0..i when feeding layer i.
            self.learnable_skip = nn.Parameter(torch.ones(self.n_layers,
                self.n_layers))

        for l in range(args.n_layers):
            if args.skip == 'all' or args.skip == 'learnable':
                hidden_input_dim = hidden_dim * (l + 1)
            else:
                hidden_input_dim = hidden_dim
            if args.conv_type == "PNA":
                self.convs_sum.append(conv_model(3*hidden_input_dim, hidden_dim))
                self.convs_mean.append(conv_model(3*hidden_input_dim, hidden_dim))
                self.convs_max.append(conv_model(3*hidden_input_dim, hidden_dim))
            else:
                self.convs.append(conv_model(hidden_input_dim, hidden_dim))

        post_input_dim = hidden_dim * (args.n_layers + 1)
        if args.conv_type == "PNA":
            post_input_dim *= 3
        self.post_mp = nn.Sequential(
            nn.Linear(post_input_dim, hidden_dim), nn.Dropout(args.dropout),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_dim, output_dim),
            nn.ReLU(),
            # Consumes the previous layer's ``output_dim`` features (was
            # hard-coded to hidden_dim, which only worked when they matched).
            nn.Linear(output_dim, 256), nn.ReLU(),
            nn.Linear(256, hidden_dim))
        self.skip = args.skip
        self.conv_type = args.conv_type

    def build_conv_model(self, model_type, n_inner_layers):
        """Return a factory ``(in_dim, out_dim) -> conv layer`` for ``model_type``.

        Raises:
            ValueError: for an unrecognised ``model_type`` (previously this
                printed a message and returned None, failing later with an
                opaque TypeError).
        """
        if model_type == "GCN":
            return pyg_nn.GCNConv
        elif model_type == "GIN":
            return lambda i, h: GINConv(nn.Sequential(
                nn.Linear(i, h), nn.ReLU(), nn.Linear(h, h)
                ))
        elif model_type == "SAGE":
            return SAGEConv
        elif model_type == "graph":
            return pyg_nn.GraphConv
        elif model_type == "GAT":
            return pyg_nn.GATConv
        elif model_type == "gated":
            return lambda i, h: pyg_nn.GatedGraphConv(h, n_inner_layers)
        elif model_type == "PNA":
            return SAGEConv
        else:
            raise ValueError(f"unrecognized model type: {model_type!r}")

    def forward(self, x, edge_index, batch, is_boundary=False):
        """Embed a batch of graphs.

        Args:
            x: node features (presumably ``(num_nodes, input_dim)`` -- TODO confirm).
            edge_index: connectivity passed to every convolution.
            batch: node-to-graph assignment used by ``global_add_pool``.
            is_boundary: when true, also return per-node embeddings.

        Returns:
            Graph embeddings, or ``(graph_emb, node_emb)`` when ``is_boundary``.
        """
        x = self.pre_mp(x)

        all_emb = x.unsqueeze(1)
        emb = x
        n_convs = len(self.convs_sum) if self.conv_type == "PNA" else len(self.convs)
        for i in range(n_convs):
            if self.skip == 'learnable':
                skip_vals = self.learnable_skip[i,
                    :i+1].unsqueeze(0).unsqueeze(-1)
                curr_emb = all_emb * torch.sigmoid(skip_vals)
                curr_emb = curr_emb.view(x.size(0), -1)
                if self.conv_type == "PNA":
                    x = torch.cat((self.convs_sum[i](curr_emb, edge_index),
                        self.convs_mean[i](curr_emb, edge_index),
                        self.convs_max[i](curr_emb, edge_index)), dim=-1)
                else:
                    x = self.convs[i](curr_emb, edge_index)
            elif self.skip == 'all':
                if self.conv_type == "PNA":
                    x = torch.cat((self.convs_sum[i](emb, edge_index),
                        self.convs_mean[i](emb, edge_index),
                        self.convs_max[i](emb, edge_index)), dim=-1)
                else:
                    x = self.convs[i](emb, edge_index)
            else:
                x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            emb = torch.cat((emb, x), 1)
            if self.skip == 'learnable':
                all_emb = torch.cat((all_emb, x.unsqueeze(1)), 1)

        # The per-node head is only run when requested (it was previously
        # computed and discarded on every call). It runs before the graph-level
        # pass so dropout draws keep their original order.
        node_emb = self.post_mp(emb) if is_boundary else None
        graph_emb = self.post_mp(pyg_nn.global_add_pool(emb, batch))
        if is_boundary:
            return graph_emb, node_emb
        return graph_emb

    def loss(self, pred, label):
        """Negative log-likelihood of log-probabilities ``pred`` against ``label``."""
        return F.nll_loss(pred, label)

class SAGE_MULT(torch.nn.Module):
    """GraphSAGE encoder with a sum-pooled graph readout.

    Runs ``num_layers`` PyG ``SAGEConv`` layers (ReLU + dropout after each), a
    bias-free node-level linear layer followed by ReLU and batch norm, then
    ``global_add_pool`` and a bias-free linear head.
    """

    def __init__(self, in_channels, hidden_channels, num_layers=4,
                 dropout=0.5):
        super(SAGE_MULT, self).__init__()
        self.num_layers = num_layers

        # NOTE(review): at least two convs are always built; with
        # num_layers < 2, forward() uses only the first num_layers of them.
        self.convs = torch.nn.ModuleList()
        self.convs.append(gnn.SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(gnn.SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(gnn.SAGEConv(hidden_channels, hidden_channels))

        # Node-level projection applied after the conv stack.
        self.linear = torch.nn.ModuleList()
        self.linear.append(nn.Linear(hidden_channels, hidden_channels, bias=False))
        self.bn0 = nn.BatchNorm1d(hidden_channels)

        self.post_mp = nn.Linear(hidden_channels, hidden_channels, bias=False)
        # Dropout probability applied after every convolution in forward().
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every learnable sub-module, including norm and head."""
        for conv in self.convs:
            conv.reset_parameters()
        for lin in self.linear:
            lin.reset_parameters()
        self.bn0.reset_parameters()
        self.post_mp.reset_parameters()

    def forward(self, x, edge_index, batch, is_boundary=False):
        """Return graph embeddings, or ``(graph_emb, node_emb)`` if ``is_boundary``."""
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            # Honour the configured rate (was hard-coded to 0.5, ignoring
            # the ``dropout`` constructor argument).
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.linear[0](x)
        x = self.bn0(F.relu(x))
        emb = pyg_nn.global_add_pool(x, batch)
        emb = self.post_mp(emb)
        if is_boundary:
            return emb, x
        return emb

class ABGNN(nn.Module):
    r"""Asynchronous Bidirectional Graph Neural Network (ABGNN) encoder.

    Note that this model handles only one propagation direction.

    Node features are lifted by ``fc_init``. Then a single shared ``SAGEConv``
    (sum aggregation) is applied once per level ``0 .. max(forward_level)``,
    each time using only the edges whose destination node lies on that level.
    Node states are sum-pooled per graph and projected by ``post_mp``.
    """
    def __init__(
        self,
        in_dim,      # dim of the input layer
        hidden_dim,  # dim of the hidden layers
        dropout,     # dropout rate, applied before every level except the first
        n_layers=None,  # stored only; the depth comes from forward_level
        activation=torch.relu,  # activation function
    ):
        super(ABGNN, self).__init__()
        self.activation = activation
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.dropout = nn.Dropout(p=dropout)
        self.layers = nn.ModuleList()  # empty; not used by forward()
        self.fc_init = nn.Linear(in_dim, hidden_dim)
        in_dim = hidden_dim

        self.post_mp = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # One convolution whose weights are shared across all levels.
        self.conv = gnn.SAGEConv(
            in_dim,
            hidden_dim,
            aggr='add',
        )

    def forward(self, features, edge_index, forward_level, batch, is_boundary=False):
        """Return graph embeddings, or ``(graph_emb, node_emb)`` if ``is_boundary``.

        ``forward_level`` is indexed by node id (``forward_level[edge_index[1]]``);
        presumably a 1-D integer tensor of per-node levels -- TODO confirm.
        """
        # One host sync up front; keeping ``depth`` as a tensor made every
        # ``i != depth - 1`` check below a separate device round-trip.
        depth = int(forward_level.max().item()) + 1
        h = self.activation(self.fc_init(features))
        for i in range(depth):
            if i != 0:
                h = self.dropout(h)
            # Only edges that point into level-i nodes are active this step.
            edge_i = edge_index[:, forward_level[edge_index[1]] == i]
            h = self.conv(h, edge_i)  # the generated node embeddings of current layer
            if i != depth - 1:
                h = self.activation(h)
        # NOTE(review): a no-op unless hidden_dim == 1, where it would drop the
        # feature axis -- confirm this squeeze is still needed.
        h = h.squeeze(1)
        emb = pyg_nn.global_add_pool(h, batch)
        emb = self.post_mp(emb)
        if is_boundary:
            return emb, h
        return emb

class SAGEConv(pyg_nn.MessagePassing):
    """GraphSAGE-style convolution: transform neighbours, aggregate, mix with self.

    For each node ``i``::

        out_i = lin_update([ AGG_{j -> i} lin(x_j)  ||  x_i ])

    with ``AGG`` given by ``aggr`` (sum by default). Self-loops are stripped
    from ``edge_index``; a node's own features enter only through ``x_i`` in
    the update step.
    """

    def __init__(self, in_channels, out_channels, aggr="add"):
        super(SAGEConv, self).__init__(aggr=aggr)

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Per-neighbour message transform.
        self.lin = nn.Linear(in_channels, out_channels)
        # Mixes [aggregated messages || own features] back to out_channels.
        self.lin_update = nn.Linear(out_channels + in_channels,
                                    out_channels)

    def forward(self, x, edge_index, edge_weight=None, size=None,
                res_n_id=None):
        """Run one round of message passing over ``edge_index``.

        Args:
            x: node feature matrix.
            edge_index: graph connectivity; self-loops are removed first.
            edge_weight: accepted for API compatibility; ignored by ``message``.
            size: optional ``(num_src, num_dst)`` forwarded to ``propagate``.
            res_n_id: accepted for API compatibility; ignored by ``update``.
        """
        loop_free_edges = pyg_utils.remove_self_loops(edge_index)[0]
        return self.propagate(loop_free_edges, size=size, x=x,
                              edge_weight=edge_weight, res_n_id=res_n_id)

    def message(self, x_j, edge_weight):
        # Edge weights are intentionally not applied.
        return self.lin(x_j)

    def update(self, aggr_out, x, res_n_id):
        return self.lin_update(torch.cat([aggr_out, x], dim=-1))

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels}, {self.out_channels})'

class GINConv(pyg_nn.MessagePassing):
    """Graph Isomorphism Network convolution with optional edge weights.

    Computes ``nn((1 + eps) * x + sum_{j -> i} w_ji * x_j)``. Self-loops are
    removed from ``edge_index`` first, so a node's own features enter only via
    the ``(1 + eps) * x`` term. Without ``edge_weight`` the messages are
    unweighted.
    """

    def __init__(self, nn, eps=0, train_eps=False, **kwargs):
        # ``nn`` is the post-aggregation network; it shadows the torch.nn module
        # name inside this method, hence ``torch.nn.Parameter`` below.
        super(GINConv, self).__init__(aggr='add', **kwargs)
        self.nn = nn
        self.initial_eps = eps
        # torch.tensor replaces the legacy torch.Tensor([...]) constructor.
        initial = torch.tensor([float(eps)])
        if train_eps:
            self.eps = torch.nn.Parameter(initial)
        else:
            self.register_buffer('eps', initial)
        self.reset_parameters()

    def reset_parameters(self):
        """Restore ``eps`` to its initial value.

        NOTE(review): ``self.nn``'s weights are not re-initialised here.
        """
        self.eps.data.fill_(self.initial_eps)

    def forward(self, x, edge_index, edge_weight=None):
        """Apply the GIN update; a 1-D ``x`` is treated as one feature per node."""
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        edge_index, edge_weight = pyg_utils.remove_self_loops(edge_index,
            edge_weight)
        out = self.nn((1 + self.eps) * x + self.propagate(edge_index, x=x,
            edge_weight=edge_weight))
        return out

    def message(self, x_j, edge_weight):
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def __repr__(self):
        return '{}(nn={})'.format(self.__class__.__name__, self.nn)

class HGCN(nn.Module):
    """Stack of PyG ``HypergraphConv`` layers with a sum-pooled graph readout.

    Each directed edge ``(src, dst)`` is turned into hypergraph incidence: both
    ``src`` and ``dst`` are made members of hyperedge ``dst``. Each destination
    node therefore owns a hyperedge containing itself and its in-neighbours.
    """

    def __init__(
        self,
        in_dim,      # width of the input node features
        hidden_dim,  # width of every convolution output
        n_layers,    # number of HypergraphConv layers (at least one is built)
        dropout,
        activation=torch.relu,  # stored on the module; see NOTE below
    ):
        super(HGCN, self).__init__()
        self.activation = activation
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        # NOTE(review): ``activation`` and ``dropout`` are stored but never used
        # in forward(), so the convolutions are stacked without any
        # nonlinearity -- confirm this is intended.
        self.dropout = nn.Dropout(p=dropout)

        # Bias-free projection of the pooled graph embedding.
        self.post_mp = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # First layer maps in_dim -> hidden_dim; the remaining n_layers - 1
        # layers keep hidden_dim.
        layer_inputs = [in_dim] + [hidden_dim] * max(n_layers - 1, 0)
        self.convs = torch.nn.ModuleList(
            pyg_nn.HypergraphConv(width, hidden_dim) for width in layer_inputs)

    def forward(self, x, edge_index, batch, is_boundary=False):
        """Return graph embeddings, or ``(graph_emb, node_emb)`` if ``is_boundary``."""
        dst = edge_index[1]
        # Row 0: member node, row 1: hyperedge id. (dst, dst) is appended once
        # per incoming edge, so a node with k fan-ins is listed k times in its
        # own hyperedge.
        incidence = torch.cat([edge_index, torch.stack([dst, dst])], dim=1)
        h = x
        for layer in self.convs:
            h = layer(h, incidence)
        graph_emb = self.post_mp(pyg_nn.global_add_pool(h, batch))
        if is_boundary == True:
            return graph_emb, h
        return graph_emb

class GCN(nn.Module):
    """Stack of PyG ``GCNConv`` layers with a sum-pooled graph readout."""

    def __init__(
        self,
        in_dim,      # width of the input node features
        hidden_dim,  # width of every convolution output
        n_layers,    # number of GCNConv layers (at least one is built)
        dropout,
        activation=torch.relu,  # stored on the module; see NOTE below
    ):
        super(GCN, self).__init__()
        self.activation = activation
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        # NOTE(review): ``activation`` and ``dropout`` are stored but never used
        # in forward(), so the convolutions are stacked without any
        # nonlinearity -- confirm this is intended.
        self.dropout = nn.Dropout(p=dropout)

        # Bias-free projection of the pooled graph embedding.
        self.post_mp = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Input GCN layer maps in_dim -> hidden_dim; the remaining
        # n_layers - 1 layers keep hidden_dim.
        layer_inputs = [in_dim] + [hidden_dim] * max(n_layers - 1, 0)
        self.convs = torch.nn.ModuleList(
            pyg_nn.GCNConv(width, hidden_dim) for width in layer_inputs)

    def forward(self, x, edge_index, batch, is_boundary=False):
        """Return graph embeddings, or ``(graph_emb, node_emb)`` if ``is_boundary``."""
        h = x
        for layer in self.convs:
            h = layer(h, edge_index)
        graph_emb = self.post_mp(pyg_nn.global_add_pool(h, batch))
        if is_boundary == True:
            return graph_emb, h
        return graph_emb
        
