import copy
import queue
import os
import time
import torch
import SMOS
from collections import defaultdict
from multiprocessing.managers import BaseManager
from SMOS_utils import RWLock

from core.storage_config import StorageConfig
from core.model import state_dict_to_cpu
from core.meta_data_manager import MetaDataSharedMemoryManager

class SharedStorage(object):
    """Central store for training state shared between processes.

    An instance is hosted by a ``multiprocessing`` manager server (see
    ``start_shared_storage_server``), and workers reach it through the proxy
    returned by ``get_shared_storage``. It keeps:

    * the location of the most recently pushed model weights, which are stored
      in a ring of SMOS objects named ``weights_storage_name + str(idx)``;
    * the training step counter, mirrored into shared meta-data;
    * aggregated self-play logs and evaluation (test) logs;
    * which workers have reported ready, used to coordinate training start.
    """

    def __init__(self, model, target_model, config, storage_config, smos_client):
        """Shared storage for models and others.

        Parameters
        ----------
        model: any
            models for self-play (update every checkpoint_interval). When not
            None, its weights are pushed once into every slot of the weights
            ring with version 0.
        target_model: any
            models for reanalyzing (update every target_model_interval).
            Not used by this class.
        config: any
            Forwarded to ``MetaDataSharedMemoryManager``.
        storage_config: StorageConfig
            Provides ``weights_storage_cycle_size`` and ``weights_storage_name``.
        smos_client: SMOS.Client
            Client used to push/pop weights in the SMOS object store.
        """
        self.step_counter = 0
        self.test_counter = 0
        self.ori_reward_log = []
        self.reward_log = []
        self.reward_max_log = []
        self.test_dict_log = {}
        self.eps_lengths = []
        self.eps_lengths_max = []
        self.temperature_log = []
        self.visit_entropies_log = []
        self.priority_self_play_log = []
        self.last_eps_logs = 0, 0, 0
        self.distributions_log = {}
        self.start = False
        self.storage_config = storage_config

        # model & target model
        self.model_version = 0
        self.weights_location = None
        self.current_cycle_idx = 0
        self.weights_storage_cycle_size = storage_config.weights_storage_cycle_size

        # smos
        self.smos_client = smos_client

        # meta-data manager
        self.meta_data_manager = MetaDataSharedMemoryManager(config, storage_config)

        if model is not None:
            # Fill every ring slot up front; pop=False because the slots are empty.
            self.set_weights(model.get_weights(), 0, num=self.weights_storage_cycle_size, pop=False)

        # start training control
        self.ready_worker_nodes = []
        self.ready_cpu_workers = []
        self.ready_gpu_workers = []
        self.ready_value_updaters = []
        self.ready_priority_updaters = []

    def set_start_signal(self):
        """Mark training as started."""
        print("[Shared Storage] Start Training")
        self.start = True

    def get_start_signal(self):
        """Return True once ``set_start_signal`` has been called."""
        return self.start

    def get_weights_location(self):
        """Return ``(weights_storage_name, entry_index)`` of the latest push, or None."""
        weights_location = self.weights_location
        return weights_location

    def set_weights(self, weights, version, num=1, pop=True):
        """Push ``[weights, version]`` into the next ``num`` slots of the weights ring.

        For each slot, when ``pop`` is True one old entry is first popped from
        the slot's SMOS object (and its handle freed). The pop is retried every
        50 ms until it succeeds. The newly pushed entry's location is stored
        locally in ``self.weights_location``. It is also published to shared
        meta-data packed as ``slot_idx * 10000000 + entry_index``, which is the
        encoding ``read_weights`` decodes.
        NOTE(review): the packing assumes entry_index < 10000000 — confirm.

        Returns None.
        """
        print("[Shared Storage] set weights")
        for _ in range(num):
            idx = self.current_cycle_idx
            weights_storage_name =  self.storage_config.weights_storage_name + str(idx)

            if pop:
                # pop out one old weight
                while True:
                    status, handle, _ = self.smos_client.pop_from_object(name=weights_storage_name)
                    if not status == SMOS.SMOS_SUCCESS:
                        time.sleep(0.05)
                    else:
                        self.smos_client.free_handle(object_handle=handle)
                        break

            # push new weights
            status, model_entry_index = self.smos_client.push_to_object(name=weights_storage_name,
                                                        data=[weights, version])
            assert status == SMOS.SMOS_SUCCESS
            self.weights_location = (weights_storage_name, model_entry_index)
            self.meta_data_manager.get("weights_location")[0] = idx * 10000000 + model_entry_index

            self.current_cycle_idx = (self.current_cycle_idx + 1) % self.weights_storage_cycle_size
        print("[Shared Storage] finish set weights", self.weights_location)
        return None

    def incr_counter(self):
        """Increment the training step counter and mirror it to shared meta-data."""
        self.step_counter += 1
        self.meta_data_manager.get("trained_steps")[0] = self.step_counter

    def get_counter(self):
        """Return the current training step counter."""
        return self.step_counter

    def set_counter(self, val):
        """Set the training step counter and mirror it to shared meta-data."""
        self.step_counter = val
        self.meta_data_manager.get("trained_steps")[0] = self.step_counter

    def set_data_worker_logs(self, last_eps_len, eps_len, eps_len_max, last_eps_ori_reward, last_eps_reward, eps_ori_reward, eps_reward, eps_reward_max, temperature, visit_entropy, priority_self_play, distributions):
        """Accumulate one batch of self-play statistics until ``get_worker_logs`` drains them.

        ``distributions`` maps a name to a sequence of values. The values are
        concatenated per name across calls.
        """
        self.eps_lengths.append(eps_len)
        self.eps_lengths_max.append(eps_len_max)
        self.ori_reward_log.append(eps_ori_reward)
        self.reward_log.append(eps_reward)
        self.reward_max_log.append(eps_reward_max)
        self.temperature_log.append(temperature)
        self.visit_entropies_log.append(visit_entropy)
        self.priority_self_play_log.append(priority_self_play)
        self.last_eps_logs = (last_eps_len, last_eps_ori_reward, last_eps_reward)

        for key, val in distributions.items():
            self.distributions_log.setdefault(key, []).extend(val)

    def add_test_log(self, test_counter, test_dict, test_label):
        """Record one evaluation result.

        Without a label (``test_label is None``), ``test_counter`` becomes the
        latest test counter. Each metric in ``test_dict`` is appended to
        ``test_dict_log[metric]``.

        With a label, metrics are grouped as ``test_dict_log[test_label][metric]``,
        a ``defaultdict(list)``. The test counter is left untouched.
        """
        print("[Shared Storage] add test log")
        if test_label is None:
            self.test_counter = test_counter
            for key, val in test_dict.items():
                self.test_dict_log.setdefault(key, []).append(val)
        else:
            # Index by metric key: the per-label container is a dict of lists,
            # not a list, so values must go into labelled_log[key].
            labelled_log = self.test_dict_log.setdefault(test_label, defaultdict(list))
            for key, val in test_dict.items():
                labelled_log[key].append(val)
        print("[Shared Storage] finish add test log")

    def get_worker_logs(self):
        """Return aggregated self-play and test logs, then reset them.

        Returns a 12-tuple in this order:
        ``(ori_reward, reward, reward_max, eps_lengths, eps_lengths_max,
        test_counter, test_dict, temperature, visit_entropy, priority_self_play,
        distributions, last_eps_logs)``.

        - The self-play scalars are means over the calls to
          ``set_data_worker_logs`` since the last drain. They, and
          ``distributions``, are None when nothing was recorded.
        - ``test_counter`` and ``test_dict`` are None when no test log is pending.
        - ``last_eps_logs`` is always returned and is not reset.
        """
        if len(self.reward_log) > 0:
            # All self-play lists are appended together in set_data_worker_logs,
            # so they share the same (non-zero) length here.
            ori_reward = sum(self.ori_reward_log) / len(self.ori_reward_log)
            reward = sum(self.reward_log) / len(self.reward_log)
            reward_max = sum(self.reward_max_log) / len(self.reward_max_log)
            eps_lengths = sum(self.eps_lengths) / len(self.eps_lengths)
            eps_lengths_max = sum(self.eps_lengths_max) / len(self.eps_lengths_max)
            temperature = sum(self.temperature_log) / len(self.temperature_log)
            visit_entropy = sum(self.visit_entropies_log) / len(self.visit_entropies_log)
            priority_self_play = sum(self.priority_self_play_log) / len(self.priority_self_play_log)
            distributions = self.distributions_log

            self.ori_reward_log = []
            self.reward_log = []
            self.reward_max_log = []
            self.eps_lengths = []
            self.eps_lengths_max = []
            self.temperature_log = []
            self.visit_entropies_log = []
            self.priority_self_play_log = []
            self.distributions_log = {}

        else:
            ori_reward = None
            reward = None
            reward_max = None
            eps_lengths = None
            eps_lengths_max = None
            temperature = None
            visit_entropy = None
            priority_self_play = None
            distributions = None

        if len(self.test_dict_log) > 0:
            test_dict = self.test_dict_log

            self.test_dict_log = {}
            test_counter = self.test_counter
        else:
            test_dict = None
            test_counter = None

        return ori_reward, reward, reward_max, eps_lengths, eps_lengths_max, test_counter, test_dict, temperature, visit_entropy, priority_self_play, distributions, self.last_eps_logs

    def get_test_dict_log(self):
        """Return a deep copy of the pending test logs without draining them."""
        return copy.deepcopy(self.test_dict_log)

    def add_ready_worker_node(self, worker_node_id):
        """Register a worker node as ready; duplicates are ignored."""
        if worker_node_id not in self.ready_worker_nodes:
            self.ready_worker_nodes.append(worker_node_id)

    def get_num_ready_worker_nodes(self):
        """Return the number of distinct ready worker nodes."""
        return len(self.ready_worker_nodes)

    # The per-role registrations below are NOT de-duplicated: each call counts once.
    def add_ready_cpu_worker(self, worker_node_id):
        """Register one ready CPU worker."""
        self.ready_cpu_workers.append(worker_node_id)

    def add_ready_gpu_worker(self, worker_node_id):
        """Register one ready GPU worker."""
        self.ready_gpu_workers.append(worker_node_id)

    def add_ready_value_updater(self, worker_node_id):
        """Register one ready value updater."""
        self.ready_value_updaters.append(worker_node_id)

    def add_ready_priority_updater(self, worker_node_id):
        """Register one ready priority updater."""
        self.ready_priority_updaters.append(worker_node_id)

    def get_num_ready_cpu_workers(self):
        """Return how many CPU-worker registrations were received."""
        return len(self.ready_cpu_workers)

    def get_num_ready_gpu_workers(self):
        """Return how many GPU-worker registrations were received."""
        return len(self.ready_gpu_workers)

    def get_num_ready_value_updaters(self):
        """Return how many value-updater registrations were received."""
        return len(self.ready_value_updaters)

    def get_num_ready_priority_updaters(self):
        """Return how many priority-updater registrations were received."""
        return len(self.ready_priority_updaters)

