import sys
sys.path.append("./")
import os
import shutil
import zmq
import time
import pickle
import random
import argparse
import socket
import lz4.frame
import numpy as np
import pyarrow as pa
import pyarrow.plasma as plasma
from worker.instance import TrainingSet
from learner.utils import setup_logger
from learner.basic_server import BasicServer


class GpuDataServer(BasicServer):
    """Buffer training instances for one learner GPU and publish batches via plasma.

    Each data server binds a ZMQ PULL socket and decodes the lz4-compressed,
    pickled messages pushed to it into a ``TrainingSet`` pool. Once training
    has started, it slices ``batch_size`` instances from the pool and puts
    them into the plasma store under a fixed per-server object ID. A new
    batch is only written after the previous object is gone from the store.
    Presumably the learner consumes and deletes it; confirm on that side.

    Training start is coordinated with PUB/SUB. The root data server (global
    rank 0, data-server local rank 0) publishes ``"start"`` once its own pool
    is full. Every other data server subscribes and waits for that message.

    ``self.context``, ``self.poller``, ``self.json_config``, ``self.log_dir``
    and ``self.send_log`` are not defined in this class. They are assumed to
    be provided by ``BasicServer``.
    """

    # Seconds between two statistics reports. The metric names sent in
    # _report_receive_stats say "per_min", so keep this at 60.
    _STATS_INTERVAL_S = 60

    def __init__(self, args):
        """Read the configuration, set up sockets and connect to the plasma store.

        Args:
            args: parsed command-line namespace providing ``rank``,
                ``world_size``, ``data_server_local_rank``, ``model_side``
                and ``test_mode``.
        """
        super().__init__()

        self.global_rank = args.rank
        self.world_size = args.world_size
        self.data_server_local_rank = args.data_server_local_rank
        self.self_play_side = args.model_side
        self.test_mode = args.test_mode

        # Index of the served GPU on its machine, and which machine that is.
        self.gpu_num_per_machine = self.json_config["gpu_num_per_machine"]
        self.local_rank = self.global_rank % self.gpu_num_per_machine

        target_machines = self.json_config["learner_machine_ips"][self.self_play_side]["machines"]
        machine_index = self.global_rank // self.gpu_num_per_machine
        self.server_ip = target_machines[machine_index]

        log_path = os.path.join(self.log_dir, self.server_ip,
                                'gpu_data_log/%d/%d' % (self.local_rank, self.data_server_local_rank))
        self.log_handler = setup_logger("GpuDataServer", log_path)

        # Each local GPU owns a contiguous range of ``data_server_to_learner_num``
        # ports. This server takes the slot given by its data-server local rank.
        # ("leaner_port_start" is the key exactly as spelled in the config.)
        self.server_port = self.json_config["leaner_port_start"] + \
                           self.local_rank * self.json_config["data_server_to_learner_num"] + \
                           self.data_server_local_rank

        # Receive training instances pushed by workers.
        self.receiver = self.context.socket(zmq.PULL)
        self.receiver.set_hwm(1000000)
        self.receiver.bind("tcp://%s:%d" % (self.server_ip, self.server_port))
        self.poller.register(self.receiver, zmq.POLLIN)

        if self._is_root():
            # The root gpu server publishes the start-training signal.
            self.root_publisher = self.context.socket(zmq.PUB)
            self.root_publisher.bind("tcp://%s:%d" % (
                self.json_config["root_gpu_ip"], self.json_config["root_gpu_pub_start_port"]))
        else:
            # Every other server subscribes to the root's start-training signal.
            self.root_subscriber = self.context.socket(zmq.SUB)
            self.root_subscriber.connect("tcp://%s:%d" % (
                self.json_config["root_gpu_ip"], self.json_config["root_gpu_pub_start_port"]))
            self.root_subscriber.setsockopt_string(zmq.SUBSCRIBE, "")
            self.poller.register(self.root_subscriber, zmq.POLLIN)

        self.batch_size = self.json_config["batch_size"]
        self.traj_len = self.json_config["traj_len"]
        # The pool holds two batches' worth of instances (256 in test mode).
        # run() only samples once the pool is full.
        self.pool_capacity = int(self.batch_size * 2)
        if self.test_mode:
            self.pool_capacity = 256
        self.training_set = TrainingSet(max_capacity=self.pool_capacity)

        self.start_training = False
        # Per-interval statistics, reset after each report.
        self.socket_time_list = []
        self.parse_data_time_list = []
        self.recv_training_instance_count = 0
        self.sampling_time_list = []
        # Deadlines use the monotonic clock, so wall-clock adjustments (NTP or
        # manual changes) cannot stall or burst the periodic work.
        self.next_print_receive_time = time.monotonic()

        # plasma ObjectIDs must be exactly 20 bytes. The id string has 4 digits
        # while global_rank * 10 + data_server_local_rank + 1000 < 10000, and
        # 5 copies of it make 20 bytes.
        # NOTE(review): ids collide across ranks if data_server_local_rank >= 10.
        # Whoever reads the batch must build the same id, so keep them in sync.
        self.plasma_data_id = plasma.ObjectID(
            5 * bytes(str(self.global_rank * 10 + self.data_server_local_rank + 1000), encoding="utf-8"))
        # The second positional argument is passed to plasma.connect unchanged.
        # It is presumably the retry count; this depends on the pyarrow version.
        self.plasma_client = plasma.connect(os.path.join(os.path.expanduser("~"), "plasma"), 2)
        self.data_server_sampling_interval = self.json_config["data_server_sampling_interval"]
        self.next_sampling_time = time.monotonic()

    def _is_root(self):
        """Return True for the single data server that publishes the start signal."""
        return self.global_rank == 0 and self.data_server_local_rank == 0

    def _drain_receiver(self):
        """Read every message currently queued on the PULL socket without blocking.

        Returns:
            list of raw message payloads (possibly empty).
        """
        raw_data_list = []
        while True:
            try:
                raw_data_list.append(self.receiver.recv(zmq.NOBLOCK))
            except zmq.error.Again:
                break  # The queue is empty; this is the normal exit.
            except zmq.ZMQError as e:
                self.log_handler.warning("recv zmq {}".format(e))
                break
        return raw_data_list

    def receive_data(self, socks):
        """Move every queued worker message into the pool, then report stats if due.

        Args:
            socks: mapping of socket -> event mask, i.e.
                ``dict(self.poller.poll(...))``.
        """
        if socks.get(self.receiver, 0) & zmq.POLLIN:
            recv_start = time.perf_counter()
            raw_data_list = self._drain_receiver()
            if raw_data_list:
                self.socket_time_list.append(time.perf_counter() - recv_start)

            parse_start = time.perf_counter()
            cur_recv_total = 0
            for raw_data in raw_data_list:
                try:
                    # SECURITY: pickle.loads can execute arbitrary code. The PULL
                    # socket accepts any peer that can reach server_ip:server_port,
                    # so this is only safe on a trusted cluster network.
                    all_data = pickle.loads(lz4.frame.decompress(raw_data))
                    self.training_set.append_instance(all_data)
                    cur_recv_total += len(all_data)
                except Exception as e:
                    # Best effort: one bad message must not kill the server.
                    self.log_handler.error(
                        f"mpi {self.global_rank}, local rank {self.data_server_local_rank}, "
                        f"Data server failed to decompress or unpickle message. Error: {e}. Skipping...")
            self.recv_training_instance_count += cur_recv_total
            # Trim the pool back to its capacity after appending. Which
            # instances are evicted depends on TrainingSet.
            self.training_set.fit_max_size()

            if raw_data_list:
                self.parse_data_time_list.append(time.perf_counter() - parse_start)

        self._report_receive_stats()

    def _report_receive_stats(self):
        """Send and log the per-interval counters once the deadline passes, then reset them."""
        now = time.monotonic()
        if now <= self.next_print_receive_time:
            return
        self.next_print_receive_time = now + self._STATS_INTERVAL_S

        self.send_log({'data_server/dataserver_recv_instance_per_min': self.recv_training_instance_count})
        self.send_log({"data_server/dataserver_socket_time_per_minutes": sum(self.socket_time_list)})
        if self.sampling_time_list:
            self.send_log({"data_server/dataserver_sampling_time_per_min": sum(self.sampling_time_list)})

        self.log_handler.info(
            "mpi %d, local rank %d,  gpu_parse_data_time_per_minutes %f, receive_instance %d socket time %f" % (
                self.global_rank, self.data_server_local_rank, sum(self.parse_data_time_list),
                self.recv_training_instance_count, sum(self.socket_time_list)))

        self.parse_data_time_list = []
        self.recv_training_instance_count = 0
        self.socket_time_list = []
        self.sampling_time_list = []

    def sampling_data(self):
        """Slice ``batch_size`` instances from the pool and put them into plasma.

        The batch is serialized with ``pa.serialize`` and stored under
        ``self.plasma_data_id``. run() only calls this when no object with
        that id is currently in the store.
        """
        if self.global_rank == 0:
            self.log_handler.info("start sampling")

        slice_start = time.perf_counter()
        slice_data = self.training_set.slice(self.batch_size)

        if self.global_rank == 0:
            self.log_handler.info(
                "slice time %f, batch size %d, pool size %d" % (
                    time.perf_counter() - slice_start, self.batch_size, self.training_set.len()))

        pa_data = pa.serialize(slice_data).to_buffer()
        # Drop the Python-side batch before the plasma copy to lower peak memory.
        del slice_data

        self.plasma_client.put(pa_data, self.plasma_data_id, memcopy_threads=16)

    def receive_start_training(self, socks):
        """On non-root servers, switch ``start_training`` on when the root's signal arrives.

        Args:
            socks: mapping of socket -> event mask from the poller.
        """
        if self._is_root():
            return  # The root has no subscriber; it decides the start itself in run().
        if socks.get(self.root_subscriber, 0) & zmq.POLLIN:
            # The payload is always "start"; only its arrival matters.
            self.root_subscriber.recv_string(zmq.NOBLOCK)
            self.start_training = True
            self.log_handler.info("rank %d recv start training info" % self.global_rank)

    def run(self):
        """Main loop: receive data, coordinate the start and feed batches to plasma forever."""
        self.log_handler.info("data server rank %d, data server local rank %d, port %d start running" % (
            self.global_rank, self.data_server_local_rank, self.server_port))
        self.log_handler.info("plasma id list: {}".format(self.plasma_data_id))

        while True:
            # 100 ms timeout so the periodic checks below still run when idle.
            sockets = dict(self.poller.poll(timeout=100))

            self.receive_data(sockets)
            self.receive_start_training(sockets)

            pool_full = self.training_set.len() >= self.training_set.max_capacity

            # NOTE(review): "start" is published exactly once. A subscriber
            # that connects (or restarts) after that moment never receives it
            # (ZMQ PUB/SUB slow-joiner behavior). Consider re-publishing
            # periodically if that can happen in deployment.
            if self._is_root() and pool_full and not self.start_training:
                self.start_training = True
                self.root_publisher.send_string("start")

            # Write a new batch only after the previous plasma object is gone.
            # The plasma RPC is checked last so it is skipped when cheaper
            # conditions already fail.
            if self.start_training and pool_full \
                    and time.monotonic() > self.next_sampling_time \
                    and not self.plasma_client.contains(self.plasma_data_id):
                sample_start = time.perf_counter()
                self.sampling_data()
                self.sampling_time_list.append(time.perf_counter() - sample_start)
                self.next_sampling_time = time.monotonic() + self.data_server_sampling_interval
                self.log_handler.info("rank %d, %d sampling data" % (self.global_rank, self.data_server_local_rank))


def parse_args(argv=None):
    """Parse the data server's command-line arguments.

    Args:
        argv: argument list to parse. ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``rank``, ``data_server_local_rank``,
        ``world_size``, ``model_side`` and ``test_mode``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--rank', default=0, type=int, help='rank of current process')
    parser.add_argument('--data_server_local_rank', default=0, type=int,
                        help='index of this data server among those feeding the same learner GPU')
    parser.add_argument('--world_size', default=2, type=int, help="world size")
    parser.add_argument('--model_side', default="self",
                        help="key into learner_machine_ips in the JSON config selecting the learner machines")
    parser.add_argument('--test_mode', default=0, type=int,
                        help="nonzero shrinks the sample pool capacity to 256 for quick testing")
    return parser.parse_args(argv)


if __name__ == '__main__':
    data_server = GpuDataServer(parse_args())
    data_server.run()
