

import sys
import math
from tqdm import tqdm
import random
import numpy as np
import scipy.sparse as ssp
from scipy.sparse.csgraph import shortest_path
import torch
from torch_sparse import spspmm
import torch_geometric
from torch_geometric.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils import (negative_sampling, add_self_loops,
                                   train_test_split_edges)
import pdb
import networkx as nx
import math
from math import ceil
from concurrent.futures import ProcessPoolExecutor, as_completed


def CN(A, edge_index, batch_size=100000):
    """Common Neighbor heuristic score for each node pair in ``edge_index``.

    For a pair (u, v) the score is ``sum_w A[u, w] * A[v, w]``, i.e. the
    (weighted) number of shared neighbours when ``A`` is an adjacency matrix.

    Args:
        A: scipy sparse matrix supporting row indexing and ``multiply``
            (presumably the CSR adjacency matrix -- confirm at call sites).
        edge_index: LongTensor of shape (2, E); row 0 holds sources, row 1
            holds targets.
        batch_size: number of pairs scored per batch.

    Returns:
        FloatTensor of shape (E,), one score per column of ``edge_index``.
    """
    num_pairs = edge_index.size(1)
    if num_pairs == 0:
        # np.concatenate below raises on an empty list of batches.
        return torch.FloatTensor([])

    link_loader = DataLoader(range(num_pairs), batch_size)
    scores = []
    for ind in tqdm(link_loader):
        src, dst = edge_index[0, ind], edge_index[1, ind]
        cur_scores = np.array(np.sum(A[src].multiply(A[dst]), 1)).flatten()
        scores.append(cur_scores)

    return torch.FloatTensor(np.concatenate(scores, 0))


def AA(A, edge_index, batch_size=100000):
    """Adamic-Adar heuristic score for each node pair in ``edge_index``.

    For a pair (u, v) the score is
    ``sum_w A[u, w] * A[v, w] / log(deg(w))`` where ``deg(w)`` is the column
    sum of ``A``. Columns with ``deg(w)`` of 0 or 1 contribute 0 (their
    ``1 / log`` is -0.0 or infinite, and infinities are zeroed).

    Args:
        A: scipy sparse matrix supporting ``sum``, ``multiply`` and row
            indexing (presumably the adjacency matrix -- confirm at call sites).
        edge_index: LongTensor of shape (2, E); row 0 sources, row 1 targets.
        batch_size: number of pairs scored per batch.

    Returns:
        FloatTensor of shape (E,), one score per column of ``edge_index``.
    """
    num_pairs = edge_index.size(1)
    if num_pairs == 0:
        # np.concatenate below raises on an empty list of batches.
        return torch.FloatTensor([])

    # log(0) and 1/log(1) trigger divide-by-zero warnings; both cases are
    # handled explicitly (the infinities are zeroed just below).
    with np.errstate(divide='ignore'):
        multiplier = 1 / np.log(A.sum(axis=0))
    multiplier[np.isinf(multiplier)] = 0
    # Scale each column w by 1/log(deg(w)) once, outside the batch loop.
    A_ = A.multiply(multiplier).tocsr()

    link_loader = DataLoader(range(num_pairs), batch_size)
    scores = []
    for ind in tqdm(link_loader):
        src, dst = edge_index[0, ind], edge_index[1, ind]
        cur_scores = np.array(np.sum(A[src].multiply(A_[dst]), 1)).flatten()
        scores.append(cur_scores)

    return torch.FloatTensor(np.concatenate(scores, 0))

def RA(A, edge_index, batch_size=100000):
    """Resource Allocation heuristic score for each node pair in ``edge_index``.

    For a pair (u, v) the score is ``sum_w A[u, w] * A[v, w] / deg(w)``
    where ``deg(w)`` is the column sum of ``A``; zero-degree columns
    contribute 0.

    Args:
        A: scipy sparse matrix supporting ``sum``, ``multiply`` and row
            indexing (presumably the adjacency matrix -- confirm at call sites).
        edge_index: LongTensor of shape (2, E); row 0 sources, row 1 targets.
        batch_size: number of pairs scored per batch.

    Returns:
        FloatTensor of shape (E,), one score per column of ``edge_index``.
    """
    num_pairs = edge_index.size(1)
    if num_pairs == 0:
        # np.concatenate below raises on an empty list of batches.
        return torch.FloatTensor([])

    # 1/0 for isolated columns warns and yields inf; those are zeroed below.
    with np.errstate(divide='ignore'):
        multiplier = 1 / (A.sum(axis=0))
    multiplier[np.isinf(multiplier)] = 0
    # Scale each column w by 1/deg(w) once, outside the batch loop.
    A_ = A.multiply(multiplier).tocsr()

    link_loader = DataLoader(range(num_pairs), batch_size)
    scores = []
    for ind in tqdm(link_loader):
        src, dst = edge_index[0, ind], edge_index[1, ind]
        cur_scores = np.array(np.sum(A[src].multiply(A_[dst]), 1)).flatten()
        scores.append(cur_scores)

    return torch.FloatTensor(np.concatenate(scores, 0))

