import sys
import numpy as np
from gph import ripser_parallel
from scipy.stats import entropy
import networkx as nx
from numpy.linalg import eig

import warnings
warnings.filterwarnings("ignore")

sys.path.append('./ph_simple/lib/')
import ph_simple

def trans(r, maxdim = 2):
    """Stack persistence diagrams into one ``(n_bars, 3)`` array of ``[birth, death, dim]``.

    Args:
        r: mapping whose ``'dgms'`` entry is a sequence indexed by homology
            dimension; each element is an ``(n, 2)`` array of (birth, death)
            pairs (the shape of the ``ripser_parallel`` results used in this file).
        maxdim: highest homology dimension to include (default 2, matching the
            previous hard-coded behaviour). Dimensions absent from ``r['dgms']``
            -- e.g. diagrams computed with ``maxdim=1`` -- contribute no rows
            instead of raising ``IndexError``.

    Returns:
        Float array with the bars of dimension 0, then 1, ... up to ``maxdim``;
        column 2 holds the dimension index.
    """
    dgms = r['dgms']
    blocks = []
    for dim in range(maxdim + 1):
        d = dgms[dim] if dim < len(dgms) else np.empty((0, 2))
        block = np.zeros((d.shape[0], 3))
        block[:, 0:2] = d
        block[:, 2] = dim
        blocks.append(block)

    return np.concatenate(blocks, axis = 0)

def f_count(barc):
    """Return the number of bars, i.e. the length of the leading axis of ``barc``."""
    return len(barc)

def f_count_t(barc, thrs):
    """Return how many entries of ``barc`` are strictly smaller than ``thrs``."""
    below = barc < thrs
    return int(np.count_nonzero(below))

def f_max(barc):
    """Return the longest bar length (death - birth), or 0. for an empty barcode."""
    if not len(barc):
        return 0.
    lengths = barc[:, 1] - barc[:, 0]
    return lengths.max()

def f_min(barc):
    """Return the shortest bar length (death - birth), or 0. for an empty barcode."""
    if len(barc) == 0:
        return 0.
    births, deaths = barc[:, 0], barc[:, 1]
    return (deaths - births).min()

def f_mean(barc):
    """Return the mean bar length (death - birth), or 0. for an empty barcode."""
    if len(barc) == 0:
        return 0.
    return np.subtract(barc[:, 1], barc[:, 0]).mean()

def f_median(barc):
    """Return the median bar length (death - birth), or 0. for an empty barcode."""
    if not len(barc):
        return 0.
    lifetimes = np.subtract(barc[:, 1], barc[:, 0])
    return np.median(lifetimes)

def f_sum(barc):
    """Return the total bar length (sum of death - birth); 0.0 for an empty barcode."""
    births, deaths = barc[:, 0], barc[:, 1]
    return (deaths - births).sum()

def f_std(barc):
    """Return the standard deviation of bar lengths (death - birth).

    Like the other ``f_*`` statistics this operates on bar lengths; the
    previous version took ``np.std`` of the flattened birth/death values,
    which mixed births and deaths into one sample.

    Args:
        barc: ``(n, 2)`` array of (birth, death) pairs.

    Returns:
        Population standard deviation of the lengths, or 0. when ``barc`` is empty.
    """
    if not len(barc):
        return 0.
    return np.std(barc[:, 1] - barc[:, 0])

def f_entropy(barc):
    """Return the persistent entropy of a barcode.

    Bar lengths (death - birth) are normalised by the total length and passed
    to ``scipy.stats.entropy`` (natural log).

    Args:
        barc: ``(n, 2)`` array of (birth, death) pairs.

    Returns:
        The entropy, or 0. when the barcode is empty or all bars have zero
        length. The guard is on total *length*; guarding on the sum of births
        and deaths let zero-length bars with positive births through and
        produced NaN.
    """
    lengths = barc[:, 1] - barc[:, 0]
    total = np.sum(lengths)
    if total > 0:
        return entropy(lengths / total)
    return 0.

