# Example: MULTIPARAMETER on DYNAMIC NETWORKS
# -------------------------------------------------------
# This code computes the ZIGZAG persistence diagram on 
# dynamic networks (DN).
# The input can be any Dynamic_Network object as defined in 
# our library: EvDyNET
# -------------------------------------------------------
# INPUT: Dynamic_Network object as defined in our library: EvDyNET
# OUTPUT: A list of Barcodes, each one saved as Numpy array
# It can also save: 
# -- BoxPlot0.pdf and BoxPlot1.pdf: Barcodes for 
# 0-dimensional and 1-dimensional features
# -- matBarcode0.txt and matBarcode1.txt: Barcodes for 
# 0-dimensional and 1-dimensional features
# -- matBarcodeTOTAL.txt: All barcodes...
# 1st column: Dimension of the topological feature
# 2nd column: time-birth
# 3rd column: time-death
# -------------------------------------------------------
# PARAMETERS
# -- nameFolderNet: CSV-input files
# -- NVertices: Number of vertices
# -- scaleParameter: Maximum epsilon for filtration,
# Maximum edge weight, keep it on 1.0
# -- maxDimHoles: Maximum dimension of the topological features.
# A value of 2 is usually enough; to obtain higher-dimensional features
# it is necessary to input non-sparse networks.
# -- sizeWindow: Number of graphs used to compute Zigzag
# -------------------------------------------------------


#%% Libraries 
import os
import time

import dionysus as d
import gudhi as gd
import matplotlib.pyplot as plt
import numpy as np
import ray # For parallel Computing
from IPython.display import display
from scipy.sparse.csgraph import shortest_path
from scipy.spatial.distance import squareform

import EvDyNET

# Number of CPU cores that Ray is told it may use for the remote tasks
NCores = 7  

# Start Ray here, before any `.remote()` call is issued further down
ray.init(num_cpus = NCores) # Ray is given NCores CPUs to schedule tasks on

