from __future__ import division
from __future__ import print_function

import time
import tensorflow as tf

import scipy.sparse as sp

from utils import *
from models import GCN_APPRO_Mix
import json
from networkx.readwrite import json_graph
import os

from dgl.data import reddit
from dgl import DGLGraph
import networkx as nx
from dgl import transform

import numpy as np
from dgl.data import gnn_benckmark as gnnbnch
from csv_json_data_loader import snap_data_loader
from sklearn.metrics import f1_score
from scipy.sparse import csr_matrix

# Datasets are read through DGL's data modules (reddit, gnn_benckmark) and the
# local snap_data_loader helper.

# Set random seed
# Fixes NumPy's global RNG, which drives the random train/val/test permutations,
# minibatch shuffling and FastGCN column sampling below, and TensorFlow's
# graph-level seed. NOTE(review): tf.set_random_seed applies to the graph that is
# the default at import time.
seed = 123
np.random.seed(seed)
tf.set_random_seed(seed)


def del_all_flags(FLAGS):
    """Clear every flag currently registered on *FLAGS*.

    The flag names are copied into a list before anything is deleted, so the
    deletions cannot disturb the iteration. (Presumably ``_flags()`` returns the
    live flag registry.)
    """
    for flag_name in list(FLAGS._flags()):
        delattr(FLAGS, flag_name)


# Remove every flag already registered on the global FLAGS object before the
# DEFINE_* calls below. NOTE(review): presumably this avoids duplicate-flag
# errors when the module is executed more than once in one interpreter.
del_all_flags(tf.flags.FLAGS)

# Settings
# Flags read in this file: model (main/test), epochs, dropout, early_stopping (main).
# NOTE(review): learning_rate, hidden1 and weight_decay are presumably read by
# models.GCN_APPRO_Mix; dataset and max_degree are not read anywhere visible here.
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('dataset', 'pubmed', 'Dataset string.')  # not used here; datasets are chosen via `choice` in __main__
flags.DEFINE_string('model', 'gcn_mix', 'Model string.')  # only 'gcn_mix' is accepted by main() and test()
flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.')
flags.DEFINE_integer('epochs', 200, 'Number of epochs to train.')
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_float('dropout', 0.0, 'Dropout rate (1 - keep probability).')  # fed to placeholders['dropout'] in training steps only
flags.DEFINE_float('weight_decay', 1e-4, 'Weight for L2 loss on embedding matrix.')
flags.DEFINE_integer('early_stopping', 30, 'Tolerance for early stopping (# of epochs).')  # also the loss-averaging window in main()
flags.DEFINE_integer('max_degree', 3, 'Maximum Chebyshev polynomial degree.')


# Load data


def iterate_minibatches_listinputs(inputs, batchsize, shuffle=False, drop_last=True):
    """Yield aligned mini-batches drawn from several row-indexable inputs.

    Args:
        inputs: sequence of arrays / sparse matrices that share their first
            dimension. Every yielded batch takes the same rows from each input.
        batchsize: number of rows per batch.
        shuffle: if True, rows are visited in a random order drawn from NumPy's
            global RNG.
        drop_last: if True (the default, matching the previous behaviour), a
            trailing batch with fewer than ``batchsize`` rows is skipped. In that
            case nothing is yielded when there are fewer than ``batchsize`` rows.
            Pass False to also yield the partial batch.

    Yields:
        A list holding one row-slice per entry of *inputs*.
    """
    assert inputs is not None
    num_samples = inputs[0].shape[0]
    if shuffle:
        indices = np.arange(num_samples)
        np.random.shuffle(indices)
    # With drop_last, the last start index must leave room for a full batch.
    stop = num_samples - batchsize + 1 if drop_last else num_samples
    for start_idx in range(0, stop, batchsize):
        end_idx = start_idx + batchsize
        excerpt = indices[start_idx:end_idx] if shuffle else slice(start_idx, end_idx)
        yield [arr[excerpt] for arr in inputs]


