# Copied and edited from https://github.com/Fsoft-AIC/Batch-Ollivier-Ricci-Flow
# Original author: Khang Nguyen and Tan Nguyen and Hieu Nong and Vinh Nguyen and Nhat Ho and Stanley Osher
# Description: This class implements BORF as described in [Revisiting Over-smoothing and Over-squashing using Ollivier-Ricci Curvature, 2023]
import os
from typing import Any

import torch
import pathlib
import numpy as np
from torch_geometric.transforms import BaseTransform

from torch_geometric.utils import to_networkx, from_networkx

# Copied from https://github.com/Fsoft-AIC/Batch-Ollivier-Ricci-Flow
# Original author: Khang Nguyen and Tan Nguyen and Hieu Nong and Vinh Nguyen and Nhat Ho and Stanley Osher
# Description: This class implements G2 as described in [Revisiting Over-smoothing and Over-squashing using Ollivier-Ricci Curvature, 2023]
import heapq
import importlib
import math
import time
import torch
import pandas as pd

from src.utils.path_io import get_path_up_to

# Import-time side effect: selects 'spawn' as torch's multiprocessing start method.
# NOTE(review): set_start_method is expected to raise RuntimeError when a start method
# has already been fixed in this interpreter -- confirm this module is imported first.
# NOTE(review): the curvature worker pool in `_compute_ricci_curvature_edges` requests
# its own 'fork' context via `mp.get_context('fork')`, so it does not use this setting.
torch.multiprocessing.set_start_method('spawn')
# CUDA device when available, CPU otherwise. Not referenced anywhere else in the code visible here.
_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
import multiprocessing as mp
from functools import lru_cache

import networkit as nk
import networkx as nx
import numpy as np
import ot

# from .util import logger, set_verbose, cut_graph_by_cutoff, get_rf_metric_cutoff

# Base directory used to build the rewired-graph cache path in `Borf3.__init__`.
# Presumably the path of this file truncated at its "src" component -- verify against
# `src.utils.path_io.get_path_up_to`.
ROOT_PATH = get_path_up_to(__file__, "src")
EPSILON = 1e-7  # threshold below which a weight or weight sum is treated as zero (avoids division by zero)

# ---Shared global variables for multiprocessing used.---
# They are overwritten by `_compute_ricci_curvature_edges` before the worker pool is
# created; the worker functions read them as module globals.
# `_Gk` (the NetworKit graph) has no module-level default; it is only created there.
# _Gk = nk.graph.Graph()
_alpha = 0.5  # share of probability mass kept on the node itself
_weight = "weight"  # edge attribute name; assigned but not read by the code visible here
_method = "OTDSinkhornMix"  # only referenced from commented-out code in this file
_base = math.e  # neighbor weight is _base ** (-(edge_weight ** _exp_power))
_exp_power = 2
_proc = mp.cpu_count()  # number of worker processes in the pool
_cache_maxsize = 1000  # read by the @lru_cache decorators when the functions are defined
_shortest_path = "all_pairs"  # "all_pairs" (precomputed matrix) or "pairwise" (per-pair Dijkstra)
_nbr_topk = 3000  # at most this many highest-weight neighbors enter a distribution
_OTDSinkhorn_threshold = 2000  # only referenced from commented-out code in this file
_apsp = {}  # placeholder; replaced by the all-pairs distance matrix when "all_pairs" is used