def PPR(A, edge_index):
    """Personalized PageRank heuristic score for each node pair.

    For a pair (u, v) the score is the PageRank value at ``v`` when the
    personalisation vector is concentrated on ``u`` (``p=0.85``,
    ``tol=1e-7``), computed with ``fast_pagerank.pagerank_power``. Pairs are
    grouped by source so one PageRank run serves every pair sharing that
    source; scores are then scattered back to the caller's edge order.

    Requires ``pip install fast-pagerank``. Too slow for large datasets.

    Args:
        A: adjacency matrix accepted by ``pagerank_power``; ``A.shape[0]`` is
            used as the number of nodes.
        edge_index: LongTensor of shape (2, E); row 0 sources, row 1 targets.

    Returns:
        FloatTensor of shape (E,), aligned with the columns of ``edge_index``.
    """
    num_edges = edge_index.size(1)
    if num_edges == 0:
        return torch.FloatTensor([])

    from fast_pagerank import pagerank_power
    num_nodes = A.shape[0]
    src_all = edge_index[0].cpu().numpy()
    dst_all = edge_index[1].cpu().numpy()

    # Sort by source and locate the start of each run of equal sources.
    order = np.argsort(src_all, kind='stable')
    src_sorted = src_all[order]
    dst_sorted = dst_all[order]
    group_starts = np.flatnonzero(np.r_[True, src_sorted[1:] != src_sorted[:-1]])
    group_ends = np.r_[group_starts[1:], num_edges]

    scores = np.empty(num_edges, dtype=np.float64)
    for start, end in tqdm(zip(group_starts, group_ends), total=len(group_starts)):
        personalize = np.zeros(num_nodes)
        personalize[src_sorted[start]] = 1
        ppr = np.asarray(
            pagerank_power(A, p=0.85, personalize=personalize, tol=1e-7)
        ).ravel()
        # Write through ``order`` so each score lands at its original column.
        scores[order[start:end]] = ppr[dst_sorted[start:end]]

    return torch.FloatTensor(scores)

def _bidirectional_capped_shorest_path(G, source, target, k):
    # return the length of shortest path, return -1 if it is larger than k
    # does BFS from both source and target and meets in the middle
    if target == source:
        return ({target: None}, {source: None}, source)

    # handle either directed or undirected
    if G.is_directed():
        Gpred = G.pred
        Gsucc = G.succ
    else:
        Gpred = G.adj
        Gsucc = G.adj

    # predecesssor and successors in search
    pred = {source: None}
    succ = {target: None}

    # initialize fringes, start with forward
    forward_fringe = [source]
    reverse_fringe = [target]
    cur_len = 0

    while forward_fringe and reverse_fringe:
        cur_len += 1
        if len(forward_fringe) <= len(reverse_fringe):
            this_level = forward_fringe
            forward_fringe = []
            for v in this_level:
                for w in Gsucc[v]:
                    if w not in pred:
                        forward_fringe.append(w)
                        pred[w] = v
                    if w in succ:  # path found
                        return cur_len
        else:
            this_level = reverse_fringe
            reverse_fringe = []
            for v in this_level:
                for w in Gpred[v]:
                    if w not in succ:
                        succ[w] = v
                        reverse_fringe.append(w)
                    if w in pred:  # found path
                        return cur_len 
        if cur_len > k:
            break

    return -1