def generate_features(dgm):
    """Summarise the dimension-0 and dimension-1 diagrams of ``dgm`` as a flat list.

    ``dgm['dgms']`` is indexed by homology dimension; a dimension that is not
    present is treated as an empty barcode. For each of dimensions 0 and 1 the
    following 17 values are appended, so the result always has 34 entries:

    * number of finite bars and number of infinite (essential) bars;
    * ``f_max``, ``f_min``, ``f_mean``, ``f_median``, ``f_sum``, ``f_std``
      and ``f_entropy`` of the finite bars;
    * how many finite bars have birth below 0.1, 0.2, 0.4 and 0.8;
    * how many finite bars have death below 0.1, 0.2, 0.4 and 0.8.
    """
    thresholds = (0.1, 0.2, 0.4, 0.8)
    summaries = (f_max, f_min, f_mean, f_median, f_sum, f_std, f_entropy)
    diagrams = dgm['dgms']
    features = []

    for dim in (0, 1):
        if dim < len(diagrams):
            barc = diagrams[dim]
        else:
            barc = np.empty(shape = (0, 2), dtype=np.float32)

        deaths = barc[:, 1]
        finite = barc[deaths < np.inf]
        essential = barc[deaths == np.inf]

        features += [f_count(finite), f_count(essential)]
        features += [summary(finite) for summary in summaries]
        # Column 0 holds births, column 1 holds deaths.
        for column in (0, 1):
            features += [f_count_t(finite[:, column], t) for t in thresholds]

    return features

def get_features_from_attention_matrix_v1(B, thrs, remove_node = False, maxdim = 1):
    """Per-token topological features from distance-thresholded neighbourhoods.

    The attention matrix is symmetrised with an element-wise max and turned
    into a dissimilarity ``1 - max(B, B.T)`` with a zero diagonal. For every
    token, the tokens at dissimilarity strictly below ``thrs`` form its
    neighbourhood. Persistent homology of that sub-matrix (``ripser_parallel``,
    precomputed metric) is summarised by ``generate_features``.

    Args:
        B: square attention matrix (numpy array).
        thrs: neighbourhood radius on the ``1 - max(B, B.T)`` scale.
        remove_node: drop the token itself from its own neighbourhood.
        maxdim: maximum homology dimension passed to ``ripser_parallel``.

    Returns:
        ``(f_all, dgm_all)``. ``f_all[node]`` is ``[B[node, node]]`` followed by
        34 TDA features. For an empty neighbourhood the 34 features are zeros.
        ``dgm_all`` only holds diagrams for nodes whose neighbourhood was
        non-empty, so it is not necessarily aligned with ``f_all``.
    """
    Bdist = 1 - np.maximum(B, np.transpose(B))
    np.fill_diagonal(Bdist, 0)

    f_all = []
    dgm_all = []

    for node in range(B.shape[0]):
        close_nodes = np.flatnonzero(Bdist[node] < thrs).tolist()

        # The node is only its own neighbour when thrs > 0; guard the removal
        # so thrs <= 0 does not raise ValueError.
        if remove_node and node in close_nodes:
            close_nodes.remove(node)

        f = [B[node, node]] # self-attention

        if close_nodes:
            Bdist_part = Bdist[np.ix_(close_nodes, close_nodes)]
            dgm = ripser_parallel(Bdist_part, metric="precomputed", maxdim=maxdim, n_threads=-1)
            dgm_all.append(dgm)
            f_tda = generate_features(dgm)
        else:
            # generate_features yields 2 dimensions x 17 statistics = 34 values.
            # (The previous list comprehension over .append produced 34 Nones.)
            f_tda = [0.0] * 34
        f.extend(f_tda)
        f_all.append(f)

    return f_all, dgm_all