# -------------------------------------------------------
@lru_cache(_cache_maxsize)
def _get_single_node_neighbors_distributions(node, direction="successors"):
    """Build the probability mass that `node` spreads over itself and its neighbors.

    Each neighbor gets a raw weight ``_base ** (-(w ** _exp_power))`` where ``w`` is the
    weight of the connecting edge in `_Gk`. Only the `_nbr_topk` largest raw weights
    are kept. The kept weights are normalised to sum to ``1 - _alpha`` and the
    remaining ``_alpha`` stays on `node` itself.

    Parameters
    ----------
    node : int
        Node index in Networkit graph `_Gk`.
    direction : {"predecessors", "successors"}
        Which neighbors to use on a directed graph. (Default value: "successors")

    Returns
    -------
    distributions : list of float
        Mass per kept neighbor, followed by the mass left on `node`.
    nbrs : list of int
        Kept neighbor indices, followed by `node` (aligned with `distributions`).

    """
    use_predecessors = direction == "predecessors"

    # In-neighbors only exist as a separate notion on directed graphs.
    if use_predecessors and _Gk.isDirected():
        adjacent = list(_Gk.iterInNeighbors(node))
    else:
        adjacent = list(_Gk.iterNeighbors(node))

    if not adjacent:
        # Isolated node: the whole mass stays where it is.
        return [1], [node]

    # Min-heap of (raw_weight, neighbor); once full, pushing evicts the smallest entry,
    # so the heap ends up holding the `_nbr_topk` largest weights.
    top_pairs = []
    for other in adjacent:
        if use_predecessors:
            edge_w = _Gk.weight(other, node)
        else:
            edge_w = _Gk.weight(node, other)
        raw = _base ** (-edge_w ** _exp_power)

        if len(top_pairs) >= _nbr_topk:
            heapq.heappushpop(top_pairs, (raw, other))
        else:
            heapq.heappush(top_pairs, (raw, other))

    total = sum(pair[0] for pair in top_pairs)
    kept = len(top_pairs)

    if total > EPSILON:
        shares = [(1.0 - _alpha) * raw / total for raw, _ in top_pairs]
    else:
        # Raw weights are all (near) zero: fall back to a uniform split.
        shares = [(1.0 - _alpha) / kept] * kept

    kept_nodes = [pair[1] for pair in top_pairs]
    return shares + [_alpha], kept_nodes + [node]


def _get_all_pairs_shortest_path():
    """Run NetworKit's APSP on the shared graph `_Gk` and return the distances as a numpy array.

    Prints a start message and the elapsed wall-clock time.
    """
    print("[OLLIVIER RICCI]: Start to compute all pair shortest path.")

    started = time.time()
    distances = nk.distance.APSP(_Gk).run().getDistances()
    elapsed = time.time() - started
    print("[OLLIVIER RICCI]: %8f secs for all pair by NetworKit." % elapsed)

    return np.array(distances)


@lru_cache(_cache_maxsize)
def _source_target_shortest_path(source, target):
    """Shortest-path length between two nodes of `_Gk` via NetworKit's BidirectionalDijkstra.

    Results are memoised per ``(source, target)`` pair by the `lru_cache` decorator.

    Parameters
    ----------
    source : int
        Source node index in Networkit graph `_Gk`.
    target : int
        Target node index in Networkit graph `_Gk`.

    Returns
    -------
    length : float
        Value reported by ``getDistance()`` after the search has run.

    """
    search = nk.distance.BidirectionalDijkstra(_Gk, source, target).run()
    return search.getDistance()


def _distribute_densities(source, target):
    """Assemble the two mass distributions and the ground-cost matrix for one edge.

    The source distribution uses the predecessors of `source` when `_Gk` is directed
    (successors otherwise); the target distribution always uses successors of `target`.
    Only the top `_nbr_topk` neighbors take part (see
    `_get_single_node_neighbors_distributions`).

    Parameters
    ----------
    source : int
        Source node index in Networkit graph `_Gk`.
    target : int
        Target node index in Networkit graph `_Gk`.

    Returns
    -------
    x : (m,) numpy.ndarray
        Mass on the source's kept neighbors and on the source itself (last entry).
    y : (n,) numpy.ndarray
        Mass on the target's kept neighbors and on the target itself (last entry).
    source_topknbr : list of int
        Node indices aligned with `x`.
    target_topknbr : list of int
        Node indices aligned with `y`.
    d : (m, n) numpy.ndarray
        Shortest-path distances from every node of `x` to every node of `y`.

    """
    source_side = "predecessors" if _Gk.isDirected() else "successors"
    src_mass, src_nodes = _get_single_node_neighbors_distributions(source, source_side)
    tgt_mass, tgt_nodes = _get_single_node_neighbors_distributions(target, "successors")

    if _shortest_path == "pairwise":
        # One (memoised) bidirectional Dijkstra query per node pair.
        cost = np.array([
            [_source_target_shortest_path(s, t) for t in tgt_nodes]
            for s in src_nodes
        ])
    else:  # all_pairs
        # Slice the precomputed all-pairs matrix down to the relevant rows/columns.
        cost = _apsp[np.ix_(src_nodes, tgt_nodes)]

    return np.array(src_mass), np.array(tgt_mass), src_nodes, tgt_nodes, cost


