import os
import time
import SMOS
from multiprocessing.managers import BaseManager

from core.storage_config import StorageConfig
from core.replay_buffer import get_replay_buffer
from core.shared_storage import get_shared_storage


class WatchdogServer:
    """In-process state shared with remote processes through ``WatchdogServerManager``.

    Holds two monotonically increasing counters (reanalyze batches and training
    steps) and an exponentially smoothed drop ratio that starts at 0.5.
    """

    def __init__(self):
        self.reanalyze_batch_count, self.training_step_count = 0, 0
        # Initial value of the exponential moving average below.
        self.drop_ratio = 0.5

    def increase_reanalyze_batch_count(self):
        """Increment the reanalyze-batch counter by one."""
        self.reanalyze_batch_count = self.reanalyze_batch_count + 1

    def get_reanalyze_batch_count(self):
        """Return the current reanalyze-batch counter."""
        current = self.reanalyze_batch_count
        return current

    def increase_training_step_count(self):
        """Increment the training-step counter by one."""
        self.training_step_count = self.training_step_count + 1

    def get_training_step_count(self):
        """Return the current training-step counter."""
        current = self.training_step_count
        return current

    def update_drop_ratio(self, drop_ratio):
        """Blend a new drop-ratio sample into the moving average.

        The previous estimate keeps weight 0.8 and the new sample gets weight 0.2.
        """
        previous = self.drop_ratio
        self.drop_ratio = previous * 0.8 + drop_ratio * 0.2

    def get_drop_ratio(self):
        """Return the smoothed drop ratio."""
        current = self.drop_ratio
        return current


class WatchdogServerManager(BaseManager):
    """``BaseManager`` subclass used both to serve and to reach the watchdog server.

    Typeids (e.g. ``get_watchdog_server``) are registered on this class at
    runtime via ``WatchdogServerManager.register`` rather than declared here.
    """


def start_watchdog_server(storage_config: StorageConfig):
    """
    Start a watchdog server. Call this method remotely.

    Creates a single WatchdogServer, registers it under the
    ``get_watchdog_server`` typeid, and serves it at the ip/port/authkey from
    ``storage_config.watchdog_server_connection``. Never returns: it blocks in
    ``serve_forever``.
    """
    shared_watchdog = WatchdogServer()

    def _provide_watchdog():
        # Returns the same instance on every call, so all clients share one object.
        return shared_watchdog

    WatchdogServerManager.register('get_watchdog_server', callable=_provide_watchdog)
    print("[Watchdog Server] Watchdog server initialized.")

    conn = storage_config.watchdog_server_connection
    server_address = (conn.ip, conn.port)
    manager = WatchdogServerManager(address=server_address, authkey=bytes(conn.authkey))
    # Build the server (this binds the address) before announcing the port.
    server = manager.get_server()
    print(f"[Watchdog Server] Starting watchdog server at port {conn.port}.")
    server.serve_forever()


def get_watchdog_server(storage_config: StorageConfig):
    """
    Get connection to a watchdog server.

    Blocks until the server accepts the connection, retrying once per second
    while it is refused, then returns the proxy produced by the server's
    ``get_watchdog_server`` typeid.
    """
    WatchdogServerManager.register('get_watchdog_server')
    conn = storage_config.watchdog_server_connection
    # NOTE(review): the other call sites wrap authkey in bytes(); it is passed
    # through unchanged here — confirm the configured type is already bytes.
    manager = WatchdogServerManager(address=(conn.ip, conn.port), authkey=conn.authkey)
    while True:
        try:
            manager.connect()
        except ConnectionRefusedError:
            print(f"[(pid={os.getpid()})] Watchdog server not ready, retry in 1 sec.")
            time.sleep(1)
        else:
            break
    return manager.get_watchdog_server()


def _connect_watchdog_server(storage_config: StorageConfig):
    """Connect to the watchdog server, retrying every second until it accepts.

    Returns the proxy produced by the server's ``get_watchdog_server`` typeid.
    """
    WatchdogServerManager.register('get_watchdog_server')
    connection = storage_config.watchdog_server_connection
    manager = WatchdogServerManager(address=(connection.ip, connection.port),
                                    authkey=bytes(connection.authkey))
    while True:
        try:
            manager.connect()
            break
        except ConnectionRefusedError:
            print(f"Watchdog server not ready, retry in 1 sec. Process {os.getpid()}.")
            time.sleep(1)
    return manager.get_watchdog_server()