def loadRedditFromG(dataset_dir, inputfile):
    """Load a pickled Reddit split plus its feature matrix.

    The pickle stream at ``dataset_dir + inputfile`` must contain an int ``k``
    followed by ``k`` pickled objects. These are unpacked as (adj, train_labels,
    val_labels, test_labels, train_index, val_index, test_index). Features are
    read from ``dataset_dir + "/reddit-feats.npy"``.

    Returns:
        (adj as CSR matrix, feats as LIL matrix, train_labels, val_labels,
         test_labels, train_index, val_index, test_index)
    """
    import pickle

    # Pickle data is binary: text mode makes pickle.load fail on Python 3.
    # The with-block also closes the file, which was previously leaked.
    with open(dataset_dir + inputfile, 'rb') as f:
        num_objects = pickle.load(f)
        objects = [pickle.load(f) for _ in range(num_objects)]
    adj, train_labels, val_labels, test_labels, train_index, val_index, test_index = tuple(objects)
    feats = np.load(dataset_dir + "/reddit-feats.npy")
    return sp.csr_matrix(adj), sp.lil_matrix(
        feats), train_labels, val_labels, test_labels, train_index, val_index, test_index


def loadRedditFromNPZ(dataset_dir):
    """Load the Reddit adjacency and split arrays stored as ``.npz`` files.

    Reads ``reddit_adj.npz`` (a scipy sparse matrix) and ``reddit.npz``, whose
    keys are the ones ``transferRedditDataFormat`` writes. *dataset_dir* is
    prepended verbatim, so it must end with a path separator.

    Returns:
        (adj, feats, y_train, y_val, y_test, train_index, val_index, test_index)
    """
    adj = sp.load_npz(dataset_dir + "reddit_adj.npz")
    # np.load on an .npz returns a lazily-reading NpzFile that holds the file open.
    # Each data[key] access reads that array into memory, and the with-block
    # closes the archive afterwards.
    with np.load(dataset_dir + "reddit.npz") as data:  # ,allow_pickle=True
        return (adj, data['feats'], data['y_train'], data['y_val'], data['y_test'],
                data['train_index'], data['val_index'], data['test_index'])


def transferRedditDataFormat(dataset_dir, output_file):
    """Convert GraphSAGE-style Reddit files into a single ``.npz`` archive.

    Reads ``reddit-G.json`` (a node-link graph whose nodes carry 'val' and
    'test' attributes), ``reddit-class_map.json``, ``reddit-feats.npy`` and
    ``reddit-id_map.json`` from *dataset_dir*. Nodes with neither flag set form
    the training split. The first two feature columns are log-transformed. The
    archive written to *output_file* holds ``feats``, ``y_train``/``y_val``/
    ``y_test`` and the matching feature-row indices ``train_index``/
    ``val_index``/``test_index``.
    """
    def _load_json(name):
        # os.path.join works with or without a trailing separator on dataset_dir.
        # The previous "dir" + "name" concatenation broke the id-map path when
        # the separator was missing.
        with open(os.path.join(dataset_dir, name)) as f:
            return json.load(f)

    G = json_graph.node_link_graph(_load_json("reddit-G.json"))
    labels = _load_json("reddit-class_map.json")

    # G.nodes[n] is the node-attribute view; G.node was removed in networkx 2.4.
    train_ids = [n for n in G.nodes() if not G.nodes[n]['val'] and not G.nodes[n]['test']]
    test_ids = [n for n in G.nodes() if G.nodes[n]['test']]
    val_ids = [n for n in G.nodes() if G.nodes[n]['val']]
    train_labels = [labels[i] for i in train_ids]
    test_labels = [labels[i] for i in test_ids]
    val_labels = [labels[i] for i in val_ids]

    feats = np.load(os.path.join(dataset_dir, "reddit-feats.npy"))
    ## Logistic gets thrown off by big counts, so log transform num comments and score
    feats[:, 0] = np.log(feats[:, 0] + 1.0)
    feats[:, 1] = np.log(feats[:, 1] - min(np.min(feats[:, 1]), -1))

    # Maps node id -> row of `feats` (dict.iteritems() no longer exists on Python 3,
    # and the old identity re-copy was unnecessary).
    feat_id_map = _load_json("reddit-id_map.json")

    train_index = [feat_id_map[node_id] for node_id in train_ids]
    val_index = [feat_id_map[node_id] for node_id in val_ids]
    test_index = [feat_id_map[node_id] for node_id in test_ids]
    np.savez(output_file, feats=feats, y_train=train_labels, y_val=val_labels, y_test=test_labels,
             train_index=train_index,
             val_index=val_index, test_index=test_index)


def transferLabel2Onehot(labels, N):
    """Return a float one-hot matrix with one row per label and N columns.

    Row ``r`` has a single 1 at column ``labels[r]`` and zeros elsewhere.
    """
    onehot = np.zeros((len(labels), N))
    for row, cls in enumerate(labels):
        onehot[row, cls] = 1
    return onehot


