import concurrent
from collections import deque

from torch_geometric.data import Data
import torch
import networkx as nx
import numpy as np
import multiprocessing as mp
import random


def generate_random_embeddings(num_nodes, embedding_dim):
    """Return a ``(num_nodes, embedding_dim)`` float32 tensor of uniform [0, 1) samples.

    Values come from NumPy's global RNG, so ``np.random.seed`` makes the
    output reproducible.
    """
    shape = (num_nodes, embedding_dim)
    samples = np.random.rand(*shape).astype(np.float32)
    return torch.tensor(samples)


def create_data_from_graph(G, embeddings):
    """Wrap a graph and its node features in a ``torch_geometric`` ``Data`` object.

    Args:
        G: graph whose ``G.edges`` iterates ``(u, v)`` pairs of integer node ids
            (presumably an undirected ``nx.Graph``, in which case each edge
            appears in one direction only -- confirm whether callers expect
            both directions).
        embeddings: node feature matrix, stored unchanged as ``data.x``.

    Returns:
        ``Data(x=embeddings, edge_index=edge_index)`` where ``edge_index`` is a
        ``(2, num_edges)`` long tensor with one column per entry of ``G.edges``.
    """
    edges = list(G.edges)
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # torch.tensor([]) is 1-D, so .t() would give shape (0,) instead of
        # the (2, 0) layout an edgeless edge_index must have.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return Data(x=embeddings, edge_index=edge_index)


def compute_shortest_paths(G):
    """Return ``{source: {target: weighted distance}}`` for every reachable pair.

    Runs Dijkstra from every node using the ``'weight'`` edge attribute.
    """
    lengths = nx.all_pairs_dijkstra_path_length(G, weight='weight')
    return {source: targets for source, targets in lengths}


def compute_labels(shortest_paths, node_pairs):
    """Look up ``shortest_paths[u][v]`` for each ``(u, v)`` pair; return them as float32.

    NOTE: ``compute_labels`` is defined again further down this module with a
    different signature. That later ``def`` rebinds the name at import time,
    so code using this module's ``compute_labels`` gets the later version.
    """
    return torch.tensor(
        [shortest_paths[src][dst] for src, dst in node_pairs],
        dtype=torch.float32,
    )

def get_random_anchorset(n,c=0.5):
    """Sample random anchor sets of geometrically decreasing size.

    For ``m = floor(log2(n))`` size levels ``i = 0 .. m-1``, draws
    ``int(c * m)`` sets of ``n // 2**(i+1)`` distinct node ids from
    ``range(n)`` using NumPy's global RNG.

    Args:
        n: number of nodes.
        c: multiplier controlling how many sets are drawn per size level.

    Returns:
        A list of 1-D integer arrays ordered by size level (largest first);
        empty when ``n < 2``.

    NOTE: a second ``get_random_anchorset`` (default ``c=1``) is defined later
    in this module and rebinds the name at import time.
    """
    if n < 1:
        # np.log2(0) is -inf and log2 of a negative is nan; int() cannot
        # convert either. With no nodes there is nothing to sample.
        return []
    n_int = int(n)
    m = n_int.bit_length() - 1  # exact floor(log2(n)) for positive integers
    copies = int(c * m)
    anchorset_id = []
    for i in range(m):
        anchor_size = n_int // 2 ** (i + 1)
        for _ in range(copies):
            anchorset_id.append(np.random.choice(n, size=anchor_size, replace=False))
    return anchorset_id

def get_dist_max(anchorset_id, dist, device):
    """For every row of ``dist`` and every anchor set, find the anchor with the largest value.

    Args:
        anchorset_id: sequence of 1-D integer index arrays (column ids into ``dist``).
        dist: 2-D NumPy array or tensor of node-to-node scores.
        device: device for the returned tensors.

    Returns:
        ``(dist_max, dist_argmax)``, both of shape
        ``(dist.shape[0], len(anchorset_id))``. ``dist_max`` is float32 and holds
        the maximum over each set's columns. ``dist_argmax`` is long and holds
        the column id (taken from the set) that attains it.
    """
    # as_tensor shares memory with a NumPy input and returns a tensor input
    # unchanged, avoiding torch.tensor's copy and its copy-construct warning.
    dist = torch.as_tensor(dist)
    num_rows, num_sets = dist.shape[0], len(anchorset_id)
    dist_max = torch.zeros((num_rows, num_sets))
    dist_argmax = torch.zeros((num_rows, num_sets), dtype=torch.long)
    for i, ids in enumerate(anchorset_id):
        temp_id = torch.as_tensor(ids, dtype=torch.long, device=dist.device)
        max_vals, max_pos = torch.max(dist[:, temp_id], dim=-1)
        dist_max[:, i] = max_vals
        # Map the position within the set back to the actual column id.
        dist_argmax[:, i] = temp_id[max_pos]
    # Fill on the host, then move once instead of once per column.
    return dist_max.to(device), dist_argmax.to(device)