def process_chunk(args):
    """Score one chunk of node pairs with capped inverse shortest-path length.

    Takes a single tuple argument so it can be used with
    ``ProcessPoolExecutor.map``.

    Args:
        args: tuple ``(chunk, G, k, remove)`` where ``chunk`` is an iterable
            of ``(s, t)`` node pairs, ``G`` is the graph searched by
            ``_bidirectional_capped_shorest_path``, ``k`` is the distance cap
            and ``remove`` says whether to temporarily delete any edge
            between ``s`` and ``t`` before measuring.

    Returns:
        list: per pair, ``999`` when ``s == t``, otherwise ``1 / d`` with
        ``d`` the capped distance (``999`` when the search reports no path
        within ``k``).
    """
    chunk, G, k, remove = args
    local_scores = []

    for s, t in chunk:
        if s == t:
            local_scores.append(999)
            continue

        # Temporarily drop the pair's own edge(s), remembering attributes so
        # they can be restored exactly.
        removed = []
        if remove:
            for u, v in ((s, t), (t, s)):
                if G.has_edge(u, v):
                    removed.append((u, v, dict(G[u][v])))
                    G.remove_edge(u, v)
        try:
            sp = _bidirectional_capped_shorest_path(G, s, t, k)
        finally:
            for u, v, data in removed:
                G.add_edge(u, v, **data)

        sp = 999 if sp < 0 else sp
        local_scores.append(1 / sp)

    return local_scores


def capped_shortest_path_mat(A, edge_index, k=2, batch_size = 1000000):
    """Capped inverse shortest-path scores computed from sparse matrix powers.

    ``k == 2`` fast path: returns 1 for adjacent pairs, 0.5 for pairs with a
    common neighbour, and 0 otherwise (scores are returned directly).

    General path: for each pair, finds the smallest L in {1, ..., min(k, 6)}
    such that a walk of length L exists (via A, A^2, A^3 and products of
    their rows) and returns ``1 / L``. Pairs with no such walk get
    ``1 / 999``; pairs with ``src == dst`` get ``1 / 0.5 = 2``.

    Args:
        A: scipy sparse matrix supporting ``dot``, ``multiply``, row indexing
            and pairwise ``A[rows, cols]`` lookup (presumably the adjacency
            matrix -- confirm at call sites).
        edge_index: LongTensor of shape (2, E); row 0 sources, row 1 targets.
        k: distance cap.
        batch_size: number of pairs scored per batch.

    Returns:
        FloatTensor of shape (E,).
    """
    num_pairs = edge_index.size(1)
    if num_pairs == 0:
        # np.concatenate below raises on an empty list of batches.
        return torch.FloatTensor([])

    link_loader = DataLoader(range(num_pairs), batch_size)

    if k == 2:
        scores = []
        for ind in tqdm(link_loader):
            src, dst = edge_index[0, ind], edge_index[1, ind]
            cur_scores = np.array(np.sum(A[src].multiply(A[dst]), 1)).flatten()
            cur_scores[cur_scores > 0] = 0.5
            mask = np.array(A[(src.tolist(), dst.tolist())]).flatten() > 0
            cur_scores[mask] = 1
            scores.append(cur_scores)
        return torch.FloatTensor(np.concatenate(scores, 0))

    print('computing A2..')
    A2 = A.dot(A)
    if k > 4:
        print('computing A3..')
        A3 = A2.dot(A)
    else:
        A3 = None

    score_list = []
    for ind in tqdm(link_loader):
        src, dst = edge_index[0, ind], edge_index[1, ind]
        src_l, dst_l = src.tolist(), dst.tolist()
        cur_scores = np.full(len(src_l), 999.0)

        # Assign from the longest length down so shorter lengths overwrite;
        # each entry ends at the shortest walk length found, i.e. the distance.
        if k >= 6:
            walks = np.array(np.sum(A3[src].multiply(A3[dst]), 1)).flatten()
            cur_scores[walks > 0] = 6
        if k >= 5:
            walks = np.array(np.sum(A2[src].multiply(A3[dst]), 1)).flatten()
            cur_scores[walks > 0] = 5
        if k >= 4:
            walks = np.array(np.sum(A2[src].multiply(A2[dst]), 1)).flatten()
            cur_scores[walks > 0] = 4
        if k >= 3:
            if A3 is not None:
                mask = np.array(A3[src_l, dst_l]).flatten() > 0
            else:
                mask = np.array(np.sum(A[src].multiply(A2[dst]), 1)).flatten() > 0
            cur_scores[mask] = 3
        if k >= 2:
            mask = np.array(A2[(src_l, dst_l)]).flatten() > 0
            cur_scores[mask] = 2
        if k >= 1:
            mask = np.array(A[(src_l, dst_l)]).flatten() > 0
            cur_scores[mask] = 1
        cur_scores[np.asarray(src_l) == np.asarray(dst_l)] = 0.5

        score_list.append(cur_scores)

    return 1 / torch.FloatTensor(np.concatenate(score_list, 0))

