import math
import time
import numpy as np

import range_tree
import preprocessing
import find_patterns

def relative_error(estimated, true, n):
    """Return the relative error of ``estimated`` with respect to ``true``.

    The denominator is ``max(true, 0.001 * n)``, so small true values are
    floored at one-thousandth of ``n``. When ``true`` is exactly zero, the plain
    absolute error ``|estimated - true|`` is returned instead.

    NOTE(review): in the zero branch the ``0.001 * n`` floor is *not* applied,
    so that case is measured on an absolute scale. This is presumably
    intentional (and it avoids dividing by zero when ``n == 0``).
    """
    abs_err = abs(estimated - true)
    if true != 0:
        return abs_err / max(true, 0.001 * n)
    return abs_err

def query_true(n, d, Q, test_input_nodes, logger):
    """Build a range tree over ``test_input_nodes`` and return the exact answers to ``Q``.

    Every node is inserted with ``range_tree.insertSplit``, and then each query
    is answered with ``range_tree.querySplit``. Only the ``weight`` component of
    each answer is kept; the noise component that ``querySplit`` also returns is
    discarded.

    Args:
        n: Coordinates are handled over the range ``[0, n - 1]`` passed to the tree.
        d: The tree's dimension argument is ``2 * d - 1``.
        Q: Iterable of queries, forwarded as-is to ``querySplit``.
        test_input_nodes: Iterable of points/records to insert.
        logger: Logger used for a progress message after the build.

    Returns:
        A list with one weight per query, in the order of ``Q``.
    """
    depth = (d << 1) - 1
    # NOTE(review): presumably a neutral noise magnitude, since the noise term
    # is thrown away below anyway. Confirm against range_tree.
    unit_mag = 1
    tree = None
    for node in test_input_nodes:
        tree = range_tree.insertSplit(tree, node, 0, n - 1, 0, n, depth, unit_mag)
    logger.info('finish building range tree dynamically for query true')

    answers = []
    for query in Q:
        # querySplit hands back a root, which is threaded into the next call.
        tree, weight, _ = range_tree.querySplit(tree, 0, n - 1, query, 0, n, depth, unit_mag)
        answers.append(weight)
    return answers
    
def pure_DP(n, d, Q, eps, test_input_nodes, pattern, logger):
    """Answer ``Q`` with a pure-DP range tree and accumulate relative errors.

    A range tree is built over ``test_input_nodes``. Its noise magnitude comes
    from ``preprocessing.pure_DP_mag``. Every query is then answered with
    ``range_tree.querySplit``, and the relative error of the noisy answer
    ``weight + noise`` against ``weight`` is accumulated. (Presumably ``weight``
    is the exact range answer and ``noise`` the DP perturbation; confirm against
    range_tree.)

    Args:
        n: Coordinates are handled over the range ``[0, n - 1]``; also used to
            floor the relative-error denominator.
        d: ``2 * d`` is the exponent in the magnitude computation, and
            ``2 * d - 1`` is the tree's dimension argument.
        Q: Iterable of queries forwarded to ``querySplit``.
        eps: Epsilon forwarded to ``preprocessing.pure_DP_mag``.
        test_input_nodes: Iterable of points/records to insert.
        pattern: Pattern name forwarded to ``preprocessing.pure_DP_mag``.
        logger: Logger for progress messages.

    Returns:
        ``(sum_err, sum_err2)``: the sum of the per-query relative errors and
        the sum of their squares. Both are ``0.0`` when ``Q`` is empty.
    """
    d2 = d << 1
    mag = preprocessing.pure_DP_mag(n, (math.ceil(math.log2(n)) + 1)**(d2), eps, pattern)
    root = None
    for attr in test_input_nodes:
        root = range_tree.insertSplit(root, attr, 0, n - 1, 0, n, d2 - 1, mag)
    logger.info('finish building range tree dynamically for pure DP')

    sum_err = 0.0
    sum_err2 = 0.0
    for q in Q:
        # querySplit hands back a root, which is threaded into the next call.
        (root, weight, noise) = range_tree.querySplit(root, 0, n - 1, q, 0, n, d2 - 1, mag)
        err = relative_error(weight + noise, weight, n)
        sum_err += err
        sum_err2 += err * err
    logger.info('finish answering queries for pure DP')
    return sum_err, sum_err2