def construct_feeddict_forMixlayers(AXfeatures, support, labels, placeholders):
    """Build the placeholder -> value mapping for one batch.

    Args:
        AXfeatures: feature rows fed to ``placeholders['AXfeatures']``.
        support: sparse support tuple fed to ``placeholders['support']``.
        labels: one-hot labels fed to ``placeholders['labels']``.
        placeholders: dict of placeholders keyed by name.

    Returns:
        A dict suitable as the ``feed_dict`` of ``Session.run``.
    """
    # NOTE(review): AXfeatures is a dense matrix at every call site in this file,
    # so AXfeatures[1].shape is the shape of its second row, not a nonzero count.
    # It looks like a leftover from a sparse (coords, values, shape) API. Confirm
    # the model does not depend on it. It also raises for single-row inputs.
    return {
        placeholders['labels']: labels,
        placeholders['AXfeatures']: AXfeatures,
        placeholders['support']: support,
        placeholders['num_features_nonzero']: AXfeatures[1].shape,
    }


def loadRedditFull():
    """Load the full Reddit graph through DGL's ``RedditDataset``.

    Returns:
        (adj, features, y_train, y_val, y_test, train_index, val_index,
         test_index), where labels and node indices are selected with the
        dataset's train/val/test masks.
    """
    data = reddit.RedditDataset()

    adj = data.graph.adjacency_matrix_scipy(return_edge_ids=False)

    features = data.features

    y_train = data.labels[data.train_mask]
    y_test = data.labels[data.test_mask]
    y_val = data.labels[data.val_mask]

    # Derive the node count from the masks instead of hard-coding 232965, so the
    # index arrays always line up with the masks that filter them.
    # NOTE(review): assumes the masks are boolean arrays, as the label indexing
    # above also does.
    node_ids = np.arange(len(data.train_mask))
    train_index = node_ids[data.train_mask]
    test_index = node_ids[data.test_mask]
    val_index = node_ids[data.val_mask]

    return adj, features, y_train, y_val, y_test, train_index, val_index, test_index


def loadBnchmrkData(choice):
    """Load a DGL Coauthor / AmazonCoBuy graph with a random 10% / 20% / 70% split.

    Args:
        choice: 0 = Coauthor physics, 1 = Coauthor CS, 2 = AmazonCoBuy
            computers, anything else = AmazonCoBuy photo.

    Returns:
        (adj, features, train_labels, val_labels, test_labels, ind_train,
         ind_val, ind_test). *adj* has its self-loops removed, and the split
        uses NumPy's global RNG.
    """
    if choice == 0:  # num_classes = 5
        data = gnnbnch.Coauthor('physics')
    elif choice == 1:  # num_classes = 15
        data = gnnbnch.Coauthor('cs')
    elif choice == 2:  # num_classes = 10
        data = gnnbnch.AmazonCoBuy('computers')
    else:  # num_classes = 8
        data = gnnbnch.AmazonCoBuy('photo')

    g = data[0]
    features = g.ndata['feat']
    labels = g.ndata['label']

    # Remove self-loops *before* taking the adjacency, as loadSNAPData does.
    # Previously the removal ran after `adj` was extracted, so it had no effect.
    g = transform.remove_self_loop(g)
    adj = g.adjacency_matrix_scipy(return_edge_ids=False)
    N = g.number_of_nodes()

    num_train = int(np.round(1 * N / 10))
    num_val = int(np.round(N / 5))
    ind = np.random.permutation(N)

    ind_train = ind[0:num_train]
    ind_val = ind[num_train:num_train + num_val]
    ind_test = ind[num_train + num_val:]

    train_labels = labels[ind_train]
    test_labels = labels[ind_test]
    val_labels = labels[ind_val]

    return adj, features, train_labels, val_labels, test_labels, ind_train, ind_val, ind_test