class SharedStorageManager(BaseManager):
    """Manager type used to serve and connect to the ``SharedStorage`` proxy.

    The ``get_shared_storage_proxy`` typeid is registered on it at runtime:
    with a callable by ``start_shared_storage_server`` (server side), and
    without one by ``get_shared_storage`` (client side).
    """


def start_shared_storage_server(config, storage_config: StorageConfig, model, target_model):
    """
    Build a SharedStorage in this process and serve it over a manager server.

    Call this method remotely. It never returns: it ends in
    ``serve_forever()``. Clients obtain the instance through the
    ``get_shared_storage_proxy`` typeid registered here.
    """
    # The storage talks to the SMOS object store through its own client.
    client = SMOS.Client(connection=storage_config.smos_connection)
    storage = SharedStorage(model=model, target_model=target_model, config=config,
                            storage_config=storage_config, smos_client=client)

    def _proxy_factory():
        # Every proxy request resolves to the single storage built above.
        return storage

    SharedStorageManager.register('get_shared_storage_proxy', callable=_proxy_factory)
    print("[Shared storage] Shared storage initialized.")

    # NOTE(review): the server wraps authkey in bytes(...) but get_shared_storage
    # passes it through unchanged — keep the two consistent.
    conn = storage_config.shared_storage_connection
    address = (conn.ip, conn.port)
    server = SharedStorageManager(address=address, authkey=bytes(conn.authkey)).get_server()
    print(f"[Shared storage] Starting shared storage server at port {conn.port}.")
    server.serve_forever()


def get_shared_storage(storage_config: StorageConfig):
    """
    Connect to a running shared storage server and return its proxy.

    Blocks, retrying once per second on ``ConnectionRefusedError``, until the
    server at ``storage_config.shared_storage_connection`` accepts.
    """
    # Client-side registration: no callable, the object lives on the server.
    SharedStorageManager.register('get_shared_storage_proxy')
    conn_info = storage_config.shared_storage_connection
    manager = SharedStorageManager(
        address=(conn_info.ip, conn_info.port),
        authkey=conn_info.authkey,
    )
    while True:
        try:
            manager.connect()
        except ConnectionRefusedError:
            print(f"[(pid={os.getpid()})] Shared storage server not ready, retry in 1 sec.")
            time.sleep(1)
        else:
            break
    return manager.get_shared_storage_proxy()

def read_weights(meta_data_manager, smos_client, shared_storage, storage_config: StorageConfig):
    """Block until the latest published weights are readable, then return them.

    The shared meta-data slot ``weights_location`` holds a packed location
    ``slot_idx * 10000000 + entry_idx``; this packing is written by
    ``SharedStorage.set_weights``. The entry is read from the SMOS object
    ``weights_storage_name + str(slot_idx)`` and its handle is released. The
    stored ``[weights, version]`` pair is then returned as
    ``(weights, version)``. The loop polls every 0.1 s while the location is
    None or the SMOS read fails.

    ``shared_storage`` is not used; it is kept for call-site compatibility.
    """
    while True:
        packed_location = meta_data_manager.get(item_name="weights_location")[0]
        if packed_location is not None:
            slot, entry = divmod(packed_location, 10000000)
            object_name = storage_config.weights_storage_name + str(slot)
            status, handles, payloads = smos_client.batch_read_from_object(
                name=object_name, entry_idx_batch=[entry])
            if status == SMOS.SMOS_SUCCESS:
                # Release the SMOS entry handles now that the read succeeded.
                smos_client.batch_release_entry(object_handle_batch=handles)
                weights, version = payloads[0]
                return weights, version
            print("Read weights error, wait and retry...")
        time.sleep(0.1)