def _kill_zombies(smos_client, storage_config: StorageConfig):
    """Drain the zombie queue, delete those replay-buffer entries, and print a summary.

    Entries whose deletion returns SMOS_PERMISSION_DENIED are pushed back onto
    the zombie queue so a later pass can retry them.
    """
    # Drain the whole queue first, so entries re-queued below are not popped
    # again within this same pass.
    zombie_list = []
    while True:
        status, handle, idx = smos_client.pop_from_object(name=storage_config.zombie_queue_name)
        if status == SMOS.SMOS_FAIL:
            break
        zombie_list.append(idx)
        smos_client.free_handle(object_handle=handle)

    kill_zombie_count = 0
    remaining_zombie_list = []
    for zombie_idx in zombie_list:
        status = smos_client.delete_entry(name=storage_config.replay_buffer_name, entry_idx=zombie_idx)
        if status == SMOS.SMOS_PERMISSION_DENIED:
            smos_client.push_to_object(name=storage_config.zombie_queue_name, data=zombie_idx)
            remaining_zombie_list.append(zombie_idx)
        else:
            # NOTE(review): every status other than PERMISSION_DENIED is counted
            # as a kill, including any other failure code — confirm intended.
            kill_zombie_count += 1

    _, replay_buffer_size = smos_client.get_entry_count(name=storage_config.replay_buffer_name)
    print("********************* Zombie Killer ********************")
    print(f"Replay buffer size: {replay_buffer_size}")
    print(f"Killed: {kill_zombie_count} zombies, remaining: {remaining_zombie_list}")
    print("********************************************************")


def start_watchdog(config, storage_config: StorageConfig, mode=None):
    """
    Start a watchdog that monitors training statistics. Call this method remotely.

    Runs forever. Each cycle sleeps 5 s, reports speed statistics (only while
    ``training_step_count < config.training_steps + config.last_steps``), sleeps
    another 5 s, and in "worker" mode runs the zombie killer.

    mode: trainer / worker. "worker" also updates the server's drop ratio toward
    ``storage_config.batches_per_10sec`` and kills zombie replay-buffer entries.
    Any other value only prints the banner lines.

    Rates are reported per 10 seconds. They are scaled by the measured elapsed
    time, so RPC and zombie-killing time do not skew them.
    """
    watchdog_server = _connect_watchdog_server(storage_config)

    if mode == "worker":
        # Used only for the worker's replay-buffer statistics.
        replay_buffer = get_replay_buffer(storage_config=storage_config)
        shared_storage = get_shared_storage(storage_config=storage_config)

    smos_client = SMOS.Client(connection=storage_config.smos_connection)

    window_sec = 10.0
    # The server counters are cumulative. Start from their current values
    # (not 0); otherwise the first cycle would report everything counted
    # before this watchdog attached as a single window's progress.
    last_batch_count = watchdog_server.get_reanalyze_batch_count()
    last_training_step_count = watchdog_server.get_training_step_count()
    last_time = time.monotonic()
    while True:

        # log speed statistics
        time.sleep(5)
        batch_count = watchdog_server.get_reanalyze_batch_count()
        training_step_count = watchdog_server.get_training_step_count()
        now = time.monotonic()
        # Scale raw deltas to a 10-second window. A cycle is only nominally
        # 10 s: the first one is about 5 s, and later ones also include RPC
        # and zombie-killer time.
        to_window = window_sec / max(now - last_time, 1e-6)
        if training_step_count < config.training_steps + config.last_steps:
            print("*********************** Watchdog ***********************")
            if mode == "worker":
                batches_per_10sec = (batch_count - last_batch_count) * to_window

                # Reanalyze speed control: the fraction of observed batches
                # above the target rate, fed to the server's moving average.
                # Presumably consumed by reanalyze workers — verify.
                drop_ratio = max(0, batches_per_10sec - storage_config.batches_per_10sec) / (batches_per_10sec + 1e-6)
                watchdog_server.update_drop_ratio(drop_ratio)

                print(f"Reanalyze speed: {batches_per_10sec:.1f} batches/10sec.")
                print("Drop ratio: {:.4f}".format(watchdog_server.get_drop_ratio()))
                print(f"Replay buffer update priority frequency: {replay_buffer.get_num_refreshes() / (shared_storage.get_counter() + 1e-6)} times/step")
            if mode == "trainer":
                steps_per_10sec = (training_step_count - last_training_step_count) * to_window
                print(f"Training speed: {steps_per_10sec:.1f} steps/10sec", flush=True)
            print("********************************************************")
        last_batch_count = batch_count
        last_training_step_count = training_step_count
        last_time = now

        # zombie killer, only worker needs this
        time.sleep(5)
        if mode == "worker":
            _kill_zombies(smos_client, storage_config)
