import copy
import random
import time
import pandas as pd
import numpy as np
from AbstractClass.TaskRelatedClasses import AbstractTaskConsumer, AbstractMetaTask
from SourceCode.TaskRelatedClasses.TaskData import MetaTask, SupportSet, QuerySet
from pybloom_live import BloomFilter
from torch.multiprocessing import Manager, Queue, Process


def check(meta_task):
    """Sanity-check a meta task and print a warning when it looks inconsistent.

    The task is accepted when either of these holds:

    * the sum of ``query_y`` differs from the number of rows in
      ``support_x`` by less than 1, or
    * the sums of ``support_y`` and ``query_y`` differ by less than 0.5.

    This is a best-effort diagnostic: it never raises for ordinary errors
    (missing attributes, unexpected types, ...). Any such failure is reported
    with the same message as a failed check.

    Args:
        meta_task: Object exposing ``support_set.support_x``,
            ``support_set.support_y`` and ``query_set.query_y``.

    Returns:
        None. The only visible effect is the printed message.
    """
    try:
        unique_item = meta_task.support_set.support_x.shape[0]
        support_stream_length = meta_task.support_set.support_y.sum()
        # Reduce query_y once and reuse it for both comparisons.
        query_stream_length = meta_task.query_set.query_y.sum()
        exist_num = query_stream_length.item()

        # An explicit boolean instead of `assert`, so the check still runs
        # when Python is started with -O (which strips assert statements).
        consistent = bool(
            abs(exist_num - unique_item) < 1
            or abs(support_stream_length - query_stream_length) < 0.5
        )
    except Exception:
        # Only ordinary errors are treated as a failed check;
        # KeyboardInterrupt / SystemExit are allowed to propagate.
        consistent = False

    if not consistent:
        print('check task error in TaskConsumer')


class TaskConsumer(AbstractTaskConsumer):
    """Rebuilds ``MetaTask`` objects from tensors received through a queue.

    The ``device`` given at construction is stored and forwarded to every
    ``SupportSet`` / ``QuerySet`` this consumer creates.
    """

    def __init__(self, device):
        self.device = device

    def consume_train_task(self, q, pass_cuda_tensor):
        """Read one training task from ``q`` and return it as a ``MetaTask``.

        Exactly four items are taken with ``q.get()``, interpreted in this
        order: support x, support y, query x, query y.

        Args:
            q: Queue-like object exposing ``get()``.
            pass_cuda_tensor: When falsy, ``to_device()`` is called on the
                assembled task (presumably because the received tensors are
                not yet on ``self.device`` -- TODO confirm with producer).

        Returns:
            The assembled ``MetaTask``, after it has been passed to
            ``check``.
        """
        # Four sequential reads; the producer must enqueue in the same order.
        received = [q.get() for _ in range(4)]
        sx, sy, qx, qy = received

        support_part = SupportSet(sx, sy, self.device)
        query_part = QuerySet(qx, qy, self.device)
        task = MetaTask(support_part, query_part)

        if pass_cuda_tensor:
            pass
        else:
            task.to_device()

        check(task)
        return task

    def del_meta_task(self, meta_task):
        """Delete the four tensor attributes held by ``meta_task``.

        Removes ``support_x`` and ``support_y`` from ``support_set``, then
        ``query_y`` and ``query_x`` from ``query_set``. The original author's
        note says this is meant to release shared memory.
        """
        targets = (
            ("support_set", "support_x"),
            ("support_set", "support_y"),
            ("query_set", "query_y"),
            ("query_set", "query_x"),
        )
        for holder_name, tensor_name in targets:
            delattr(getattr(meta_task, holder_name), tensor_name)
