import sys

sys.path.append("./")
import queue
import zmq
import os
import time
import json
import pickle
from learner.summary_util import SummaryLog
from learner.utils import setup_logger
from learner.utils import *


class LogServer:
    """Receive pickled log batches over a ZeroMQ PULL socket and feed them into a SummaryLog.

    Settings are read from ``config.json`` in the working directory; the keys used
    here are ``log_dir``, ``log_server_address``, ``log_server_port``,
    ``learner_machine_ips``, ``gpu_num_per_machine`` and ``data_server_to_learner_num``.
    """

    # Seconds between "data_server/active_docker_count" reports.
    ACTIVE_DOCKER_REPORT_INTERVAL = 60 * 3

    def __init__(self):
        # Keep the context on the instance so its lifetime is explicit.
        self.context = zmq.Context()
        self.poller = zmq.Poller()

        with open('config.json', encoding='utf-8') as f:
            self.json_config = json.load(f)
        self.log_dir = self.json_config["log_dir"]

        self.receiver = self.context.socket(zmq.PULL)
        self.receiver.bind("tcp://%s:%d" % (self.json_config["log_server_address"],
                                            self.json_config["log_server_port"]))
        self.poller.register(self.receiver, zmq.POLLIN)

        self.log_basic_info = setup_logger("basic", os.path.join(self.log_dir, "log_server_log"))

        summary_dir = os.path.join(self.log_dir, "summary_log")
        create_path(summary_dir)
        self.summary_logger = SummaryLog(summary_dir, self.json_config)

        # docker_id values seen since the last report; used as a set (every value is 1).
        self.active_docker_dict = {}
        self.next_cal_docker_time = time.time()

    def summary_definition(self):
        """Register every summary tag with ``self.summary_logger`` and record the tag groups.

        Tag names must match the ``/``-containing keys that clients send (``log_detail``
        forwards those keys verbatim to ``add_summary``), so they are kept exactly as-is,
        including the ``leaner_server`` spelling.
        """
        sampler_avg_num = 100
        self.sampler_framework_keys = [
            "sampler/sampler_model_request_time",
            "sampler/sampler_model_update_interval",
            "sampler/error_per_min",
            "sampler/p2p_download_time",
            "sampler/trajectory_running_time"
        ]
        # error_per_min is a 60 s "time_total" tag; the other framework keys are "avg" tags.
        for key in self.sampler_framework_keys:
            if key != "sampler/error_per_min":
                self.summary_logger.add_tag(key, sampler_avg_num, "avg")
        self.summary_logger.add_tag('sampler/error_per_min', 0, "time_total", time_threshold=60)

        self.summary_logger.add_tag('model/state_max_prob', sampler_avg_num, "avg")

        self.sampler_episode_keys = [
            "sampler/ep_return",
            "sampler/ep_length",
            "sampler/ep_time",

            "sampler/speed",
            "sampler/ttc",

            "sampler/act_left_ratio",
            "sampler/act_idle_ratio",
            "sampler/act_right_ratio",
            "sampler/act_fast_ratio",
            "sampler/act_slow_ratio",

            "sampler/lane_left_ratio",
            "sampler/lane_right_ratio",
            "sampler/lane_mid_ratio"
        ]
        for key in self.sampler_episode_keys:
            self.summary_logger.add_tag(key, sampler_avg_num, "avg")

        self.reward_episode_keys = [
            "reward/speed",
            "reward/ttc",
            "reward/change",
            "reward/left",
            "reward/right",
            "reward/mid"
        ]
        for key in self.reward_episode_keys:
            self.summary_logger.add_tag(key, sampler_avg_num, "avg")

        self.data_server_keys = [
            'data_server/dataserver_recv_instance_per_min',
            "data_server/dataserver_parse_time_per_minutes",
            "data_server/dataserver_socket_time_per_minutes",
            "data_server/dataserver_sampling_time_per_min",
            "data_server/active_docker_count"
        ]
        # Presumably the number of data-server processes reporting recv counts
        # (machines x GPUs per machine x data servers per learner) - verify against deployment.
        total_data_server = (len(self.json_config["learner_machine_ips"]["self"]["machines"])
                             * self.json_config["gpu_num_per_machine"]
                             * self.json_config["data_server_to_learner_num"])
        for key in self.data_server_keys:
            if key == 'data_server/dataserver_recv_instance_per_min':
                self.summary_logger.add_tag(key, total_data_server, "total")
            else:
                self.summary_logger.add_tag(key, 1, "avg")

        self.leaner_server_keys = [
            "leaner_server/sgd_round_per_min",
            "leaner_server/sgd_total_time",
            "leaner_server/wait_data_time_per_min"
        ]
        for key in self.leaner_server_keys:
            self.summary_logger.add_tag(key, 1, "avg")

        self.model_keys = [
            "model/entropy_mean",
            "model/entropy",
            "model/value_loss",
            "model/policy_loss",
            "model/total_loss",
            "model/avg_q_value",
            "model/advantage",
            "model/advantage_norm",
            "model/advantage_std",
            "model/approx_kl",
            "model/clip_frac",
            "model/ratio",
        ]
        for key in self.model_keys:
            self.summary_logger.add_tag(key, 10, "avg")

    def log_detail(self, data):
        """Route one metrics record (a mapping) into the summary logger.

        Keys containing ``'/'`` are forwarded to ``add_summary``. Keys without ``'/'`` are
        metadata and are not summarized; the only one acted on is ``"docker_id"``, whose
        value is recorded in ``active_docker_dict``.
        """
        for field_key, value in data.items():
            if '/' in field_key:
                self.summary_logger.add_summary(field_key, value)
            elif field_key == "docker_id":
                self.active_docker_dict[value] = 1

    def run(self):
        """Register summary tags, then poll and process incoming log batches forever."""
        self.summary_definition()

        while True:
            self._report_active_dockers()
            self.summary_logger.generate_time_data_output(self.log_basic_info)

            socks = dict(self.poller.poll(timeout=100))
            # poll() reports an event bitmask; test the POLLIN bit rather than equality.
            if socks.get(self.receiver, 0) & zmq.POLLIN:
                self._handle_raw_messages(self._drain_receiver())

    def _report_active_dockers(self):
        """Once the report deadline has passed, emit the distinct docker_id count and reset it."""
        now = time.time()
        if now > self.next_cal_docker_time:
            self.next_cal_docker_time = now + self.ACTIVE_DOCKER_REPORT_INTERVAL
            self.summary_logger.add_summary("data_server/active_docker_count", len(self.active_docker_dict))
            self.active_docker_dict = {}

    def _drain_receiver(self):
        """Return every message currently queued on the receiver, reading without blocking."""
        raw_data_list = []
        while True:
            try:
                raw_data_list.append(self.receiver.recv(zmq.NOBLOCK))
            except zmq.Again:
                # Queue is empty: the normal way out of the loop.
                break
            except zmq.ZMQError as e:
                self.log_basic_info.warning("recv zmq {}".format(e))
                break
        return raw_data_list

    def _handle_raw_messages(self, raw_data_list):
        """Unpickle each message and process its records; bad messages/records are logged and skipped."""
        for raw_data in raw_data_list:
            # SECURITY: pickle.loads can execute arbitrary code embedded in the payload.
            # This is only acceptable if the bound address is reachable by trusted workers alone.
            try:
                data = pickle.loads(raw_data)
            except Exception as e:
                self.log_basic_info.error(f"Log Server failed unpickle message. Error: {str(e)}. Skipping...")
                continue
            try:
                records = list(data)
            except TypeError as e:
                self.log_basic_info.error(f"Log Server got a non-iterable message. Error: {str(e)}. Skipping...")
                continue
            # Guard each record separately so one bad record does not drop the rest of the batch.
            for log in records:
                try:
                    self._handle_log(log)
                except Exception as e:
                    self.log_basic_info.error(
                        f"Log Server failed to process log record. Error: {str(e)}. Skipping...")

    def _handle_log(self, log):
        """Process one record: client errors are logged and counted, anything else goes to log_detail."""
        if "error_log" in log:
            # 1-tuple so a tuple-valued error_log is formatted, not splatted into '%'.
            self.log_basic_info.error("client_error, %s" % (log["error_log"],))
            self.summary_logger.add_summary('sampler/error_per_min', 1, timestamp=time.time())
        else:
            self.log_detail(log)


if __name__ == '__main__':
    # Script entry point: construct the log server and run its receive loop forever.
    LogServer().run()