def get_features_from_attention_matrix_v2(B, thrs, remove_node = False, maxdim = 2):
    """Per-token topological features from graph-distance neighbourhoods.

    The attention matrix is symmetrised and converted to dissimilarities
    ``1 - max(B, B.T)`` (zero diagonal). These are the edge weights of a
    complete graph. For every source token:

    * the neighbourhood is the set of tokens whose weighted shortest-path
      distance from the source is strictly below ``thrs``;
    * an edge ``(i, j)`` keeps its weight only when it can be traversed within
      ``thrs`` from the source, i.e. ``d(src, i) + w(i, j) < thrs`` or
      ``d(src, j) + w(i, j) < thrs``. Other off-diagonal entries are set to
      ``inf``. NOTE(review): this presumes ripser_parallel treats inf as a
      missing edge; confirm.

    The restricted distance matrix is rebuilt for each source. It was
    previously shared across the loop, so edges admitted for earlier sources
    leaked into later ones and results depended on token order.

    Args:
        B: square attention matrix (numpy array).
        thrs: radius on the shortest-path distance scale.
        remove_node: drop the source token from its own neighbourhood.
        maxdim: maximum homology dimension for ``ripser_parallel``
            (default 2, as before).

    Returns:
        ``(f_all, dgm_all)``. ``f_all[node]`` is ``[B[node, node]]`` followed by
        34 TDA features. For an empty neighbourhood the features are zeros and
        no diagram is appended to ``dgm_all``, matching the v1/v21 variants.
    """
    Bdist = 1 - np.maximum(B, np.transpose(B))
    np.fill_diagonal(Bdist, 0)
    n = Bdist.shape[0]

    # Explicit edge list (not from_numpy_array) so zero-distance pairs stay edges.
    G = nx.Graph()
    G.add_weighted_edges_from((i, j, Bdist[i, j]) for i in range(n) for j in range(n))

    f_all = []
    dgm_all = []

    for node in range(n):
        shortest_dist = np.full(n, np.inf)
        for k, v in nx.single_source_dijkstra_path_length(G, node, cutoff = thrs).items():
            shortest_dist[k] = v

        # Edges reachable from this source within thrs (inf + x stays inf, so
        # unreachable endpoints never qualify).
        reachable = ((shortest_dist[:, None] + Bdist < thrs) |
                     (shortest_dist[None, :] + Bdist < thrs))
        Bdist2 = np.where(reachable, Bdist, np.inf)
        np.fill_diagonal(Bdist2, 0)

        close_nodes = np.flatnonzero(shortest_dist < thrs).tolist()

        # Guarded so thrs <= 0 (source not in its own neighbourhood) does not raise.
        if remove_node and node in close_nodes:
            close_nodes.remove(node)

        f = [B[node, node]] # self-attention

        if close_nodes:
            Bdist_part = Bdist2[np.ix_(close_nodes, close_nodes)]
            dgm = ripser_parallel(Bdist_part, metric="precomputed", maxdim=maxdim, n_threads=-1)
            dgm_all.append(dgm)
            f_tda = generate_features(dgm)
        else:
            # generate_features yields 2 dimensions x 17 statistics = 34 values.
            f_tda = [0.0] * 34

        f.extend(f_tda)
        f_all.append(f)

    return f_all, dgm_all

def get_features_from_attention_matrix_v21(B, thrs, remove_node = False, maxdim = 1):
    """Per-token topological features using shortest-path neighbourhoods.

    Dissimilarities ``1 - max(B, B.T)`` (zero diagonal) are the edge weights of
    a complete graph. All-pairs Dijkstra, explored up to ``thrs``, selects for
    each token the tokens at shortest-path distance strictly below ``thrs``.
    Unlike v2, the sub-matrix passed to ``ripser_parallel`` uses the original
    dissimilarities for all pairs of selected tokens.

    Args:
        B: square attention matrix (numpy array).
        thrs: radius on the shortest-path distance scale.
        remove_node: drop the token itself from its own neighbourhood.
        maxdim: maximum homology dimension for ``ripser_parallel``.

    Returns:
        ``(f_all, dgm_all)``. ``f_all[node]`` is ``[B[node, node]]`` followed by
        34 TDA features. For an empty neighbourhood the 34 features are zeros.
        ``dgm_all`` only holds diagrams for non-empty neighbourhoods.
    """
    Bdist = 1 - np.maximum(B, np.transpose(B))
    np.fill_diagonal(Bdist, 0)
    n = Bdist.shape[0]

    # Explicit edge list (not from_numpy_array) so zero-distance pairs stay edges.
    G = nx.Graph()
    G.add_weighted_edges_from((i, j, Bdist[i, j]) for i in range(n) for j in range(n))

    dijkstra = dict(nx.all_pairs_dijkstra_path_length(G, cutoff = thrs))

    f_all = []
    dgm_all = []

    for node in range(n):
        shortest_dist = dijkstra[node]
        close_nodes = [i for i in range(n) if shortest_dist.get(i, np.inf) < thrs]

        # Guarded so thrs <= 0 (node not in its own neighbourhood) does not raise.
        if remove_node and node in close_nodes:
            close_nodes.remove(node)

        f = [B[node, node]] # self-attention

        if close_nodes:
            Bdist_part = Bdist[np.ix_(close_nodes, close_nodes)]
            dgm = ripser_parallel(Bdist_part, metric="precomputed", maxdim=maxdim, n_threads=-1)
            f_tda = generate_features(dgm)
            dgm_all.append(dgm)
        else:
            # generate_features yields 2 dimensions x 17 statistics = 34 values.
            # (The previous list comprehension over .append produced 34 Nones.)
            f_tda = [0.0] * 34

        f.extend(f_tda)
        f_all.append(f)

    return f_all, dgm_all


