import numpy as np
import torch
from pybloom_live import ScalableBloomFilter

from SourceCode.ModelModule.MemoryMatrix import SparseBasicMemoryMatrixAndCM12, AccelerateBasicMemoryMatrixAndCM12, \
    SparseBasicMemoryMatrixAndCM12ForDegree, AccelerateBasicMemoryMatrixAndCM12ForDegree


def convert_base_sketch_sparse(model):
    """Swap ``model.memory_matrix`` for a new ``SparseBasicMemoryMatrixAndCM12``.

    The replacement is constructed only from the current matrix's
    ``row_dim``, ``col_dim``, ``edge_embedding_dim`` and ``device``; the old
    matrix object itself is not handed to the constructor, so any state it
    holds is presumably discarded (NOTE(review): confirm this is intended).

    Args:
        model: object exposing a ``memory_matrix`` attribute; it is mutated
            in place.

    Returns:
        The same ``model`` object, for call chaining.
    """
    old_matrix = model.memory_matrix
    model.memory_matrix = SparseBasicMemoryMatrixAndCM12(
        old_matrix.row_dim,
        old_matrix.col_dim,
        old_matrix.edge_embedding_dim,
        old_matrix.device,
    )
    return model

def convert_degree_sketch_sparse(model):
    """Swap ``model.memory_matrix`` for a new ``SparseBasicMemoryMatrixAndCM12ForDegree``.

    The replacement is constructed only from the current matrix's
    ``row_dim``, ``col_dim``, ``edge_embedding_dim`` and ``device``; the old
    matrix object itself is not handed to the constructor, so any state it
    holds is presumably discarded (NOTE(review): confirm this is intended).

    Args:
        model: object exposing a ``memory_matrix`` attribute; it is mutated
            in place.

    Returns:
        The same ``model`` object, for call chaining.
    """
    old_matrix = model.memory_matrix
    model.memory_matrix = SparseBasicMemoryMatrixAndCM12ForDegree(
        old_matrix.row_dim,
        old_matrix.col_dim,
        old_matrix.edge_embedding_dim,
        old_matrix.device,
    )
    return model

def convert_degree_sketch_accelerate(model):
    """Swap ``model.memory_matrix`` for a new ``AccelerateBasicMemoryMatrixAndCM12ForDegree``.

    Prints ``'accelerate memory'`` before touching the model. The
    replacement is constructed only from the current matrix's ``row_dim``,
    ``col_dim``, ``edge_embedding_dim`` and ``device``; the old matrix object
    itself is not handed to the constructor (NOTE(review): its state is
    presumably discarded — confirm this is intended).

    Args:
        model: object exposing a ``memory_matrix`` attribute; it is mutated
            in place.

    Returns:
        The same ``model`` object, for call chaining.
    """
    print('accelerate memory')
    old_matrix = model.memory_matrix
    model.memory_matrix = AccelerateBasicMemoryMatrixAndCM12ForDegree(
        old_matrix.row_dim,
        old_matrix.col_dim,
        old_matrix.edge_embedding_dim,
        old_matrix.device,
    )
    return model

def convert_base_sketch_accelerate(model):
    """Swap ``model.memory_matrix`` for a new ``AccelerateBasicMemoryMatrixAndCM12``.

    Prints ``'accelerate memory'`` before touching the model. The
    replacement is constructed only from the current matrix's ``row_dim``,
    ``col_dim``, ``edge_embedding_dim`` and ``device``; the old matrix object
    itself is not handed to the constructor (NOTE(review): its state is
    presumably discarded — confirm this is intended).

    Args:
        model: object exposing a ``memory_matrix`` attribute; it is mutated
            in place.

    Returns:
        The same ``model`` object, for call chaining.
    """
    print('accelerate memory')
    old_matrix = model.memory_matrix
    model.memory_matrix = AccelerateBasicMemoryMatrixAndCM12(
        old_matrix.row_dim,
        old_matrix.col_dim,
        old_matrix.edge_embedding_dim,
        old_matrix.device,
    )
    return model