def _compute_ricci_curvature_single_edge(source, target):
    """Ricci curvature computation for a given single edge.

    The curvature is ``1 - W / d(source, target)`` where ``W`` is the exact optimal
    transport cost (``ot.emd``) between the neighborhood distributions of the two
    endpoints and ``d(source, target)`` is the edge weight in `_Gk`.

    Parameters
    ----------
    source : int
        Source node index in Networkit graph `_Gk`.
    target : int
        Target node index in Networkit graph `_Gk`.

    Returns
    -------
    result : dict[(int, int), dict]
        ``{(source, target): info}`` where ``info`` has the keys

        - ``'rc_curvature'``: the Ricci curvature of the edge,
        - ``'rc_transport_cost'``: always ``None`` (the full plan is deliberately not kept),
        - ``'rc_max_cost_edge'``: pair ``(p, q)`` of `_Gk` node indices whose entry in
          ``plan * distance`` is the largest.

        Edges whose weight is below `EPSILON` get curvature ``0`` and
        ``'rc_max_cost_edge' == (source, target)``, in the same dict layout, so that
        callers can index every result uniformly.

    """
    edge_weight = _Gk.weight(source, target)

    # A (near) zero-length edge would make the division below blow up; report
    # curvature 0 instead, using the same result layout as the regular path.
    if edge_weight < EPSILON:
        return {
            (source, target): {
                'rc_curvature': 0,
                'rc_transport_cost': None,
                'rc_max_cost_edge': (source, target),
            }
        }

    x, y, neighbors_x, neighbors_y, d = _distribute_densities(source, target)

    # Exact optimal transport plan and the cost contributed by every (row, column) pair.
    optimal_plan = ot.emd(x, y, d)
    optimal_cost = optimal_plan * d
    optimal_total_cost = np.sum(optimal_cost)

    # compute Ricci curvature: k=1-(m_{x,y})/d(x,y)
    result = 1 - (optimal_total_cost / edge_weight)

    # Only remember which node pair carries the largest transport cost; rows are
    # aligned with `neighbors_x`, columns with `neighbors_y`.
    p, q = np.unravel_index(np.argmax(optimal_cost), optimal_cost.shape)
    max_optimal_cost = (neighbors_x[p], neighbors_y[q])

    return {
        (source, target): {
            'rc_curvature': result,
            # Avoid shipping the whole transportation plan back from the worker.
            'rc_transport_cost': None,
            'rc_max_cost_edge': max_optimal_cost,
        }
    }


def _wrap_compute_single_edge(stuff):
    """Adapter for pool workers: expand one packed argument tuple into positional arguments."""
    packed_args = stuff
    return _compute_ricci_curvature_single_edge(*packed_args)


