import numpy as np
import torch
import scipy.sparse as scsp
import numba
from ppr_solver import *
import warnings

# Suppress every warning category for the whole run.
warnings.filterwarnings("ignore")

# Root directory that holds one sub-directory per graph.
path = './dataset/'

# Graphs to process, in execution order.
graph_names = [
    'Cora',
    'Citeseer',
    'ogbn-arxiv',
    'as-skitter',
    'ogbn-proteins',
    'com-orkut',
    'cit-patent',
    'ogbl-ppa',
    'ogbn-products',
    'wiki-talk',
    'com-youtube',
    'ogbn-mag',
    'soc-lj1',
    'reddit',
    'pubmed',
    'wiki-en21',
    'com-friendster',
    'ogbn-papers100M',
]

# Per-graph results accumulated by the main loop, keyed by graph name.
all_result = {}

# alpha is first mapped to a' = alpha / (2 - alpha); mu and omega are derived
# from that transformed value.
# NOTE(review): mu and omega are not read anywhere in this part of the
# script -- confirm whether later code still needs them.
alpha = 0.1
alpha /= 2 - alpha
mu = (1. - alpha) / (1. + alpha)
omega = 1. + (mu / (1. + np.sqrt(1. - mu ** 2.))) ** 2.
# Inverse of the mapping above: 2a' / (1 + a') == alpha, so alpha is back at
# 0.1 (up to floating-point rounding) after this line.
alpha = 2 * alpha / (1 + alpha)

def _select_source_nodes(degree, max_test_num, seed=17):
    """Pick the source nodes to benchmark for one graph.

    The selection is the union of four groups of ``max_test_num`` nodes each:
    the lowest-degree nodes, the highest-degree nodes, the nodes around the
    middle of the degree ordering, and uniformly random nodes drawn after
    seeding NumPy's global RNG with ``seed``.

    Args:
        degree: 1-D array holding one degree value per node.
        max_test_num: Number of nodes taken from each of the four groups.
        seed: Value passed to ``np.random.seed`` before the random draw.

    Returns:
        Sorted, de-duplicated node ids as an ``np.int32`` array.
    """
    n = len(degree)
    # A stable argsort ranks nodes by ascending degree and breaks ties by
    # node id, without creating one Python object per node.
    order = np.argsort(degree, kind='stable')
    half = max_test_num / 2
    np.random.seed(seed)
    candidates = np.concatenate([
        order[:max_test_num],
        order[-max_test_num:],
        order[int(n / 2 - half):int(n / 2 + half)],
        np.random.randint(n, size=max_test_num),
    ])
    # The groups may overlap; np.unique removes duplicates and sorts the ids.
    return np.unique(candidates).astype(np.int32)


# For every graph: load its CSR adjacency matrix, choose the source nodes,
# run the solver, and persist all results gathered so far.
for graph_name in graph_names:
    # These two graphs use a smaller per-group sample.
    # NOTE(review): presumably to bound runtime on very large graphs -- confirm.
    if graph_name in ('ogbn-papers100M', 'com-friendster'):
        max_test_num = 8
    else:
        max_test_num = 30

    graph_path = path + graph_name + '/'
    adj_matrix = scsp.load_npz(graph_path + graph_name + '_csr-mat.npz')
    indices = adj_matrix.indices
    indptr = adj_matrix.indptr
    n = len(indptr) - 1  # number of rows in the CSR matrix
    m = len(indices)     # number of stored entries
    degree = np.array(adj_matrix.sum(1)).flatten()  # row sums of the matrix
    # Tolerance handed to the solver, scaled down by the graph size.
    eps = 1e-10 / (m + n)

    s_nodes = _select_source_nodes(degree, max_test_num)

    print(graph_name, end=' ')
    graph_result = solve_a_graph_gpu(n, indptr, indices, degree, alpha, eps, s_nodes, local=True)
    all_result[graph_name] = (n, degree.mean(), graph_result)
    # NOTE(review): graph_result[1] is assumed to be an array-like of per-source
    # measurements; its meaning is defined by ppr_solver -- confirm there.
    print(graph_result[1].mean())
    # Re-save after every graph so partial results survive an interrupted run.
    # The dict is stored as a pickled object array, so reading it back needs
    # np.load(..., allow_pickle=True).
    np.save('./results/ppr_exp_pk_result_new.npy', all_result)