import time
from collections import defaultdict

from .Thread_Enqueuer import *
# from .Process_Enqueuer import *


class Prefetcher(object):
    """
    This helper class enables sample enqueuing and batch dequeuing, to speed
    up batch fetching. It abstracts away the enqueuing and dequeuing logic.

    Samples are put on ``self.enqueuer.queue`` by an ``Enqueuer`` and taken
    off here in batches. Two consumption modes are provided:

    * ``get_next_train_batch``: every sample in a batch has a distinct label.
      A dequeued sample whose label is already in the batch being built is
      stashed in ``self.tmp_samples`` (keyed by label) and reused later.
    * ``get_next_test_batch``: samples are taken from the queue in order; the
      last batch of an epoch may be smaller than ``batch_size``.
    """

    def __init__(self, get_sample, dataset_size, batch_size, final_batch=True,
                 num_threads=1, prefetch_size=400):
        """
        Args:
          get_sample: a function that takes a pointer (index) and returns a sample
          dataset_size: total number of samples in the dataset
          batch_size: number of samples per batch
          final_batch: True or False, whether to keep or drop the final incomplete
            batch
          num_threads: num of parallel threads, >= 1
          prefetch_size: the maximum size of the queue. Set to some positive integer
            to save memory, otherwise, set to 0.
        """
        self.full_dataset_size = dataset_size
        self.final_batch = final_batch
        final_sz = self.full_dataset_size % batch_size
        if not final_batch:
            # Drop the trailing incomplete batch so an epoch covers only
            # full batches.
            dataset_size = self.full_dataset_size - final_sz
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        # label -> list of samples set aside because that label was already
        # present in the batch being built (train mode only).
        self.tmp_samples = defaultdict(list)

        # A process-based Enqueuer (from Process_Enqueuer, which takes
        # ``num_processes`` instead of ``num_threads``) can be swapped in here.
        self.enqueuer = Enqueuer(get_element=get_sample, num_elements=dataset_size,
                                 num_threads=num_threads, queue_size=prefetch_size)
        # Number of samples taken from the queue so far in the current epoch.
        self.step_id = 0
        # True once an epoch has been consumed; also True before the first one
        # so that the first batch request starts prefetching.
        self.ep_done = True

    def batch_init(self):
        """Start a new epoch if the previous one is done, then mark the epoch
        as in progress."""
        if self.ep_done:
            self.start_ep_prefetching()
        self.ep_done = False

    def get_next_train_batch(self, **kwargs):
        """Return the next training batch, with at most one sample per label.

        Keyword Args:
          used_label_index: required. Index into each sample such that
            ``sample[used_label_index][0]`` is the label kept unique within a
            batch (this depends on what ``get_sample`` returns).

        Returns:
          (samples, ep_done): a list of exactly ``batch_size`` samples, and
          whether ``step_id`` has reached ``dataset_size``.
        """
        used_label_index = kwargs['used_label_index']
        batch_labels = []
        samples = []
        self.batch_init()
        while len(samples) != self.batch_size:
            sample = None
            if not self.enqueuer.queue.empty():
                candidate = self.enqueuer.queue.get_nowait()
                label = candidate[used_label_index][0]
                if label in batch_labels:
                    # Duplicate label for this batch: keep it for a later one.
                    self.tmp_samples[label].append(candidate)
                else:
                    sample = candidate
            elif self.tmp_samples:
                # Queue is empty right now: try a stashed sample whose label
                # is not yet in this batch.
                sample = self.get_from_tmp_samples(batch_labels)

            if sample is not None:
                batch_labels.append(sample[used_label_index][0])
                samples.append(sample)
                self.step_id += 1
            elif self.enqueuer.queue.empty():
                # Nothing usable and nothing queued: ask the enqueuer to
                # (re)start enqueuing.
                # NOTE(review): this loop polls without blocking, so it keeps
                # spinning (re-calling start_ep()) until a sample arrives.
                self.enqueuer.start_ep()
        self.ep_done = self.step_id >= self.dataset_size
        if self.ep_done:
            self.ep_done_operation()
        return samples, self.ep_done

    def get_next_test_batch(self, **kwargs):
        """Return the next evaluation batch, taking samples in queue order.

        Extra keyword arguments are accepted and ignored.

        Returns:
          (samples, ep_done): up to ``batch_size`` samples (fewer only for the
          last batch of an epoch), and whether the epoch has been consumed.
        """
        self.batch_init()
        samples = []
        while len(samples) < self.batch_size and self.step_id < self.dataset_size:
            self.step_id += 1
            # get() with no timeout: presumably waits until the enqueuer
            # provides the next sample.
            samples.append(self.enqueuer.queue.get())
        # step_id never exceeds dataset_size here, so `>=` acts as `==`.
        self.ep_done = self.step_id >= self.dataset_size
        if self.ep_done:
            self.ep_done_operation()
        return samples, self.ep_done

    def ep_done_operation(self):
        """Finish the current epoch.

        Calls the enqueuer's ``end_ep()`` and then ``clear_queue()``.
        """
        self.enqueuer.end_ep()
        self.enqueuer.clear_queue()

    def get_from_tmp_samples(self, this_batch_samples):
        """Pop a stashed sample whose label is not in ``this_batch_samples``.

        Labels are tried in ``tmp_samples`` insertion order. Within a label,
        samples come out in the order they were stashed. A label whose list
        becomes empty is removed.

        Args:
          this_batch_samples: labels already used in the batch being built.

        Returns:
          The popped sample, or None if no stashed sample is usable
          (including when nothing is stashed).
        """
        # Copy the keys: the dict may be modified before we return.
        for label in list(self.tmp_samples):
            if label in this_batch_samples:
                continue
            stash = self.tmp_samples[label]
            sample = stash.pop(0)
            if not stash:
                del self.tmp_samples[label]
            return sample
        return None

    def start_ep_prefetching(self):
        """
        NOTE: Has to be called at the start of every epoch.

        Discards stashed samples and calls the enqueuer's ``end_ep()``, then
        its ``start_ep()``. Resets ``step_id`` to 0.
        """
        self.tmp_samples.clear()
        try:
            self.enqueuer.end_ep()
        except Exception:
            # Best effort: a failure in end_ep() must not prevent the new
            # epoch from starting. KeyboardInterrupt/SystemExit still
            # propagate.
            pass
        self.enqueuer.start_ep()
        self.step_id = 0

    def stop(self):
        """This can be called to stop threads, e.g. after finishing using the
            dataset, or when exiting the python main program.
        """
        self.enqueuer.stop()