def extract_node(train_queries_nd):
    """Split every row of a 2-D array into two half-width rows.

    Input row ``r`` becomes output row ``2 * r`` (its first
    ``shape[1] // 2`` columns) followed by output row ``2 * r + 1`` (the
    remaining columns). The input is copied before reshaping, so the result
    never shares memory with ``train_queries_nd``.

    Args:
        train_queries_nd: NumPy array indexed as ``[row, column]``. For a
            non-empty input the column count must be even, otherwise NumPy's
            ``reshape`` raises ``ValueError``.

    Returns:
        Array of shape ``(2 * shape[0], shape[1] // 2)``.
    """
    doubled_rows = train_queries_nd.shape[0] * 2
    half_width = train_queries_nd.shape[1] // 2
    independent_copy = train_queries_nd.copy()
    return independent_copy.reshape(doubled_rows, half_width)

def filter_node(node_nd):
    """Drop duplicate rows, keeping the first occurrence of each in original order.

    Two rows count as duplicates when their raw bytes (``ndarray.tobytes()``)
    are identical, so the comparison is exact and dtype-sensitive.

    Args:
        node_nd: array-like; each element along axis 0 is treated as one node.

    Returns:
        A new array with the same dtype and trailing shape as ``node_nd``
        holding only the first occurrence of each distinct row. An empty input
        yields an empty array that still has that dtype and trailing shape.
    """
    node_nd = np.asarray(node_nd)
    seen = set()
    keep_index = []
    for i, node in enumerate(node_nd):
        key = node.tobytes()
        if key not in seen:
            seen.add(key)
            keep_index.append(i)
    # Select by index instead of rebuilding with np.array(list_of_rows):
    # for an empty result the latter would produce a float64 array of
    # shape (0,), losing the input's dtype and column count.
    return node_nd[np.asarray(keep_index, dtype=np.intp)]


def generate_fake_edge_tensor(train_queries_nd, node_nd, device, fake_edge_num_ratio=10, seed=None):
    """Sample random node pairs that do not appear among the existing edges.

    ``int(len(train_queries_nd) * fake_edge_num_ratio)`` candidate edges are
    built by drawing that many pairs of nodes uniformly (with replacement)
    from ``node_nd``; candidate ``i`` is draw ``2*i`` followed by draw
    ``2*i + 1``. A Bloom filter is seeded with the bytes of every row of
    ``train_queries_nd``; each candidate is then kept only if the filter has
    not seen its bytes before (and is added to the filter), so candidates
    matching an existing edge or an earlier candidate are dropped. Because
    the filter is probabilistic (``error_rate=0.01``), a small fraction of
    genuinely new candidates may be dropped as well, so the number returned
    is at most the requested amount. Nothing here excludes candidates whose
    two endpoints are the same node.

    Args:
        train_queries_nd: NumPy array of existing edges, one edge per row.
        node_nd: NumPy array of nodes, one node per element along axis 0.
            Presumably produced by ``extract_node`` so that two concatenated
            node rows have the same byte layout as an edge row — TODO
            confirm; otherwise no candidate can ever match an existing edge.
        device: torch device for the returned tensor.
        fake_edge_num_ratio: requested ratio num(fake edges) / num(real edges).
        seed: optional seed passed to ``np.random.default_rng``; ``None``
            (the default) keeps the sampling non-deterministic.

    Returns:
        ``(fake_edges, indices)``: ``fake_edges`` is a float32 tensor on
        ``device`` with one kept candidate per row, ``indices`` is
        ``list(range(len(fake_edges)))``. NOTE: converting integer node ids
        to float32 is only exact for magnitudes up to 2**24.
    """
    exist_edge_num = train_queries_nd.shape[0]
    fake_edge_num = int(exist_edge_num * fake_edge_num_ratio)
    bf = ScalableBloomFilter(initial_capacity=exist_edge_num * fake_edge_num_ratio, error_rate=0.01,
                             mode=ScalableBloomFilter.LARGE_SET_GROWTH)
    for edge in train_queries_nd:
        bf.add(edge.tobytes())

    rng = np.random.default_rng(seed)
    # Two endpoints are drawn for every requested fake edge.
    sample_node_index = rng.integers(0, node_nd.shape[0], fake_edge_num * 2)
    sample_node = node_nd[sample_node_index]
    # Pair consecutive draws into one candidate edge per row. The row width
    # is given explicitly because reshape((0, -1)) raises ValueError when
    # no candidates were requested (e.g. an empty training set).
    edge_width = 2 * int(np.prod(node_nd.shape[1:]))
    candidate_edges = sample_node.reshape((fake_edge_num, edge_width))

    # ScalableBloomFilter.add returns True when the key was (probably)
    # already present, so "not bf.add(...)" keeps only unseen candidates
    # while also recording them to reject later duplicates.
    not_exist_index = []
    for i, edge in enumerate(candidate_edges):
        if not bf.add(edge.tobytes()):
            not_exist_index.append(i)

    not_exist_edge = candidate_edges[not_exist_index]
    return torch.tensor(not_exist_edge, device=device).float(), list(range(not_exist_edge.shape[0]))