def _compute_ricci_curvature_edges(G: nx.Graph, weight="weight", edge_list=None,
                                   alpha=0.5, method="OTDSinkhornMix",
                                   base=math.e, exp_power=2, proc=mp.cpu_count(), chunksize=None,
                                   cache_maxsize=1000,
                                   shortest_path="all_pairs", nbr_topk=3000):
    """Compute Ricci curvature for edges in given edge lists.

    The graph is converted to NetworKit, the settings are published as module globals,
    and the per-edge work is distributed over a 'fork' process pool.

    Parameters
    ----------
    G : NetworkX graph
        A given directional or undirectional NetworkX graph. If no edge carries the
        `weight` attribute, every edge of `G` is given weight 1.0 (in place).
    weight : str
        The edge weight used to compute Ricci curvature. (Default value = "weight")
    edge_list : list of edges, optional
        The list of edges to compute Ricci curvature for. ``None`` or an empty list
        means all edges in G. (Default value = None)
    alpha : float
        The parameter for the discrete Ricci curvature, range from 0 ~ 1.
        It means the share of mass to leave on the original node.
        E.g. x -> y, alpha = 0.4 means 0.4 for x, 0.6 to evenly spread to x's nbr.
        (Default value = 0.5)
    method : str
        Stored in the module global `_method`. The per-edge computation visible in this
        file always uses ``ot.emd`` and does not read it. (Default value = "OTDSinkhornMix")
    base : float
        Base variable for weight distribution. (Default value = `math.e`)
    exp_power : float
        Exponential power for weight distribution. (Default value = 2)
    proc : int
        Number of processor used for multiprocessing. (Default value = `cpu_count()`)
    chunksize : int
        Chunk size for multiprocessing, set None for auto decide. (Default value = `None`)
    cache_maxsize : int
        Stored in the module global `_cache_maxsize`. (Default value = 1000)
        NOTE: the `lru_cache` decorators read that global when the functions are
        defined, so changing it here does not resize the existing caches.
    shortest_path : {"all_pairs","pairwise"}
        Method to compute shortest path. (Default value = `all_pairs`)
    nbr_topk : int
        Only take the top k edge weight neighbors for density distribution.
        Smaller k run faster but the result is less accurate. (Default value = 3000)

    Returns
    -------
    output : dict[(node, node), dict]
        Maps each edge (in NetworkX node labels) to the per-edge result produced by
        `_compute_ricci_curvature_single_edge`. Its ``'rc_max_cost_edge'`` entry is
        translated to NetworkX node labels as well.

    """

    if not nx.get_edge_attributes(G, weight):
        # No weights present at all: treat the graph as unweighted.
        for (v1, v2) in G.edges():
            G[v1][v2][weight] = 1.0

    # ---set to global variable for multiprocessing used.---
    global _Gk
    global _alpha
    global _weight
    global _method
    global _base
    global _exp_power
    global _proc
    global _cache_maxsize
    global _shortest_path
    global _nbr_topk
    global _apsp
    # -------------------------------------------------------

    _Gk = nk.nxadapter.nx2nk(G, weightAttr=weight)
    _alpha = alpha
    _weight = weight
    _method = method
    _base = base
    _exp_power = exp_power
    _proc = proc
    _cache_maxsize = cache_maxsize
    _shortest_path = shortest_path
    _nbr_topk = nbr_topk

    # Translation tables between NetworkX node labels and NetworKit indices
    # (NetworKit indices follow the iteration order of G.nodes()).
    nx2nk_ndict, nk2nx_ndict = {}, {}
    for idx, n in enumerate(G.nodes()):
        nx2nk_ndict[n] = idx
        nk2nx_ndict[idx] = n

    if _shortest_path == "all_pairs":
        # Precompute the distance matrix once; forked workers inherit it.
        _apsp = _get_all_pairs_shortest_path()

    edges = edge_list if edge_list else G.edges()
    args = [(nx2nk_ndict[source], nx2nk_ndict[target]) for source, target in edges]

    # Decide chunksize following method in map_async
    if chunksize is None:
        chunksize, extra = divmod(len(args), proc * 4)
        if extra:
            chunksize += 1
        if chunksize == 0:
            chunksize = 1

    with mp.get_context('fork').Pool(processes=_proc) as pool:
        # WARNING: Now only fork works, spawn will hang.
        pending = pool.imap_unordered(_wrap_compute_single_edge, args, chunksize=chunksize)
        pool.close()
        pool.join()
        # Drain the iterator while the pool object is still alive.
        results = list(pending)

    # Convert node indices from nk back to nx for final output
    output = {}
    for rc in results:
        for (nk_source, nk_target), info in rc.items():
            if isinstance(info, dict) and info.get('rc_max_cost_edge') is not None:
                # The max-cost pair is reported in NetworKit indices; callers work on
                # the NetworkX graph, so map it back like the edge key itself.
                p, q = info['rc_max_cost_edge']
                info = {**info, 'rc_max_cost_edge': (nk2nx_ndict[p], nk2nx_ndict[q])}
            output[(nk2nx_ndict[nk_source], nk2nx_ndict[nk_target])] = info

    return output