def loadSNAPData(choice=5):
    """Load a SNAP MUSAE graph with a random 10% / 20% / 70% train/val/test split.

    Args:
        choice: 5 = git, 6 = Twitch_DE, 7 = Twitch_FR, 8 = crocs, 9 = squirrels.

    Returns:
        (adj, features, train_labels, val_labels, test_labels, ind_train,
         ind_val, ind_test). *adj* has its self-loops removed, and the split
        uses NumPy's global RNG.

    Raises:
        ValueError: if *choice* does not select one of the datasets above.
    """
    SNAP_edge_files = ['git_web_ml/musae_git_edges.csv', 'twitch/musae_DE_edges.csv', 'twitch/musae_FR_edges.csv',
                       'wikipedia/musae_crocodile_edges.csv', 'wikipedia/musae_squirrel_edges.csv']
    SNAP_features_files = ['git_web_ml/musae_git_features.json', 'twitch/musae_DE_features.json',
                           'twitch/musae_FR_features.json', 'wikipedia/musae_crocodile_features.json',
                           'wikipedia/musae_squirrel_features.json']
    SNAP_label_files = ['git_web_ml/musae_git_target.csv', 'twitch/musae_DE_target.csv', 'twitch/musae_FR_target.csv',
                        'wikipedia/musae_crocodile_target.csv', 'wikipedia/musae_squirrel_target.csv']
    SNAP_dataset_name = ['git', 'Twitch_DE', 'Twitch_FR', 'crocs', 'squirrels']

    offset = choice - 5
    # Without this check, choices below 5 silently wrap around to the end of the lists.
    if not 0 <= offset < len(SNAP_dataset_name):
        raise ValueError('choice must be between 5 and {}, got {!r}'.format(
            4 + len(SNAP_dataset_name), choice))

    data = snap_data_loader(SNAP_edge_files[offset], SNAP_label_files[offset],
                            SNAP_features_files[offset], SNAP_dataset_name[offset])

    N = data.N
    features = data.X
    labels = np.array(data.L).astype(int)

    num_train = int(np.round(1 * N / 10))
    num_val = int(np.round(1 * N / 5))

    ind = np.random.permutation(N)
    ind_train = ind[0:num_train]
    ind_val = ind[num_train:num_train + num_val]
    ind_test = ind[num_train + num_val:]

    train_labels = labels[ind_train]
    test_labels = labels[ind_test]
    val_labels = labels[ind_val]

    g = DGLGraph()
    g.from_scipy_sparse_matrix(csr_matrix(data.A))
    g = transform.remove_self_loop(g)
    adj = g.adjacency_matrix_scipy(return_edge_ids=False)

    return adj, features, train_labels, val_labels, test_labels, ind_train, ind_val, ind_test


def _sample_top_layer(normADJ_batch, p0, rank1):
    """Importance-sample ``rank1`` columns of a batch adjacency (FastGCN layer sampling).

    Columns are drawn without replacement from the nonzero columns of
    *normADJ_batch*, with probability proportional to *p0*. When the batch
    touches fewer than ``rank1`` columns, all of them are kept. The kept columns
    are rescaled by ``1 / (p0[q] * rank1)``.

    Returns:
        (support, q): the rescaled support as produced by ``sparse_to_tuple``
        and the kept column indices.
    """
    distr = np.nonzero(np.sum(normADJ_batch, axis=0))[1]
    if rank1 > len(distr):
        q = distr
    else:
        q = np.random.choice(distr, rank1, replace=False, p=p0[distr] / sum(p0[distr]))  # top layer
    support = sparse_to_tuple(normADJ_batch[:, q].dot(sp.diags(1.0 / (p0[q] * rank1))))
    return support, q