#%% FUNCTIONS DEFINITION
# -------  PERSISTENCE HOMOLOGY ON DYNAMIC GRAPHS ----------
# NODE-based FILTRATION runs on the features of each node and is constrained 
# by the connectivity between nodes... 
# EDGE-based FILTRATION has the option of using Power Filtration (very computationally expensive);  
# the default option (powerFiltration = False) runs a simple filtration on edges
@ray.remote # Decorator for Parallel Computing 
def Persistence_Dynamic_Network(DyNet, iG, param_INput):
    """Compute the persistence barcodes of network ``iG`` over a grid of thresholds.

    The network is thresholded ``MGT`` times on a *first* attribute (node or
    edge), which yields a sequence of filtered sub-networks stored in a new
    ``EvDyNET.Dynamic_Network``. For every sub-network a Rips filtration is
    built from a *second* attribute and its persistence barcode is computed.

    Parameters
    ----------
    DyNet : EvDyNET.Dynamic_Network
        Source dynamic network. This function reads ``df_nodes``,
        ``df_edges``, ``df_features``, ``df_property`` and ``NOTES``.
    iG : int
        Index of the network inside ``DyNet``.
    param_INput : sequence
        ``[lsATTR_1, lsATTR_2, powerFiltration, maxDimHoles, MGT]`` where

        * ``lsATTR_1 = [name, 'Node' | 'Edge', mode]`` describes the first
          filtration. ``mode`` 0 places the thresholds linearly between the
          minimum and maximum attribute value, ``mode`` 1 places them on
          quantiles of the attribute values.
        * ``lsATTR_2 = [name, 'Node' | 'Edge']`` describes the second
          filtration.
        * ``powerFiltration`` (bool) only matters for an 'Edge' second
          filtration: edge weights are replaced by shortest-path distances.
        * ``maxDimHoles`` is passed as maximum dimension to the Rips
          construction; only diagrams of dimension ``< maxDimHoles`` are kept.
        * ``MGT`` is the number of thresholds (grid steps) of the first
          filtration.

    Returns
    -------
    list of numpy.ndarray
        One ``(n_bars, 3)`` array per sub-network that was added, with columns
        (dimension, birth, death). Infinite values are replaced by the largest
        finite scale found for that sub-network. If ``addNet`` rejects a
        sub-network the grid is cut short at that point, so the list can hold
        fewer than ``MGT`` entries.

    Raises
    ------
    ValueError
        If the division mode is not 0 or 1, or if the second filtration type
        is neither 'Node' nor 'Edge'.
    """
    # --- Unpack input parameters
    lsATTR_1 = param_INput[0] # First filtration: [attribute, 'Node' | 'Edge', division mode]
    lsATTR_2 = param_INput[1] # Second filtration: [attribute, 'Node' | 'Edge']
    powerFiltration = param_INput[2] # True / False -> only used when typeATTR == 'Edge'
    maxDimHoles = param_INput[3] # Maximum dimension of holes (2 means dimensions 0 and 1)
    MGT = param_INput[4] # Number of grid steps of the first filtration

    nameFilt1 = lsATTR_1[0]
    typeFilt1 = lsATTR_1[1]
    nameATTR = lsATTR_2[0]
    typeATTR = lsATTR_2[1]

    # --- Values of the first attribute, used to place the thresholds
    if(typeFilt1=='Node'):
        dfFilt1 = DyNet.df_nodes[iG]
    elif(typeFilt1=='Edge'):
        dfFilt1 = DyNet.df_edges[iG]
    else:
        dfFilt1 = None
    hasFilt1 = (dfFilt1 is not None) and (nameFilt1 in dfFilt1)
    if(hasFilt1):
        vecAux = dfFilt1[nameFilt1].to_numpy()
        vecAux = vecAux[vecAux!=np.inf] # To remove np.INF values
    else:
        # Previously this case crashed on an unbound variable. The filtering
        # below is already skipped when the attribute is missing, so the
        # network is kept unfiltered and the situation is reported.
        print(f'\n *>*>*> WARNING: attribute {nameFilt1!r} ({typeFilt1}) not found in net index {iG}; first filtration is skipped')
        vecAux = np.array([])
    # np.array([0]) avoids an error when computing min / max / quantile of nothing
    arrayForQuantile = vecAux if(len(vecAux)>0) else np.array([0])
    maxScaleFilt = np.amax(arrayForQuantile)
    minScaleFilt = np.amin(arrayForQuantile)

    # --- Build the filtered networks as a new Dynamic Network
    DN_filt = EvDyNET.Dynamic_Network(False, False, DyNet.df_property.to_numpy(), DyNet.df_property.columns.tolist(), DyNet.NOTES)
    for scl_attr in np.linspace(0.0, 1, num=MGT):
        # Threshold of the first filtration for this grid step
        if(lsATTR_1[2]==0): # Linear interval-division
            tuneFilt = minScaleFilt + scl_attr*(maxScaleFilt-minScaleFilt)
        elif(lsATTR_1[2]==1): # Quantile-based interval-division
            tuneFilt = np.quantile(arrayForQuantile, scl_attr)
        else:
            raise ValueError(f'Unknown division mode {lsATTR_1[2]!r}: expected 0 (linear) or 1 (quantile)')

        # Graph-level features are copied over unchanged
        matGRAPH = DyNet.df_features[iG].to_numpy()
        nameATTgraph = list(DyNet.df_features[iG].columns)

        # Edges surviving the first filtration
        df_Edges_Aux = DyNet.df_edges[iG].copy()
        if(hasFilt1 and typeFilt1=='Node'):
            # Nodes above the threshold are discarded together with their edges
            valFilt = dfFilt1[nameFilt1].to_numpy()
            lsNoSel = dfFilt1.ID.to_numpy()[valFilt>tuneFilt].tolist()
            maskDrop = df_Edges_Aux.From.isin(lsNoSel) | df_Edges_Aux.To.isin(lsNoSel)
            df_Edges_Aux = df_Edges_Aux[~maskDrop]
        elif(hasFilt1 and typeFilt1=='Edge'):
            df_Edges_Aux = df_Edges_Aux[df_Edges_Aux[nameFilt1]<=tuneFilt].copy()

        # Node / edge matrices of the filtered network
        if(df_Edges_Aux.shape[0]>0): # There is at least one edge
            # Nodes: only those that still take part in an edge
            arrNodes = np.unique(df_Edges_Aux[['From', 'To']].to_numpy().flatten())
            matNODE = np.ones((len(arrNodes), 1)).astype(int)
            nameATTnode = ['Exist']
            # Extra node attributes are the columns prefixed with 'nodattr'
            colsNodeAttr = [c for c in DyNet.df_nodes[iG].columns if str(c).split('_')[0]=='nodattr']
            if(len(colsNodeAttr)>0):
                matNODEextra = DyNet.df_nodes[iG][colsNodeAttr].to_numpy().reshape((-1, len(colsNodeAttr)))
                # NOTE(review): indexing rows by node ID assumes IDs are row positions in df_nodes -- confirm
                matNODE = np.concatenate((matNODE, matNODEextra[arrNodes,:]), axis=1)
                nameATTnode = nameATTnode + colsNodeAttr
            # Edges
            arrEdges = df_Edges_Aux[['From_Label', 'To_Label']].to_numpy()
            matEDGE = np.ones((df_Edges_Aux.shape[0], 1)).astype(int)
            nameATTedge = ['Exist']
            # Extra edge attributes are the columns prefixed with 'edgattr'
            colsEdgeAttr = [c for c in DyNet.df_edges[iG].columns if str(c).split('_')[0]=='edgattr']
            if(len(colsEdgeAttr)>0):
                matEDGEextra = df_Edges_Aux[colsEdgeAttr].to_numpy().reshape((-1, len(colsEdgeAttr)))
                matEDGE = np.concatenate((matEDGE, matEDGEextra), axis=1)
                nameATTedge = nameATTedge + colsEdgeAttr
        else: # No edge survived: add an empty network
            arrNodes = np.array([])
            matNODE = np.array([[]])
            nameATTnode = []
            arrEdges = np.array([[]])
            matEDGE = np.array([[]])
            nameATTedge = []

        # Adding Network
        extended_Extra = True
        flagADD = DN_filt.addNet(arrNodes, matNODE, nameATTnode, arrEdges, matEDGE, nameATTedge, matGRAPH, nameATTgraph, extended_Extra)
        if(not flagADD):
            print(f'\n *>*>*> ERROR: when adding filtered-network net index {iG}')
            break

    # --- Persistence barcode of every filtered network
    lsArray_BarCode = []
    for kfG in range(DN_filt.number_nets):
        NVertices = DN_filt.number_nodes[kfG]
        if(not (NVertices>0)): # No vertices: empty barcode
            lsArray_BarCode.append( np.zeros((0, 3)) )
            continue

        # Distance matrix; np.inf marks "no edge" until it is replaced below
        A = np.full((NVertices, NVertices), np.inf)
        maxScaleParam = 0 # At least 0 when there is nothing to measure
        hasEdges = DN_filt.df_edges[kfG].shape[0]>0

        if(typeATTR=='Node'):
            matValNode = np.zeros(NVertices)
            if(hasEdges): # Otherwise keep np.inf
                idFrom = DN_filt.df_edges[kfG].From.to_numpy()
                idTo = DN_filt.df_edges[kfG].To.to_numpy()
                # Copy: the array is modified below and to_numpy() may hand
                # back a view of the DataFrame (or a read-only array)
                matValNode = DN_filt.df_nodes[kfG][nameATTR].to_numpy().copy()
                A[idFrom, idTo] = np.maximum(matValNode[idFrom], matValNode[idTo]) # Maximum among two nodes
                A[idTo, idFrom] = A[idFrom, idTo] # Symmetric matrix
                # Largest finite node value (guarded: all values may be np.inf)
                auxNodeArray = matValNode[matValNode!=np.inf]
                if(auxNodeArray.size>0):
                    maxScaleParam = max(maxScaleParam, np.amax(auxNodeArray))
            A[range(NVertices), range(NVertices)] = 0 # Diagonal
            # np.inf node values are pushed beyond the largest finite scale
            matValNode[matValNode==np.inf] = 2*maxScaleParam + 1
        elif(typeATTR=='Edge'):
            if(hasEdges): # Otherwise keep np.inf
                idFrom = DN_filt.df_edges[kfG].From.to_numpy()
                idTo = DN_filt.df_edges[kfG].To.to_numpy()
                auxValEdge = DN_filt.df_edges[kfG][nameATTR].to_numpy()
                A[idFrom, idTo] = auxValEdge # Edge value
                A[idTo, idFrom] = A[idFrom, idTo] # Symmetric matrix
                if(powerFiltration):
                    # Power filtration: all-pairs shortest paths (Floyd-Warshall).
                    # NOTE(review): for dense input scipy treats 0-valued entries
                    # as missing edges -- confirm edge weights are never exactly 0
                    A = shortest_path(A, directed=False, method='FW', unweighted=False)
                    auxValEdge = squareform(A)
                # Largest finite edge value (guarded: all values may be np.inf)
                auxValEdge = auxValEdge[auxValEdge!=np.inf]
                if(auxValEdge.size>0):
                    maxScaleParam = max(maxScaleParam, np.amax(auxValEdge))
            A[range(NVertices), range(NVertices)] = 0 # Diagonal
        else:
            raise ValueError(f"Unknown second filtration type {typeATTR!r}: expected 'Node' or 'Edge'")

        # Missing edges get a length beyond the largest scale so they never enter
        A[A==np.inf] = 2*maxScaleParam + 1

        # --- Building Vietoris-Rips Complex
        tuneSParam = maxScaleParam
        if(typeATTR=='Node'):
            rips_complex = gd.RipsComplex(distance_matrix = A, max_edge_length=tuneSParam)
            simplex_tree = rips_complex.create_simplex_tree(max_dimension = maxDimHoles)
            # Vertices enter at their own node value instead of at 0
            both_degree = matValNode.tolist()
            for j in range(NVertices):
                simplex_tree.assign_filtration([j], both_degree[j])
            # Restore the inclusion property: a simplex cannot appear
            # before any of its faces
            simplex_tree.make_filtration_non_decreasing()
            # Move GUDHI's filtration into a Dionysus filtration
            ripsAux = d.Filtration()
            for vertices, valtime in simplex_tree.get_filtration():
                ripsAux.append(d.Simplex(vertices, valtime))
        else: # typeATTR == 'Edge'
            pDisAux = squareform( A ) # Distance in condensed form
            ripsAux = d.fill_rips(pDisAux, maxDimHoles, tuneSParam)

        # --- The Barcode
        homoPers = d.homology_persistence(ripsAux)
        G_dgms = d.init_diagrams(homoPers, ripsAux)
        BCfull = np.zeros((0, 3))
        for i, g_dgm in enumerate(G_dgms):
            if(i<maxDimHoles):
                BCD = np.zeros((len(g_dgm), 3))
                for j, pb in enumerate(g_dgm):
                    BCD[j] = np.array([i, pb.birth, pb.death])
                BCfull = np.concatenate((BCfull, BCD))
        # Replace INF in the barcode by the maximum scale parameter found
        BCfull[ BCfull==np.inf ] = maxScaleParam

        lsArray_BarCode.append( BCfull )

    # Return List of Barcodes
    return(lsArray_BarCode)

    

