import random

import numpy as np
import torch

from AbstractClass.TaskRelatedClasses import AbstractQueryGenerator, SupportGeneratorInterface
from pybloom_live import ScalableBloomFilter


class QueryGeneratorForExist(AbstractQueryGenerator):
    """Build edge-existence queries: top-ranked real edges vs. sampled fake edges.

    Positives are the support rows with the largest ``support_y`` values
    (labelled 1). Negatives are a contiguous window of ``fake_edge_tensor``
    rows of the same size (labelled 0).
    """

    def __init__(self, filter_ratio=0.1):
        """
        Args:
            filter_ratio: fraction of the support set, ranked by ``support_y``
                in descending order, used as positive queries. The default
                0.1 matches the value previously hard-coded on the
                train/test paths.
        """
        AbstractQueryGenerator.__init__(self)
        self.filter_ratio = filter_ratio

    def generate_train_query(self, support_x, support_y, fake_edge_tensor):
        """Return ``(query_x, query_y)`` for training; see ``generate_query_by_real_node``."""
        return self.generate_query_by_real_node(
            support_x, support_y, fake_edge_tensor, self.filter_ratio)

    def generate_test_query(self, support_x, support_y, fake_edge_tensor):
        """Return ``(query_x, query_y)`` for testing; see ``generate_query_by_real_node``."""
        return self.generate_query_by_real_node(
            support_x, support_y, fake_edge_tensor, self.filter_ratio)

    @staticmethod
    def generate_query_by_real_node(support_x, support_y, fake_edge_tensor, filter_ratio=0.1):
        """Build a labelled query set from real (positive) and fake (negative) rows.

        Args:
            support_x: support rows, indexed along dim 0.
            support_y: one score per support row; the highest scores become
                positives. A trailing singleton dim, e.g. ``(N, 1)``, is
                flattened before ranking.
            fake_edge_tensor: 2-D tensor of candidate negative rows; a random
                contiguous window of it is taken.
            filter_ratio: fraction of ``support_y.shape[0]`` used as positives,
                clamped so the count lies in ``[0, N]``.

        Returns:
            ``(query_x, query_y)``, where ``query_x`` is positives followed by
            negatives and ``query_y`` is a float column of shape ``(-1, 1)``
            (1 for positives, 0 for negatives). If ``fake_edge_tensor`` has
            fewer rows than the positive count, fewer negatives are returned;
            the labels always match the actual rows.
        """
        total = support_y.shape[0]
        # Clamp so ratios outside [0, 1] cannot produce negative or
        # oversized slice bounds.
        positive_num = min(max(round(total * filter_ratio), 0), total)

        # Flatten so an (N, 1) label tensor is ranked across rows rather than
        # within each single-element row.
        _, sort_index = support_y.reshape(-1).sort(descending=True)
        positive_index = sort_index[:positive_num]

        # Last valid window start. Clamp at 0 so the start is never negative:
        # a negative start would silently slice from the tensor's end.
        max_start = max(fake_edge_tensor.shape[0] - positive_num, 0)
        # random.random() is in [0, 1), so the "+ 1" makes every start in
        # [0, max_start] reachable, including the window that ends on the
        # last fake row.
        start_node = int(random.random() * (max_start + 1))

        # Advanced indexing already copies; torch.cat copies again, so no clone is needed.
        positive_query_x = support_x[positive_index]
        negative_query_x = fake_edge_tensor[start_node:start_node + positive_num, :]
        query_x = torch.cat((positive_query_x, negative_query_x))

        device = support_x.device
        query_y = torch.cat((
            torch.ones(positive_query_x.shape[0], device=device),
            torch.zeros(negative_query_x.shape[0], device=device),
        ), dim=-1).view(-1, 1)
        return query_x, query_y

