# Copied from https://github.com/mengsig/DomiRank

########## Here are the associated DomiRank functions #############
import networkx as nx
import numpy as np
import scipy as sp
import scipy.sparse

########## Here are the general functions needed for efficient dismantling and testing of networks #############


def get_largest_component(G, strong=False):
    """
    here we get the largest component of a graph, either from scipy.sparse or from networkX.Graph datatype.
    1. The argument changes whether or not you want to find the strong or weak - connected components of the graph
    """
    if type(G) == nx.classes.graph.Graph:  # check if it is a networkx Graph
        if nx.is_directed(G) and strong == False:
            GMask = max(nx.weakly_connected_components(G), key=len)
        if nx.is_directed(G) and strong == True:
            GMask = max(nx.strongly_connected_components(G), key=len)
        else:
            GMask = max(nx.connected_components(G), key=len)
        G = G.subgraph(GMask)
    else:
        raise TypeError("You must input a networkx.Graph Data-Type")
    return G


def relabel_nodes(G, yield_map=False):
    """relabels the nodes to be from 0, ... len(G).
    1. Yield_map returns an extra output as a dict. in case you want to save the hash-map to retrieve node-id
    """
    if yield_map == True:
        nodes = dict(zip(range(len(G)), G.nodes()))
        G = nx.relabel_nodes(G, dict(zip(G.nodes(), range(len(G)))))
        return G, nodes
    else:
        G = nx.relabel_nodes(G, dict(zip(G.nodes(), range(len(G)))))
        return G


def get_component_size(G, strong=False):
    """
    here we get the largest component of a graph, either from scipy.sparse or from networkX.Graph datatype.
    1. The argument changes whether or not you want to find the strong or weak - connected components of the graph
    """
    if type(G) == nx.classes.graph.Graph:  # check if it is a networkx Graph
        if nx.is_directed(G) and strong == False:
            GMask = max(nx.weakly_connected_components(G), key=len)
        if nx.is_directed(G) and strong == True:
            GMask = max(nx.strongly_connected_components(G), key=len)
        else:
            GMask = max(nx.connected_components(G), key=len)
        G = G.subgraph(GMask)
        return len(GMask)
    elif type(G) == scipy.sparse.csr_array:
        if strong == False:
            connection_type = "weak"
        else:
            connection_type = "strong"
        noComponent, lenComponent = sp.sparse.csgraph.connected_components(
            G, directed=True, connection=connection_type, return_labels=True
        )
        return np.bincount(lenComponent).max()
    else:
        raise TypeError(
            "You must input a networkx.Graph Data-Type or scipy.sparse.csr array"
        )


def get_link_size(G):
    if type(G) == nx.classes.graph.Graph:  # check if it is a networkx Graph
        links = len(G.edges())  # convert to scipy sparse if it is a graph
    elif type(G) == scipy.sparse.csr_array:
        links = G.sum()
    else:
        raise TypeError("You must input a networkx.Graph Data-Type")
    return links


def remove_node(G, removedNode):
    """
    removes the node from the graph by removing it from a networkx.Graph type, or zeroing the edges in array form.
    """
    if type(G) == nx.classes.graph.Graph:  # check if it is a networkx Graph
        if type(removedNode) == int:
            G.remove_node(removedNode)
        else:
            for node in removedNode:
                G.remove_node(node)  # remove node in graph form
        return G
    elif type(G) == scipy.sparse.csr_array:
        diag = sp.sparse.csr_array(sp.sparse.eye(G.shape[0]))
        diag[removedNode, removedNode] = (
            0  # set the rows and columns that are equal to zero in the sparse array
        )
        G = diag @ G
        return G @ diag


def generate_attack(centrality, node_map=False):
    """we generate an attack based on a centrality measure -
    you can possibly input the node_map to convert the attack to have the correct nodeID
    """
    if node_map == False:
        node_map = range(len(centrality))
    else:
        node_map = list(node_map.values())
    zipped = dict(zip(node_map, centrality))
    attackStrategy = sorted(zipped, reverse=True, key=zipped.get)
    return attackStrategy