# %%
#%% ************* MAIN CODE SECTION *******************
#######################################################################
#######################################################################
#%% Parameters 
# --- Dataset (uncomment exactly one line)
#path_dataset = 'DANYCH/TUDataset/TUD_BZR_MD.dyne'    # 0, 1, '', 'node_labels'    #  1, 1, 'node_labels', 'node_labels'
#path_dataset = 'DANYCH/TUDataset/TUD_COX2_MD.dyne'    # 0, 1, '', 'node_labels'    #  1, 1, 'node_labels', 'node_labels'
#path_dataset = 'DANYCH/TUDataset/TUD_DHFR_MD.dyne'    # 0, 1, '', 'node_labels'    #  1, 1, 'node_labels', 'node_labels'
#path_dataset = 'DANYCH/TUDataset/TUD_ER_MD.dyne'    # 0, 1, '', 'node_labels'    #  1, 1, 'node_labels', 'node_labels'

#path_dataset = 'DANYCH/TUDataset/TUD_BZR.dyne'   # 3, 0, 'node_attr', ''    #  1, 0, 'node_labels', ''
#path_dataset = 'DANYCH/TUDataset/TUD_COX2.dyne'   # 3, 0, 'node_attr', ''    #  1, 0, 'node_labels', ''
#path_dataset = 'DANYCH/TUDataset/TUD_DHFR.dyne'   # 3, 0, 'node_attr', ''    #  1, 0, 'node_labels', ''
#path_dataset = 'DANYCH/TUDataset/TUD_PROTEINS.dyne'   # 1, 0, 'node_attr', ''     #  1, 0, 'node_labels', ''