def pure_DP_qtime(n, d, Q, eps, test_input_nodes, pattern):
    """Return the mean query time, in microseconds, of the pure-DP range tree.

    The tree is first built from ``test_input_nodes``. The build is not timed.
    Only the loop of ``range_tree.querySplit`` calls over ``Q`` is measured,
    with ``time.perf_counter``.

    Args:
        n: Coordinates are handled over the range ``[0, n - 1]``.
        d: ``2 * d`` is the magnitude exponent, and ``2 * d - 1`` is the tree's
            dimension argument.
        Q: Sized collection of queries (``len(Q)`` is used for the average).
        eps: Epsilon forwarded to ``preprocessing.pure_DP_mag``.
        test_input_nodes: Iterable of points/records to insert.
        pattern: Pattern name forwarded to ``preprocessing.pure_DP_mag``.

    Returns:
        Average elapsed time per query, in microseconds.

    Raises:
        ValueError: If ``Q`` is empty, because the average is undefined.
    """
    # Fail before building the tree instead of raising ZeroDivisionError at the end.
    if len(Q) == 0:
        raise ValueError('Q must contain at least one query')
    d2 = d << 1
    mag = preprocessing.pure_DP_mag(n, (math.ceil(math.log2(n)) + 1)**(d2), eps, pattern)
    root = None
    for attr in test_input_nodes:
        root = range_tree.insertSplit(root, attr, 0, n - 1, 0, n, d2 - 1, mag)

    start = time.perf_counter()
    for q in Q:
        (root, weight, noise) = range_tree.querySplit(root, 0, n - 1, q, 0, n, d2 - 1, mag)
    end = time.perf_counter()
    return (end - start) * 1e6 / len(Q)

def pure_DP_prtime(n, d, eps, edges, h, repeat_times, pattern):
    """Return the mean preprocessing time, in minutes, of the pure-DP pipeline.

    Each of the ``repeat_times`` repetitions re-runs the whole pipeline inside
    the timed region:
    - pattern enumeration over ``edges`` (preceded by a degree count for
      ``'triangle'``);
    - the magnitude computation from ``preprocessing.pure_DP_mag``;
    - the range-tree build from the enumerated ``((x, y), z)`` entries.

    Args:
        n: Number of vertices; edge endpoints index a length-``n`` degree array.
        d: ``2 * d`` is the magnitude exponent, and ``2 * d - 1`` is the tree's
            dimension argument.
        eps: Epsilon forwarded to ``preprocessing.pure_DP_mag``.
        edges: Iterable of ``(u, v)`` vertex pairs. It is re-iterated on every
            repetition, so pass a re-iterable collection.
        h: Forwarded to the ``find_patterns.enumerate_*_project`` helpers.
        repeat_times: Number of timed repetitions; must be positive.
        pattern: ``'triangle'`` or ``'2star'``; any other value enumerates edges.

    Returns:
        Average wall-clock time per repetition, in minutes.

    Raises:
        ValueError: If ``repeat_times`` is not positive.
    """
    if repeat_times <= 0:
        raise ValueError('repeat_times must be a positive integer')
    d2 = d << 1  # loop-invariant, so it is computed once, outside the timed loop
    start = time.perf_counter()
    for _ in range(repeat_times):
        if pattern == 'triangle':
            deg = np.zeros(n, dtype=int)
            for u, v in edges:
                deg[u] += 1
                deg[v] += 1
            projected = find_patterns.enumerate_triangles_project(edges, deg, n, h)
        elif pattern == '2star':
            projected = find_patterns.enumerate_2stars_project(edges, n, h)
        else:
            projected = find_patterns.enumerate_edges_project(edges, h)
        mag = preprocessing.pure_DP_mag(n, (math.ceil(math.log2(n)) + 1)**(d2), eps, pattern)
        root = None
        for (x, y), z in projected.items():
            root = range_tree.insertSplit(root, (x, y, z), 0, n - 1, 0, n, d2 - 1, mag)
    end = time.perf_counter()
    return (end - start) / repeat_times / 60.0

def approx_DP(n, d, Q, d_max, eps, delta, test_input_nodes, pattern, logger):
    """Answer ``Q`` with an approximate-DP range tree and accumulate relative errors.

    The noise magnitude comes from ``preprocessing.approx_DP_mag``. For each
    query, the relative error of ``weight + noise`` against ``weight`` (both
    returned by ``range_tree.querySplit``) is summed, and so is its square.

    Args:
        n: Coordinates are handled over the range ``[0, n - 1]``; also used to
            floor the relative-error denominator.
        d: ``2 * d`` is the magnitude exponent, and ``2 * d - 1`` is the tree's
            dimension argument.
        Q: Iterable of queries forwarded to ``querySplit``.
        d_max, eps, delta, pattern: Forwarded to ``preprocessing.approx_DP_mag``.
        test_input_nodes: Iterable of points/records to insert.
        logger: Logger for progress messages.

    Returns:
        ``(sum_err, sum_err2)``: the sum of relative errors and the sum of their
        squares over all queries.
    """
    dims = d << 1
    levels = math.ceil(math.log2(n)) + 1
    mag = preprocessing.approx_DP_mag(n, levels ** dims, d_max, eps, delta, pattern)

    tree = None
    for node in test_input_nodes:
        tree = range_tree.insertSplit(tree, node, 0, n - 1, 0, n, dims - 1, mag)
    logger.info('finish building range tree dynamically for approximate DP')

    total = 0.0
    total_sq = 0.0
    for query in Q:
        tree, exact, noise = range_tree.querySplit(tree, 0, n - 1, query, 0, n, dims - 1, mag)
        e = relative_error(exact + noise, exact, n)
        total += e
        total_sq += e * e
    logger.info('finish answering queries for approximate DP')
    return total, total_sq