class QueryGeneratorForDegree(AbstractQueryGenerator):
    """Build out-degree regression queries from a support set of edges.

    Each support row is split at ``shape[1] // 2``. The left half is treated as
    the out-node vector and the right half as the in-node vector (assumed
    layout ``[out_node | in_node]`` — TODO confirm against the caller). Queries
    are the distinct out-node vectors, in first-occurrence order. Targets are
    the per-out-node sums of per-edge weights.
    """

    def __init__(self, ):
        AbstractQueryGenerator.__init__(self)
        # Maps an in-node vector's ``ndarray.tobytes()`` key to a multiplier
        # applied to edge frequencies on the training path (presumably
        # numeric). Set it via ``set_node_address_dic``.
        self.address_dic = None

    def set_node_address_dic(self, address_dic):
        """Install the in-node-bytes -> multiplier mapping used by training queries."""
        self.address_dic = address_dic

    def generate_train_query(self, support_x, support_y, _):
        """Return weighted out-degree queries; see ``generate_out_degree_query_by_statistic_for_train``."""
        return self.generate_out_degree_query_by_statistic_for_train(support_x, support_y, _)

    def generate_test_query(self, support_x, support_y, _):
        """Return raw out-degree queries; see ``generate_out_degree_query_by_statistic``."""
        return self.generate_out_degree_query_by_statistic(support_x, support_y, _)

    @staticmethod
    def _row_values(support_y):
        """Return one Python number per support row.

        Assumes a single value per row (``(N,)`` or ``(N, 1)``). A single
        ``tolist()`` avoids one device sync per row.
        """
        return support_y.detach().reshape(-1).tolist()

    @staticmethod
    def _aggregate_by_out_node(out_node_nd, weights, device):
        """Sum ``weights`` per distinct out-node row and build ``(query_x, query_y)``.

        ``out_node_nd[i]`` is paired with ``weights[i]``. Distinct rows keep
        their first-occurrence order because dicts preserve insertion order.
        ``query_y`` has shape ``(-1, 1)``.
        """
        rows = {}
        totals = {}
        for i, weight in enumerate(weights):
            row = out_node_nd[i]
            key = row.tobytes()
            if key in totals:
                totals[key] += weight
            else:
                totals[key] = weight
                rows[key] = row
        if rows:
            node_nd = np.stack(list(rows.values()))
        else:
            # Keep a 2-D (0, width) shape; np.array([]) would collapse to (0,).
            node_nd = np.empty((0, out_node_nd.shape[1]), dtype=out_node_nd.dtype)
        query_x = torch.tensor(node_nd, device=device).float()
        query_y = torch.tensor(list(totals.values()), device=device).float().view(-1, 1)
        return query_x, query_y

    def generate_out_degree_query_by_statistic_for_train(self, support_x, support_y, fake_edge_tensor):
        """Out-degree queries where each edge counts ``frequency * address_dic[in_node]``.

        Args:
            support_x: 2-D edge tensor, split in half along dim 1.
            support_y: per-edge frequency, one value per row.
            fake_edge_tensor: ignored; kept for interface compatibility.

        Returns:
            ``(query_x, query_y)``, with ``query_y`` of shape ``(-1, 1)``,
            consistent with the test path and the other generators.

        Raises:
            RuntimeError: if ``set_node_address_dic`` has not been called.
            KeyError: if an edge's in-node is missing from ``address_dic``.
        """
        address_dic = self.address_dic
        if address_dic is None:
            raise RuntimeError(
                "address_dic is not set; call set_node_address_dic() before "
                "generating training queries")
        support_x_nd = support_x.detach().cpu().numpy()
        half = support_x_nd.shape[1] // 2
        in_node_nd = support_x_nd[:, half:]
        weights = []
        for i, frequency in enumerate(self._row_values(support_y)):
            in_key = in_node_nd[i].tobytes()
            if in_key not in address_dic:
                # Previously printed 'error' and then raised a bare KeyError
                # on the next line; now fail once, with context.
                raise KeyError(f"in-node of support row {i} is missing from address_dic")
            weights.append(frequency * address_dic[in_key])
        return self._aggregate_by_out_node(support_x_nd[:, :half], weights, support_x.device)

    @staticmethod
    def generate_out_degree_query_by_statistic(support_x, support_y, fake_edge_tensor):
        """Out-degree queries where each edge counts its raw ``support_y`` frequency.

        Args:
            support_x: 2-D edge tensor; the left half of dim 1 is the out-node.
            support_y: per-edge frequency, one value per row.
            fake_edge_tensor: ignored; kept for interface compatibility.

        Returns:
            ``(query_x, query_y)``, with ``query_y`` of shape ``(-1, 1)``.
        """
        support_x_nd = support_x.detach().cpu().numpy()
        out_node_nd = support_x_nd[:, :support_x_nd.shape[1] // 2]
        return QueryGeneratorForDegree._aggregate_by_out_node(
            out_node_nd, QueryGeneratorForDegree._row_values(support_y), support_x.device)

class SimpleQueryGenerator(AbstractQueryGenerator):
    """Identity query generator: the support set is reused as the query set.

    Train and test queries are built the same way. ``support_x`` is returned
    unchanged and ``support_y`` is reshaped with ``view(-1, 1)``.
    """

    def __init__(self):
        AbstractQueryGenerator.__init__(self)

    @staticmethod
    def _support_as_query(support_x, support_y):
        """Return ``support_x`` as-is together with a ``(-1, 1)`` view of ``support_y``."""
        column_y = support_y.view(-1, 1)
        return support_x, column_y

    def generate_train_query(self, support_x, support_y, stream_node_vec_list):
        # stream_node_vec_list is accepted only to match the generator interface.
        return self._support_as_query(support_x, support_y)

    def generate_test_query(self, support_x, support_y, stream_node_vec_list):
        # stream_node_vec_list is accepted only to match the generator interface.
        return self._support_as_query(support_x, support_y)
