import threading
import time
from collections import defaultdict
from queue import Empty, Full, Queue

class Counter(object):
    """An integer counter that never counts past a configurable maximum.

    `reset`, `increment` and `get_value` run while holding an internal lock;
    `set_max_value` assigns the new maximum without taking it.
    """

    def __init__(self, val=0, max_val=0):
        """
        Args:
          val: the initial count
          max_val: the count at which `increment` stops advancing
        """
        self._lock = threading.Lock()
        self._value = val
        self.max_value = max_val

    def reset(self):
        """Put the count back to zero."""
        with self._lock:
            self._value = 0

    def set_max_value(self, max_val):
        """Replace the maximum; the current count is left untouched."""
        self.max_value = max_val

    def increment(self):
        """Advance the count by one unless it has already reached the maximum.

        Returns:
          A tuple `(incremented, value)`: whether the count was advanced by this
          call, and the count after the call.
        """
        with self._lock:
            advanced = self._value < self.max_value
            if advanced:
                self._value += 1
            return advanced, self._value

    def get_value(self):
        """Return the current count."""
        with self._lock:
            return self._value


class Enqueuer(object):
    """Fill a queue with elements produced by a pool of daemon worker threads.

    The workers share one `Counter` as the pointer to the next element. A worker
    claims an index by incrementing the pointer, builds the element with
    `get_element(index)` and puts it on `self.queue`. One pass over all
    `num_elements` indices is an epoch: it begins when `start_ep()` sets the
    wake-up event and ends when a worker finds the pointer exhausted and calls
    `end_ep()`.
    """

    def __init__(self, get_element, num_elements, num_threads=1, queue_size=20):
        """
        Args:
          get_element: a function that takes a pointer (a 0-based index below
            `num_elements`) and returns an element
          num_elements: total number of elements to put into the queue
          num_threads: num of parallel threads, >= 1
          queue_size: the maximum size of the queue. Set to some positive integer
            to save memory, otherwise, set to 0.
        """
        self.get_element = get_element
        assert num_threads > 0, 'num_threads must be >= 1'
        self.num_threads = num_threads
        self.queue_size = queue_size
        self.queue = Queue(maxsize=queue_size)
        # The pointer shared by threads.
        self.ptr = Counter(max_val=num_elements)
        # The event to wake up threads, it's set at the beginning of an epoch.
        # It's cleared after an epoch is enqueued or when the states are reset.
        self.event = threading.Event()
        # To reset states.
        self.reset_event = threading.Event()
        # The event to terminate the threads.
        self.stop_event = threading.Event()
        self.threads = []
        # Every attribute read by `enqueue` is assigned above, before the first
        # worker starts running.
        for _ in range(num_threads):
            thread = threading.Thread(target=self.enqueue)
            # Set the thread in daemon mode, so that the main program ends normally.
            thread.daemon = True
            thread.start()
            self.threads.append(thread)

    def clear_queue(self):
        """Discard every element currently in the queue to get a new start."""
        # `get_nowait` never blocks, unlike an `empty()` check followed by a
        # blocking `get()`, which hangs if the queue is drained in between.
        while True:
            try:
                self.queue.get_nowait()
            except Empty:
                break

    def start_ep(self):
        """Start enqueuing an epoch."""
        self.event.set()

    def end_ep(self):
        """When all elements are enqueued, let threads sleep to save resources.

        Clears the wake-up event and rewinds the pointer to zero.
        """
        # NOTE(review): every worker that finds the pointer exhausted calls this,
        # so with several threads it can run more than once per epoch. A late
        # call arriving after the next `start_ep()` would clear the event and
        # rewind the pointer mid-epoch -- confirm callers tolerate this.
        self.event.clear()
        self.ptr.reset()

    def reset(self):
        """Reset the threads, pointer and the queue to initial states. In common
        case, this will not be called.

        Blocks the caller for 5 seconds while the workers are given time to
        pause, then replaces `self.queue` with a new, empty queue.
        """
        self.reset_event.set()
        self.event.clear()
        # wait for threads to pause. This is not an absolutely safe way. The safer
        # way is to check some flag inside a thread, not implemented yet.
        time.sleep(5)
        self.reset_event.clear()
        self.ptr.reset()
        self.queue = Queue(maxsize=self.queue_size)

    def set_num_elements(self, num_elements):
        """Reset the max number of elements.

        Performs a full `reset()` first, so this call also blocks while the
        workers pause.
        """
        self.reset()
        self.ptr.set_max_value(num_elements)

    def stop(self):
        """Signal the threads to terminate and wait for them to exit."""
        self.stop_event.set()
        for thread in self.threads:
            thread.join()

    def enqueue(self):
        """Worker loop: claim indices and enqueue their elements until stopped.

        While the wake-up event is set, the worker repeatedly claims the next
        pointer value, builds the element for index `ptr - 1` with
        `get_element` and puts it on the queue. When the pointer is exhausted it
        calls `end_ep()`. The loop exits once `stop_event` is set.
        """
        while not self.stop_event.is_set():
            # If the enqueuing event is not set, the thread just waits; the
            # timeout lets it re-check the stop signal periodically.
            if not self.event.wait(0.5):
                continue
            # Increment the counter to claim that this element has been enqueued by
            # this thread.
            incremented, ptr = self.ptr.increment()
            if not incremented:
                self.end_ep()
                continue
            # The counter is 1-based after incrementing; elements are 0-based.
            element = self.get_element(ptr - 1)
            # When enqueuing, keep an eye on the stop and reset signal. If either
            # is set before a slot frees up, the element is dropped.
            while not (self.stop_event.is_set() or self.reset_event.is_set()):
                try:
                    # This operation will wait at most `timeout` for a free slot in
                    # the queue to be available.
                    self.queue.put(element, timeout=0.5)
                except Full:
                    # Still no free slot: re-check the signals and retry. Any
                    # other exception is a real error and is allowed to propagate.
                    continue
                break
        print('Exiting thread {}!!!!!!!!'.format(
            threading.current_thread().name))