def get_features_from_attention_matrix_v3(B, graph_laplacian = False):
    """Per-token statistics of the tree edges selected by ``ph_simple``.

    Every unordered token pair ``(i, j)`` with ``j < i`` becomes an edge with
    weight ``1 - max(B[i, j], B[j, i])``. ``ph_simple.calc_barcodes_batch_cycles``
    fills ``h0`` in place with edge indices (``-1`` = unused slot).
    NOTE(review): these are presumably the H0 death edges, i.e. a minimum
    spanning tree; confirm against ph_simple.

    Args:
        B: square attention matrix (numpy array).
        graph_laplacian: also append the 5 smallest non-trivial eigenvalues of
            the tree's normalized Laplacian (same for every row) and the
            absolute values of the matching eigenvectors.

    Returns:
        ``(n_nodes, 5)`` array of [min, max, count, sum, mean] of incident tree
        edge weights per token, or ``(n_nodes, 15)`` with ``graph_laplacian``.
        Tokens with no incident tree edge get 0 for min and mean instead of
        inf/NaN. Missing eigenpairs (fewer than 6 tokens) are zero-padded.
    """
    Bdist = 1 - np.maximum(B, np.transpose(B))
    np.fill_diagonal(Bdist, 0)
    n_nodes = B.shape[0]

    # Lower-triangle pairs in row-major order: (1,0), (2,0), (2,1), ...
    rows, cols = np.tril_indices(n_nodes, k = -1)
    edges = np.ascontiguousarray(np.stack((rows, cols)), dtype = np.int32)
    w = np.ascontiguousarray(Bdist[rows, cols], dtype = np.float32)

    batch_size = 1
    edge_ptr = np.array([0, edges.shape[1]], dtype = np.int32)
    node_ptr = np.array([0, n_nodes], dtype = np.int32)
    h0 = np.full(n_nodes, -1, dtype = np.int32)
    h0_e = np.full(w.shape, -1, dtype = np.int32)
    h1_e = np.full(w.shape, -1, dtype = np.int32)
    multiprocessing = 0
    filter_cycles = 0

    ph_simple.calc_barcodes_batch_cycles(batch_size, edges, w, edge_ptr, node_ptr, h0, h0_e, h1_e, filter_cycles, multiprocessing)

    min_st = np.full(n_nodes, np.inf)
    max_st = np.zeros(n_nodes)
    count_st = np.zeros(n_nodes)
    sum_st = np.zeros(n_nodes)
    adj = np.zeros((n_nodes, n_nodes))

    for k in h0[h0 != -1]:
        i, j = edges[:, k]
        adj[i, j] = adj[j, i] = 1
        for end in (i, j):
            min_st[end] = min(min_st[end], w[k])
            max_st[end] = max(max_st[end], w[k])
            count_st[end] += 1
            sum_st[end] += w[k]

    # Isolated tokens (e.g. a single-token input): report 0 rather than inf / NaN.
    connected = count_st > 0
    mean_st = np.zeros(n_nodes)
    np.divide(sum_st, count_st, out = mean_st, where = connected)
    min_st[~connected] = 0.0

    features = np.stack((min_st, max_st, count_st, sum_st, mean_st), axis = 1)

    if graph_laplacian:
        NL = nx.normalized_laplacian_matrix(nx.Graph(adj)).toarray()

        # NL is symmetric: eigh returns real eigenvalues in ascending order.
        vals, vecs = np.linalg.eigh(NL)

        # Skip the smallest (trivial, zero) eigenvalue; keep up to the next five.
        vals, vecs = vals[1:6], vecs[:, 1:6]
        n_found = vals.shape[0]

        top_eigv = np.zeros((n_nodes, 5))
        top_eigv[:, :n_found] = vals  # same eigenvalues broadcast to every row

        selected_vecs = np.zeros((n_nodes, 5))
        selected_vecs[:, :n_found] = vecs

        # abs() removes the arbitrary sign of each eigenvector.
        features = np.hstack((features, top_eigv, np.abs(selected_vecs)))

    return features

