import networkx as nx
import torch
import numpy as np
import os.path as osp
from scipy.spatial import distance
from torch_geometric.utils.convert import to_networkx
from datasets.struc2vec.models.struc2vec import Struc2Vec
from collections import Counter


def analysis_graph_structure_statis_info(G, undirected):
    """Collect basic structural statistics for a graph object.

    Args:
        G: graph object exposing ``y``, ``num_nodes`` and ``num_edges`` and
            accepted by ``to_networkx`` (presumably a torch_geometric ``Data``
            — TODO confirm against callers).
        undirected: forwarded as ``to_undirected`` when building the networkx
            graph used for the degree ranking.

    Returns:
        dict with keys, in order:
            "num_nodes", "num_edges",
            "average_degree"      -- num_edges / num_nodes,
            "density"             -- num_edges / (num_nodes * (num_nodes - 1)),
            "label_distribution"  -- {label: count}, ordered by label,
            "average_path_length", "isolate_nodes_num"
                                  -- from average_shortest_path_length_for_all
                                     on the *undirected* conversion of G,
            "nodes_degree_info"   -- the 10 (node, degree) pairs with the
                                     highest degree; ties go to the larger node.
    """
    n_nodes = G.num_nodes
    n_edges = G.num_edges
    info = {
        "num_nodes": n_nodes,
        "num_edges": n_edges,
        "average_degree": n_edges / n_nodes,
        "density": n_edges / (n_nodes * (n_nodes - 1)),
    }

    graph_nx = to_networkx(G, to_undirected=undirected)

    label_counts = Counter(G.y.tolist())
    info["label_distribution"] = dict(sorted(label_counts.items(), key=lambda item: item[0]))

    # Path statistics always use the undirected view, regardless of `undirected`.
    undirected_nx = to_networkx(G, to_undirected=True)
    info["average_path_length"], info["isolate_nodes_num"] = average_shortest_path_length_for_all(undirected_nx)

    degree_by_node = dict(nx.degree(graph_nx))
    ranked = sorted(degree_by_node.items(), key=lambda item: (item[1], item[0]), reverse=True)
    info["nodes_degree_info"] = ranked[:10]
    return info


def analysis_graph_structure_homo_hete_info(G):
    """Return a pair ``(label_info, feature_info)`` of homophily statistics.

    ``label_info`` holds "node_homophily" (label_node_homogeneity) and
    "edge_homophily" (label_edge_homogeneity). ``feature_info`` is currently
    always an empty dict.
    """
    label_info = {
        "node_homophily": label_node_homogeneity(G),
        "edge_homophily": label_edge_homogeneity(G),
    }
    feature_info = {}
    return label_info, feature_info


def average_shortest_path_length_for_all(G):
    """Average shortest path length of a possibly disconnected graph.

    Isolated (degree-0) nodes are dropped first. The remaining graph is split
    into connected components, and the unweighted mean of
    ``nx.average_shortest_path_length`` over those components is returned.
    For a connected graph this equals ``nx.average_shortest_path_length(G)``.

    Args:
        G: undirected networkx graph. ``nx.connected_components`` does not
            accept directed graphs.

    Returns:
        (average, isolate_nodes_num): ``average`` is 0.0 when nothing remains
        after removing isolates (empty graph or only isolated nodes).
        ``isolate_nodes_num`` is the number of isolated nodes in ``G``.
    """
    isolates = list(nx.isolates(G))
    tmp_G = G.copy()
    tmp_G.remove_nodes_from(isolates)
    if tmp_G.number_of_nodes() == 0:
        # No paths to measure. nx.is_connected would raise
        # NetworkXPointlessConcept on an empty graph.
        return 0.0, len(isolates)

    subgraphs = [tmp_G.subgraph(c) for c in nx.connected_components(tmp_G)]
    total = sum(nx.average_shortest_path_length(sb) for sb in subgraphs)
    return total / len(subgraphs), len(isolates)

def label_node_homogeneity(G):
    """Node-level label homophily: mean over nodes of the same-label neighbour fraction.

    For every node u, its neighbours are the targets ``edge_index[1]`` of the
    edges whose source ``edge_index[0]`` equals u. A node without such edges
    contributes 0 but is still counted in the denominator ``G.num_nodes``.

    Args:
        G: object exposing ``num_nodes``, ``edge_index`` (integer tensor,
            presumably shape (2, E)) and ``y`` (labels indexed by node id).

    Returns:
        float in [0, 1]; 0.0 for a graph with no nodes.
    """
    num_nodes = G.num_nodes
    if num_nodes == 0:
        return 0.0
    src, dst = G.edge_index[0], G.edge_index[1]
    same = G.y[src] == G.y[dst]
    if same.dim() > 1:
        # Labels stored as (N, 1) or wider: an edge matches only if every entry matches.
        same = same.flatten(1).all(dim=1)
    # Per-source edge counts and same-label counts in one pass over the edges,
    # instead of rescanning edge_index once per node.
    out_degree = torch.bincount(src, minlength=num_nodes).to(torch.float64)
    hits = torch.zeros(num_nodes, dtype=torch.float64, device=src.device)
    hits.index_add_(0, src, same.to(torch.float64))
    has_neighbours = out_degree > 0
    return (hits[has_neighbours] / out_degree[has_neighbours]).sum().item() / num_nodes