def preselect_anchor(data, layer_num=1, anchor_num=32, anchor_size_num=4, device='cpu'):
    """Attach random anchor sets and per-anchor-set max scores to ``data`` in place.

    Sets:
      * ``data.anchor_size_num``;
      * ``data.anchor_set``: one array per size level ``i`` of shape
        ``(layer_num, anchor_num // anchor_size_num, 2**(i+1) - 1)``, with node
        ids drawn with replacement;
      * ``data.anchor_set_indicator``: an all-zero int array;
      * ``data.dists_max`` / ``data.dists_argmax``: computed by ``get_dist_max``
        over ``data.dists`` for sets from ``get_random_anchorset(..., c=1)``.

    Prints the anchor sets and the number of sampled anchor sets. Returns ``None``.

    NOTE: this function is defined again later in the module (that version
    stores ``data.dists_avg`` and returns ``data``), and the later ``def``
    rebinds the name at import time.
    """
    data.anchor_size_num = anchor_size_num
    data.anchor_set = []
    per_size = anchor_num // anchor_size_num
    data.anchor_set.extend(
        np.random.choice(
            data.num_nodes,
            size=(layer_num, per_size, 2 ** (level + 1) - 1),
            replace=True,
        )
        for level in range(anchor_size_num)
    )
    print("data.anchor_set=", data.anchor_set)
    data.anchor_set_indicator = np.zeros((layer_num, anchor_num, data.num_nodes), dtype=int)

    sampled = get_random_anchorset(data.num_nodes, c=1)
    print("len(anchorset_id)=", len(sampled))

    data.dists_max, data.dists_argmax = get_dist_max(sampled, data.dists, device)

def single_source_shortest_path_length_range(graph, node_range, cutoff):
    """Run an unweighted BFS from each node in ``node_range``.

    Returns ``{source: {target: hop_count}}``, limited to targets within
    ``cutoff`` hops (``None`` means no limit). Edge weights are not consulted.
    """
    return {
        source: nx.single_source_shortest_path_length(graph, source, cutoff)
        for source in node_range
    }

def merge_dicts(dicts):
    """Combine an iterable of dicts into one new dict; later values override earlier ones."""
    return {key: value for mapping in dicts for key, value in mapping.items()}

