import torch

from AbstractClass.TaskRelatedClasses import AbstractTaskProducer, AbstractMetaTask, SupportGeneratorInterface, \
    AbstractQueryGenerator
from SourceCode.TaskRelatedClasses.TaskData import MetaTask, SupportSet, QuerySet


class TaskProducer(AbstractTaskProducer):
    """Produce meta-learning tasks from a support generator and a query generator.

    Training tasks are pushed onto a queue forever (``produce_train_task``), while
    test tasks are generated once, in groups, according to either a list of stream
    lengths or a list of item sizes (``produce_test_task``).
    """

    def __init__(self, support_generator: SupportGeneratorInterface, query_generator: AbstractQueryGenerator,
                 device, test_task_zipf_param_list=None, test_task_item_size_list=None, test_task_length_list=None,
                 test_task_length_ratio_list=None, test_task_group_size=10):
        """
        Args:
            support_generator: source of support sets (``sample_train_support``,
                ``sample_test_support``, item-range and zipf setters are used).
            query_generator: builds query sets from a sampled support set.
            device: forwarded to ``SupportSet`` / ``QuerySet`` when building test tasks.
            test_task_zipf_param_list: optional zipf parameters; only consulted together
                with ``test_task_item_size_list`` (one test group per item size x zipf param).
            test_task_item_size_list: optional item sizes, one test group each.
            test_task_length_list: optional stream lengths, one test group each.
            test_task_length_ratio_list: optional ratios of the base support generator's
                ``test_stream_length``; when given it overrides ``test_task_length_list``.
            test_task_group_size: number of tasks generated per test group.

        Raises:
            AssertionError: if both an item-size list and a length list end up configured.
        """
        super().__init__(support_generator, query_generator)
        self.support_generator = support_generator
        self.query_generator = query_generator
        self.device = device
        self.test_task_group_size = test_task_group_size
        if test_task_length_ratio_list is not None:
            # Unwrap decorated generators until reaching the one that defines the test stream length.
            base_support_generator = self.support_generator
            while not hasattr(base_support_generator, "test_stream_length"):
                base_support_generator = base_support_generator.base_support_generator
            base_length = base_support_generator.test_stream_length
            # Scale element-wise: ``list * int`` would repeat the list (and ``list * float``
            # raises TypeError). Round because a stream length is presumably integral.
            self.test_task_length_list = [int(round(ratio * base_length)) for ratio in test_task_length_ratio_list]
        else:
            self.test_task_length_list = test_task_length_list
        self.test_task_zipf_param_list = test_task_zipf_param_list
        self.test_task_item_size_list = test_task_item_size_list
        assert self.test_task_item_size_list is None or self.test_task_length_list is None

    def produce_one_task(self, q, pass_cuda_tensor):
        """Sample one training task and put its four tensors on ``q``.

        Tensors are enqueued in the order support_x, support_y, query_x, query_y,
        so the consumer must read them back in that same order.

        Args:
            q: queue-like object exposing ``put`` (presumably a multiprocessing queue
                shared with a consumer process -- confirm).
            pass_cuda_tensor: when False, tensors that are not on the CPU are copied to
                the CPU before being enqueued.
        """
        support_x, support_y, support_extra = self.support_generator.sample_train_support()
        query_x, query_y = self.query_generator.generate_train_query(support_x, support_y, support_extra)
        if not pass_cuda_tensor and support_y.device != torch.device("cpu"):
            # Only support_y's device is inspected; assumes all four tensors share a device.
            support_x, support_y, query_x, query_y = (
                tensor.cpu() for tensor in (support_x, support_y, query_x, query_y))
        for tensor in (support_x, support_y, query_x, query_y):
            q.put(tensor)

    def produce_train_task(self, q, pass_cuda_tensor, pre_item_lower_list=None, pre_item_upper_list=None,
                           pre_step_list=None, checkpointing_task_num=None):
        """Push training tasks onto ``q`` forever; this method never returns.

        If ``pre_step_list`` is given, tasks are first replayed: for each entry the
        support generator's item range is set to (upper, lower) and
        ``pre_step + checkpointing_task_num`` tasks are produced. After that the
        original item range is restored. Presumably this reproduces the task stream
        when resuming from a checkpoint -- confirm with the caller.

        Args:
            q: queue-like object exposing ``put``.
            pass_cuda_tensor: see ``produce_one_task``.
            pre_item_lower_list / pre_item_upper_list / pre_step_list: parallel lists
                describing the replay phases; must all have the same length.
            checkpointing_task_num: extra tasks produced in every replay phase.
        """
        self.support_generator.flush_tensor()
        if pre_step_list is not None:
            # Save the current item range so it can be restored after the replay.
            item_upper, item_lower = self.support_generator.get_item_upper_lower()
            assert pre_item_upper_list is not None and pre_item_lower_list is not None and checkpointing_task_num is not None
            assert len(pre_item_upper_list) == len(pre_item_lower_list) == len(pre_step_list), 'length not equal'
            for pre_item_upper, pre_item_lower, pre_step in zip(pre_item_upper_list, pre_item_lower_list,
                                                                pre_step_list):
                print('set item size section according to train step ', pre_item_upper, pre_item_lower)
                self.support_generator.set_item_upper_lower(pre_item_upper, pre_item_lower)
                for _ in range(pre_step + checkpointing_task_num):
                    self.produce_one_task(q, pass_cuda_tensor)
            # Restore the saved item range.
            self.support_generator.set_item_upper_lower(item_upper, item_lower)
        while True:
            self.produce_one_task(q, pass_cuda_tensor)

    def produce_test_task(self):
        """Generate every configured test task group.

        Groups come from ``test_task_length_list`` when it is set, and otherwise from
        ``test_task_item_size_list``. In the item-size case, each zipf parameter in
        ``test_task_zipf_param_list`` (if any) gives its own group.

        Returns:
            (test_meta_task_group_list, meta_task_group_discribe_list): a list of task
            lists and a parallel list of description strings.

        Raises:
            AssertionError: if neither a length list nor an item-size list is configured.
        """
        print('start producing task...')
        meta_task_group_discribe_list = []
        test_meta_task_group_list = []
        if self.test_task_length_list is not None:
            # One group per stream length; the description also carries the mean item count.
            for test_stream_length in self.test_task_length_list:
                print('producing ', test_stream_length, "stream length test meta tasks ....")
                task_list, item_description = self.produce_one_group_test_tasks(test_stream_length=test_stream_length)
                meta_task_group_discribe_list.append('_' + str(test_stream_length) + '_length_' + item_description)
                test_meta_task_group_list.append(task_list)
        elif self.test_task_item_size_list is not None:
            for test_item_size in self.test_task_item_size_list:
                if self.test_task_zipf_param_list is None:
                    print('producing ', test_item_size, "item size test meta tasks ....")
                    task_list, _ = self.produce_one_group_test_tasks(test_item_size=test_item_size)
                    test_meta_task_group_list.append(task_list)
                    meta_task_group_discribe_list.append('_' + str(test_item_size) + '_item_')
                else:
                    # One group per (item size, zipf param) pair.
                    for zip_param in self.test_task_zipf_param_list:
                        print('producing ', test_item_size, "item size", zip_param, "zipf_param test meta tasks ....")
                        task_list, _ = self.produce_one_group_test_tasks(test_item_size=test_item_size,
                                                                         zipf_param=zip_param)
                        test_meta_task_group_list.append(task_list)
                        meta_task_group_discribe_list.append(
                            '_' + str(test_item_size) + '_item_' + str(zip_param) + '_zipf_')
        else:
            # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
            raise AssertionError('error , num of group')
        print('producing test task done...')
        return test_meta_task_group_list, meta_task_group_discribe_list

    def produce_one_group_test_tasks(self, test_stream_length=None, test_item_size=None, zipf_param=None):
        """Build ``self.test_task_group_size`` test ``MetaTask`` objects.

        At most one of ``test_stream_length`` / ``test_item_size`` may be given; both
        are forwarded to ``sample_test_support``. Tensors are moved to the CPU before
        being wrapped, and ``self.device`` is passed to ``SupportSet`` / ``QuerySet``.

        Returns:
            (task_list, item_description): the tasks, and a string of the form
            ``'<mean of support_y.shape[0]>_item_'`` (integer mean).
        """
        assert test_stream_length is None or test_item_size is None
        task_list = []
        item_sum = 0
        for _ in range(self.test_task_group_size):
            if zipf_param is not None:
                # Re-applied before every sample; presumably the setter only affects the next draw.
                self.support_generator.set_once_zip_param(zipf_param)
            test_support_x, test_support_y, support_extra = self.support_generator.sample_test_support(
                stream_length=test_stream_length, item_size=test_item_size)
            test_query_x, test_query_y = self.query_generator.generate_test_query(test_support_x, test_support_y,
                                                                                  support_extra)
            item_sum += test_support_y.shape[0]
            test_support_set = SupportSet(test_support_x.cpu(), test_support_y.cpu(), self.device)
            test_query_set = QuerySet(test_query_x.cpu(), test_query_y.cpu(), self.device)
            task_list.append(MetaTask(test_support_set, test_query_set))
        # Guard the divisor so a zero group size yields '0_item_' instead of ZeroDivisionError.
        item_description = str(item_sum // max(self.test_task_group_size, 1)) + '_item_'
        return task_list, item_description