def get_features_from_attention_matrix_v31(B):
    """Per-token tree-edge statistics plus self-attention and asymmetry.

    Every unordered token pair ``(i, j)`` with ``j < i`` becomes an edge with
    weight ``1 - max(B[i, j], B[j, i])``. ``ph_simple.calc_barcodes_batch_cycles``
    fills ``h0`` in place with edge indices (``-1`` = unused slot).
    NOTE(review): these are presumably the H0 death edges, i.e. a minimum
    spanning tree; confirm against ph_simple.

    Args:
        B: square attention matrix (numpy array).

    Returns:
        ``(n_nodes, 7)`` array with columns [min, max, count, sum, mean] of
        incident tree edge weights, ``B[i, i]``, and ``sum_j |B[i, j] - B[j, i]|``.
        Tokens with no incident tree edge get 0 for min and mean instead of
        inf/NaN.
    """
    Bdist = 1 - np.maximum(B, np.transpose(B))
    np.fill_diagonal(Bdist, 0)
    n_nodes = B.shape[0]

    # Lower-triangle pairs in row-major order: (1,0), (2,0), (2,1), ...
    rows, cols = np.tril_indices(n_nodes, k = -1)
    edges = np.ascontiguousarray(np.stack((rows, cols)), dtype = np.int32)
    w = np.ascontiguousarray(Bdist[rows, cols], dtype = np.float32)

    batch_size = 1
    edge_ptr = np.array([0, edges.shape[1]], dtype = np.int32)
    node_ptr = np.array([0, n_nodes], dtype = np.int32)
    h0 = np.full(n_nodes, -1, dtype = np.int32)
    h0_e = np.full(w.shape, -1, dtype = np.int32)
    h1_e = np.full(w.shape, -1, dtype = np.int32)
    multiprocessing = 0
    filter_cycles = 0

    ph_simple.calc_barcodes_batch_cycles(batch_size, edges, w, edge_ptr, node_ptr, h0, h0_e, h1_e, filter_cycles, multiprocessing)

    min_st = np.full(n_nodes, np.inf)
    max_st = np.zeros(n_nodes)
    count_st = np.zeros(n_nodes)
    sum_st = np.zeros(n_nodes)

    for k in h0[h0 != -1]:
        i, j = edges[:, k]
        for end in (i, j):
            min_st[end] = min(min_st[end], w[k])
            max_st[end] = max(max_st[end], w[k])
            count_st[end] += 1
            sum_st[end] += w[k]

    # Isolated tokens (e.g. a single-token input): report 0 rather than inf / NaN.
    connected = count_st > 0
    mean_st = np.zeros(n_nodes)
    np.divide(sum_st, count_st, out = mean_st, where = connected)
    min_st[~connected] = 0.0

    # |B - B.T| is symmetric, so summing over axis 0 equals summing over axis 1.
    asym_B = np.sum(np.abs(B - np.transpose(B)), axis = 0)

    return np.stack((min_st, max_st, count_st, sum_st, mean_st, np.diagonal(B), asym_B), axis = 1)

def get_features_from_attention_matrix_v4(B):
    """Betweenness centrality of each token in the graph of ``ph_simple``-selected edges.

    Every unordered token pair ``(i, j)`` with ``j < i`` gets weight
    ``1 - max(B[i, j], B[j, i])``. ``ph_simple.calc_barcodes_batch_cycles``
    writes selected edge indices into ``h0`` (``-1`` = unused slot).
    NOTE(review): presumably the H0 death / minimum-spanning-tree edges; confirm.
    The selected edges form an unweighted graph whose networkx betweenness
    centrality is returned.

    Args:
        B: square attention matrix (numpy array).

    Returns:
        float16 array of length ``B.shape[0]``, one centrality per token in index order.
    """
    dist = 1 - np.maximum(B, np.transpose(B))
    np.fill_diagonal(dist, 0)
    n_nodes = B.shape[0]

    # Same pair order as a nested ``for i: for j < i`` loop.
    lower_i, lower_j = np.tril_indices(n_nodes, k = -1)
    edges = np.ascontiguousarray(np.stack((lower_i, lower_j)), dtype = np.int32)
    w = np.ascontiguousarray(dist[lower_i, lower_j], dtype = np.float32)

    h0 = np.full(n_nodes, -1, dtype = np.int32)
    ph_simple.calc_barcodes_batch_cycles(
        1,                                              # batch_size
        edges,
        w,
        np.array([0, edges.shape[1]], dtype = np.int32),  # edge_ptr
        np.array([0, n_nodes], dtype = np.int32),         # node_ptr
        h0,
        np.full(w.shape, -1, dtype = np.int32),           # h0_e
        np.full(w.shape, -1, dtype = np.int32),           # h1_e
        0,                                              # filter_cycles
        0,                                              # multiprocessing
    )

    picked = h0[h0 != -1]
    src, dst = edges[0, picked], edges[1, picked]
    adj = np.zeros((n_nodes, n_nodes))
    adj[src, dst] = 1
    adj[dst, src] = 1

    centrality = nx.betweenness_centrality(nx.Graph(adj))
    return np.array([centrality[v] for v in range(n_nodes)], dtype = np.float16)

    
