import io
import json
import logging
import os
import pickle
import subprocess
import sys
import time
import traceback
import uuid

import zmq

sys.path.append("./")
import lz4.frame
import numpy as np
from worker.rollout import SamplerRollout
from worker.agent import SamplerAgent
from worker.statistics import StatisticsUtils

# Record layout applied to the root logger's default handler:
# timestamp, logger name, level and message joined by dashes.
FORMAT = "%(asctime)s-%(name)s-%(levelname)s-%(message)s"
logging.basicConfig(
    format=FORMAT,
)


class SampleWorker(object):
    """Sampler process: pulls the latest model, rolls out trajectories and ships them.

    The worker reads ``config.json`` from the current directory, opens three
    ZeroMQ sockets (a REQ socket to ask the config server for the latest
    model, a PUSH socket for log records and a PUSH socket for trajectory
    data) and then repeatedly runs ``run()`` from ``run_loop()``.
    """

    # (key in the log record, name of the statistic it is averaged from).
    # Kept as an ordered table so the record layout lives in one place.
    _RESULT_STATS = (
        ("model/state_max_prob", "state_max_prob"),

        ("sampler/sampler_model_request_time", "model_request_time"),
        ("sampler/sampler_model_update_interval", "model_update_interval"),
        ("sampler/p2p_download_time", "p2p_download_time"),
        ("sampler/trajectory_running_time", "trajectory_running_time"),

        ("sampler/ep_return", "ep_return"),
        ("sampler/ep_length", "ep_length"),
        ("sampler/ep_time", "ep_time"),

        ("sampler/speed", "speed"),
        ("sampler/ttc", "ttc"),

        ("sampler/act_left_ratio", "act_left_ratio"),
        ("sampler/act_idle_ratio", "act_idle_ratio"),
        ("sampler/act_right_ratio", "act_right_ratio"),
        ("sampler/act_fast_ratio", "act_fast_ratio"),
        ("sampler/act_slow_ratio", "act_slow_ratio"),

        ("sampler/lane_left_ratio", "lane_left_ratio"),
        ("sampler/lane_right_ratio", "lane_right_ratio"),
        ("sampler/lane_mid_ratio", "lane_mid_ratio"),

        ("reward/speed", "reward_speed"),
        ("reward/ttc", "reward_ttc"),
        ("reward/change", "reward_change"),
        ("reward/left", "reward_left"),
        ("reward/right", "reward_right"),
        ("reward/mid", "reward_mid"),
    )

    def __init__(
            self,
            args
    ):
        """Load the configuration and connect every socket the worker uses.

        Args:
            args: parsed command-line namespace. It is accepted for interface
                compatibility but not read by this class.

        Raises:
            OSError: if ``config.json`` cannot be opened.
            KeyError: if a required configuration key is missing.
        """
        self.uuid = str(uuid.uuid4())
        self.statistic = StatisticsUtils()
        with open('config.json') as f:
            self.json_config = json.load(f)

        self.agent_side = 'self'
        self.cwd = os.getcwd()
        self.sampler_method = self.json_config["sampler_method"]
        # The agent and rollout receive this worker instance, so they are
        # created only after the attributes above exist.
        self.agent = SamplerAgent(self, sampler_method=self.sampler_method)
        self.rollout = SamplerRollout(self, self.agent, statistic=self.statistic)

        self.logger = logging.getLogger('sampler.py')
        self.logger.setLevel(logging.DEBUG)

        self.context = zmq.Context()

        # REQ socket used to ask the config server which model is the latest.
        self.latest_model_requester = self.context.socket(zmq.REQ)
        self.latest_model_requester.connect("tcp://%s:%d" % (
            self.json_config["config_server_address"],
            self.json_config["config_server_request_model_port"]))
        self.need_update_eval_model = False
        self.latest_update_interval = self.json_config["latest_update_interval"]

        # build the log sender
        self.log_sender = self.context.socket(zmq.PUSH)
        self.log_sender.connect("tcp://%s:%d" % (
            self.json_config["log_server_address"],
            self.json_config["log_server_port"]))

        # build the data sender: pick one data server uniformly at random and
        # translate its flat index into (machine, port offset on that machine).
        self.data_sender = self.context.socket(zmq.PUSH)
        learner_machines = self.json_config["learner_machine_ips"][self.agent_side]["machines"]
        data_server_num_per_machine = \
            self.json_config["gpu_num_per_machine"] * self.json_config["data_server_to_learner_num"]
        total_data_server = len(learner_machines) * data_server_num_per_machine
        choose_index = np.random.randint(0, total_data_server)
        machine_index, port_num = divmod(choose_index, data_server_num_per_machine)
        target_ip = learner_machines[machine_index]
        # "leaner_port_start" is the spelling the configuration file uses.
        target_port = self.json_config["leaner_port_start"] + port_num
        self.logger.info("total process %d, choose index %d, target_ip %s, target port %d" % (
            total_data_server, choose_index, target_ip, target_port))
        self.data_sender.connect("tcp://%s:%d" % (target_ip, target_port))

        self.dragon = os.environ.get("SUPERNODE", None)
        self.logger.info('use dragonfly supernode {}'.format(self.dragon))

        self.latest_model_timestamp = time.time()

        # Value substituted into dfget's --locallimit / --totallimit (as "<n>M").
        self.dragonfly_limit = 100
        self.dragonfly_path = '/home/dragonfly/Dragonfly_1.0.2_linux_amd64/dfget'

        self.last_download_time = 0

    def _download_model(self, url, model_path):
        """Download ``url`` to ``model_path`` with dfget; return True if a file was written.

        Any previous file at ``model_path`` is removed first so that a failed
        download can never be mistaken for a fresh model.
        """
        if os.path.exists(model_path):
            os.remove(model_path)

        # Argument list instead of a shell string: the URL is passed verbatim,
        # so characters such as '&' or ';' in it cannot alter the command.
        command = [
            self.dragonfly_path,
            '-u', url,
            '-o', model_path,
            '--node', str(self.dragon),
            '--locallimit={0}M'.format(self.dragonfly_limit),
            '--totallimit={0}M'.format(self.dragonfly_limit),
        ]

        s_time = time.time()
        try:
            return_code = subprocess.call(command)
        except OSError as err:
            # Raised when the dfget binary is missing or not executable.
            self.logger.error("failed to launch dfget: %s", err)
            return False
        self.statistic.append("p2p_download_time", time.time() - s_time)

        if return_code != 0:
            self.logger.warning("dfget exited with code %d", return_code)
        if not os.path.exists(model_path):
            self.logger.error("model download from %s produced no file", url)
            return False
        return True

    def fetch_self_model(self):
        """Ask the config server for the latest model and load it if it is new.

        A model is downloaded only when its timestamp differs from the one
        currently loaded and more than ``latest_update_interval + 1`` seconds
        have passed since the last successful download.

        Returns:
            bool: True if new parameters were downloaded and handed to the
            agent, False if the update was skipped or the download failed.
        """
        self.latest_model_requester.send(pickle.dumps({"model_side": self.agent_side}))

        s_time = time.time()
        model_raw = self.latest_model_requester.recv()
        self.statistic.append("model_request_time", time.time() - s_time)

        # NOTE(review): pickle.loads on bytes received over the network is only
        # safe when the config server is trusted.
        model_info = pickle.loads(model_raw)

        if model_info is None:
            # Nothing to compare against; keep sampling with the current model.
            self.logger.warning('config server returned no model info, skip updating model')
            return False

        if "hot_update_config" in model_info:
            hot_update_config = model_info["hot_update_config"]
            if hot_update_config['hot_update']:
                self.json_config.update(hot_update_config)
                self.rollout.json_config.update(hot_update_config)

        model_changed = self.latest_model_timestamp != model_info["time"]
        cooled_down = time.time() - self.last_download_time > self.latest_update_interval + 1
        if not (model_changed and cooled_down):
            self.logger.debug('same self model, skip updating model')
            return False

        url = model_info["url"]
        self.logger.info("get self model url %s" % url)

        model_path = os.path.join(self.cwd, 'model.pth')
        if not self._download_model(url, model_path):
            # Bookkeeping is left untouched so the next call retries this model.
            return False

        with open(model_path, 'rb') as m:
            models = io.BytesIO(m.read())
        self.agent.fetch_model_parameters(models)
        self.logger.info("load model parameters")

        # Record the new model only once it has actually been loaded.
        self.statistic.append("model_update_interval", model_info["time"] - self.latest_model_timestamp)
        self.latest_model_timestamp = model_info["time"]
        self.last_download_time = time.time()
        return True

    def _collect_result_info(self):
        """Build the log record from averaged statistics, dropping entries equal to 0."""
        result_info = {"docker_id": self.uuid}
        for log_key, stat_name in self._RESULT_STATS:
            result_info[log_key] = self.statistic.get_avg_value(stat_name)

        # delete meaningless values
        keys_to_delete = [key for key, value in result_info.items() if value == 0]
        for key in keys_to_delete:
            del result_info[key]
        return result_info

    def run(self):
        """Refresh the model once, then sample and send ``rollout.n_sample`` trajectories.

        Each trajectory is pickled, lz4-compressed and pushed to the data
        server. When the model was refreshed at the start of this call, a
        statistics record is pushed to the log server after every trajectory
        and the statistics are cleared.
        """
        fetch_model = self.fetch_self_model()

        for _ in range(self.rollout.n_sample):
            s_time = time.time()
            self.logger.info("sample one trajectory")
            # Only the trajectory data is forwarded; the other values are unused here.
            data, _result, _reward, _info = self.rollout.sample_one_traj()
            self.statistic.append("trajectory_running_time", time.time() - s_time)

            compressed_data = lz4.frame.compress(pickle.dumps(data))
            self.data_sender.send(compressed_data)

            if fetch_model:
                self.logger.info("send result to log server")
                result_info = self._collect_result_info()
                self.log_sender.send(pickle.dumps([result_info]))
                self.statistic.clear()

    def run_loop(self):
        """Call ``run()`` forever; on the first exception report it and return.

        The traceback is logged locally and pushed to the log server as
        ``{'error_log': ...}``. The loop is not restarted afterwards.
        """
        try:
            while True:
                self.run()
        except Exception as e:
            error_str = traceback.format_exc()
            logging.error(e)
            logging.error(error_str)
            send_message = {'error_log': error_str}
            p = pickle.dumps([send_message])
            self.log_sender.send(p)
            # NOTE(review): presumably gives the PUSH socket time to flush the
            # error record before the process exits - confirm.
            time.sleep(3)


if __name__ == '__main__':
    # Script entry point: parse the (currently empty) command line, build the
    # sampler worker and hand control to its endless sampling loop.
    from argparse import ArgumentParser

    cli_args = ArgumentParser(description='parameters：').parse_args()
    SampleWorker(cli_args).run_loop()
