import numpy as np
import random

def sample_k(begin, end, k, *, rng=None):
    """Return a 0/1 mask of length ``end`` marking ``k`` random indices in ``[begin, end)``.

    Indices are drawn uniformly without replacement. If the range holds at
    most ``k`` indices, every index in it is selected. Entries outside
    ``[begin, end)`` are always 0, so masks for adjacent ranges never overlap.

    Args:
        begin: First candidate index (inclusive); negative values are
            treated as 0.
        end: One past the last candidate index; also the mask length
            (a negative value gives an empty mask).
        k: Number of indices to select.
        rng: Optional ``random.Random`` instance for reproducible draws;
            defaults to the module-level ``random`` generator.

    Returns:
        A float ``numpy`` array of length ``max(end, 0)`` holding 1.0 at the
        selected indices and 0.0 elsewhere.

    Raises:
        ValueError: if ``k`` is negative and the range is larger than ``k``
            (propagated from ``random.sample``).
    """
    rng = random if rng is None else rng
    choice = np.zeros(max(end, 0))
    begin = max(begin, 0)
    if begin >= end:
        return choice
    if end - begin <= k:
        # Small range: take every candidate. Only [begin, end) is set, so the
        # mask never marks indices that belong to a different range.
        choice[begin:end] = 1
        return choice
    picked = rng.sample(range(begin, end), k)
    choice[np.array(picked, dtype=np.intp)] = 1
    return choice

# Evaluate a stratified sampling estimator of the total triangle count that
# uses per-edge predictions to decide which edges to count exactly.
#
# Edges are ranked by predicted value (largest first) and split into strata:
#   heavy  -- the top `bucket_heavy` ranks, always counted exactly;
#   medium -- the remaining ranks whose prediction is >= PRED_THRESHOLD,
#             of which `n_medium` are sampled without replacement;
#   light  -- ranks whose prediction is < PRED_THRESHOLD, of which
#             `n_light` are sampled without replacement.
# Each sampled stratum sum is scaled by its inverse sampling rate, and the
# mean absolute error over `num_times` trials is reported relative to the
# true total.
#
# NOTE(review): assumes both .npy files hold 1-D arrays; `res` is presumably
# the true per-edge counts and `pred` their predictions -- confirm.
res = np.load("edge_result.npy")
pred = np.load("edge_pred.npy")
if len(res) != len(pred):
    raise ValueError(
        f"edge_result.npy has {len(res)} entries but edge_pred.npy has {len(pred)}"
    )

# Rank edges by predicted value, largest first; the stable sort keeps tied
# edges in index order.
idx = np.argsort(-pred, kind="stable")
res_sorted = res[idx]

total = np.sum(res)
print("total number of triangles", total)

num_edge = len(pred)
print("number of total edges", num_edge)
num_times = 50

bucket = 1000
bucket_heavy = int(0.1 * bucket)
# Sample sizes for the two sampled strata. The same integers are used both
# for drawing samples and for the inverse-rate scaling, so the two always
# agree regardless of `bucket`.
n_medium = int(0.2 * bucket)
n_light = int(0.7 * bucket)

print("number of sampling edges", bucket)

# Prediction cut-off between the medium and light strata.
# We should set this parameter for other datasets.
PRED_THRESHOLD = 5

# thre = rank of the first edge predicted below PRED_THRESHOLD; with ranks
# sorted by prediction, every earlier rank is at or above the threshold.
# If no edge falls below it, the light stratum is empty (thre = num_edge).
below = pred[idx] < PRED_THRESHOLD
thre = int(np.argmax(below)) if below.any() else num_edge
print("number of edge about the threshold", thre)

# Stratum boundaries in rank order. Clamping keeps the strata disjoint even
# when `thre` < `bucket_heavy` or there are fewer edges than `bucket_heavy`,
# so light-stratum draws never land on ranks already counted as heavy.
heavy_end = min(bucket_heavy, num_edge)
light_start = max(thre, heavy_end)
medium_size = light_start - heavy_end
light_size = num_edge - light_start

# Fraction of each stratum that sample_k selects. An empty stratum contributes
# a zero sum, so its rate is set to 1 to avoid dividing by zero.
sampling_rate1 = min(1, n_medium / medium_size) if medium_size > 0 else 1
sampling_rate2 = min(1, n_light / light_size) if light_size > 0 else 1

# The heavy stratum is counted exactly and is identical in every trial.
sum_heavy = res_sorted[:heavy_end].sum()

ans_learned = 0
for _ in range(num_times):
    medium_choice = sample_k(heavy_end, light_start, n_medium)
    light_choice = sample_k(light_start, num_edge, n_light)
    # Only each stratum's own slice of its mask is read.
    medium_mask = medium_choice[heavy_end:light_start] == 1
    light_mask = light_choice[light_start:num_edge] == 1
    sum_medium = res_sorted[heavy_end:light_start][medium_mask].sum()
    sum_light = res_sorted[light_start:num_edge][light_mask].sum()
    estimate = sum_heavy + sum_medium / sampling_rate1 + sum_light / sampling_rate2
    ans_learned += abs(estimate - total)
ans_learned /= num_times

print("error of learned", ans_learned / total)