def _compute_ricci_curvature(G: nx.Graph, weight="weight", **kwargs):
    """Compute Ricci curvature of edges and nodes.

    Edge results are stored under the edge attribute "ricciCurvature". The node value
    is the sum of the curvature of the edges reached through ``G.neighbors(n)`` divided
    by ``G.degree(n)``; nodes of degree 0 receive no attribute.

    Parameters
    ----------
    G : NetworkX graph
        A given directional or undirectional NetworkX graph.
    weight : str
        The edge weight used to compute Ricci curvature. (Default value = "weight")
    **kwargs
        Additional keyword arguments passed to `_compute_ricci_curvature_edges`.

    Returns
    -------
    G: NetworkX graph
        A NetworkX graph with "ricciCurvature" on nodes and edges.
    """
    global _apsp

    try:
        # compute Ricci curvature for all edges
        edge_ricci = _compute_ricci_curvature_edges(G, weight=weight, **kwargs)
    finally:
        # Drop the module-level all-pairs distance matrix (N x N) so it does not stay
        # resident between calls; it is rebuilt on the next computation.
        _apsp = {}

    # Assign edge Ricci curvature from result to graph G
    nx.set_edge_attributes(G, edge_ricci, "ricciCurvature")

    # Compute node Ricci curvature
    for n in G.nodes():
        degree = G.degree(n)
        if degree == 0:
            continue

        rc_sum = 0  # sum of the neighbor Ricci curvature
        for nbr in G.neighbors(n):
            edge_data = G[n][nbr]
            if 'ricciCurvature' not in edge_data:
                continue
            edge_rc = edge_data['ricciCurvature']
            # Regular results are dicts; tolerate a bare number as well.
            rc_sum += edge_rc['rc_curvature'] if isinstance(edge_rc, dict) else edge_rc

        # Assign the node Ricci curvature to be the average of node's adjacency edges
        G.nodes[n]['ricciCurvature'] = rc_sum / degree

    return G


class OllivierRicci:
    """Container that computes Ollivier-Ricci curvature for the nodes and edges of a graph.

    The node curvature is the average over the node's adjacent edges. The instance
    works on its own copy of the input graph.
    """

    def __init__(self, G: nx.Graph,
                 weight="weight",
                 alpha=0.5,
                 method="OTDSinkhornMix",
                 base=math.e,
                 exp_power=2,
                 proc=mp.cpu_count(),
                 chunksize=None,
                 shortest_path="all_pairs",
                 cache_maxsize=1000,
                 nbr_topk=3000,
                 verbose="ERROR"):
        """Store the settings and prepare a private, cleaned copy of `G`.

        Preparation consists of assigning weight 1.0 to every edge when no edge has the
        `weight` attribute, and of removing self-loop edges.

        Parameters
        ----------
        G : NetworkX graph
            A given directional or undirectional NetworkX graph. It is copied, not modified.
        weight : str
            The edge weight used to compute Ricci curvature. (Default value = "weight")
        alpha : float
            The parameter for the discrete Ricci curvature, range from 0 ~ 1.
            It is the share of mass left on the original node; the rest is spread
            over the node's neighbors. (Default value = 0.5)
        method : str
            Transportation method name, forwarded unchanged to the computation.
            (Default value = "OTDSinkhornMix")
        base : float
            Base variable for weight distribution. (Default value = `math.e`)
        exp_power : float
            Exponential power for weight distribution. (Default value = 2)
        proc : int
            Number of processor used for multiprocessing. (Default value = `cpu_count()`)
        chunksize : int
            Chunk size for multiprocessing, set None for auto decide. (Default value = `None`)
        shortest_path : {"all_pairs","pairwise"}
            Method to compute shortest path. (Default value = `all_pairs`)
        cache_maxsize : int
            Max size for LRU cache for pairwise shortest path computation.
            (Default value = 1000)
        nbr_topk : int
            Only take the top k edge weight neighbors for density distribution.
            Smaller k run faster but the result is less accurate. (Default value = 3000)
        verbose : str
            Accepted for interface compatibility; not used by this constructor.
            (Default value = "ERROR")

        """
        graph = G.copy()
        self.G = graph

        # Plain configuration, handed to the computation later on.
        self.weight = weight
        self.alpha = alpha
        self.method = method
        self.base = base
        self.exp_power = exp_power
        self.proc = proc
        self.chunksize = chunksize
        self.shortest_path = shortest_path
        self.cache_maxsize = cache_maxsize
        self.nbr_topk = nbr_topk

        self.lengths = {}  # all pair shortest path dictionary
        self.densities = {}  # density distribution dictionary

        # Unweighted input: give every edge unit weight.
        existing_weights = nx.get_edge_attributes(graph, weight)
        if not existing_weights:
            for u, v in graph.edges():
                graph[u][v][weight] = 1.0

        # Self-loops are dropped before any curvature is computed.
        loops = list(nx.selfloop_edges(graph))
        if loops:
            graph.remove_edges_from(loops)

    def compute_ricci_curvature(self):
        """Compute Ricci curvature of edges and nodes.

        Delegates to `_compute_ricci_curvature` with the stored settings and keeps the
        returned graph as ``self.G``.

        Returns
        -------
        G: NetworkX graph
            A NetworkX graph with "ricciCurvature" on nodes and edges.
        """
        settings = dict(
            weight=self.weight,
            alpha=self.alpha,
            method=self.method,
            base=self.base,
            exp_power=self.exp_power,
            proc=self.proc,
            chunksize=self.chunksize,
            cache_maxsize=self.cache_maxsize,
            shortest_path=self.shortest_path,
            nbr_topk=self.nbr_topk,
        )
        self.G = _compute_ricci_curvature(G=self.G, **settings)
        return self.G


