import copy
import queue
import os
import time
import torch
import SMOS
import SMOS_utils
import zmq
import numpy as np
from collections import defaultdict
from multiprocessing.managers import BaseManager
from SMOS_utils import RWLock

from core.storage_config import StorageConfig

class BatchControlPacket:
    """Message exchanged between a worker and the trainer for batch control.

    Attributes are plain values so the packet can be passed through
    ``SMOS_utils.serialize`` / ``SMOS_utils.deserialize``.
    """

    def __init__(self, worker_node_id, request_id, data=None):
        """
        Parameters
        ----------
        worker_node_id:
            Identifier of the worker the packet belongs to.
        request_id:
            Identifier used to pair a response with its request.
        data:
            Payload of the packet. When omitted, a fresh empty dict is
            created for this packet.
        """
        self.worker_node_id = worker_node_id
        self.request_id = request_id
        # A `dict()` default would be created once at definition time and
        # shared by every packet built without `data`; allocate per instance.
        self.data = {} if data is None else data

class BatchControl(object):
    """Batch control object to control priority staleness and reanalyze staleness for each batch.

    An instance created with ``worker_node_id < 0`` acts as the trainer: it
    precomputes the whole schedule of batch control metas and serves them
    from local lists. Any other instance acts as a worker: it holds no
    schedule and requests each meta over ZeroMQ from ``trainer_address``.
    """

    def __init__(self, worker_node_id, config, storage_config):
        """
        Parameters
        ----------
        worker_node_id:
            Node identifier; a negative value marks the trainer.
        config:
            Must provide ``priority_gap``, ``reanalyze_staleness``,
            ``target_model_interval`` and ``world_size``; the trainer also
            reads ``training_steps``.
        storage_config:
            Only read by workers: ``trainer_address``,
            ``zmq_batch_control_worker_outbound_port`` and
            ``zmq_batch_control_worker_inbound_port``.
        """
        self.worker_node_id = worker_node_id
        self.is_trainer = (worker_node_id < 0)

        self.priority_gap = config.priority_gap
        self.reanalyze_staleness = config.reanalyze_staleness
        self.target_model_interval = config.target_model_interval
        # Number of copies of each step's meta in the two schedules.
        multi = int(config.world_size * 1.3)
        train_multi = config.world_size

        if self.is_trainer:
            # batch control: priority version & model version
            self.batch_control_meta = []
            self.train_batch_control_meta = []
            interval = self.target_model_interval
            last_target_checkpoint = 0
            for i in range(config.training_steps):
                if i % interval == 0:
                    last_target_checkpoint = i
                # The first `reanalyze_staleness` steps after a checkpoint
                # still use versions from one interval earlier.
                in_stale_window = i % interval < self.reanalyze_staleness

                priority_version = max(last_target_checkpoint - self.priority_gap * interval, 0)
                priority_version = min(priority_version, config.training_steps)
                if in_stale_window:
                    priority_version = max(priority_version - interval, 0)

                if in_stale_window:
                    target_model_version = last_target_checkpoint - interval
                else:
                    target_model_version = last_target_checkpoint
                # Clamp into [0, training_steps].
                target_model_version = max(target_model_version, 0)
                target_model_version = min(target_model_version, config.training_steps)

                batch_control_meta = dict(priority_version=priority_version,
                                          target_model_version=target_model_version)
                # The same dict object is referenced by every copy.
                self.batch_control_meta.extend([batch_control_meta] * multi)
                self.train_batch_control_meta.extend([batch_control_meta] * train_multi)
            self.total_batch_controls = len(self.batch_control_meta)
        else:
            zmq_context = zmq.Context()

            # worker -> trainer
            self.__outbound_socket = zmq_context.socket(zmq.PUSH)
            outbound_port = storage_config.zmq_batch_control_worker_outbound_port
            outbound_address = f"tcp://{storage_config.trainer_address}:{outbound_port}"
            self.__outbound_socket.set_hwm(2)
            self.__outbound_socket.connect(outbound_address)

            # trainer -> worker
            self.__inbound_socket = zmq_context.socket(zmq.SUB)
            inbound_port = storage_config.zmq_batch_control_worker_inbound_port
            inbound_address = f"tcp://{storage_config.trainer_address}:{inbound_port}"
            # Empty prefix: receive every published message and filter by
            # worker id / request id in gen_batch_control.
            self.__inbound_socket.setsockopt_string(zmq.SUBSCRIBE, "")
            self.__inbound_socket.connect(inbound_address)

        self.lock = RWLock()

    def gen_batch_control(self):
        """Return the next batch control meta.

        Trainer: removes and returns the head of the schedule, or a dict
        with both versions set to ``-1`` once the schedule is exhausted.

        Worker: sends a request packet tagged with a random ``request_id``
        and reads published responses until one carries this worker's id
        and that request id, then returns its ``data`` (presumably a meta
        dict produced by the trainer -- confirm against the trainer side).

        Raises
        ------
        RuntimeError
            Worker only: a non-matching response was received more than one
            second after the request started. Note that ``recv()`` itself
            blocks, so the deadline is only evaluated when some packet
            arrives.
        """
        if self.is_trainer:
            if not self.batch_control_meta:
                return dict(priority_version=-1, target_model_version=-1)
            return self.batch_control_meta.pop(0)

        self.lock.writer_enter()
        try:
            start_time = time.time()
            request_id = np.random.randint(0, 123456789 + 1)
            request = BatchControlPacket(worker_node_id=self.worker_node_id, request_id=request_id)
            self.__outbound_socket.send(SMOS_utils.serialize(request))

            response_packet = SMOS_utils.deserialize(self.__inbound_socket.recv())
            while response_packet.worker_node_id != self.worker_node_id or response_packet.request_id != request_id:
                current_time = time.time()
                if current_time - start_time > 1.:
                    raise RuntimeError("[Batch Control] Batch control client waits for 1 sec!!")
                response_packet = SMOS_utils.deserialize(self.__inbound_socket.recv())

            return response_packet.data
        finally:
            # Release on every exit path, including the timeout error and
            # socket/deserialization failures, so later calls cannot block
            # on a lock that was never released.
            self.lock.writer_leave()

    def match(self, bc_meta):
        """Compare ``bc_meta`` against the head of the training schedule.

        Trainer only (asserted).

        Parameters
        ----------
        bc_meta:
            Mapping with ``"priority_version"`` and
            ``"target_model_version"`` keys.

        Returns
        -------
        tuple
            ``(True, "Matched")`` when both versions equal the head's; the
            head is then removed. ``(False, "Drop")`` when both versions
            are less than or equal to the head's. ``(False, "Wait")``
            otherwise, including when the schedule is empty.
        """
        # Check the role before taking the lock so a failed assertion
        # cannot leave the lock held.
        assert self.is_trainer
        self.lock.writer_enter()
        try:
            if not self.train_batch_control_meta:
                return False, "Wait"
            expected = self.train_batch_control_meta[0]
            priority_version = bc_meta["priority_version"]
            target_model_version = bc_meta["target_model_version"]
            if priority_version == expected["priority_version"] and target_model_version == expected["target_model_version"]:
                self.train_batch_control_meta.pop(0)
                return True, "Matched"
            if priority_version <= expected["priority_version"] and target_model_version <= expected["target_model_version"]:
                return False, "Drop"
            return False, "Wait"
        finally:
            # Also covers a malformed bc_meta raising KeyError/TypeError.
            self.lock.writer_leave()

    def process_str(self):
        """Return progress as ``"<handed out>/<total>"`` batch controls (trainer only)."""
        assert self.is_trainer
        return f"{self.total_batch_controls - len(self.batch_control_meta)}/{self.total_batch_controls}"