def merge_task_file(queries_np_list, counts_np_list, start_pos=None, end_pos=None, merge_num=None):
    """Merge per-file query/count arrays, summing the counts of identical queries.

    Files ``start_pos`` .. ``end_pos - 1`` are merged. Queries are compared
    by raw bytes (``ndarray.tobytes()``); the first occurrence of each
    distinct query fixes its position in the output, and the counts of all
    its occurrences are added together.

    Range resolution:
        * ``start_pos`` defaults to 0 when omitted.
        * An explicit ``end_pos`` always wins. Otherwise ``end_pos`` is
          ``start_pos + merge_num`` when ``merge_num`` is given, else
          ``len(queries_np_list)``.

    Args:
        queries_np_list: list of NumPy arrays; element ``j`` of array ``i``
            is one query.
        counts_np_list: list parallel to ``queries_np_list``; element ``j``
            of entry ``i`` is the count of query ``j`` of file ``i``.
        start_pos: first file index to merge (inclusive).
        end_pos: last file index to merge (exclusive).
        merge_num: number of files to merge, used only when ``end_pos`` is
            not given.

    Returns:
        ``(res_item_nd, res_counts_nd)``: the distinct queries and their
        summed counts, in first-occurrence order. If the selected files
        contain no queries, both are empty arrays that keep the dtype and
        trailing shape of the first selected file's arrays.

    Raises:
        AssertionError: if the lists differ in length or the resolved range
            is empty or out of bounds.
    """
    if start_pos is None:
        start_pos = 0
    if end_pos is None:
        end_pos = start_pos + merge_num if merge_num is not None else len(queries_np_list)

    # safety check
    assert len(queries_np_list) == len(counts_np_list), 'error, query and count list must have the same length'
    assert start_pos >= 0 and end_pos <= len(queries_np_list), 'error, pos out of the bound'
    assert start_pos < end_pos, 'error, end pos <= start pos'

    # Both dicts are keyed by query bytes and always updated together, so
    # their insertion orders (first occurrence) stay aligned.
    first_item = {}
    total_count = {}
    for i in range(start_pos, end_pos):
        queries_nd = queries_np_list[i]
        counts_nd = counts_np_list[i]
        for j in range(queries_nd.shape[0]):
            item_nd = queries_nd[j]
            count = counts_nd[j]
            key = item_nd.tobytes()
            if key in total_count:
                total_count[key] += count
            else:
                first_item[key] = item_nd
                total_count[key] = count

    if not total_count:
        # np.array([]) would collapse to float64 with shape (0,); keep the
        # input layout instead.
        return (np.asarray(queries_np_list[start_pos])[:0].copy(),
                np.asarray(counts_np_list[start_pos])[:0].copy())

    res_item_nd = np.array(list(first_item.values()))
    res_counts_nd = np.array(list(total_count.values()))
    return res_item_nd, res_counts_nd