def _preprocess_data(data, is_undirected=False):
    """Turn a graph data object into a NetworkX graph plus basic metadata.

    Parameters
    ----------
    data :
        Object exposing ``x``, ``edge_index``, ``keys()`` and, optionally, ``edge_type``;
        it must be accepted by ``to_networkx``.
    is_undirected : bool
        When true, the converted graph is replaced by its ``to_undirected()`` version.
        (Default value = False)

    Returns
    -------
    G : NetworkX graph
        Result of ``to_networkx(data)``, optionally made undirected.
    N : int
        ``data.x.shape[0]``.
    edge_type :
        ``data.edge_type`` when present, otherwise an all-zero integer numpy array with
        one entry per column of ``data.edge_index``.
    """
    num_nodes = data.x.shape[0]
    num_edges = data.edge_index.shape[1]

    # Use the stored edge types when available, otherwise a single default type 0.
    if "edge_type" in data.keys():
        edge_type = data.edge_type
    else:
        edge_type = np.zeros(num_edges, dtype=int)

    graph = to_networkx(data)
    if is_undirected:
        graph = graph.to_undirected()

    return graph, num_nodes, edge_type


class Borf3(BaseTransform):
    """Graph transform that rewires a graph with batched Ollivier-Ricci-curvature steps.

    In every iteration the curvature of all edges is computed. For each of the
    `batch_add` most negatively curved edges, an edge is added between the node pair
    that carries the largest transport cost. The `batch_remove` most positively curved
    edges are then removed. The resulting ``edge_index`` / ``edge_type`` are cached on
    disk per graph and reused on later calls.

    Parameters
    ----------
    graph_name : str
        Name used for the cache directory ``<ROOT_PATH>/data/graphs/<graph_name>/borf``.
    n_loops : int
        Number of rewiring iterations. (Default value = 10)
    batch_add : int
        Number of most negatively curved edges considered for additions per iteration.
        (Default value = 4)
    batch_remove : int
        Number of most positively curved edges removed per iteration; ``0`` removes
        nothing. (Default value = 2)
    shortest_path : {"all_pairs", "pairwise"}
        Shortest-path strategy forwarded to `OllivierRicci`. (Default value = "all_pairs")
    remove_edges, removal_bound, tau, is_undirected, device, debug
        Stored on the instance. NOTE(review): none of them is read by `forward` in the
        code visible here (edges are removed regardless of `remove_edges`, and the graph
        is converted without `is_undirected`); behavior is left unchanged on purpose.
    """

    def __init__(self,
                 graph_name:str,
                 n_loops=10,
                 remove_edges=True,
                 removal_bound=0.5,
                 tau=1,
                 is_undirected=False,
                 batch_add=4,
                 batch_remove=2,
                 device=None,
                 shortest_path= "all_pairs",
                 debug=False):
        super().__init__()
        # Counts processed graphs; part of the cache file names.
        self.graph_index = 0
        self.n_loops = n_loops
        self.remove_edges = remove_edges
        self.removal_bound = removal_bound
        self.tau = tau
        self.is_undirected = is_undirected
        self.batch_add = batch_add
        self.batch_remove = batch_remove
        self.device = device
        self.debug = debug
        self.shortest_path = shortest_path
        self.dirname = os.path.join(ROOT_PATH, 'data','graphs', graph_name, 'borf')

    def _cache_path(self, kind: str) -> str:
        """Return the cache file path for `kind` ('edge_index' or 'edge_type') of the current graph."""
        return os.path.join(
            self.dirname,
            f'iters_{self.n_loops}_add_{self.batch_add}_remove_{self.batch_remove}_{kind}_{self.graph_index}.pt')

    def forward(self, data: Any) -> Any:
        """Rewire `data` (or load its cached rewiring) and return it.

        `data.edge_index` and `data.edge_type` are replaced in place and
        `self.graph_index` is advanced by one per call.
        """

        # Check if there is a preprocessed graph
        print(f'[BORF]: Processing graph {self.graph_index}...')
        pathlib.Path(self.dirname).mkdir(parents=True, exist_ok=True)
        edge_index_filename = self._cache_path('edge_index')
        edge_type_filename = self._cache_path('edge_type')

        if os.path.exists(edge_index_filename) and os.path.exists(edge_type_filename):
            # Reuse the rewiring computed earlier for the same settings and graph index.
            data.edge_index = torch.load(edge_index_filename, weights_only=False)
            data.edge_type = torch.load(edge_type_filename, weights_only=False)
            self.graph_index += 1
            return data

        # Preprocess data (only the NetworkX graph is needed here)
        G, _, _ = _preprocess_data(data)

        # Guard the slice sizes: `edges[-0:]` is the WHOLE list, so a batch size of 0
        # must be handled explicitly instead of being used as a negative slice bound.
        n_add = max(self.batch_add, 0)
        n_remove = max(self.batch_remove, 0)

        # Rewiring begins
        for _ in range(self.n_loops):
            # Compute ORC
            orc = OllivierRicci(G, alpha=0, shortest_path=self.shortest_path)
            orc.compute_ricci_curvature()
            # Edges ordered from most negative to most positive curvature.
            ranked = sorted(orc.G.edges, key=lambda x: orc.G[x[0]][x[1]]['ricciCurvature']['rc_curvature'])

            # Get top negative and positive curved edges
            most_neg_edges = ranked[:n_add]
            most_pos_edges = ranked[-n_remove:] if n_remove else []

            # Add edges
            for (u, v) in most_neg_edges:
                p, q = orc.G[u][v]['ricciCurvature']['rc_max_cost_edge']

                if p != q and not G.has_edge(p, q):
                    G.add_edge(p, q)

            # Remove edges
            for (u, v) in most_pos_edges:
                if G.has_edge(u, v):
                    G.remove_edge(u, v)

        edge_index = from_networkx(G).edge_index
        # All edges of the rewired graph get edge type 0.
        edge_type = torch.zeros(size=(len(G.edges),)).type(torch.LongTensor)

        with open(edge_index_filename, 'wb') as f:
            torch.save(edge_index, f)

        with open(edge_type_filename, 'wb') as f:
            torch.save(edge_type, f)

        # Apply new edges to data object
        data.edge_index = edge_index
        data.edge_type = edge_type

        self.graph_index += 1

        return data