def network_attack_sampled(G, attackStrategy, sampling=0):
    """Attack a network in a sampled manner... recompute links and largest component after every xth node removal, according to some -
    G: is the input graph, preferably as a sparse array.
    inputed attack strategy
    Note: if sampling is not set, it defaults to sampling every 1%, otherwise, sampling is an integer
    that is equal to the number of nodes you want to skip every time you sample.
    So for example sampling = int(len(G)/100) would sample every 1% of the nodes removed
    """
    if type(G) == nx.classes.graph.Graph:  # check if it is a networkx Graph
        GAdj = nx.to_scipy_sparse_array(G)  # convert to scipy sparse if it is a graph
    else:
        GAdj = G.copy()

    if sampling == 0:
        if GAdj.shape[0] < 100:
            sampling = 1
        else:
            sampling = int(GAdj.shape[0] / 100)
    N = GAdj.shape[0]
    initialComponent = get_component_size(GAdj)
    initialLinks = get_link_size(GAdj)
    m = GAdj.sum() / N
    componentEvolution = np.zeros((N // sampling + 1))
    linksEvolution = np.zeros((N // sampling) + 1)
    j = 0
    for i in range(N - 1):
        if i % sampling == 0:
            if i == 0:
                componentEvolution[j] = get_component_size(GAdj) / initialComponent
                linksEvolution[j] = get_link_size(GAdj) / initialLinks
                j += 1
            else:
                GAdj = remove_node(GAdj, attackStrategy[i - sampling : i])
                componentEvolution[j] = get_component_size(GAdj) / initialComponent
                linksEvolution[j] = get_link_size(GAdj) / initialLinks
                j += 1
    return componentEvolution, linksEvolution


######## Beginning of domirank stuff! ####################


def domirank(
    G, analytical=True, sigma=-1, dt=0.1, epsilon=1e-5, maxIter=1000, checkStep=10
):
    """
        G is the input graph as a (preferably) sparse array.
        This solves the dynamical equation presented in the Paper: "DomiRank Centrality: revealing structural fragility of
    complex networks via node dominance" and yields the following output: bool, DomiRankCentrality
        Here, sigma needs to be chosen a priori.
        dt determines the step size, usually, 0.1 is sufficiently fine for most networks (could cause issues for networks
        with an extremely high degree, but has never failed me!)
        maxIter is the depth that you are searching with in case you don't converge or diverge before that.
        Checkstep is the amount of steps that you go before checking if you have converged or diverged.


        This algorithm scales with O(m) where m is the links in your sparse array.
    """
    if type(G) == nx.classes.graph.Graph:  # check if it is a networkx Graph
        G = nx.to_scipy_sparse_array(G)  # convert to scipy sparse if it is a graph
    else:
        G = G.copy()
    if analytical == False:
        if sigma == -1:
            sigma, _ = optimal_sigma(
                G,
                analytical=False,
                dt=dt,
                epsilon=epsilon,
                maxIter=maxIter,
                checkStep=checkStep,
            )
        pGAdj = sigma * G.astype(np.float64)
        Psi = np.ones(pGAdj.shape[0]).astype(np.float64) / pGAdj.shape[0]
        maxVals = np.zeros(int(maxIter / checkStep)).astype(np.float64)
        dt = np.float64(dt)
        j = 0
        boundary = epsilon * pGAdj.shape[0] * dt
        for i in range(maxIter):
            tempVal = ((pGAdj @ (1 - Psi)) - Psi) * dt
            Psi += tempVal.real
            if i % checkStep == 0:
                if np.abs(tempVal).sum() < boundary:
                    break
                maxVals[j] = tempVal.max()
                if i == 0:
                    initialChange = maxVals[j]
                if j > 0:
                    if maxVals[j] > maxVals[j - 1] and maxVals[j - 1] > maxVals[j - 2]:
                        return False, Psi
                j += 1

        return True, Psi
    else:
        if sigma == -1:
            sigma = optimal_sigma(
                G,
                analytical=True,
                dt=dt,
                epsilon=epsilon,
                maxIter=maxIter,
                checkStep=checkStep,
            )

        sigma = sigma[0] if isinstance(sigma, tuple) else sigma

        A = sigma * G + sp.sparse.identity(G.shape[0])
        B = sigma * G.sum(axis=-1)

        Psi = sp.sparse.linalg.spsolve(A, B)
        return True, Psi, sigma


def find_eigenvalue(
    G, minVal=0, maxVal=1, maxDepth=100, dt=0.1, epsilon=1e-5, maxIter=100, checkStep=10
):
    """
    G: is the input graph as a sparse array.
    Finds the largest negative eigenvalue of an adjacency matrix using the DomiRank algorithm.
    Currently this function is only single-threaded, as the bisection algorithm only allows for single-threaded
    exection. Note, that this algorithm is slightly different, as it uses the fact that DomiRank diverges
    at values larger than -1/lambN to its benefit, and thus, it is not exactly bisection theorem. I haven't
    tested in order to see which exact value is the fastest for execution, but that will be done soon!
    Some notes:
    Increase maxDepth for increased accuracy.
    Increase maxIter if DomiRank doesn't start diverging within 100 iterations -- i.e. increase at the expense of
    increased computational cost if you want potential increased accuracy.
    Decrease checkstep for increased error-finding for the values of sigma that are too large, but higher compcost
    if you are frequently less than the value (but negligible compcost).
    """
    x = (minVal + maxVal) / G.sum(axis=-1).max()
    minValStored = 0
    for i in range(maxDepth):
        if maxVal - minVal < epsilon:
            break
        if domirank(G, False, x, dt, epsilon, maxIter, checkStep)[0]:
            minVal = x
            x = (minVal + maxVal) / 2
            minValStored = minVal
        else:
            maxVal = (x + maxVal) / 2
            x = (minVal + maxVal) / 2
        # if minVal == 0:
        #     print(f'Current Interval : [-inf, -{1/maxVal}]')
        # else:
        #     print(f'Current Interval : [-{1/minVal}, -{1/maxVal}]')
    finalVal = (maxVal + minVal) / 2
    return -1 / finalVal


############## This section is for finding the optimal sigma #######################


def process_iteration(
    analytical, sigma, spArray, maxIter, checkStep, dt, epsilon, sampling
):
    tf, domiDist, _ = domirank(
        spArray,
        analytical=analytical,
        sigma=sigma,
        dt=dt,
        epsilon=epsilon,
        maxIter=maxIter,
        checkStep=checkStep,
    )
    domiAttack = generate_attack(domiDist)
    ourTempAttack, __ = network_attack_sampled(spArray, domiAttack, sampling=sampling)
    finalErrors = ourTempAttack.sum()
    return finalErrors


def optimal_sigma(
    spArray,
    analytical=True,
    endVal=0,
    startval=0.000001,
    iterationNo=100,
    dt=0.1,
    epsilon=1e-5,
    maxIter=100,
    checkStep=10,
    maxDepth=100,
    sampling=0,
):
    """This part finds the optimal sigma by searching the space, here are the novel parameters:
    spArray: is the input sparse array/matrix for the network.
    startVal: is the starting value of the space that you want to search.
    endVal: is the ending value of the space that you want to search (normally it should be the eigenvalue)
    iterationNo: the number of partitions of the space between lambN that you set

    return : the function returns the value of sigma - the numerator of the fraction of (sigma)/(-1*lambN)
    """
    if endVal == 0:
        endVal = find_eigenvalue(
            spArray,
            maxDepth=maxDepth,
            dt=dt,
            epsilon=epsilon,
            maxIter=maxIter,
            checkStep=checkStep,
        )
    endval = -0.9999 / endVal
    tempRange = np.arange(
        startval,
        endval + (endval - startval) / iterationNo,
        (endval - startval) / iterationNo,
    )
    results = []
    for sigma in tempRange:
        result = process_iteration(
            analytical, sigma, spArray, maxIter, checkStep, dt, epsilon, sampling
        )
        results.append(result)

    finalErrors = np.array(results)
    minEig = np.where(finalErrors == finalErrors.min())[0][-1]
    minEig = tempRange[minEig]
    return minEig, finalErrors


def domirank_fast(
    G: scipy.sparse.csr_array,
    sigma_numerator=-1,
    dt=0.1,
    epsilon=1e-5,
    maxIter=1000,
    checkStep=10,
    maxDepth=100,
):
    G = G.copy()
    eigenvalue = find_eigenvalue(
        G,
        maxDepth=maxDepth,
        dt=dt,
        epsilon=epsilon,
        maxIter=maxIter,
        checkStep=checkStep,
    )
    sigma = -sigma_numerator / eigenvalue

    A = sigma * G + sp.sparse.identity(G.shape[0])
    B = sigma * G.sum(axis=-1)

    Psi = sp.sparse.linalg.spsolve(A, B)
    return True, Psi, sigma