class BatchControlManager(BaseManager):
    """Manager type on which the ``get_batch_control_proxy`` typeid is registered.

    Adds nothing to ``BaseManager``; having a dedicated subclass keeps the
    registration off ``BaseManager`` itself.
    """


def start_batch_control_server(worker_node_id, config, storage_config: StorageConfig):
    """
    Create a BatchControl in the current process and serve it through a
    BatchControlManager. Call this method remotely.

    The call does not return: it ends in ``serve_forever()``.

    :param worker_node_id: forwarded to ``BatchControl``.
    :param config: forwarded to ``BatchControl``.
    :param storage_config: forwarded to ``BatchControl``; its
        ``batch_control_connection`` (``ip``, ``port``, ``authkey``) is the
        address the server binds to.
    """
    # Build the single instance that every proxy request will receive.
    controller = BatchControl(worker_node_id, config, storage_config)

    def _provide_batch_control():
        return controller

    BatchControlManager.register('get_batch_control_proxy', callable=_provide_batch_control)
    print("[Batch Control] Batch control initialized.")

    # Bind the manager server to the configured endpoint and block on it.
    conn = storage_config.batch_control_connection
    server = BatchControlManager(
        address=(conn.ip, conn.port),
        authkey=bytes(conn.authkey),
    ).get_server()
    print(f"[Batch Control] Starting batch control server at port {conn.port}.")
    server.serve_forever()


def get_batch_control(storage_config: StorageConfig):
    """
    Get connection to a batch control server.

    Connects to the endpoint in ``storage_config.batch_control_connection``
    (``ip``, ``port``, ``authkey``). While the connection is refused, the
    call prints a notice and retries every second, with no retry limit.

    :param storage_config: provides ``batch_control_connection``.
    :return: the object returned by the manager's
        ``get_batch_control_proxy()``.
    """
    BatchControlManager.register('get_batch_control_proxy')
    connection = storage_config.batch_control_connection
    manager = BatchControlManager(address=(connection.ip, connection.port),
                                  authkey=connection.authkey)
    while True:
        try:
            manager.connect()
            break
        except ConnectionRefusedError:
            # Only a refused connection is retried; other errors propagate.
            # The message names the batch control server (it previously
            # said "Shared storage server", which is a different service).
            print(f"[(pid={os.getpid()})] Batch control server not ready, retry in 1 sec.")
            time.sleep(1)
    return manager.get_batch_control_proxy()