def label_edge_homogeneity(G):
    """Edge-level label homophily: fraction of edges whose endpoints share a label.

    Args:
        G: object exposing ``num_edges``, ``edge_index`` (integer tensor whose
            rows 0 and 1 hold edge sources and targets) and ``y`` (labels
            indexed by node id).

    Returns:
        float in [0, 1]; 0.0 when the graph has no edges.
    """
    num_edges = G.num_edges
    if num_edges == 0:
        return 0.0
    src = G.edge_index[0][:num_edges]
    dst = G.edge_index[1][:num_edges]
    same = G.y[src] == G.y[dst]
    if same.dim() > 1:
        # Labels stored as (N, 1) or wider: an edge matches only if every entry matches.
        same = same.flatten(1).all(dim=1)
    return same.sum().item() / num_edges

def feature_node_homogeneity(G):
    """Node-level feature homophily based on cosine similarity.

    For each node u with at least one out-edge (``edge_index[0] == u``):
      * compute ``1 - scipy.spatial.distance.cosine(G.x[u], G.x[v])`` for every
        target v in ``edge_index[1]``;
      * average those similarities;
      * min-max rescale the average using the smallest and largest of them.
        When they are all equal, the un-rescaled average is used instead.

    Nodes without out-edges contribute 0. The total is divided by ``G.num_nodes``.
    """
    node_count = G.num_nodes
    total = 0
    for u in range(node_count):
        targets = G.edge_index[1][torch.where(G.edge_index[0] == u)]
        if len(targets) == 0:
            continue
        sims = [1 - distance.cosine(G.x[u], G.x[v]) for v in targets]
        mean_sim = sum(sims) / len(targets)
        lo, hi = min(sims), max(sims)
        spread = hi - lo
        total += (mean_sim - lo) / spread if spread != 0 else mean_sim
    return total / node_count

def feature_edge_homogeneity(G):
    """Edge-level feature homophily, min-max normalised by the edge similarities.

    The raw score is the mean, over edges, of
    ``1 - scipy.spatial.distance.cosine`` between the feature vectors of the
    edge's two endpoints. It is then rescaled as ``(mean - min) / (max - min)``.
    The min and max are taken over the edge similarities that are neither
    exactly 0.0 nor exactly 1.0.

    Args:
        G: object exposing ``num_edges``, ``edge_index`` (rows 0 and 1 hold
            edge sources and targets) and ``x`` (features indexed by node id).

    Returns:
        float. 0.0 when there are no edges. The un-normalised mean when no
        similarity survives the 0/1 filter, or when all survivors are equal
        (no range to normalise by).
    """
    num_edges = G.num_edges
    if num_edges == 0:
        return 0.0
    src, dst = G.edge_index[0], G.edge_index[1]
    sim_list = [1 - distance.cosine(G.x[src[i]], G.x[dst[i]]) for i in range(num_edges)]
    homophily = sum(sim_list) / num_edges

    # Compare with != directly. A filter built on (1).__ne__ removes nothing,
    # because int.__ne__(float) returns NotImplemented, which is truthy.
    informative = [s for s in sim_list if s != 0.0 and s != 1.0]
    if not informative:
        return homophily
    sim_min, sim_max = min(informative), max(informative)
    if sim_max == sim_min:
        return homophily
    return (homophily - sim_min) / (sim_max - sim_min)

def _sample_four_point_deltas(graph, num_samples):
    """Return up to ``num_samples`` four-point Gromov deltas from random node 4-tuples of ``graph``."""
    nodes = list(graph.nodes())
    if len(nodes) < 4:
        # Sampling 4 distinct nodes without replacement is impossible.
        return []
    deltas = []
    for _ in range(num_samples):
        # Sample positions rather than nodes so non-scalar node labels stay intact.
        idx = np.random.choice(len(nodes), 4, replace=False)
        n0, n1, n2, n3 = (nodes[i] for i in idx)
        try:
            d01 = nx.shortest_path_length(graph, source=n0, target=n1, weight=None)
            d23 = nx.shortest_path_length(graph, source=n2, target=n3, weight=None)
            d02 = nx.shortest_path_length(graph, source=n0, target=n2, weight=None)
            d13 = nx.shortest_path_length(graph, source=n1, target=n3, weight=None)
            d03 = nx.shortest_path_length(graph, source=n0, target=n3, weight=None)
            d12 = nx.shortest_path_length(graph, source=n1, target=n2, weight=None)
        except nx.NetworkXNoPath:
            # Only possible for directed graphs: a weakly connected component
            # need not contain a directed path between every pair.
            continue
        s = sorted([d01 + d23, d02 + d13, d03 + d12])
        deltas.append((s[-1] - s[-2]) / 2)
    return deltas