def main(rank1, choice, adj, features, y_train, y_val, y_test, train_index, val_index, test_index):
    """Train GCN_APPRO_Mix on the training nodes and evaluate it on the test nodes.

    Args:
        rank1: number of top-layer columns sampled per mini-batch, or None to
            train on the full (unsampled) batch adjacency.
        choice: dataset index. It selects the class count from ``nc``.
        adj: scipy sparse adjacency of the whole graph.
        features: node features (anything ``sp.lil_matrix`` accepts).
        y_train, y_val, y_test: integer class labels of each split.
        train_index, val_index, test_index: node indices of each split.

    Returns:
        (rank1, test_cost, test_acc, train_duration, epochs_run, test_duration,
         total_edges, test_f1), where total_edges adds the sampled edges to
        ``adj_train.nnz / 2`` dense-layer edges per epoch.
    """
    nc = [5, 15, 10, 8, 41, 2]  # number of classes for each dataset `choice`
    ckpt_path = "tmp/tmp_MixModel_sampleA_full.ckpt"

    num_classes = nc[choice]

    y_train = transferLabel2Onehot(y_train, num_classes)
    y_val = transferLabel2Onehot(y_val, num_classes)
    y_test = transferLabel2Onehot(y_test, num_classes)

    print("Sizes of data split:", y_train.shape[0], y_val.shape[0], y_test.shape[0])

    features = sp.lil_matrix(features)

    # Subgraph induced by the training nodes only.
    adj_train = adj[train_index, :][:, train_index]

    if FLAGS.model == 'gcn_mix':
        normADJ_train = nontuple_preprocess_adj(adj_train)
        normADJ = nontuple_preprocess_adj(adj)
        model_func = GCN_APPRO_Mix
    else:
        raise ValueError('Invalid argument for model: ' + str(FLAGS.model))

    # Apply the first aggregation (normalized A times X) once, up front; training
    # then samples only the top layer.
    features = nontuple_preprocess_features(features).todense()
    train_features = normADJ_train.dot(features[train_index])
    features = normADJ.dot(features)

    # Define placeholders
    placeholders = {
        'support': tf.sparse_placeholder(tf.float32),
        'AXfeatures': tf.placeholder(tf.float32, shape=(None, features.shape[1])),
        'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
        'dropout': tf.placeholder_with_default(0., shape=()),
        'num_features_nonzero': tf.placeholder(tf.int32)  # helper variable for sparse dropout
    }

    print("creating model...")
    model = model_func(placeholders, input_dim=features.shape[-1], logging=True)
    # Build the prediction op once. Calling model.predict() inside every
    # evaluation may add new ops to the graph each time.
    predict_op = model.predict()

    # Saver.save needs the checkpoint directory to exist.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir and not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    def comp_f1_score(preds, labels):
        """Micro-averaged F1 of argmax predictions against one-hot labels."""
        pr_class = np.argmax(preds, axis=1)
        gt_class = np.argmax(labels, axis=1)
        return f1_score(gt_class, pr_class, average='micro')  # sklearn order: (y_true, y_pred)

    with tf.Session() as sess:

        def evaluate(features, support, labels, placeholders):
            """Return (loss, accuracy, wall time, predictions) for one evaluation feed."""
            t_test = time.time()
            feed_dict_val = construct_feeddict_forMixlayers(features, support, labels, placeholders)
            outs_val = sess.run([model.loss, model.accuracy, predict_op], feed_dict=feed_dict_val)
            return outs_val[0], outs_val[1], (time.time() - t_test), outs_val[2]

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        cost_val = []
        p0 = column_prop(normADJ_train)
        valSupport = sparse_to_tuple(normADJ[val_index, :])
        testSupport = sparse_to_tuple(normADJ[test_index, :])

        t = time.time()
        maxACC = 0.0
        saved = False  # restore only a checkpoint written by *this* call
        total_edges_sampled = 0
        epoch = -1  # keeps `epoch + 1` defined when FLAGS.epochs == 0
        print("Starting epochs")
        for epoch in range(FLAGS.epochs):
            t1 = time.time()

            n = 0
            edges_sampled = 0
            train_acc = float('nan')  # stays NaN if every batch in this epoch is skipped
            for normADJ_batch, y_train_batch in iterate_minibatches_listinputs(
                    [normADJ_train, y_train], batchsize=256, shuffle=True):
                if rank1 is None:
                    support1 = sparse_to_tuple(normADJ_batch)
                    features_inputs = train_features
                else:
                    support1, q1 = _sample_top_layer(normADJ_batch, p0, rank1)
                    if len(support1[1]) == 0:
                        continue  # no sampled edges, nothing to train on
                    features_inputs = train_features[q1, :]  # selected nodes for approximation
                edges_sampled += len(support1[1])

                feed_dict = construct_feeddict_forMixlayers(features_inputs, support1, y_train_batch,
                                                            placeholders)
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})

                # Training step
                outs = sess.run([model.opt_op, model.loss, model.accuracy], feed_dict=feed_dict)
                train_acc = outs[2]
                n += 1

            t2 = time.time()
            # Validation
            cost, acc, _, _ = evaluate(features, valSupport, y_val, placeholders)
            cost_val.append(cost)

            # Checkpoint the best validation accuracy after a 20-epoch warm-up.
            if epoch > 20 and acc > maxACC:
                maxACC = acc
                saver.save(sess, ckpt_path)
                saved = True

            print("Epoch:", '%04d' % (epoch + 1),
                  "train_acc=", "{:.5f}".format(train_acc),
                  "val_acc=", "{:.5f}".format(acc),
                  "time/Epoch=", "{:.5f}".format((t2 - t1)),
                  "Edges/epoch", "{:.1f}".format(edges_sampled + adj_train.nnz / 2),
                  "Edges in dense layer", "{:.1f}".format(adj_train.nnz / 2),
                  "Number of batches", "{:.1f}".format(n),
                  "#edges/batch ", "{:.1f}".format(edges_sampled / n if n else 0.0))

            # Count this epoch's edges before a possible early stop, so the total
            # agrees with the `epoch + 1` dense-layer term in the return value.
            total_edges_sampled += edges_sampled

            if epoch > FLAGS.early_stopping and np.mean(cost_val[-2:]) > np.mean(
                    cost_val[-(FLAGS.early_stopping + 1):-1]):
                break

        train_duration = time.time() - t

        # Testing
        if saved:
            saver.restore(sess, ckpt_path)
        test_cost, test_acc, test_duration, pr_test = evaluate(features, testSupport, y_test,
                                                               placeholders)

    test_f1 = comp_f1_score(pr_test, y_test)
    epochs_run = epoch + 1
    return (rank1, test_cost, test_acc, train_duration, epochs_run, test_duration,
            total_edges_sampled + adj_train.nnz * epochs_run / 2, test_f1)


