import os
import time
import ctypes
import subprocess
import threading
import multiprocessing as mp

import numpy as np

from prover.utils import AttrDict


class TaskQueue(object):
    """Batch-draining task list built on ``multiprocessing`` primitives.

    Pending items are kept in a manager-backed list and every access to it
    is serialised through ``self.lock``. A background thread prints
    throughput statistics about once a minute until :meth:`close` sets the
    ``all_tasks_done`` event.

    Args:
        batch_size: Maximum number of items handed out by one :meth:`get`.
        name: Label used in the monitor's log line.
    """

    def __init__(self, batch_size=512, name='test'):
        self.name = name
        self.batch_size = batch_size
        self.manager = mp.Manager()
        self.waiting_list = self.manager.list()
        self.all_tasks_done = mp.Event()
        self.lock = mp.Lock()

        # One entry per successful get(): the size of the batch handed out.
        self._monitor_log = self.manager.list()
        # Daemon thread: a caller that never reaches close() (for example
        # because of an exception) must not keep the interpreter alive.
        self._monitor_thread = threading.Thread(target=self._monitor, daemon=True)
        self._monitor_thread.start()

    def _monitor(self):
        """Print batch statistics every 60 seconds until the queue is closed."""
        last_log_time = time.time()
        while not self.all_tasks_done.is_set():
            if time.time() - last_log_time >= 60.0:
                with self.lock:
                    # Copy the log once so the report is computed from a
                    # single consistent snapshot instead of repeated reads
                    # of the shared list.
                    batch_sizes = self._monitor_log[:]
                    if len(batch_sizes) > 0:
                        print('TaskQueue-{}:  {} requests popped with avg batch_size {:.1f} in last period  {} waiting in queue'.format(
                            self.name, np.sum(batch_sizes), np.mean(batch_sizes), len(self.waiting_list),
                        ))
                        self._monitor_log[:] = []
                last_log_time = time.time()
            # Wait on the event rather than sleeping blindly so that close()
            # wakes this thread immediately instead of after up to a second.
            self.all_tasks_done.wait(1.0)

    def __len__(self):
        """Return the number of items currently waiting."""
        return len(self.waiting_list)

    def put(self, item):
        """Append ``item`` to the end of the waiting list."""
        with self.lock:
            self.waiting_list.append(item)

    def get(self, no_wait=False):
        """Remove and return up to ``batch_size`` items from the front.

        Args:
            no_wait: When true, check the list once and give up if it is
                empty. Otherwise poll every 0.1 s until items arrive or the
                queue is closed.

        Returns:
            A list of items in insertion order, or ``None`` when nothing was
            available (``no_wait``) or the queue has been closed.
        """
        while not self.all_tasks_done.is_set():
            with self.lock:
                if len(self.waiting_list) > 0:
                    tasks = self.waiting_list[:self.batch_size]
                    # Slice assignment drops exactly the items just copied.
                    self.waiting_list[:self.batch_size] = []
                    self._monitor_log.append(len(tasks))
                    return tasks
            if no_wait:
                break
            time.sleep(0.1)
        return None

    def close(self):
        """Signal shutdown and wait for the monitor thread to finish."""
        self.all_tasks_done.set()
        self._monitor_thread.join()


class ProcessScheduler(object):
    """Submit requests to a task queue and poll a shared table for results.

    Every request gets an integer id (the first one is 1) and an entry in
    ``request_statuses`` that starts out as ``None``. The entry is handed
    back, and removed, once it holds a non-``None`` value.
    NOTE(review): nothing in this class writes the results; presumably a
    consumer of ``task_queue`` fills ``request_statuses`` -- confirm.

    Args:
        batch_size: Forwarded to the underlying ``TaskQueue``.
        name: Label stored on the scheduler and forwarded to the queue.
    """

    def __init__(self, batch_size=512, name='test'):
        self.name, self.batch_size = name, batch_size
        self.manager = mp.Manager()
        self.task_queue = TaskQueue(batch_size=batch_size, name=name)
        self.request_statuses = self.manager.dict()
        self.request_counter = mp.Value(ctypes.c_int32, 0)
        self.lock = mp.Lock()

    def submit_request(self, data):
        """Register ``data`` as a new request, enqueue it and return its id.

        The queued item is the tuple ``(submit_time, request_id, data)``.
        """
        counter = self.request_counter
        with self.lock:
            counter.value += 1
            request_id = counter.value
            # None marks the request as registered but not yet answered.
            self.request_statuses[request_id] = None
            task = (time.time(), request_id, data)
            self.task_queue.put(task)
        return request_id

    def submit_all_request(self, data_list):
        """Submit each element of ``data_list``; return the ids in order."""
        request_id_list = []
        for data in data_list:
            request_id_list.append(self.submit_request(data))
        return request_id_list

    def get_request_status(self, request_id):
        """Return the stored result for ``request_id`` or ``None``.

        A non-``None`` result is removed from the table as it is returned,
        so it can be fetched only once. ``None`` is also returned for ids
        that are not in the table.
        """
        with self.lock:
            response = self.request_statuses.get(request_id, None)
            if response is None:
                return None
            self.request_statuses.pop(request_id)
            return response

    def get_request_outputs(self, request_id):
        """Block, polling once per second, until a result is available."""
        outputs = self.get_request_status(request_id)
        while outputs is None:
            time.sleep(1.0)
            outputs = self.get_request_status(request_id)
        return outputs

    def get_all_request_outputs(self, request_id_list):
        """Collect the results for ``request_id_list`` in the given order."""
        return [self.get_request_outputs(request_id) for request_id in request_id_list]

    def close(self):
        """Close the underlying task queue."""
        self.task_queue.close()


class Scheduler(object):
    """Facade exposing several named schedulers through one object.

    For each ``name -> scheduler`` pair the instance gets:

    * an attribute ``name`` holding the scheduler itself, and
    * an attribute ``f'{name}_{key}'`` for every entry ``key`` of
      ``dir(scheduler)`` that does not start with an underscore.

    The flattened attributes are looked up once, during construction, so
    plain data attributes are copies of the value at that moment while
    methods stay bound to the original scheduler.

    Args:
        scheduler_dict: Mapping from attribute prefix to scheduler object.
            Each scheduler must provide a ``close()`` method.
    """

    def __init__(self, scheduler_dict):
        self._scheduler_dict = scheduler_dict
        for name, scheduler in scheduler_dict.items():
            setattr(self, name, scheduler)
            for key in dir(scheduler):
                if key.startswith('_'):
                    continue
                # getattr() (unlike calling __getattribute__ directly) also
                # honours a __getattr__ fallback, so schedulers that expose
                # dynamic attributes through __dir__ do not break here.
                setattr(self, f'{name}_{key}', getattr(scheduler, key))

    def close(self):
        """Close every wrapped scheduler in the mapping's iteration order."""
        for scheduler in self._scheduler_dict.values():
            scheduler.close()