def gromov_hyperbolicity(G, undirected, num_samples=5000):
    """Estimate Gromov delta-hyperbolicity by random sampling of the four-point condition.

    For each connected component (weakly connected if the graph is directed),
    ``num_samples`` random 4-node tuples are drawn. Each tuple yields
    ``(largest - second largest) / 2`` of the three pairwise distance sums.
    Components with fewer than 4 nodes are skipped.

    Args:
        G: graph accepted by ``to_networkx``.
        undirected: forwarded as ``to_undirected`` to ``to_networkx``.
        num_samples: tuples drawn per component.

    Returns:
        Mean of all sampled deltas, or ``nan`` when no tuple could be sampled.
    """
    graph_nx = to_networkx(G, to_undirected=undirected)
    if graph_nx.is_directed():
        components = nx.weakly_connected_components(graph_nx)
    else:
        components = nx.connected_components(graph_nx)

    hyps = []
    for component in components:
        hyps.extend(_sample_four_point_deltas(graph_nx.subgraph(component), num_samples))
    if not hyps:
        return float("nan")
    return np.mean(hyps)

def idx_to_mask(index, size):
    """Return a boolean tensor of length ``size`` that is True exactly at ``index``."""
    out = torch.zeros((size,), dtype=torch.bool)
    out[index] = True
    return out

def remove_duplicate_two_dimension_list_element(input_list):
    """Return ``input_list`` without repeated elements, keeping first occurrences in order.

    Duplicates are detected with ``==`` (via ``in``), so unhashable elements
    such as inner lists are supported.
    """
    unique = []
    for row in input_list:
        if row in unique:
            continue
        unique.append(row)
    return unique

def struc2vec_sim_distribution(name, subgraph, cliend_id):
    """Histogram of struc2vec cosine similarities between the endpoints of each edge.

    Steps:
      * The subgraph's edge list is cached as
        ``./datasets/struc2vec/graph/{name}{cliend_id}.edgelist``, with one
        "u v" line per edge; it is written if it cannot be read.
      * The edge list is loaded as a ``nx.DiGraph`` with string node labels.
      * The graph is embedded with Struc2Vec (64-dim).
      * For every edge (u, v) in ``subgraph.edge_index``, compute
        ``1 - cosine distance`` between the embeddings of u and v.

    Returns:
        list[int]: occurrence count of each distinct similarity value, ordered
        by ascending similarity. Empty when the subgraph has no edges.
    """
    import os

    struc2vec_graph_root_path = "./datasets/struc2vec/graph"
    struc2vec_graph_file_path = osp.join(struc2vec_graph_root_path, "{}{}.edgelist".format(name, cliend_id))
    edge_index = subgraph.edge_index.numpy()
    num_edges = subgraph.num_edges
    if num_edges == 0:
        return []

    def _read_graph():
        return nx.read_edgelist(struc2vec_graph_file_path, create_using=nx.DiGraph(), nodetype=None,
                                data=[('weight', int)])

    try:
        G = _read_graph()
    except OSError:
        # No cached edge list yet: write it, then load it. Previously the
        # graph was never loaded on this path, so `G` was undefined below.
        os.makedirs(struc2vec_graph_root_path, exist_ok=True)
        with open(struc2vec_graph_file_path, 'w') as file:
            for i in range(num_edges):
                file.write("{} {}\n".format(edge_index[0][i], edge_index[1][i]))
        G = _read_graph()

    model = Struc2Vec(graph=G, walk_length=10, num_walks=100, workers=4, verbose=0, stay_prob=0.3,
        opt1_reduce_len=True,
        opt2_reduce_sim_calc=True,
        opt3_num_layers=None,
        temp_path='./datasets/struc2vec/temp_struc2vec{}/'.format(cliend_id), reuse=False)

    model.train(embed_size=64, window_size=5, workers=3, iter=5)
    embeddings = model.get_embeddings()
    # Look embeddings up by node label instead of by row position. The edge
    # list was read with nodetype=None, so labels are the strings written
    # above, and the mapping's iteration order need not follow node ids.
    # NOTE(review): assumes get_embeddings() is keyed by node label — confirm.
    embedding_by_node = {str(node): vec for node, vec in embeddings.items()}

    sims = [
        1 - distance.cosine(embedding_by_node[str(u)], embedding_by_node[str(v)])
        for u, v in zip(edge_index[0][:num_edges], edge_index[1][:num_edges])
    ]
    counts = Counter(sims)
    return [counts[v] for v in sorted(counts)]


        



