import os
import time
import matplotlib.pyplot as plt
from pybloom_live import ScalableBloomFilter, BloomFilter
from AbstractClass.TaskRelatedClasses import SupportGeneratorInterface, \
    AbstractDecoratorSupportGenerator
import pandas as pd
import numpy as np
import copy
import random
import torch
from Utils.Util import merge_task_file, extract_node, generate_fake_edge_tensor, filter_node


class ConstructionAnonymousSupportGenerator(SupportGeneratorInterface):
    """Support-set generator backed by task files stored on disk.

    Every file under ``<data_path>/<train_dir_name>`` and
    ``<data_path>/<test_dir_name>`` is loaded with ``np.load`` and must expose
    an ``'embeddings'`` and a ``'counts'`` entry.  All train files are merged
    once (via ``merge_task_file``) into a single pool used by the item-size
    sampler; the stream-length samplers instead merge a random contiguous
    window of the per-file lists on every call.

    ``set_device`` must be called at least once before sampling, because it
    is what builds the tensors the samplers read.

    NOTE(review): files are used in ``os.listdir`` order, which is
    filesystem dependent; the "contiguous window" of the stream-length
    samplers therefore depends on that order - confirm this is intended.
    """

    def __init__(self, data_path, input_dim=10,
                 item_lower=None, item_upper=None, train_dir_name="train_task_files", test_dir_name="test_task_files",
                 generate_fake_edge=False, total_step=None, step_scale=None):
        """Load every task file and prepare the merged training pool.

        Args:
            data_path: directory containing the train / test task directories.
            input_dim: expected second dimension of the ``'embeddings'`` arrays
                (asserted against the first train and test file).
            item_lower: lower bound used when an item size is drawn at random.
            item_upper: upper bound used when an item size is drawn at random;
                ``-1`` means "number of merged training items minus one".
            train_dir_name: name of the train task directory under ``data_path``.
            test_dir_name: name of the test task directory under ``data_path``.
            generate_fake_edge: when True the samplers also return a fake-edge
                tensor as their third result, otherwise ``None``.
            total_step: total number of steps, used by the incremental
                stream-length strategy (both this and ``step_scale`` must be
                set for the strategy to be active).
            step_scale: cap on the incremental stage, see
                ``sample_train_support_by_length``.
        """
        super().__init__()
        # path for load all data
        self.data_path = data_path
        self.train_dir_name = train_dir_name
        self.test_dir_name = test_dir_name
        # index for shuffle
        self.train_index = None
        # for shift sample func: when curr_step reaches shift_step the sampler
        # named by next_sample_func is installed (both are set from outside)
        self.next_sample_func = None
        self.shift_step = None
        # function for sampling train support
        self.sample_train_support_func = self.sample_train_support_by_itemsize
        # store all data file used to generate train task
        self.num_of_train_file = None
        self.num_of_test_file = None
        self.test_file_ave_length = None
        self.test_stream_length = None
        self.train_file_ave_length = None
        self.train_stream_length = None
        self.train_task_file_queries_np_list = []
        self.train_task_file_counts_np_list = []
        self.test_task_file_queries_np_list = []
        self.test_task_file_counts_np_list = []
        self.train_task_file_queries_tensor_list = []
        self.train_task_file_counts_tensor_list = []
        self.test_task_file_queries_tensor_list = []
        self.test_task_file_counts_tensor_list = []
        self.train_queries_nd = None
        self.train_counts_nd = None
        self.train_node_np = None
        self.filtered_train_node_np = None
        self.train_queries_tensor = None
        self.train_counts_tensor = None
        self.train_node_tensor = None
        self.fake_edge_tensor = None
        self.fake_edge_index = None
        # for checking data format
        self.input_dimension = input_dim
        # for constructing exist
        self.generate_fake_edge = generate_fake_edge
        # generate 2 times fake edge to real edge
        self.fake_edge_num_ratio = 2
        self.train_item_pos = 0
        self.item_lower = item_lower
        self.item_upper = item_upper
        self.rng = np.random.default_rng()
        self.curr_step = 0
        self.total_step = total_step
        self.step_scale = step_scale
        self.check_set = set()
        if self.item_upper is not None and self.item_lower is not None:
            assert self.item_lower < self.item_upper or self.item_upper == -1, \
                'item upper should not be smaller than item_lower!'
        self.load_all_data()
        self._resolve_item_upper()

    def _resolve_item_upper(self):
        """Replace the ``-1`` sentinel in ``item_upper`` by the training pool size minus one."""
        if self.item_upper == -1:
            self.item_upper = self.train_counts_nd.shape[0] - 1
            print('set item_upper to the training data item upper', self.item_upper)

    def _to_float_tensor(self, array):
        """Copy ``array`` to ``self.device`` as a float tensor."""
        return torch.tensor(array, device=self.device).float()

    # before generating task , this fun must be called at least once
    def set_device(self, device):
        """Forward ``device`` to the parent class, then rebuild every tensor."""
        super().set_device(device)
        self.flush_tensor()

    def shift_sample_function(self):
        """Install ``next_sample_func`` once ``curr_step`` equals ``shift_step``.

        Does nothing while either ``shift_step`` or ``next_sample_func`` is None.
        """
        if self.shift_step is None or self.next_sample_func is None:
            return
        if self.shift_step == self.curr_step:
            self.set_sample_func(self.next_sample_func)

    def get_item_upper_lower(self):
        """Return ``(item_upper, item_lower)`` - note the order."""
        return self.item_upper, self.item_lower

    def set_item_upper_lower(self, item_upper, item_lower):
        """Set the random item-size bounds; ``item_upper == -1`` means "pool size - 1"."""
        self.item_upper = item_upper
        self.item_lower = item_lower
        assert self.item_lower < self.item_upper or self.item_upper == -1, \
            'item upper should not be smaller than item_lower!'
        self._resolve_item_upper()

    def flush_tensor(self):
        """Convert every ndarray to a float tensor on ``self.device`` and reshuffle.

        The per-file tensor lists are rebuilt from scratch (in place, so the
        list objects keep their identity); appending instead would duplicate
        every entry each time ``set_device`` is called again.
        """
        # convert all ndarray to tensor , and set current device to all tensor
        self.train_queries_tensor = self._to_float_tensor(self.train_queries_nd)
        self.train_counts_tensor = self._to_float_tensor(self.train_counts_nd)
        self.test_task_file_counts_tensor_list[:] = [
            self._to_float_tensor(counts) for counts in self.test_task_file_counts_np_list]
        self.test_task_file_queries_tensor_list[:] = [
            self._to_float_tensor(queries) for queries in self.test_task_file_queries_np_list]
        self.train_task_file_counts_tensor_list[:] = [
            self._to_float_tensor(counts) for counts in self.train_task_file_counts_np_list]
        self.train_task_file_queries_tensor_list[:] = [
            self._to_float_tensor(queries) for queries in self.train_task_file_queries_np_list]
        self.train_node_tensor = self._to_float_tensor(self.train_node_np)
        # the fake edges are generated here regardless of generate_fake_edge
        self.fake_edge_tensor, self.fake_edge_index = generate_fake_edge_tensor(self.train_queries_nd,
                                                                                self.train_node_np, self.device,
                                                                                self.fake_edge_num_ratio)
        self.shuffle_train()

    def _load_task_dir(self, dir_name, queries_np_list, counts_np_list):
        """Append every file of ``<data_path>/<dir_name>`` to the given lists.

        Returns:
            The sum of all ``'counts'`` entries over the loaded files.
        """
        task_dir_path = os.path.join(self.data_path, dir_name)
        sum_of_counts = 0
        for file_name in os.listdir(task_dir_path):
            file = np.load(os.path.join(task_dir_path, file_name))
            # read once: each key access on a loaded archive may re-read the data
            counts = file['counts']
            queries_np_list.append(file['embeddings'])
            counts_np_list.append(counts)
            sum_of_counts += counts.sum()
        return sum_of_counts

    def load_all_data(self):
        """Load all train and test task files and derive the stream statistics.

        Sets the total stream length (sum of counts) and the average stream
        length per file for both splits, checks the embedding width against
        ``input_dimension`` and finally merges all train files.
        """
        # load all train task file into list
        self.train_stream_length = self._load_task_dir(
            self.train_dir_name, self.train_task_file_queries_np_list, self.train_task_file_counts_np_list)
        self.train_file_ave_length = self.train_stream_length // len(self.train_task_file_counts_np_list)
        # load all test task file into list
        self.test_stream_length = self._load_task_dir(
            self.test_dir_name, self.test_task_file_queries_np_list, self.test_task_file_counts_np_list)
        self.test_file_ave_length = self.test_stream_length // len(self.test_task_file_counts_np_list)

        self.num_of_train_file = len(self.train_task_file_counts_np_list)
        self.num_of_test_file = len(self.test_task_file_counts_np_list)

        assert self.input_dimension == self.train_task_file_queries_np_list[0].shape[1]
        assert self.input_dimension == self.test_task_file_queries_np_list[0].shape[1]
        self.merge_all_train_data()

    def merge_all_train_data(self):
        """Merge all train files into one pool and derive its node arrays."""
        self.train_queries_nd, self.train_counts_nd = merge_task_file(self.train_task_file_queries_np_list,
                                                                      self.train_task_file_counts_np_list)
        self.train_index = list(range(self.train_counts_nd.shape[0]))
        self.train_node_np = extract_node(self.train_queries_nd)
        self.filtered_train_node_np = filter_node(self.train_node_np).astype(np.float32)

    def regenerate_fake_edge(self):
        """Refresh the fake edges of the training pool.

        New fake edges are generated only when
        ``train_item_pos % (fake_edge_num_ratio * 2) == 1``; otherwise the
        existing ones are merely re-ordered.
        """
        if self.train_item_pos % (self.fake_edge_num_ratio * 2) == 1:
            print('generate new fake edge')
            self.fake_edge_tensor, self.fake_edge_index = generate_fake_edge_tensor(self.train_queries_nd,
                                                                                    self.train_node_np, self.device,
                                                                                    self.fake_edge_num_ratio)
        else:
            random.shuffle(self.fake_edge_index)
            self.fake_edge_tensor = self.fake_edge_tensor[self.fake_edge_index]
            print('reuse old fake edge')

    def shuffle_train(self):
        """Reshuffle the merged training pool and rewind the read position."""
        if self.generate_fake_edge:
            # must run before train_item_pos is reset: it reads the old position
            self.regenerate_fake_edge()
        self.train_item_pos = 0
        random.shuffle(self.train_index)
        self.train_queries_tensor = self.train_queries_tensor[self.train_index]
        self.train_counts_tensor = self.train_counts_tensor[self.train_index]

    def sample_train_support(self, item_size=None, skew_ratio=None, stream_length=None):
        """Sample one training support set with the currently active sampler.

        Increments ``curr_step`` and gives ``shift_sample_function`` a chance
        to switch sampler first.  ``skew_ratio`` is not supported and must be
        None.

        Returns:
            ``(items, frequencies, info)`` as produced by the active sampler.
        """
        assert skew_ratio is None
        self.curr_step += 1
        self.shift_sample_function()
        return self.sample_train_support_func(item_size=item_size, stream_length=stream_length)

    def set_sample_func(self, func_name):
        """Select the training sampler: ``'itemsize'`` or ``'stream_length'``.

        Raises:
            SystemExit: (status -1) for any other name.
        """
        if func_name == 'itemsize':
            self.sample_train_support_func = self.sample_train_support_by_itemsize
            print('set sample func itemsize')
        elif func_name == 'stream_length':
            self.sample_train_support_func = self.sample_train_support_by_length
            print('set sample func stream_length')
        else:
            print('error , func_name have two choices: itemsize and stream_length')
            # same exception and status as the former exit(-1), without
            # depending on the site-provided ``exit`` helper
            raise SystemExit(-1)

    def sample_train_support_by_itemsize(self, item_size, stream_length):
        """Take the next ``item_size`` items of the shuffled training pool.

        When ``item_size`` is None it is drawn uniformly from
        ``[item_lower, item_upper)``.  The pool is reshuffled when the request
        would reach its end.  ``stream_length`` must be None.

        Returns:
            ``(items, frequencies, info)``: clones of the selected slices, and
            the pool's fake-edge tensor (or None when fake edges are disabled).

        Raises:
            SystemExit: (status 1) when ``item_size`` exceeds the pool size.
        """
        assert stream_length is None
        if item_size is None:
            item_size = int(random.random() * (self.item_upper - self.item_lower) + self.item_lower)
        if item_size + self.train_item_pos >= self.train_counts_tensor.shape[0]:
            self.shuffle_train()
            if item_size + self.train_item_pos > self.train_counts_tensor.shape[0]:
                print('train item_size should be smaller than the number of all items')
                # this is a failure, so it must not exit with status 0
                raise SystemExit(1)
        window_end = item_size + self.train_item_pos
        items = self.train_queries_tensor[self.train_item_pos:window_end]
        frequencies = self.train_counts_tensor[self.train_item_pos:window_end]
        self.train_item_pos += item_size
        info = self.fake_edge_tensor if self.generate_fake_edge else None
        return items.clone(), frequencies.clone(), info

    def _draw_train_stream_length(self):
        """Draw a random training stream length.

        After step 100, and only when both ``total_step`` and ``step_scale``
        are set, the incremental strategy is used: the drawn fraction of the
        train stream is bounded by a stage that grows with ``curr_step`` up to
        ``step_scale``.  Otherwise the fraction is uniform in
        ``[0.001, 0.991)``.
        """
        # self.curr_step > 100 for producing checkpointing task
        if self.total_step is not None and self.step_scale is not None and self.curr_step > 100:
            curr_stage = 1.5 * (self.curr_step * self.step_scale) / self.total_step
            if curr_stage > self.step_scale:
                curr_stage = self.step_scale
            print('Incremental sampling strategy,curr stage:', curr_stage)

            ratio = random.random() * curr_stage
            if ratio >= 1:
                ratio = 0.999
            return int(ratio * self.train_stream_length)
        return int((random.random() * 0.99 + 0.001) * self.train_stream_length)

    def _generate_support_fake_edges(self, res_item_nd, num_items):
        """Generate fake edges for a freshly merged support set.

        A first attempt uses ratio 1.5; when it does not yield more rows than
        ``num_items`` a second attempt with ratio 2 replaces it.
        """
        node_np = extract_node(res_item_nd)
        fake_edge_tensor, _ = generate_fake_edge_tensor(res_item_nd, node_np, self.device, 1.5)
        if fake_edge_tensor.shape[0] <= num_items:
            print('not enough')
            fake_edge_tensor, _ = generate_fake_edge_tensor(res_item_nd, node_np, self.device, 2)
        return fake_edge_tensor

    def _sample_support_from_files(self, queries_np_list, counts_np_list, num_of_file, num):
        """Merge ``num`` consecutive task files starting at a random position.

        ``num`` is clamped to ``[1, num_of_file]``; without the upper clamp a
        request longer than the whole stream produced a negative start
        position.

        Returns:
            ``(items, frequencies, info)`` where ``info`` is a fake-edge tensor
            or None when fake edges are disabled.
        """
        if num < 1:
            num = 1
        if num > num_of_file:
            num = num_of_file
        start_pos_end = num_of_file - num
        start_pos = int(self.rng.random() * start_pos_end)
        res_item_nd, res_counts_nd = merge_task_file(
            queries_np_list, counts_np_list, start_pos=start_pos, merge_num=num)
        res_item_tensor = self._to_float_tensor(res_item_nd)
        res_counts_tensor = self._to_float_tensor(res_counts_nd)
        info = None
        if self.generate_fake_edge:
            info = self._generate_support_fake_edges(res_item_nd, res_item_tensor.shape[0])
        return res_item_tensor, res_counts_tensor, info

    def sample_train_support_by_length(self, item_size, stream_length):
        """Sample a training support set covering roughly ``stream_length``.

        The number of merged train files is ``stream_length`` floor-divided by
        the average stream length per train file.  When ``stream_length`` is
        None it is drawn at random (see ``_draw_train_stream_length``).
        ``item_size`` must be None.

        Returns:
            ``(items, frequencies, info)``.
        """
        assert item_size is None
        if stream_length is None:
            stream_length = self._draw_train_stream_length()
        num = stream_length // self.train_file_ave_length
        return self._sample_support_from_files(
            self.train_task_file_queries_np_list, self.train_task_file_counts_np_list,
            self.num_of_train_file, num)

    def sample_test_support_by_item_size(self, item_size=None, ):
        """Not implemented: prints a message and terminates with status -1."""
        print('this func need to be implemented!')
        raise SystemExit(-1)

    def sample_test_support_by_stream_length(self, stream_length, ):
        """Sample a test support set covering roughly ``stream_length``.

        The number of merged test files is ``stream_length`` divided by the
        average stream length per test file, rounded to the nearest integer.

        Returns:
            ``(items, frequencies, info)``.
        """
        num = round(stream_length / self.test_file_ave_length)
        return self._sample_support_from_files(
            self.test_task_file_queries_np_list, self.test_task_file_counts_np_list,
            self.num_of_test_file, num)

    # generate_test_support
    def sample_test_support(self, item_size=None, stream_length=None, stream_length_ratio=None, skew_ratio=None):
        """Sample a test support set.

        Exactly one of ``item_size``, ``stream_length`` and
        ``stream_length_ratio`` must be given.  ``stream_length_ratio`` is a
        fraction of the total test stream length.  ``skew_ratio`` is not
        supported and must be None.

        Returns:
            ``(items, frequencies, info)``.
        """
        assert skew_ratio is None
        given = [value for value in (item_size, stream_length, stream_length_ratio) if value is not None]
        assert len(given) == 1

        if item_size is not None:
            return self.sample_test_support_by_item_size(item_size)
        if stream_length_ratio is not None:
            return self.sample_test_support_by_stream_length(self.test_stream_length * stream_length_ratio)
        return self.sample_test_support_by_stream_length(stream_length)