#path_dataset = 'DANYCH/TUDataset/TUD_IMDB-BINARY.dyne'   # 0, 0, '', ''
#path_dataset = 'DANYCH/TUDataset/TUD_MUTAG.dyne'   # 0, 0, '', ''
path_dataset = 'DANYCH/TUDataset/TUD_IMDB-MULTI.dyne'   # 0, 0, '', ''
#path_dataset = 'DANYCH/TUDataset/TUD_NCI1.dyne'   # 0, 0, '', ''

# --- Output location (folder only)
pathSaveOut = 'EXPERIMENT/MP_2D/'

# --- Attributes of the two filtrations
# lsATTR_1 = [attribute, 'Node' | 'Edge', division mode (0: linear, 1: quantile)]
# lsATTR_2 = [attribute, 'Node' | 'Edge']
### for-> TUD_BZR_MD, TUD_COX2_MD, TUD_DHFR_MD, TUD_ER_MD
###lsATTR_1, lsATTR_2 = ['Degree_edgattr_0', 'Node', 1], ['edgattr_0', 'Edge']
#lsATTR_1, lsATTR_2 = ['nodattr_0', 'Node', 1], ['edgattr_0', 'Edge']

### for-> TUD_BZR, TUD_COX2, TUD_DHFR, TUD_ER  
###lsATTR_1, lsATTR_2 = ['Closeness', 'Node', 1], ['Betweenness', 'Edge']
#lsATTR_1, lsATTR_2 = ['nodattr_0', 'Node', 1], ['Betweenness', 'Edge']