def approx_DP_qtime(n, d, Q, d_max, eps, delta, test_input_nodes, pattern):
    """Return the mean query time, in microseconds, of the approximate-DP range tree.

    The tree is first built from ``test_input_nodes``. The build is not timed.
    Only the loop of ``range_tree.querySplit`` calls over ``Q`` is measured,
    with ``time.perf_counter``.

    Args:
        n: Coordinates are handled over the range ``[0, n - 1]``.
        d: ``2 * d`` is the magnitude exponent, and ``2 * d - 1`` is the tree's
            dimension argument.
        Q: Sized collection of queries (``len(Q)`` is used for the average).
        d_max, eps, delta, pattern: Forwarded to ``preprocessing.approx_DP_mag``.
        test_input_nodes: Iterable of points/records to insert.

    Returns:
        Average elapsed time per query, in microseconds.

    Raises:
        ValueError: If ``Q`` is empty, because the average is undefined.
    """
    # Fail before building the tree instead of raising ZeroDivisionError at the end.
    if len(Q) == 0:
        raise ValueError('Q must contain at least one query')
    d2 = d << 1
    mag = preprocessing.approx_DP_mag(n, (math.ceil(math.log2(n)) + 1)**(d2), d_max, eps, delta, pattern)
    root = None
    for attr in test_input_nodes:
        root = range_tree.insertSplit(root, attr, 0, n - 1, 0, n, d2 - 1, mag)

    start = time.perf_counter()
    for q in Q:
        (root, weight, noise) = range_tree.querySplit(root, 0, n - 1, q, 0, n, d2 - 1, mag)
    end = time.perf_counter()
    return (end - start) * 1e6 / len(Q)

def approx_DP_prtime(n, d, eps, delta, edges, h, repeat_times, pattern):
    """Return the mean preprocessing time, in minutes, of the approximate-DP pipeline.

    Each of the ``repeat_times`` repetitions re-runs the whole pipeline inside
    the timed region:
    - for ``'triangle'``/``'2star'``, a degree count and the matching
      ``preprocessing.calc_f1_*`` value;
    - pattern enumeration;
    - the magnitude computation from ``preprocessing.approx_DP_mag``;
    - the range-tree build from the enumerated ``((x, y), z)`` entries.

    Args:
        n: Number of vertices; edge endpoints index a length-``n`` degree array.
        d: ``2 * d`` is the magnitude exponent, and ``2 * d - 1`` is the tree's
            dimension argument.
        eps, delta: Forwarded to ``preprocessing.approx_DP_mag``.
        edges: Iterable of ``(u, v)`` vertex pairs. It is re-iterated on every
            repetition, so pass a re-iterable collection.
        h: Forwarded to the ``find_patterns.enumerate_*_project`` helpers.
        repeat_times: Number of timed repetitions; must be positive.
        pattern: ``'triangle'`` or ``'2star'``; any other value enumerates edges.

    Returns:
        Average wall-clock time per repetition, in minutes.

    Raises:
        ValueError: If ``repeat_times`` is not positive.
    """
    if repeat_times <= 0:
        raise ValueError('repeat_times must be a positive integer')
    d2 = d << 1  # loop-invariant, so it is computed once, outside the timed loop
    start = time.perf_counter()
    for _ in range(repeat_times):
        f1 = 0.0
        if pattern in ('triangle', '2star'):
            # Both vertex-based patterns need the degree of every vertex.
            deg = np.zeros(n, dtype=int)
            for u, v in edges:
                deg[u] += 1
                deg[v] += 1
            if pattern == 'triangle':
                f1 = preprocessing.calc_f1_triangle(edges, deg, n)
                projected = find_patterns.enumerate_triangles_project(edges, deg, n, h)
            else:
                f1 = preprocessing.calc_f1_2star(edges, deg, n)
                projected = find_patterns.enumerate_2stars_project(edges, n, h)
        else:
            projected = find_patterns.enumerate_edges_project(edges, h)
        # f1 goes in the d_max position of approx_DP_mag (compare approx_DP).
        # NOTE(review): for the edge pattern it stays 0.0. Confirm that
        # approx_DP_mag does not depend on d_max in that case.
        mag = preprocessing.approx_DP_mag(n, (math.ceil(math.log2(n)) + 1)**(d2), f1, eps, delta, pattern)
        root = None
        for (x, y), z in projected.items():
            root = range_tree.insertSplit(root, (x, y, z), 0, n - 1, 0, n, d2 - 1, mag)
    end = time.perf_counter()
    return (end - start) / repeat_times / 60.0