from .message_queue import MessageQueue
from concurrent.futures import ThreadPoolExecutor
import asyncio
import types
import time


class AsyncQueue(MessageQueue):
    """Producer/consumer message queue driven by a single asyncio event loop.

    One producer task (``put_message``) pushes messages into a bounded
    ``asyncio.Queue`` while ``self.concurrent`` consumer tasks
    (``handle_message``) pull them and pass each one to ``executor.run``.
    After every put the producer waits on ``self.event``, which consumers set
    after finishing a message, so production is throttled by consumption.

    NOTE(review): ``executor.run`` is called without ``await``, so it runs
    synchronously on the event-loop thread and blocks every other consumer
    while it executes; consumers interleave between calls but never overlap.
    If real parallelism is intended, offload it (e.g. ``loop.run_in_executor``)
    -- TODO confirm ``executor`` is thread-safe first.
    """

    def __init__(self, concurrent):
        """Create the queue.

        Args:
            concurrent: Forwarded to ``MessageQueue.__init__``; this class reads
                it back as ``self.concurrent`` for both the number of consumer
                tasks and the queue's ``maxsize``. Should be >= 1: with no
                consumers, ``run`` blocks as soon as any message is produced.
        """
        super().__init__(concurrent)

        self._make_sync_primitives()

        # Number of messages the producer has handed over; never reset, so it
        # is cumulative across run() calls.
        self.count = 0

    def _make_sync_primitives(self):
        """(Re)create the queue and event shared by producer and consumers.

        asyncio primitives bind to an event loop (at construction before
        Python 3.10, on first blocking use from 3.10 on), while each ``run``
        call starts a brand-new loop via ``asyncio.run``. Fresh objects per run
        avoid "attached to / bound to a different event loop" errors.
        """
        self.eval_queue = asyncio.Queue(maxsize=self.concurrent)
        self.event = asyncio.Event()

    async def put_message(self, message_generator):
        """Feed every message from *message_generator* into ``eval_queue``.

        After each put, waits for ``self.event`` (set by a consumer when it
        finishes a message) before producing the next one, and increments
        ``self.count``. Once the source is exhausted, enqueues one ``None``
        sentinel per consumer so that every ``handle_message`` loop exits.
        A message that is itself ``None`` is therefore indistinguishable from
        the sentinel and stops a consumer.

        Args:
            message_generator: Any iterable of messages. When it is a
                generator object, "Message Reader Down..." is logged once it
                is exhausted.
        """
        is_generator = isinstance(message_generator, types.GeneratorType)

        for message in message_generator:
            await self.eval_queue.put(message)
            await self.event.wait()
            self.event.clear()
            self.count += 1

        if is_generator:
            self.logger.info("Message Reader Down...")

        for _ in range(self.concurrent):
            await self.eval_queue.put(None)

    async def handle_message(self, executor, idx):
        """Consumer loop: pass queued messages to ``executor.run`` until a ``None`` sentinel.

        Exceptions raised by ``executor.run`` are logged and the loop moves on
        to the next message (best-effort processing).

        Args:
            executor: Object whose ``run(message)`` method processes one message.
            idx: Consumer index, used only in log messages.
        """
        while True:
            # The "Spent" figure below includes the time spent waiting for a message.
            start_time = time.monotonic()
            self.logger.info("Runner AsyncQueue ID [{}]-> Start...".format(idx))

            message = await self.eval_queue.get()
            if message is None:
                self.eval_queue.task_done()
                break

            try:
                executor.run(message)
            except Exception:
                # Best-effort: one failing message must not stop this consumer,
                # but the failure is recorded. Catching Exception (not
                # everything) lets KeyboardInterrupt/SystemExit propagate.
                self.logger.exception(
                    "Runner AsyncQueue ID [{}]-> message failed".format(idx)
                )
            finally:
                # Always account for the message and wake the producer;
                # otherwise eval_queue.join() or event.wait() could block forever.
                self.eval_queue.task_done()
                self.event.set()

            self.logger.info(
                "Runner Queue ID [{}]-> Spent {} s.".format(
                    idx,
                    round(time.monotonic() - start_time, 3),
                )
            )

    def run(self, executor, message_generator):
        """Process every message from *message_generator* with *executor*.

        Starts a new event loop via ``asyncio.run``, so it cannot be called
        from code that is already running inside an event loop. Returns ``None``.
        """
        return asyncio.run(self._exec(executor, message_generator))

    async def _exec(self, executor, message_generator):
        """Run one producer and ``self.concurrent`` consumers to completion."""
        # Bind fresh queue/event to the loop that asyncio.run just created.
        self._make_sync_primitives()
        # Pre-set so the producer's first wait() passes before any consumer
        # has finished a message.
        self.event.set()

        producer = asyncio.create_task(self.put_message(message_generator))
        consumers = [
            asyncio.create_task(self.handle_message(executor, idx))
            for idx in range(self.concurrent)
        ]
        await producer
        await self.eval_queue.join()
        await asyncio.gather(*consumers)

        self.logger.info("AsyncQueue End. Length={}".format(self.count))

    def run_in_executor(self, in_funcs, func_args, max_workers=5):
        """Call each ``in_funcs[i](*func_args[i])`` in a thread pool.

        Args:
            in_funcs: Iterable of callables.
            func_args: Iterable of positional-argument sequences, paired with
                *in_funcs* by position; surplus items on either side are
                ignored, as with ``zip``.
            max_workers: Size of the ``ThreadPoolExecutor`` (default 5).

        Returns:
            List of the callables' return values, in input order.

        Raises:
            The first exception raised by any callable (``asyncio.gather``
            default behaviour).
        """
        return asyncio.run(
            self._in_exec(in_funcs, func_args, max_workers=max_workers)
        )

    async def _in_exec(self, in_funcs, func_args, max_workers=5):
        """Async body of ``run_in_executor``: fan the calls out to a thread pool."""
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            tasks = [
                loop.run_in_executor(pool, func, *args)
                for func, args in zip(in_funcs, func_args)
            ]
            results = await asyncio.gather(*tasks)
        return results