### for-> TUD_IMDB-BINARY, TUD_IMDB-MULTI, TUD_MUTAG, TUD_NCI1
lsATTR_1 = ['Katz', 'Node', 1]
lsATTR_2 = ['Ricci_Positive', 'Edge']

# --- Remaining settings
# Number of thresholds of the first filtration (values tried so far: 8 and 50)
MGT = 50
# Maximum dimension of holes
maxDimHoles = 2
# Power filtration on the second attribute (only meaningful for 'Edge')
powerFiltration_ATTR_2 = True

#%% Open Dataset and Fill-Out Additional Parameters
print('Open Dataset ...')
DyNet = EvDyNET.Load_Dynamic_Network(path_dataset)  
print('Done!')  
#DyNet.show('last')

# Indices of the networks to process (all of them by default)
indexPeriod = np.arange(0, DyNet.number_nets) # Numpy Array with indices of Nets
#indexPeriod = indexPeriod[:30]

#%% *******  MULTIPARAMETER ZIGZAG PERSISTENCE  **********
print('Computing MULTIPARAMETER PERSISTENCE  ')  
# Start accumulating time (also covers the saving step further down)
start_time_MPZ = time.time()  

# --- Input of the remote function; identical for every network, so built once
param_INput = [
    lsATTR_1,                # First filtration: [attribute, 'Node' | 'Edge', division mode]
    lsATTR_2,                # Second filtration: [attribute, 'Node' | 'Edge']
    powerFiltration_ATTR_2,  # True / False -> only used for an 'Edge' second filtration
    maxDimHoles,             # Maximum dimension of holes (2 means dimensions 0 and 1)
    MGT,                     # Number of thresholds of the first filtration
]