def all_pairs_shortest_path_length_parallel(graph,cutoff=None,num_workers=4):
    """Compute unweighted BFS hop counts from every node, split across processes.

    Nodes are shuffled (``random`` module's global RNG) and split into
    contiguous chunks, one per worker. Each chunk is processed by
    ``single_source_shortest_path_length_range`` in a ``multiprocessing`` pool,
    and the per-chunk dicts are merged.

    Args:
        graph: networkx graph.
        cutoff: maximum hop count to explore (``None`` means unlimited).
        num_workers: requested pool size. It is reduced to a quarter for graphs
            with fewer than 50 nodes and to half for fewer than 400, but never
            below one process.

    Returns:
        ``{source: {target: hop_count}}`` for every node in ``graph``.
    """
    nodes = list(graph.nodes)
    random.shuffle(nodes)
    if len(nodes) < 50:
        num_workers = int(num_workers / 4)
    elif len(nodes) < 400:
        num_workers = int(num_workers / 2)
    # The reductions above hit 0 for small num_workers, and mp.Pool rejects
    # processes=0 with ValueError; always keep at least one worker.
    num_workers = max(1, num_workers)

    num_nodes = len(nodes)
    # Integer boundaries partition the node list exactly, with no float rounding.
    chunks = [
        nodes[num_nodes * i // num_workers:num_nodes * (i + 1) // num_workers]
        for i in range(num_workers)
    ]
    # The context manager tears the pool down even if a worker raises. Every
    # result has been collected before __exit__ terminates the pool.
    with mp.Pool(processes=num_workers) as pool:
        results = [
            pool.apply_async(single_source_shortest_path_length_range, args=(graph, chunk, cutoff))
            for chunk in chunks
        ]
        output = [r.get() for r in results]
    return merge_dicts(output)

def precompute_dist_data(edge_index, num_nodes, edge_weights=None, approximate=10):
    '''
    Build a dense ``(num_nodes, num_nodes)`` closeness matrix from hop counts.

    Entry ``[i, j]`` is ``1 / (hops(i, j) + 1)``: 1 on the diagonal, 1/2 for
    direct neighbours, and so on. Higher means closer. 0 means the pair is
    disconnected, or farther apart than ``approximate`` hops when
    ``approximate > 0``.

    :param edge_index: ``(2, E)`` tensor/array of node ids (``.transpose(1, 0).tolist()``
        is applied to it).
    :param num_nodes: number of nodes; every id in ``edge_index`` must be in
        ``range(num_nodes)``.
    :param edge_weights: optional weights, one per ``edge_index`` column, stored as
        the ``'weight'`` attribute of the matching graph edge. NOTE: the hop-count
        BFS used here ignores them, so they do not affect the result.
    :param approximate: BFS cutoff in hops; ``<= 0`` means unlimited.
    :return: float64 ``np.ndarray`` of shape ``(num_nodes, num_nodes)``.
    '''
    graph = nx.Graph()
    # Register every node, not only those that appear in an edge, so isolated
    # nodes still get their self-entry of 1 on the diagonal.
    graph.add_nodes_from(range(num_nodes))
    edge_list = edge_index.transpose(1, 0).tolist()
    graph.add_edges_from(edge_list)

    if edge_weights is not None:
        # Pair each weight with its own edge_index column. graph.edges() is
        # deduplicated and adjacency-ordered, so its enumeration index does not
        # line up with edge_weights.
        for (u, v), weight in zip(edge_list, edge_weights):
            graph[u][v]['weight'] = weight

    dists_array = np.zeros((num_nodes, num_nodes))

    # Unweighted (hop-count) BFS from every node, run in a process pool.
    dists_dict = all_pairs_shortest_path_length_parallel(graph, cutoff=approximate if approximate > 0 else None)

    # Each per-source dict contains only reachable targets, so iterate it
    # directly instead of probing all n*n pairs.
    for node_i, shortest_dist in dists_dict.items():
        for node_j, dist in shortest_dist.items():
            dists_array[node_i, node_j] = 1 / (dist + 1)
    return dists_array

def precompute_dist_data2(edge_index, num_nodes, edge_weights=None, approximate=10):
    '''
    Build a dense ``(num_nodes, num_nodes)`` matrix of hop counts shifted by one.

    The value is NOT inverted: entry ``[i, j]`` is ``hops(i, j) + 1`` (1 on the
    diagonal, 2 for direct neighbours, ...), so higher means farther. 0 means
    the pair is disconnected, or farther apart than ``approximate`` hops when
    ``approximate > 0``.

    :param edge_index: ``(2, E)`` tensor/array of node ids (``.transpose(1, 0).tolist()``
        is applied to it).
    :param num_nodes: number of nodes; every id in ``edge_index`` must be in
        ``range(num_nodes)``.
    :param edge_weights: optional weights, one per ``edge_index`` column, stored as
        the ``'weight'`` attribute of the matching graph edge. NOTE: the hop-count
        BFS used here ignores them, so they do not affect the result.
    :param approximate: BFS cutoff in hops; ``<= 0`` means unlimited.
    :return: float64 ``np.ndarray`` of shape ``(num_nodes, num_nodes)``.
    '''
    graph = nx.Graph()
    # Register every node, not only those that appear in an edge, so isolated
    # nodes still get their self-entry of 1 on the diagonal.
    graph.add_nodes_from(range(num_nodes))
    edge_list = edge_index.transpose(1, 0).tolist()
    graph.add_edges_from(edge_list)

    if edge_weights is not None:
        # Pair each weight with its own edge_index column. graph.edges() is
        # deduplicated and adjacency-ordered, so its enumeration index does not
        # line up with edge_weights.
        for (u, v), weight in zip(edge_list, edge_weights):
            graph[u][v]['weight'] = weight

    dists_array = np.zeros((num_nodes, num_nodes))

    # Unweighted (hop-count) BFS from every node, run in a process pool.
    dists_dict = all_pairs_shortest_path_length_parallel(graph, cutoff=approximate if approximate > 0 else None)

    # Each per-source dict contains only reachable targets, so iterate it
    # directly instead of probing all n*n pairs. The +1 keeps a node's
    # distance to itself non-zero, distinct from the 0 "unreachable" marker.
    for node_i, shortest_dist in dists_dict.items():
        for node_j, dist in shortest_dist.items():
            dists_array[node_i, node_j] = dist + 1
    return dists_array

def bfs_shortest_path(G, source):
    """Return ``{node: hop_count}`` for every node reachable from ``source``.

    Plain breadth-first search over ``G.neighbors``. Edge weights are ignored,
    and unreachable nodes are absent from the result.
    """
    hops = {source: 0}
    frontier = deque((source,))
    while frontier:
        current = frontier.popleft()
        next_hop = hops[current] + 1
        for nbr in G.neighbors(current):
            if nbr in hops:
                continue
            hops[nbr] = next_hop
            frontier.append(nbr)
    return hops



def compute_required_shortest_paths(G, train_pivots, test_pivots, all_nodes):
    """Run ``bfs_shortest_path`` once for each distinct pivot.

    Args:
        G: graph passed to ``bfs_shortest_path``.
        train_pivots: iterable of source nodes.
        test_pivots: iterable of source nodes. Pivots already seen (in either
            list) are not recomputed.
        all_nodes: unused; kept for call-site compatibility.

    Returns:
        ``{pivot: {node: hop_count}}`` keyed by each distinct pivot in
        first-seen order (train pivots first).
    """
    required_paths = {}
    for pivots in (train_pivots, test_pivots):
        for pivot in pivots:
            # A pivot may repeat within or across the two lists. A BFS over the
            # same graph yields the same result, so compute it only once.
            if pivot not in required_paths:
                required_paths[pivot] = bfs_shortest_path(G, pivot)
    return required_paths
def bfs_shortest_path_from_pivot(G, pivot):
    """Return BFS hop counts from ``pivot`` as a list indexed by node label.

    The list has ``len(G.nodes())`` entries and is indexed directly by node
    label, so nodes are assumed to be labelled ``0 .. len(G.nodes()) - 1``.
    Entry ``k`` is the hop count from ``pivot`` to node ``k``, or ``-1`` if
    node ``k`` is unreachable.
    """
    distances = [-1 for _ in range(len(G.nodes()))]
    distances[pivot] = 0
    pending = deque([pivot])
    while pending:
        node = pending.popleft()
        step = distances[node] + 1
        for neighbor in G.neighbors(node):
            if distances[neighbor] < 0:
                distances[neighbor] = step
                pending.append(neighbor)
    return distances


def compute_shortest_paths_from_pivots(G, pivot_nodes):
    """Compute BFS hop counts from each pivot and regroup them per node.

    Each pivot's BFS (``bfs_shortest_path_from_pivot``) is submitted to a
    thread pool. That helper returns a list indexed by node label, so nodes are
    assumed to be labelled ``0 .. len(G.nodes()) - 1``.

    Returns:
        ``{node: {pivot: hop_count}}`` for every node in ``G``, where ``-1``
        marks an unreachable pair. Returns ``{}`` when ``pivot_nodes`` is empty.
    """
    # A bare `import concurrent` does not load the `futures` submodule, so
    # `concurrent.futures` only worked if another module had already imported it.
    from concurrent.futures import ThreadPoolExecutor

    shortest_paths = {}
    nodes = list(G.nodes())
    # NOTE(review): the BFS helper is pure Python, so threads likely give
    # little speedup under the GIL.
    with ThreadPoolExecutor() as executor:
        futures = {
            pivot: executor.submit(bfs_shortest_path_from_pivot, G, pivot)
            for pivot in pivot_nodes
        }
        for pivot, future in futures.items():
            pivot_distances = future.result()
            for node in nodes:
                # Index by node label, matching how the helper fills its list.
                # The enumerate position only coincides with the label when
                # G.nodes() happens to be in sorted order.
                shortest_paths.setdefault(node, {})[pivot] = pivot_distances[node]
    return shortest_paths


def compute_labels(node_embeddings, train_samples, pivot_nodes):
    """Gather ``node_embeddings[node, column_of(pivot)]`` for each ``(pivot, node)`` sample.

    ``column_of`` maps each pivot to its position in ``pivot_nodes``; for a
    repeated pivot, the last position wins. Despite the parameter name,
    ``node_embeddings`` is indexed as a 2-D node-by-pivot table (presumably a
    distance matrix -- confirm with callers).

    Returns a float32 tensor with one entry per sample, in sample order.
    """
    column_of = dict((pivot, col) for col, pivot in enumerate(pivot_nodes))
    gathered = [node_embeddings[node, column_of[pivot]] for pivot, node in train_samples]
    return torch.tensor(gathered, dtype=torch.float32)

def get_anchorset_avg_dist(anchorset_id, dist, device):
    """Average ``dist`` over each anchor set's columns, for every row.

    Args:
        anchorset_id: sequence of 1-D integer index arrays (column ids into ``dist``).
        dist: 2-D NumPy array or tensor of floating-point scores.
        device: device of the returned tensor.

    Returns:
        float32 tensor of shape ``(dist.shape[0], len(anchorset_id))`` whose
        column ``i`` is the row-wise mean of ``dist[:, anchorset_id[i]]``.
    """
    # as_tensor shares memory with a NumPy input and returns a tensor input
    # unchanged, avoiding torch.tensor's copy and its copy-construct warning.
    dist = torch.as_tensor(dist)
    dist_avg = torch.zeros((dist.shape[0], len(anchorset_id))).to(device)
    for i, ids in enumerate(anchorset_id):
        temp_id = torch.as_tensor(ids, dtype=torch.long, device=dist.device)
        dist_avg[:, i] = dist[:, temp_id].mean(dim=-1)
    return dist_avg


def get_random_anchorset(n, c=1):
    """Sample random anchor sets of geometrically decreasing size.

    For ``m = floor(log2(n))`` size levels ``i = 0 .. m-1``, draws
    ``int(c * m)`` sets of ``n // 2**(i+1)`` distinct node ids from
    ``range(n)`` using NumPy's global RNG.

    Args:
        n: number of nodes.
        c: multiplier controlling how many sets are drawn per size level.

    Returns:
        A list of 1-D integer arrays ordered by size level (largest first);
        empty when ``n < 2``.
    """
    if n < 1:
        # np.log2(0) is -inf and log2 of a negative is nan; int() cannot
        # convert either. With no nodes there is nothing to sample.
        return []
    n_int = int(n)
    m = n_int.bit_length() - 1  # exact floor(log2(n)) for positive integers
    copies = int(c * m)
    anchorset_id = []
    for i in range(m):
        anchor_size = n_int // 2 ** (i + 1)
        for _ in range(copies):
            anchorset_id.append(np.random.choice(n, size=anchor_size, replace=False))
    return anchorset_id


def preselect_anchor(data, layer_num=1, anchor_num=32, anchor_size_num=4, device='cpu'):
    """Attach random anchor sets and per-anchor-set average scores to ``data``.

    Mutates and returns ``data``, setting:
      * ``anchor_size_num``;
      * ``anchor_set``: a list with one array per size level ``i``, of shape
        ``(layer_num, anchor_num // anchor_size_num, 2**(i+1) - 1)``, with node
        ids drawn with replacement;
      * ``anchor_set_indicator``: an all-zero ``(layer_num, anchor_num, num_nodes)``
        int array;
      * ``dists_avg``: ``get_anchorset_avg_dist`` over ``data.dists`` for anchor
        sets from ``get_random_anchorset(num_nodes, c=1)``.
    """
    data.anchor_size_num = anchor_size_num
    data.anchor_set = []
    sets_per_size = anchor_num // anchor_size_num
    widths = [(1 << (level + 1)) - 1 for level in range(anchor_size_num)]
    for width in widths:
        data.anchor_set.append(
            np.random.choice(data.num_nodes, size=(layer_num, sets_per_size, width), replace=True)
        )
    data.anchor_set_indicator = np.zeros((layer_num, anchor_num, data.num_nodes), dtype=int)

    sampled_sets = get_random_anchorset(data.num_nodes, c=1)
    data.dists_avg = get_anchorset_avg_dist(sampled_sets, data.dists, device)
    return data