def csr_to_sparse_tensor(csr_mat, device='cuda'):
    """
    Convert a scipy sparse (or dense) matrix to a torch sparse COO tensor.

    Args:
        csr_mat: scipy.sparse matrix in any format, or an array-like that
            ``scipy.sparse.csr_matrix`` accepts; non-CSR inputs are converted
            to CSR first.
        device: str - Target device ('cuda' or 'cpu').

    Returns:
        torch sparse COO tensor (float32 values, int64 indices) with the same
        shape as ``csr_mat``, on ``device``.
    """
    # Anything that is not already CSR (COO, CSC, dense ndarray, ...) lacks a
    # usable ``indptr`` and is converted first.
    if getattr(csr_mat, 'format', None) != 'csr':
        csr_mat = ssp.csr_matrix(csr_mat)

    indptr = csr_mat.indptr
    col_indices = csr_mat.indices
    values = csr_mat.data

    # Expand the CSR row pointer into one row index per stored value (COO).
    row_indices = np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))

    indices = torch.from_numpy(np.vstack((row_indices, col_indices))).long()
    values = torch.from_numpy(np.asarray(values)).float()
    size = torch.Size(csr_mat.shape)

    sparse_tensor = torch.sparse_coo_tensor(indices, values, size)
    return sparse_tensor.to(device)

def capped_shortest_path_gpu(A_scipy, edge_index, k=6, batch_size=100000, device='cuda'):
    """
    Capped inverse shortest-path scores using PyTorch sparse ops on a device.

    For each pair, finds the smallest L in {1, ..., min(k, 6)} such that a
    walk of length L exists (via A, A^2, A^3 and products of their rows) and
    returns ``1 / L``; pairs with no such walk get ``1 / 999``.

    Args:
        A_scipy: scipy sparse adjacency matrix (powers are computed with
            ``.dot`` before transfer).
        edge_index: LongTensor of shape (2, E); row 0 sources, row 1 targets.
        k: int - Maximum path length to consider.
        batch_size: int - Number of pairs processed per batch.
        device: str - Device for the sparse matrices and per-batch work.

    Returns:
        torch.FloatTensor of shape (E,) on the CPU.
    """
    num_pairs = edge_index.size(1)
    if num_pairs == 0:
        # torch.cat below raises on an empty list of batches.
        return torch.FloatTensor([])

    A2_scipy = A_scipy.dot(A_scipy)
    A3_scipy = A2_scipy.dot(A_scipy)

    # Coalesce once so row selection works on a canonical COO layout.
    A = csr_to_sparse_tensor(A_scipy, device=device).coalesce()
    A2 = csr_to_sparse_tensor(A2_scipy, device=device).coalesce()
    A3 = csr_to_sparse_tensor(A3_scipy, device=device).coalesce()

    def dense_rows(mat, idx):
        # Sparse COO tensors support index_select on dim 0, but not the
        # advanced bracket indexing used for dense tensors.
        return mat.index_select(0, idx).to_dense()

    def lookup(rows, dst_col):
        # rows[i, dst[i]] for every i, i.e. the (src, dst) entry of the matrix.
        return rows.gather(1, dst_col).squeeze(1)

    link_loader = DataLoader(range(num_pairs), batch_size)
    score_list = []

    for ind in tqdm(link_loader):
        src = edge_index[0, ind].to(device)
        dst = edge_index[1, ind].to(device)
        dst_col = dst.unsqueeze(1)

        cur_scores = torch.full((src.size(0),), 999.0, device=device)

        A2_src = dense_rows(A2, src)
        A3_src = dense_rows(A3, src)
        A3_dst = dense_rows(A3, dst) if k >= 5 else None
        A2_dst = dense_rows(A2, dst) if k >= 4 else None

        # Assign from the longest length down so shorter lengths overwrite.
        if k >= 6:
            cur_scores[(A3_src * A3_dst).sum(dim=1) > 0] = 6
        if k >= 5:
            cur_scores[(A2_src * A3_dst).sum(dim=1) > 0] = 5
        if k >= 4:
            cur_scores[(A2_src * A2_dst).sum(dim=1) > 0] = 4
        if k >= 3:
            cur_scores[lookup(A3_src, dst_col) > 0] = 3
        if k >= 2:
            cur_scores[lookup(A2_src, dst_col) > 0] = 2
        cur_scores[lookup(dense_rows(A, src), dst_col) > 0] = 1

        score_list.append(cur_scores.cpu())

    return 1 / torch.cat(score_list, 0)