try:
    # Store the dataset once in Ray's object store. Passing DyNet itself to
    # every .remote() call would serialise it again for each task.
    DyNet_ref = ray.put(DyNet)

    # --- Create FUTURES (one task per selected network)
    lsFUT = [Persistence_Dynamic_Network.remote(DyNet_ref, iG, param_INput) for iG in indexPeriod]

    # --- Run the tasks in parallel and wait for all of them
    start = time.time()
    MPGrid_lis = ray.get( lsFUT ) 
    print("duration =", time.time() - start)
    print("len(results) = ", len(MPGrid_lis))
finally:
    # --- Shutdown the parallelization library, even if a task failed
    ray.shutdown()  

#%% To Save Barcodes 
# Saving the results in a Numpy Array
# MPGrid[ i, k ], where 
# i-> position of the network inside indexPeriod
# k-> First Filtration, the lower index means lower filtration-value
# The grid is filled cell by cell: np.array(list, dtype=object) would return a
# 4-D array whenever all barcodes happen to share the same shape. Cells that
# were not computed (a task stopped early) stay None.
MPGrid = np.empty((len(MPGrid_lis), MGT), dtype=object)
for iRow, lsBC in enumerate(MPGrid_lis):
    for kCol, barcode in enumerate(lsBC):
        MPGrid[iRow, kCol] = barcode

# Time Variable: the same elapsed time in seconds, minutes and hours
elapsedMPZ = time.time() - start_time_MPZ
arrayTime = np.array([elapsedMPZ, elapsedMPZ/60, elapsedMPZ/(60*60)])

# Dataset Variables
arrayParams = np.array([
    path_dataset,
    lsATTR_1,
    lsATTR_2,
    powerFiltration_ATTR_2,
    maxDimHoles,
    MGT,
    indexPeriod,
], dtype=object)

# The class of each net, stored by position so that a subset of networks
# (indexPeriod not starting at 0) is handled correctly
arrayClass = np.zeros(len(indexPeriod))
for iPos, iNet in enumerate(indexPeriod):
    arrayClass[iPos] = DyNet.df_features[iNet]['class'].to_numpy()[0]

# Saving the Obtained MultiParameter Zigzag Persistence and other data
nameFilt = lsATTR_1[0] + '_' + lsATTR_1[1] + '_and_' 
if((powerFiltration_ATTR_2==True) and (lsATTR_2[1]=='Edge')):
    nameFilt = nameFilt + 'Power_'
nameFilt = nameFilt + lsATTR_2[0] + '_' + lsATTR_2[1]  
nameSaveOut = pathSaveOut + 'MP_'+path_dataset.split('/')[-1].split('.')[0]+'_Filt1_'+str(MGT)+'_'+nameFilt+'.npz'   
# Create the output folder if needed, so the results are not lost at the very end
if(pathSaveOut):
    os.makedirs(pathSaveOut, exist_ok=True)
np.savez(nameSaveOut, MPGrid=MPGrid, arrayClass=arrayClass, arrayTime=arrayTime, arrayParams=arrayParams) 

#%% Load Saved File and Show Shape of Barcodes
# The archive is opened in a `with` block so the file is closed afterwards.
# NOTE: allow_pickle=True is needed for the object arrays saved above, but
# unpickling can execute code -- only load files produced by this script.
with np.load(nameSaveOut, allow_pickle=True) as npzfile:
    display(npzfile.files)
    display(npzfile['MPGrid'].shape)
    display(npzfile['arrayTime'])
    display(npzfile['arrayParams'])
    display(npzfile['arrayClass'])
    # for iNet in range(npzfile['MPGrid'].shape[0]):
    #     print(' *** Net: ', iNet, end='  >>> ')
    #     for iFil in range(npzfile['MPGrid'].shape[1]):
    #         print(npzfile['MPGrid'][iNet, iFil].shape, end='')
    #     print(' ')

#%%

#%%