# simplified version for subimago
class BasicAnonymousSupportGenerator(SupportGeneratorInterface):
    """Support-set generator that needs no data files.

    It enumerates binary node codes, pairs them at random into synthetic
    (source, destination) items and serves slices of that pool with a constant
    frequency (``ave_frequency``) for every item.

    ``set_device`` must be called at least once before sampling, because it
    is what builds the tensors the samplers read.
    """

    def __init__(self, input_dim=10, item_lower=10, item_upper=10000, ):
        """Build the synthetic node and item pools.

        Args:
            input_dim: width of one item; each node code uses
                ``input_dim // 2`` bits.
            item_lower: lower bound used when an item size is drawn at random.
            item_upper: upper bound used when an item size is drawn at random.
                NOTE(review): ``-1`` passes the assertion below but, unlike in
                ``set_item_upper_lower`` of the file-backed generator, is never
                resolved to a real bound here - confirm the intended meaning.
        """
        super().__init__()
        # index for shuffle
        self.data_index = None
        # each fake node is repeated this many times when pairing nodes
        self.generate_ratio = 30
        # store fake nodes' embeddings
        self.fake_node_nd = None
        self.fake_data_nd = None
        self.fake_node_tensor = None
        self.fake_data_tensor = None
        # frequency assigned to every sampled item
        self.ave_frequency = 5
        self.input_dimension = input_dim
        self.item_pos = None
        self.item_lower = item_lower
        self.item_upper = item_upper
        self.rng = np.random.default_rng()
        assert self.item_lower < self.item_upper or self.item_upper == -1, \
            'item upper should not be smaller than item_lower!'
        self.load_all_data()

    # before generating task , this fun must be called at least once
    def set_device(self, device):
        """Forward ``device`` to the parent class, then rebuild every tensor."""
        super().set_device(device)
        self.flush_tensor()

    def get_item_upper_lower(self):
        """Return ``(item_upper, item_lower)`` - note the order."""
        return self.item_upper, self.item_lower

    def set_item_upper_lower(self, item_upper, item_lower):
        """Set the random item-size bounds.

        Raises:
            SystemExit: (status 1) when ``item_upper`` reaches two thirds of
                ``generate_ratio`` times the number of fake nodes.
        """
        self.item_upper = item_upper
        self.item_lower = item_lower
        assert self.item_lower < self.item_upper, \
            'item upper should not be smaller than item_lower!'
        if self.item_upper >= self.generate_ratio * self.fake_node_nd.shape[0] * 2 / 3:
            print('the item_upper is a too big, have a high risk of  overflow', self.item_upper)
            # this is a failure, so it must not exit with status 0
            raise SystemExit(1)

    def flush_tensor(self):
        """Convert the ndarrays to float tensors on ``self.device`` and reshuffle."""
        # convert all ndarray to tensor , and set current device to all tensor
        self.fake_node_tensor = torch.tensor(self.fake_node_nd, device=self.device).float()
        self.fake_data_tensor = torch.tensor(self.fake_data_nd, device=self.device).float()
        self.shuffle_data()

    def load_all_data(self):
        """Generate the fake nodes, then the fake items built from them."""
        self.generate_fake_node()
        self.generate_fake_data()

    def generate_fake_node(self):
        """Build the binary codes of all fake nodes.

        With ``encode_dim = input_dimension // 2`` the result holds the
        ``encode_dim``-bit codes (most significant bit first, as float64) of
        the integers ``0 .. 2**encode_dim - 2``; the all-ones code is not
        generated.
        """
        encode_dim = self.input_dimension // 2
        max_node = (2 ** encode_dim) - 1
        node_ids = np.arange(max_node)
        # shifting by encode_dim-1 .. 0 puts the most significant bit first
        bit_shifts = np.arange(encode_dim - 1, -1, -1)
        self.fake_node_nd = ((node_ids[:, None] >> bit_shifts) & 1).astype(np.float64)

    # generate_ratio is the approximate ratio of fake edges divide the fake node
    def generate_fake_data(self):
        """Pair randomly ordered fake nodes into de-duplicated fake items.

        Every node is repeated ``generate_ratio`` times; two independent
        shuffles give the source and destination halves of each item.  Items
        already reported as present by the bloom filter are dropped.
        """
        repeat_fake_node_nd = np.tile(self.fake_node_nd, (self.generate_ratio, 1))
        index = list(range(repeat_fake_node_nd.shape[0]))
        random.shuffle(index)
        source_fake_node = repeat_fake_node_nd[index]
        random.shuffle(index)
        destination_fake_node = repeat_fake_node_nd[index]
        fake_data = np.concatenate((source_fake_node, destination_fake_node), axis=1)
        fake_data_index = []
        # use bloomFilter to accelerate filtering repeated edge
        bf = BloomFilter(capacity=fake_data.shape[0], error_rate=0.01)
        for i, row in enumerate(fake_data):
            # keep the row only when add() reports it was not present yet
            if not bf.add(row.tobytes()):
                fake_data_index.append(i)
        self.fake_data_nd = fake_data[fake_data_index]
        self.data_index = list(range(self.fake_data_nd.shape[0]))

    def shuffle_data(self):
        """Reshuffle the fake item pool and rewind the read position."""
        self.item_pos = 0
        random.shuffle(self.data_index)
        self.fake_data_tensor = self.fake_data_tensor[self.data_index]

    def _constant_frequencies(self, items):
        """Return one float frequency per item, all equal to ``ave_frequency``."""
        return torch.ones(items.shape[0], device=items.device).float() * self.ave_frequency

    def sample_train_support(self, item_size=None, skew_ratio=None, stream_length=None):
        """Sample a training support set of ``item_size`` fake items.

        When ``item_size`` is None it is drawn uniformly from
        ``[item_lower, item_upper)``.  ``skew_ratio`` and ``stream_length``
        are not supported and must be None.

        Returns:
            ``(items, frequencies, None)``.
        """
        assert skew_ratio is None
        assert stream_length is None
        if item_size is None:
            item_size = int(random.random() * (self.item_upper - self.item_lower) + self.item_lower)
        items = self.sample_items_data_by_item_size(item_size=item_size)
        return items, self._constant_frequencies(items), None

    def sample_items_data_by_item_size(self, item_size):
        """Take the next ``item_size`` items of the shuffled fake pool.

        The pool is reshuffled when the request would reach its end.

        Raises:
            SystemExit: (status 1) when ``item_size`` exceeds the pool size.
        """
        if item_size + self.item_pos >= self.fake_data_tensor.shape[0]:
            self.shuffle_data()
            if item_size + self.item_pos > self.fake_data_tensor.shape[0]:
                print('train item_size should be smaller than the number of all items')
                # this is a failure, so it must not exit with status 0
                raise SystemExit(1)
        items = self.fake_data_tensor[self.item_pos:item_size + self.item_pos]
        self.item_pos += item_size
        return items

    # generate_test_support,basic MGS only sample test tasks by item size
    def sample_test_support(self, item_size=None, stream_length=None, stream_length_ratio=None, skew_ratio=None):
        """Sample a test support set of exactly ``item_size`` fake items.

        ``item_size`` is required; every other argument must be None.

        Returns:
            ``(items, frequencies, None)``.
        """
        assert skew_ratio is None
        assert stream_length is None
        assert stream_length_ratio is None
        assert item_size is not None
        items = self.sample_items_data_by_item_size(item_size=item_size)
        return items, self._constant_frequencies(items), None