def transferG2ADJ():
    """Build the Reddit adjacency matrix from GraphSAGE JSON files and save it.

    Reads ``reddit/reddit-G.json`` and ``reddit/reddit-id_map.json``. Both
    endpoints of every edge are mapped to feature-row indices, and a CSR matrix
    with 1.0 per edge (duplicates summed) is written to ``reddit_adj.npz``.
    """
    with open("reddit/reddit-G.json") as f:
        G = json_graph.node_link_graph(json.load(f))
    with open("reddit/reddit-id_map.json") as f:
        feat_id_map = json.load(f)  # node id -> row index (dict.iteritems() is Python 2 only)
    numNode = len(feat_id_map)

    # A single pass over the edges. No dense numNode x numNode buffer is built:
    # the previous unused np.zeros((numNode, numNode)) needs hundreds of GB at
    # Reddit scale.
    rows, cols = [], []
    for src, dst in G.edges():
        rows.append(feat_id_map[src])
        cols.append(feat_id_map[dst])

    adj = sp.csr_matrix((np.ones((len(rows),)), (rows, cols)), shape=(numNode, numNode))
    sp.save_npz("reddit_adj.npz", adj)


def test(rank1=None):
    """Restore a trained GCN_APPRO_Mix checkpoint and evaluate it on Reddit's test split.

    Loads ``data/reddit_adj.npz`` / ``data/reddit.npz``, symmetrizes the
    adjacency and restores ``tmp/tmp_MixModel_sampleA.ckpt``. It then prints the
    cost, accuracy and wall-clock time of one evaluation pass.

    Args:
        rank1: None to evaluate with the full test-row adjacency. Otherwise,
            ``rank1`` columns are importance-sampled (FastGCN style) and rescaled.
    """
    adj, features, y_train, _, y_test, _, _, test_index = loadRedditFromNPZ("data/")
    adj = adj + adj.T

    # Reddit has 41 classes; y_train is only needed for the label placeholder width.
    y_train = transferLabel2Onehot(y_train, 41)
    y_test = transferLabel2Onehot(y_test, 41)

    features = sp.lil_matrix(features)

    if FLAGS.model == 'gcn_mix':
        normADJ = nontuple_preprocess_adj(adj)
        normADJ_test = normADJ[test_index, :]
        model_func = GCN_APPRO_Mix
    else:
        raise ValueError('Invalid argument for model: ' + str(FLAGS.model))

    # Some preprocessing: apply the first aggregation up front.
    features = nontuple_preprocess_features(features).todense()
    features = normADJ.dot(features)

    # Define placeholders
    placeholders = {
        'support': tf.sparse_placeholder(tf.float32),
        'AXfeatures': tf.placeholder(tf.float32, shape=(None, features.shape[1])),
        'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
        'dropout': tf.placeholder_with_default(0., shape=()),
        'num_features_nonzero': tf.placeholder(tf.int32)  # helper variable for sparse dropout
    }

    # Create model
    model = model_func(placeholders, input_dim=features.shape[-1], logging=True)

    # The with-block closes the session, which was previously leaked.
    with tf.Session() as sess:

        def evaluate(features, support, labels, placeholders):
            """Return (loss, accuracy, wall time) for one evaluation feed."""
            t_test = time.time()
            feed_dict_val = construct_feeddict_forMixlayers(features, support, labels, placeholders)
            outs_val = sess.run([model.loss, model.accuracy], feed_dict=feed_dict_val)
            return outs_val[0], outs_val[1], (time.time() - t_test)

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, "tmp/tmp_MixModel_sampleA.ckpt")

        p0 = column_prop(normADJ_test)

        t = time.time()

        if rank1 is None:
            support1 = sparse_to_tuple(normADJ_test)
            features_inputs = features
        else:
            distr = np.nonzero(np.sum(normADJ_test, axis=0))[1]
            if rank1 > len(distr):
                q1 = distr
            else:
                q1 = np.random.choice(distr, rank1, replace=False, p=p0[distr] / sum(p0[distr]))  # top layer

            support1 = sparse_to_tuple(normADJ_test[:, q1].dot(sp.diags(1.0 / (p0[q1] * rank1))))

            features_inputs = features[q1, :]  # selected nodes for approximation

        test_cost, test_acc, _ = evaluate(features_inputs, support1, y_test,
                                          placeholders)

        # Wall time including the sampling above, not just the session run.
        test_duration = time.time() - t

    print("rank1 = {}".format(rank1), "cost=", "{:.5f}".format(test_cost),
          "accuracy=", "{:.5f}".format(test_acc),
          "test time=", "{:.5f}".format(test_duration))


