import numpy as np
import random

def sample_k(begin, end, k):
    """Return a 0/1 mask of length ``end`` marking ``k`` random picks from ``[begin, end)``.

    Positions outside ``[begin, end)`` are always 0. If the range holds at
    most ``k`` positions, every position in it is selected. ``begin`` is
    expected to be non-negative.

    Args:
        begin: First index eligible for sampling (inclusive).
        end: One past the last eligible index; also the length of the mask.
        k: Number of distinct indices to draw without replacement.

    Returns:
        A float ``np.ndarray`` of shape ``(end,)`` with 1.0 at selected indices
        and 0.0 elsewhere.
    """
    choice = np.zeros(end)
    if begin >= end:
        # Empty eligible range: nothing can be selected.
        return choice
    if end - begin <= k:
        # Budget covers the whole range. Select it, but leave indices before
        # ``begin`` unselected so this branch agrees with the sampling branch.
        choice[begin:end] = 1
        return choice
    # Uniform sample of k distinct indices from the eligible range.
    choice[random.sample(range(begin, end), k)] = 1
    return choice

# Driver: compare a learned heavy/light estimate of a per-edge sum against the
# exact sum. The top `bucket_heavy` edges by prediction are counted exactly.
# A uniform sample of the remaining edges is scaled up by the sampling rate.
#
# NOTE(review): `res` and `pred` are assumed to be 1-D arrays aligned by edge
# index, with len(res) >= len(pred). Confirm against whatever writes these files.
res = np.load("edge_result.npy")
pred = np.load("edge_pred.npy")

# Edge indices ordered by descending prediction. A stable sort keeps ties in
# index order, matching sorted(enumerate(-pred), key=lambda x: x[1]).
idx = np.argsort(-pred, kind="stable")

total = np.sum(res)
print("total number of triangles", total)

num_edge = len(pred)
print("number of total edges", num_edge)
num_times = 50  # number of independent sampling trials averaged together

bucket = 5000  # total edge budget: exact (heavy) edges plus sampled light edges
bucket_heavy = int(0.1 * bucket)  # top-ranked edges that are always counted

print("number of sampling edges", bucket)

# Ground truth for the estimator: the sum over exactly the edges it ranks
# (indices 0..num_edge-1 of res).
total = np.sum(res[:num_edge])

# res values in descending-prediction order.
res_ranked = res[idx]
# The heavy part is deterministic, so compute it once rather than per trial.
sum_heavy = np.sum(res_ranked[:bucket_heavy])
light_values = res_ranked[bucket_heavy:]

num_light = num_edge - bucket_heavy
# Inclusion probability of each light edge. With no light edges there is
# nothing to scale, so use 1 instead of dividing by zero.
if num_light > 0:
    sampling_rate = min(1, (bucket - bucket_heavy) / num_light)
else:
    sampling_rate = 1

ans_learned = 0
for _ in range(num_times):
    choice = sample_k(bucket_heavy, num_edge, bucket - bucket_heavy)
    # `choice` is a 0/1 mask over ranked positions. Only positions at or after
    # bucket_heavy are light edges, so slice the mask to line up with them.
    sum_light = np.sum(light_values * choice[bucket_heavy:])
    ans_learned += abs(sum_heavy + sum_light / sampling_rate - total)
ans_learned /= num_times

# Relative error is undefined when the true total is zero; report nan there.
rel_error = ans_learned / total if total != 0 else float("nan")
print("error of learned", rel_error)
