import numpy as np
import random

def sample_k(begin, end, k):
    """Return a 0/1 indicator array marking a random k-subset of ``[begin, end)``.

    The returned float array always has length ``end``. Positions before
    ``begin`` are never marked; ``begin`` is expected to be non-negative.

    Args:
        begin: First index (inclusive) of the range to sample from.
        end: Last index (exclusive) of the range; also the output length.
        k: Number of indices to mark. If the range holds ``k`` or fewer
            indices, every index in ``[begin, end)`` is marked.

    Returns:
        ``numpy.ndarray`` of length ``end`` with 1.0 at the chosen indices and
        0.0 elsewhere. An empty range (``begin >= end``) yields all zeros.

    Raises:
        ValueError: from ``random.sample`` if ``k`` is negative.
    """
    if begin >= end:
        return np.zeros(end)

    choice = np.zeros(end)
    if end - begin <= k:
        # Take the whole range -- only [begin, end), never the prefix [0, begin).
        choice[begin:end] = 1
        return choice

    # random.sample draws k distinct indices; mark them all in one assignment.
    choice[random.sample(range(begin, end), k)] = 1
    return choice

np.set_printoptions(threshold=100000)

# Preallocated graph storage with fixed capacities: at most 500000 distinct
# edges (rows of Edge) and vertex ids below 200000 (vertex / deg). Indexing
# past these limits raises IndexError.
Edge = np.zeros((500000, 2))  # row i: endpoints (smaller id first) of the i-th distinct edge
edge = {}  # used as a set of directed vertex pairs; every value is 1

vertex = [[] for _ in range(200000)]  # adjacency list per vertex id
deg = np.zeros(200000)  # degree per vertex id

# NOTE(review): res presumably holds one triangle count per edge (its sum is
# reported as the triangle total) and order a per-vertex rank -- confirm
# against the script that writes these files.
res = np.load("edge_result.npy")
order = np.load("vertex_order.npy")
total = np.sum(res)
print("total number of triangles", total)

num_edge = len(res)
print("number of total edges", num_edge)


num_times = 50  # number of independent sampling trials that are averaged

bucket = 1000  # number of edges drawn per trial

print("number of sampling edges", bucket)


test_graph = "graph2.txt"

idx = 0  # number of distinct edges stored in Edge so far

num_node = 0  # largest vertex id seen on a non-self-loop line

# Edge-list parser. Each non-blank line carries two integer vertex ids as its
# first two whitespace-separated fields (tabs or spaces). Self-loops and
# repeated edges are dropped. Each kept edge is stored once in Edge with the
# smaller id first, and is recorded in both directions in edge, vertex and deg.
with open(test_graph) as test_file:
    for line in test_file:
        fields = line.split()
        if not fields:
            continue  # blank line; int('') would otherwise raise
        l0, l1 = int(fields[0]), int(fields[1])
        if l0 == l1:
            continue  # self-loop
        if l0 > l1:
            l0, l1 = l1, l0
        num_node = max(num_node, l1)
        if (l0, l1) in edge:
            continue  # duplicate edge

        Edge[idx] = (l0, l1)

        edge[(l0, l1)] = 1
        edge[(l1, l0)] = 1
        vertex[l0].append(l1)
        vertex[l1].append(l0)
        deg[l0] += 1
        deg[l1] += 1

        idx += 1
# num_node is the largest id seen; when vertex 0 is present, ids are presumably
# 0-based, so the vertex count is one larger.
if deg[0] >= 1:
    num_node += 1

sampling_rate = min(1, bucket / num_edge)
ans_mvv = 0

# Edges whose res value reaches this threshold are candidates for the "heavy"
# estimator (original note: "consider for saving time").
heavy_threshold = 50 / sampling_rate

# Everything about a heavy candidate except the sampling outcome is the same in
# every trial, so resolve it once here. The edge is oriented so that `low`
# comes first in `order`, and we keep the neighbours `mid` of `low` that lie
# strictly between the two endpoints in `order` and are adjacent to `high`.
# Per trial, only the test "(low, mid) was sampled" remains.
# NOTE(review): order is indexed as order[v - 1], i.e. presumably 1-based
# vertex ids -- confirm.
heavy_candidates = []  # (edge index, low endpoint, qualifying middle vertices)
for i in range(num_edge):
    if res[i] >= heavy_threshold:
        low = int(Edge[i][0])
        high = int(Edge[i][1])
        if order[low - 1] >= order[high - 1]:
            low, high = high, low
        low_rank = order[low - 1]
        high_rank = order[high - 1]
        mids = [
            mid
            for mid in vertex[low]
            if not (order[mid - 1] <= low_rank or order[mid - 1] >= high_rank)
            and (mid, high) in edge
        ]
        heavy_candidates.append((i, low, mids))

for trial in range(num_times):

    sample = sample_k(0, num_edge, bucket)
    sampled_idx = np.flatnonzero(sample == 1).tolist()

    # Both orientations of every sampled edge, for O(1) membership tests.
    sampled_edges = {}
    for i in sampled_idx:
        u = int(Edge[i][0])
        v = int(Edge[i][1])
        sampled_edges[(u, v)] = 1
        sampled_edges[(v, u)] = 1

    # Heavy part: count middle vertices reached through a sampled (low, mid)
    # edge. If the count reaches 100 (original note: "the threshold is
    # eps = 0.1"), add the scaled count and exclude this edge from the light part.
    sum_heavy = 0
    heavy_now = set()
    for i, low, mids in heavy_candidates:
        est = sum(1 for mid in mids if (low, mid) in sampled_edges)
        if est >= 100:
            sum_heavy += est / sampling_rate
            heavy_now.add(i)

    # Light part: every sampled edge not handled as heavy contributes its
    # res value, scaled by 1 / sampling_rate below.
    sum_light = 0
    for i in sampled_idx:
        if i not in heavy_now:
            sum_light += res[i]

    ans_mvv += abs(sum_heavy + sum_light / sampling_rate - total)
ans_mvv /= num_times

print("error of mvv", ans_mvv / total)