def _run_trials(choice, data, num_trials):
    """Run ``num_trials`` paired ``main`` calls (rank1 = 400, then 800) on one dataset.

    Args:
        choice: dataset index forwarded to ``main``.
        data: the 8-tuple returned by a loader (adj, features, y_train, y_val,
            y_test, train_index, val_index, test_index).
        num_trials: number of repetitions.

    Returns:
        (f1_400, dur_400, f1_800, dur_800): per-trial test F1 scores and
        per-epoch training times for each rank.
    """
    f1_400, dur_400, f1_800, dur_800 = [], [], [], []
    for _ in range(num_trials):
        # main returns (rank1, test_cost, test_acc, train_duration, epochs,
        #               test_duration, total_edges, test_f1)
        res_400 = main(400, choice, *data)
        res_800 = main(800, choice, *data)
        f1_400.append(res_400[7])
        dur_400.append(res_400[3] / res_400[4])
        f1_800.append(res_800[7])
        dur_800.append(res_800[3] / res_800[4])
    return f1_400, dur_400, f1_800, dur_800


if __name__ == "__main__":
    num_trials = 3

    files = ["Phy", "CS", "Comp", "Photo"]

    avg_f1_scores_400 = []
    avg_f1_scores_800 = []

    std_f1_scores_400 = []
    std_f1_scores_800 = []

    for choice in range(5, 6):
        # Pick the loader and the summary header per dataset family; the trial
        # loop and the reporting are shared.
        if choice < 4:
            print("loading data :", files[choice])
            data = loadBnchmrkData(choice)
            summary_header = ("Dataset:", files[choice])
        elif choice == 4:
            print("loading full reddit data:")
            data = loadRedditFull()
            summary_header = ("reddit data",)
        else:
            print("loading Github data:")
            data = loadSNAPData(choice)
            summary_header = ("github data",)

        test_acc_400_trials, train_dur_400_trials, test_acc_800_trials, train_dur_800_trials = \
            _run_trials(choice, data, num_trials)

        print(*summary_header)
        print("Avg test accuracy, std, avg_train time, std,  400: ", np.mean(test_acc_400_trials),
              np.std(test_acc_400_trials), np.mean(train_dur_400_trials), np.std(train_dur_400_trials))
        print("Avg test accuracy, avg_train time 800: ", np.mean(test_acc_800_trials), np.std(test_acc_800_trials),
              np.mean(train_dur_800_trials), np.std(train_dur_800_trials))

        avg_f1_scores_400.append(np.mean(test_acc_400_trials))
        std_f1_scores_400.append(np.std(test_acc_400_trials))
        avg_f1_scores_800.append(np.mean(test_acc_800_trials))
        std_f1_scores_800.append(np.std(test_acc_800_trials))

#    np.savez("FastGCN_f1_score",avg_f1_scores_400, std_f1_scores_400 ,avg_f1_scores_800, std_f1_scores_800)
