from .message_queue import MessageQueue
from multiprocessing import Value, Lock
import queue
import threading
import types
import time


class ThreadingQueue(MessageQueue):
    """Thread-based ``MessageQueue`` that fans messages out to worker threads.

    Messages come either from a generator, which workers consume lazily, or
    from any other iterable, which is copied eagerly into ``eval_queue``.
    Worker threads repeatedly take the next message and hand it to an
    executor until the source is exhausted.

    Attributes:
        eval_queue: holds messages taken from non-generator iterables.
        result_queue: collects executor return values in ``run_in_executor``.
        eval_generator: the lazily consumed generator, or ``None`` when the
            messages live in ``eval_queue``.
        concurrent: number of worker threads used by ``run``, ``run_param``
            and ``run_noparam``.
        count: ``multiprocessing.Value`` counting messages/calls taken; it
            is cumulative across runs (nothing in this class resets it).
        error: ``multiprocessing.Value`` counting executor calls that raised.
        lock: serialises message fetching, counter updates and some logging.

    NOTE(review): ``self.logger`` is used throughout but not set here; it is
    presumably provided by ``MessageQueue`` -- confirm.
    """

    def __init__(self, concurrent):
        super(ThreadingQueue, self).__init__(concurrent)

        self.eval_queue = queue.Queue()
        self.result_queue = queue.Queue()
        self.eval_generator = None

        self.concurrent = concurrent

        # multiprocessing primitives also work across threads; callers read
        # the counters through ``.value``.
        self.count = Value('i', 0)
        self.error = Value('i', 0)
        self.lock = Lock()

    def put_message(self, message_generator):
        """Register the message source for the next run.

        Args:
            message_generator: a generator (kept and consumed lazily by the
                workers) or any other iterable (copied into ``eval_queue``
                immediately).
        """
        if isinstance(message_generator, types.GeneratorType):
            self.eval_generator = message_generator
        else:
            # Forget any generator left over from a previous run; otherwise
            # the workers would keep reading that (exhausted) generator and
            # never see the messages queued below.
            self.eval_generator = None
            for message in message_generator:
                self.eval_queue.put(message)

    def _next_message(self):
        """Take the next message from the current source, thread-safely.

        Returns:
            ``(True, message)`` when a message was taken (``count`` is then
            incremented), or ``(False, None)`` once the source is exhausted.
        """
        # The whole fetch happens under the lock: a generator cannot be
        # advanced from several threads at once, and the "is it empty?"
        # decision must be atomic with the fetch itself. ``get_nowait`` (not
        # a blocking ``get``) ensures a worker that loses the race for the
        # last item receives ``queue.Empty`` instead of blocking forever
        # while holding the lock.
        with self.lock:
            try:
                if self.eval_generator is not None:
                    message = next(self.eval_generator)
                else:
                    message = self.eval_queue.get_nowait()
            except (StopIteration, queue.Empty):
                return False, None
            self.count.value += 1
            return True, message

    def _consume(self, idx, handle, start_template):
        """Worker loop: pass messages to ``handle`` until the source runs dry.

        Args:
            idx: worker index, used only in log lines.
            handle: callable invoked with each message.
            start_template: ``str.format`` template with one ``{}`` for
                ``idx``, logged at the start of every iteration.
        """
        while True:
            start_time = time.time()
            self.logger.info(start_template.format(idx))

            has_message, message = self._next_message()
            if not has_message:
                break

            try:
                handle(message)
            except Exception:
                # Count and log the failure but keep this worker alive, so a
                # single bad message does not silently shrink the pool.
                with self.lock:
                    self.error.value += 1
                self.logger.exception(
                    "Runner Queue ID [{}]-> Executor failed.".format(idx)
                )

            with self.lock:
                self.logger.info(
                    "Runner Queue ID [{} / {}]-> Spent {} s.".format(
                        idx,
                        self.count.value,
                        round(time.time() - start_time, 3),
                    )
                )

    def _run_workers(self, jobs):
        """Start one thread per ``(target, args)`` pair and wait for all."""
        threads = [threading.Thread(target=target, args=args) for target, args in jobs]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    def run(self, executor, message_generator):
        """Process every message with ``executor.run(message)``.

        Uses ``self.concurrent`` worker threads and blocks until the message
        source is exhausted.
        """
        self.put_message(message_generator)

        def handle(message):
            executor.run(message)

        self._run_workers(
            (self._consume, (idx, handle, "Runner ThreadingQueue ID [{}]-> Start..."))
            for idx in range(self.concurrent)
        )

        with self.lock:
            self.logger.info("Runner Queue End. Length={} Error={}".format(self.count.value, self.error.value))

    def run_noparam(self, executor, max_calls=None):
        """Call ``executor()`` repeatedly from ``self.concurrent`` threads.

        Args:
            executor: zero-argument callable.
            max_calls: optional cap on the total number of ``executor()``
                calls, across all threads, in this run. ``None`` (the
                default) keeps the original behaviour: the workers loop
                forever and this method never returns.
        """
        calls_started = 0

        def worker(idx):
            nonlocal calls_started
            while True:
                start_time = time.time()
                self.logger.info("Runner ThreadingQueue No Param ID [{}]-> Start...".format(idx))

                # Claim a call slot atomically so ``max_calls`` is exact.
                with self.lock:
                    if max_calls is not None and calls_started >= max_calls:
                        break
                    calls_started += 1
                    self.count.value += 1

                try:
                    executor()
                except Exception:
                    with self.lock:
                        self.error.value += 1
                    self.logger.exception(
                        "Runner ThreadingQueue No Param ID [{}]-> Executor failed.".format(idx)
                    )

                with self.lock:
                    self.logger.info(
                        "Runner ThreadingQueue No Param ID [{} / {}]-> Spent {} s.".format(
                            idx,
                            self.count.value,
                            round(time.time() - start_time, 3),
                        )
                    )

        self._run_workers((worker, (idx,)) for idx in range(self.concurrent))

        with self.lock:
            self.logger.info("Runner ThreadingQueue No Param End. Length={} Error={}".format(self.count.value, self.error.value))

    def run_param(self, executor, message_generator):
        """Process every message with ``executor(message)``.

        Uses ``self.concurrent`` worker threads and blocks until the message
        source is exhausted.
        """
        self.put_message(message_generator)

        self._run_workers(
            (self._consume, (idx, executor, "Runner ThreadingQueue ID [{}]-> Start..."))
            for idx in range(self.concurrent)
        )

        with self.lock:
            self.logger.info("Runner Queue End. Length={} Error={}".format(self.count.value, self.error.value))

    def run_in_executor(self, executor, message_generator):
        """Run ``call(*message)`` for each message and collect the results.

        Args:
            executor: iterable of callables; one worker thread is started per
                callable (``self.concurrent`` is not used here).
            message_generator: message source. Each message must be an
                iterable of positional arguments, because it is unpacked with
                ``*``.

        Returns:
            List of return values from successful calls. The order depends on
            thread scheduling, not on message order.
        """
        self.put_message(message_generator)

        def make_handler(call):
            # Bind ``call`` per worker; each result goes to the shared queue.
            def handle(message):
                self.result_queue.put(call(*message))
            return handle

        self._run_workers(
            (self._consume, (idx, make_handler(call), "Runner Inner Func [{}]-> Start..."))
            for idx, call in enumerate(executor)
        )

        with self.lock:
            self.logger.info("Runner Inner Func End. Length={} Error={}".format(self.count.value, self.error.value))

        # All producers have been joined, so draining without blocking is safe.
        results = []
        while not self.result_queue.empty():
            results.append(self.result_queue.get_nowait())
        return results
