import sys
import logging
import random
import numpy as np
import torch
import queue
import threading

import torch.multiprocessing as mp
from torch.utils.data import get_worker_info


class Logger(logging.Logger):
    def log_dict(self, msg: dict, level=logging.INFO):
        self.log(level, ", ".join(f"{k}: {v}" for k, v in msg.items()))


def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHandle=True):
    """Build a ``Logger`` that writes bare messages to a file and/or stdout.

    Args:
        name: Name given to the new logger.
        loglevel: Threshold level set on the logger.
        logfile: If not ``None``, path of a file that records are appended to.
        streamHandle: If true, records are also written to ``sys.stdout``.

    Returns:
        A newly constructed ``Logger`` whose handlers format records with
        ``"%(message)s"`` (the message text only).
    """
    # Constructed directly rather than via logging.getLogger, so every call
    # yields a fresh, unregistered logger instance.
    logger = Logger(name)
    logger.setLevel(loglevel)

    # Drop any pre-existing handlers (only useful when debugging). Iterate
    # over a copy because removeHandler mutates ``logger.handlers``, and close
    # each removed handler so resources such as open files are released.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter(fmt="%(message)s")
    handlers = []
    if logfile is not None:
        handlers.append(logging.FileHandler(logfile, mode="a"))
    if streamHandle:
        handlers.append(logging.StreamHandler(stream=sys.stdout))

    for handler in handlers:
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


# Seed most recently passed to set_main_process_seed (0 until then). It is kept
# in a one-element list so it can be updated in place without a `global`
# statement; ModelProxy.worker_init_fn reads element 0 to derive worker seeds.
_main_process_seed = [0]


def set_main_process_seed(seed):
    """Seed this process's RNGs with ``seed`` and record it for later use.

    The recorded value is stored in ``_main_process_seed[0]``.
    """
    set_seed(seed)
    seed_holder = _main_process_seed
    seed_holder[0] = seed


def set_seed(seed):
    """Seed ``random``, NumPy's global RNG, torch's CPU RNG and all CUDA RNGs.

    The seeders are called in that order, each with the same ``seed``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


class ModelPlaceholder:
    """Callable stand-in for a model that answers through a pair of queues.

    Each call puts the batch on ``in_queue`` and then blocks until a reply can
    be taken from ``out_queue``; that reply is returned unchanged.
    """

    def __init__(self, in_queue, out_queue):
        self.in_queue, self.out_queue = in_queue, out_queue

    def __call__(self, batch):
        # Queues are looked up on every call because they may be rebound
        # after construction (e.g. per DataLoader worker).
        request_queue = self.in_queue
        request_queue.put(batch)
        reply = self.out_queue.get()
        return reply


class ModelProxy:
    """Let DataLoader workers use one model that lives in the constructing process.

    With ``num_workers > 1``, one request queue and one reply queue
    (``torch.multiprocessing`` queues) are created per worker. A daemon thread
    runs ``run()`` to serve forward passes, and ``placeholder`` is a
    ``ModelPlaceholder`` that sends batches through those queues. Otherwise
    ``placeholder`` is the model itself.

    NOTE(review): workers must receive this object's queues by inheritance
    (e.g. the ``fork`` start method); the thread and event held here are not
    picklable, so a ``spawn`` start method would presumably fail — confirm.
    """

    def __init__(self, model, num_workers):
        self.model = model
        self.num_workers = num_workers
        # Stop flag for the serving thread. It must not be called ``stop``:
        # an instance attribute with that name would shadow the stop() method.
        self._stop_event = threading.Event()
        self.thread = None
        if num_workers > 1:
            # One request/reply queue pair per worker so each reply goes back
            # to the worker that sent the request.
            self.in_queues = [mp.Queue() for _ in range(num_workers)]
            self.out_queues = [mp.Queue() for _ in range(num_workers)]
            # Worker-side stand-in. worker_init_fn rebinds its queues to the
            # current worker's pair; index 0 is only the initial binding.
            self.placeholder = ModelPlaceholder(self.in_queues[0], self.out_queues[0])
            self.thread = threading.Thread(target=self.run, daemon=True)
            self.thread.start()
        else:
            # NOTE(review): with num_workers == 1 a DataLoader still uses one
            # worker process, which would then call the model directly —
            # confirm that is intended.
            self.placeholder = model

    def worker_init_fn(self, worker_id):
        """DataLoader ``worker_init_fn``: wire this worker's queues and reseed.

        Does nothing when ``num_workers <= 1``. Otherwise it:

        - rebinds ``dataset.model`` (expected to be the placeholder) to this
          worker's queue pair;
        - seeds the worker's RNGs with ``_main_process_seed[0] + id + 42``.

        Args:
            worker_id: Index passed by the DataLoader. It is not used directly;
                ``get_worker_info().id`` is read instead.
        """
        if self.num_workers <= 1:
            return
        worker_info = get_worker_info()
        placeholder = worker_info.dataset.model
        placeholder.in_queue = self.in_queues[worker_info.id]
        placeholder.out_queue = self.out_queues[worker_info.id]
        # The seed does not depend on the epoch, so every worker restart
        # reuses the same seed.
        set_seed(_main_process_seed[0] + worker_info.id + 42)

    def run(self):
        """Serve worker requests until ``stop()`` is called.

        Polls each worker's request queue in turn with a 1e-5 s timeout, so the
        thread busy-waits while idle. For each batch received, it:

        - moves the batch to ``self.model.device``;
        - calls the model under ``torch.no_grad()``;
        - puts ``(fwd_dist, bck_dist, values)`` on the same worker's reply
          queue, with each tensor moved to CPU and ``None`` entries kept as
          ``None``.

        NOTE(review): if the model raises, this thread exits and the waiting
        worker blocks forever on its reply queue.
        """
        while not self._stop_event.is_set():
            for in_queue, out_queue in zip(self.in_queues, self.out_queues):
                try:
                    batch = in_queue.get(True, 1e-5)
                except queue.Empty:
                    continue

                with torch.no_grad():
                    fwd_dist, bck_dist, values = self.model(batch.to(self.model.device))
                    # Copies to CPU happen inside no_grad so the results never
                    # carry autograd history through the queue.
                    output = (
                        fwd_dist.to('cpu'),
                        bck_dist.to('cpu') if bck_dist is not None else None,
                        values.to('cpu') if values is not None else None,
                    )

                out_queue.put(output)

    def stop(self, timeout=None):
        """Signal the serving thread to exit and wait for it to finish.

        Safe to call more than once. When no thread was started
        (``num_workers <= 1``), it only sets the flag.

        Args:
            timeout: Seconds to wait for the thread; ``None`` waits until it
                exits.
        """
        self._stop_event.set()
        thread = self.thread
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout)