import sys
sys.path.append("./")
import zmq
import os
import time
import json
import pickle
import traceback
from learner.basic_server import BasicServer
from learner.utils import setup_logger, zmq_nonblocking_multipart_recv, zmq_nonblocking_recv


class ConfigServer(BasicServer):
    """Track the latest published model per side and serve it to requesters.

    Two sockets are bound on ``json_config["config_server_address"]``:

    * a ROUTER socket on ``config_server_request_model_port`` that answers
      model-info requests (see ``_serve_model_requests``), and
    * a PULL socket on ``config_server_model_update_port`` that receives
      new-model notifications (see ``_receive_model_updates``).

    The JSON file named by ``json_config["hot_update_file"]`` is re-read at
    most once every ``HOT_UPDATE_INTERVAL`` seconds; while it enables hot
    update, its parsed contents are attached to every model-info reply.
    """

    # Seconds between two reads of the hot-update file.
    HOT_UPDATE_INTERVAL = 60
    # Model sides a requester may ask for.
    REQUESTABLE_SIDES = ("self", "opponent", "self_delay", "opponent_delay")
    # Model sides a publisher may announce a new model for.
    UPDATABLE_SIDES = ("self", "opponent")

    def __init__(self):
        BasicServer.__init__(self)
        log_path = os.path.join(self.log_dir, 'config_server_log')
        self.log_handler = setup_logger("ConfigServer", log_path)

        # model requester: ROUTER socket, so replies can be routed back using
        # the envelope frames of each incoming multipart request.
        self.model_requester = self.context.socket(zmq.ROUTER)
        self.model_requester.set_hwm(1000000)
        self.model_requester.bind("tcp://%s:%d" % (
            self.json_config["config_server_address"],
            self.json_config["config_server_request_model_port"]))
        self.poller.register(self.model_requester, zmq.POLLIN)

        # model receiver: PULL socket for new-model notifications.
        self.model_receiver = self.context.socket(zmq.PULL)
        self.model_receiver.set_hwm(1000000)
        self.model_receiver.bind("tcp://%s:%d" % (
            self.json_config["config_server_address"],
            self.json_config["config_server_model_update_port"]))
        self.poller.register(self.model_receiver, zmq.POLLIN)

        # Hot-update state: the first check happens on the first loop iteration.
        self.next_hot_update_time = time.time()
        self.hot_update_config = None

        # Latest model info per side; an empty dict means "nothing recorded yet".
        self.latest_model_info = {
            "self": {},
            "self_delay": {},
            "opponent": {},
            "opponent_delay": {}
        }

        # NOTE(review): the "*_delay" entries above, next_update_delay_model_time
        # and delay_update_time are never updated anywhere in this class, so
        # requests for "self_delay"/"opponent_delay" get no reply unless
        # something outside this class fills them. Confirm whether the
        # delayed-model rotation logic is missing.
        self.next_update_delay_model_time = {
            "self": time.time(),
            "opponent": time.time()
        }

        self.delay_update_time = self.json_config["delay_update_interval"]
        self.hot_update_file = self.json_config['hot_update_file']

    def hot_update(self):
        """Re-read the hot-update file if the refresh interval has elapsed.

        If the file's JSON object has a truthy ``"hot_update"`` key,
        ``self._hot_update`` (not defined in this class; presumably inherited)
        is called with it and it becomes ``self.hot_update_config``; otherwise
        ``self.hot_update_config`` is cleared. Any failure (missing file,
        invalid JSON, error inside ``_hot_update``) is logged with its
        traceback and the previous configuration is kept. The next check is
        scheduled ``HOT_UPDATE_INTERVAL`` seconds later in every case.
        """
        if time.time() <= self.next_hot_update_time:
            return
        try:
            with open(self.hot_update_file, encoding="utf-8") as f:
                update_dict = json.load(f)
            if update_dict.get("hot_update", False):
                self._hot_update(update_dict)
                self.hot_update_config = update_dict
            else:
                self.hot_update_config = None
        except Exception:
            # Best effort: a broken hot-update file must not stop the server.
            self.log_handler.error(traceback.format_exc())
        self.next_hot_update_time = time.time() + self.HOT_UPDATE_INTERVAL

    def run(self):
        """Serve forever: refresh hot-update config, answer requests, record updates.

        Each iteration polls the sockets for at most 100 ms, so the hot-update
        check keeps running even when there is no traffic.
        """
        while True:
            self.hot_update()
            sockets = dict(self.poller.poll(timeout=100))

            if sockets.get(self.model_requester, 0) & zmq.POLLIN:
                self._serve_model_requests()

            if sockets.get(self.model_receiver, 0) & zmq.POLLIN:
                self._receive_model_updates()

    def _serve_model_requests(self):
        """Answer every pending model-info request on the ROUTER socket.

        The last frame of each multipart request is a pickled dict with a
        ``"model_side"`` key. The reply reuses the request's frames (keeping
        the ROUTER envelope) with the last frame replaced by the pickled model
        info for that side, plus ``"hot_update_config"`` when hot update is on.
        Malformed requests are logged and skipped.

        NOTE(review): a side with no recorded model gets no reply at all;
        confirm clients do not block forever waiting for one.
        """
        for frames in zmq_nonblocking_multipart_recv(self.model_requester):
            try:
                # SECURITY: pickle.loads on network input can execute arbitrary
                # code; this port must only be reachable by trusted peers.
                request = pickle.loads(frames[-1])
                side = request["model_side"]
                if side not in self.REQUESTABLE_SIDES:
                    self.log_handler.error("model_requester model_side error!")
                    continue
                model_info = self.latest_model_info[side]
                if not model_info:
                    continue  # no model recorded for this side yet
                # Build the reply on a copy: writing the hot-update config into
                # the stored entry would keep serving it after hot update is
                # disabled.
                reply = dict(model_info)
                if self.hot_update_config is not None:
                    reply["hot_update_config"] = self.hot_update_config
                frames[-1] = pickle.dumps(reply)
                self.model_requester.send_multipart(frames)
            except Exception as e:
                self.log_handler.error(f"Model requester fail to unpickle message. Error: {str(e)}. Skipping...")

    def _receive_model_updates(self):
        """Record every pending new-model notification from the PULL socket.

        Each message is a pickled dict with ``"model_side"`` and ``"url"``
        keys. For sides in ``UPDATABLE_SIDES`` the side's entry in
        ``latest_model_info`` is replaced by ``{"url": ..., "time": <now>}``.
        Other sides and malformed messages are logged and skipped.
        """
        for raw in zmq_nonblocking_recv(self.model_receiver):
            try:
                # SECURITY: pickle.loads on network input (see _serve_model_requests).
                model_info = pickle.loads(raw)
                self.log_handler.info(
                    "recv new model, side %s, url %s" % (model_info["model_side"], model_info["url"]))
                side = model_info["model_side"]
                if side in self.UPDATABLE_SIDES:
                    self.latest_model_info[side] = {
                        "url": model_info["url"],
                        "time": time.time(),
                    }
                else:
                    self.log_handler.error("model_receiver model_side error!")
            except Exception as e:
                self.log_handler.error(f"Model receiver failed to unpickle message. Error: {str(e)}. Skipping...")


if __name__ == '__main__':
    # Script entry point: build the config server and block in its serve loop.
    ConfigServer().run()