def capped_shortest_path(A, edge_index, k=6, remove=False, num_workers=4, parallel=False):
    """Inverse capped shortest-path length for each node pair.

    For each pair (s, t) the score is ``1 / d`` where ``d`` comes from
    ``_bidirectional_capped_shorest_path`` with cap ``k``. When that search
    reports no path (a negative result), ``1 / 999`` is used. Pairs with
    ``s == t`` receive the raw value ``999``, kept as in the original code.

    Args:
        A: scipy sparse adjacency matrix, converted to a networkx graph.
        edge_index: LongTensor of shape (2, E); row 0 sources, row 1 targets.
        k: distance cap forwarded to the search.
        remove: if True, any edge between s and t is temporarily deleted
            before measuring and restored (with its attributes) afterwards.
            In the parallel path the flag is forwarded to ``process_chunk``.
        num_workers: number of worker processes for the parallel path.
        parallel: if True, score pairs with a ProcessPoolExecutor via
            ``process_chunk``; by default pairs are scored serially.

    Returns:
        FloatTensor of shape (E,).
    """
    # networkx >= 3.0 only provides the *_array variant.
    to_graph = getattr(nx, 'from_scipy_sparse_array', None) or nx.from_scipy_sparse_matrix
    G = to_graph(A)

    if parallel:
        print("divide tasks")
        tasks = [
            (edge_index[0][i].item(), edge_index[1][i].item())
            for i in tqdm(range(edge_index.size(1)))
        ]
        if not tasks:
            return torch.FloatTensor([])
        # Split tasks into one contiguous chunk per worker; map keeps order.
        chunk_size = ceil(len(tasks) / num_workers)
        chunked_tasks = [tasks[i:i + chunk_size] for i in range(0, len(tasks), chunk_size)]
        chunk_args = [(chunk, G, k, remove) for chunk in chunked_tasks]

        scores = []
        print("start executing")
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            for chunk_result in executor.map(process_chunk, chunk_args):
                scores.extend(chunk_result)
        return torch.FloatTensor(scores)

    scores = []
    count = count1 = count2 = 0
    for i in tqdm(range(edge_index.size(1))):
        s = edge_index[0][i].item()
        t = edge_index[1][i].item()
        if s == t:
            count += 1
            scores.append(999)
            continue

        # Temporarily drop the pair's own edge(s), remembering attributes so
        # they can be restored exactly.
        removed = []
        if remove and G.has_edge(s, t):
            removed.append((s, t, dict(G[s][t])))
            G.remove_edge(s, t)
            count1 += 1
        if remove and G.has_edge(t, s):
            removed.append((t, s, dict(G[t][s])))
            G.remove_edge(t, s)
            count2 += 1
        try:
            sp = _bidirectional_capped_shorest_path(G, s, t, k)
        finally:
            for u, v, data in removed:
                G.add_edge(u, v, **data)

        if sp < 0:
            sp = 999
        scores.append(1 / sp)

    print('equal number: ', count)
    print('count1: ', count1)
    print('count2: ', count2)
    return torch.FloatTensor(scores)


def shortest_path(A, edge_index, remove=False):
    """Inverse (uncapped) shortest-path length for each node pair.

    For each pair (s, t) the score is ``1 / d`` with ``d`` the networkx
    shortest-path length, or ``1 / 999`` when no path exists. Pairs with
    ``s == t`` receive the raw value ``999``, kept as in the original code.

    Args:
        A: scipy sparse adjacency matrix, converted to a networkx graph.
        edge_index: LongTensor of shape (2, E); row 0 sources, row 1 targets.
        remove: if True, any edge between s and t is temporarily deleted
            before measuring and restored (with its attributes) afterwards.

    Returns:
        FloatTensor of shape (E,).
    """
    # networkx >= 3.0 only provides the *_array variant.
    to_graph = getattr(nx, 'from_scipy_sparse_array', None) or nx.from_scipy_sparse_matrix
    G = to_graph(A)

    scores = []
    count = count1 = count2 = 0
    print('remove: ', remove)
    for i in range(edge_index.size(1)):
        s = edge_index[0][i].item()
        t = edge_index[1][i].item()
        if s == t:
            count += 1
            scores.append(999)
            continue

        removed = []
        if remove and G.has_edge(s, t):
            removed.append((s, t, dict(G[s][t])))
            G.remove_edge(s, t)
            count1 += 1
        if remove and G.has_edge(t, s):
            removed.append((t, s, dict(G[t][s])))
            G.remove_edge(t, s)
            count2 += 1

        # One search instead of has_path followed by shortest_path_length.
        try:
            sp = nx.shortest_path_length(G, source=s, target=t)
        except nx.NetworkXNoPath:
            sp = 999
        finally:
            for u, v, data in removed:
                G.add_edge(u, v, **data)

        scores.append(1 / sp)

    print('equal number: ', count)
    print('count1: ', count1)
    print('count2: ', count2)

    return torch.FloatTensor(scores)

def katz_apro(A, edge_index, beta=0.005, path_len=3, remove=False):
    """Truncated Katz score counted over simple paths.

    For each pair (s, t) the score is ``sum_L beta**L * n_L`` for
    ``L = 1 .. path_len``, where ``n_L`` is the number of simple paths of
    ``L`` edges from s to t (enumerated with ``nx.all_simple_paths``).
    Pairs with ``s == t`` score 0.

    Args:
        A: scipy sparse adjacency matrix, converted to a networkx graph.
        edge_index: LongTensor of shape (2, E); row 0 sources, row 1 targets.
        beta: decay factor per path edge.
        path_len: maximum path length (cast to int).
        remove: if True, any edge between s and t is temporarily deleted
            before counting and restored (with its attributes) afterwards.

    Returns:
        FloatTensor of shape (E,).
    """
    # networkx >= 3.0 only provides the *_array variant.
    to_graph = getattr(nx, 'from_scipy_sparse_array', None) or nx.from_scipy_sparse_matrix
    G = to_graph(A)
    path_len = int(path_len)
    count = count1 = count2 = 0
    # betas[L - 1] = beta ** L
    betas = beta ** np.arange(1, path_len + 1, dtype=float)
    print('remove: ', remove)

    scores = []
    for i in range(edge_index.size(1)):
        s = edge_index[0][i].item()
        t = edge_index[1][i].item()

        if s == t:
            count += 1
            scores.append(0)
            continue

        removed = []
        if remove and G.has_edge(s, t):
            removed.append((s, t, dict(G[s][t])))
            G.remove_edge(s, t)
            count1 += 1
        if remove and G.has_edge(t, s):
            removed.append((t, s, dict(G[t][s])))
            G.remove_edge(t, s)
            count2 += 1

        try:
            paths = np.zeros(path_len)
            for path in nx.all_simple_paths(G, source=s, target=t, cutoff=path_len):
                # A path with n nodes has n - 1 edges -> slot n - 2.
                paths[len(path) - 2] += 1
        finally:
            for u, v, data in removed:
                G.add_edge(u, v, **data)

        scores.append(np.sum(betas * paths))

    print('equal number: ', count)
    print('count1: ', count1)
    print('count2: ', count2)

    return torch.FloatTensor(scores)


def katz_close(A, edge_index, beta=0.005):
    """Closed-form Katz score ``(I - beta * A^T)^-1 - I`` for each node pair.

    Builds a networkx graph from ``A``, reads its adjacency matrix back with
    nodes ordered ``0 .. n-1``, inverts the dense ``I - beta * adj^T``, and
    subtracts the identity. Dense O(n^3) inversion: only practical for small
    graphs.

    Args:
        A: scipy sparse adjacency matrix.
        edge_index: LongTensor of shape (2, E); row 0 sources, row 1 targets.
        beta: Katz decay factor.

    Returns:
        FloatTensor of shape (E,) with ``sim[s, t]`` for each pair.
    """
    # networkx >= 3.0 only provides the *_array variant.
    to_graph = getattr(nx, 'from_scipy_sparse_array', None) or nx.from_scipy_sparse_matrix
    G = to_graph(A)

    adj = nx.adjacency_matrix(G, nodelist=range(len(G.nodes)))
    # Plain ndarray (not np.matrix) so diagonal edits and fancy indexing are 1-D.
    aux = np.asarray(adj.T.multiply(-beta).todense())
    aux[np.diag_indices_from(aux)] += 1
    sim = np.linalg.inv(aux)
    sim[np.diag_indices_from(sim)] -= 1

    src = edge_index[0].tolist()
    dst = edge_index[1].tolist()
    return torch.FloatTensor(sim[src, dst])

