import sys
import os
sys.path.append("./")
# print(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import zmq
import time
import pickle
import argparse
import json
import pyarrow as pa
import pyarrow.plasma as plasma
import torch as th
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.distributions import Categorical
from worker.instance import to_device
from learner.utils import setup_logger
from learner.basic_server import BasicServer
from learner.model_utils import create_models, load_model, serialize_model
from learner.model_utils import convert_to_pt, load_model_from_p2p


class GpuLeanerServer(BasicServer):
    """Distributed PPO learner that trains on batches read from a plasma store.

    Each process is one DDP rank. Every rank repeatedly reads a training batch
    from the local plasma store under a fixed rotation of object ids
    (presumably written by data servers -- see ``data_server_to_learner_num``),
    runs one PPO optimisation step and moves on to the next id. Rank 0 also
    publishes model snapshots to the config server over a ZMQ PUSH socket and
    reports timing statistics through ``send_log``.

    NOTE(review): ``BasicServer`` is expected to provide ``json_config``,
    ``log_dir``, ``context`` (a ZMQ context) and ``send_log`` -- confirm
    against learner/basic_server.py.
    """

    def __init__(self, args):
        """Set up logging, the model, DDP, the optimizer and the plasma client.

        Args:
            args: parsed command-line namespace providing ``rank``,
                ``world_size``, ``model_side``, ``test_mode``, ``ip`` and
                ``init_method``.
        """
        BasicServer.__init__(self)
        self.global_rank = args.rank
        self.world_size = args.world_size
        self.self_play_side = args.model_side
        self.test_mode = args.test_mode

        self.gpu_num_per_machine = self.json_config["gpu_num_per_machine"]
        # Rank within this machine; also used as the CUDA device index below.
        self.local_rank = self.global_rank % self.gpu_num_per_machine
        log_path = os.path.join(self.log_dir, args.ip, 'gpu_learner_log/%d' % self.local_rank)
        self.log_handler = setup_logger("GpuLeanerServer", log_path)

        # Validate before p2p_path is concatenated below: a null value would
        # otherwise surface as a bare TypeError before anything is logged.
        # Exit non-zero so the launcher can tell that start-up failed.
        if self.json_config["p2p_path"] is None:
            self.log_handler.error('p2p path is null')
            sys.exit(1)

        self.traj_len = self.json_config["traj_len"]
        self.learning_rate = self.json_config["lr"]
        self.clip_epsilon = self.json_config["ppo_clip"]
        self.dual_clip_c = self.json_config["dual_clip_c"]
        self.ent_c = self.json_config["ent"]
        self.grad_clip = self.json_config["grad_clip"]
        self.latest_update_interval = self.json_config["latest_update_interval"]

        # Plain string concatenation: p2p_path is expected to end with a separator.
        self.model_save_path_prefix = self.json_config["p2p_path"] + self.json_config["p2p_filename"]
        self.model_save_url_prefix = self.json_config["p2p_url"] + self.json_config["p2p_filename"]

        self.net = create_models(self.json_config["state_dim"], self.json_config["action_dim"])

        # Only rank 0 loads initial weights; DDP broadcasts rank 0's parameters
        # to every other rank when it wraps the model below.
        if self.global_rank == 0 and self.json_config["load_model"] is True:
            model_path = self.json_config["model_path"]
            if model_path:
                self.net = load_model(model_path)
                self.log_handler.info("load init model from {}".format(model_path))
            else:
                self.net, file_name = load_model_from_p2p(self.model_save_path_prefix)
                self.log_handler.info("load init model from {}".format(file_name))

        # Test mode runs on CPU with gloo; otherwise NCCL with one GPU per rank.
        dist.init_process_group(init_method=args.init_method,
                                backend='gloo' if self.test_mode else 'nccl',
                                rank=self.global_rank,
                                world_size=self.world_size)
        if self.test_mode:
            self.net = DDP(self.net)
        else:
            self.net.to(self.local_rank).train()
            self.net = DDP(self.net, device_ids=[self.local_rank])
        th.manual_seed(0)
        self.optimizer_policy = th.optim.Adam(self.net.parameters(), lr=self.learning_rate)

        if self.global_rank == 0:
            path = self.json_config["p2p_path"]
            if not os.path.exists(path):
                os.makedirs(path)
            elif self.json_config["load_model"] is not True:
                # Fresh run: drop snapshots left behind by a previous run.
                self._remove_stale_models(path)

            self.model_sender = self.context.socket(zmq.PUSH)
            self.model_sender.connect("tcp://%s:%d" % (
                self.json_config["config_server_address"],
                self.json_config["config_server_model_update_port"]))
            # Start in the past so the initial send_model() below is not rate-limited.
            self.next_model_update_time = time.time() - 10

            self.log_handler.info("save init parameter")
            self.send_model(0)

        # Plasma ObjectIDs are 20 bytes: a 4-digit number repeated 5 times. This
        # stays collision-free across ranks only while
        # data_server_to_learner_num <= 10 and global_rank * 10 + i + 1000 < 10000.
        self.plasma_data_id_list = []
        for i in range(self.json_config["data_server_to_learner_num"]):
            plasma_id = plasma.ObjectID(5 * bytes(str(self.global_rank * 10 + i + 1000), encoding="utf-8"))
            self.plasma_data_id_list.append(plasma_id)

        plasma_path = os.path.join(os.path.expanduser("~"), "plasma")
        self.plasma_client = plasma.connect(plasma_path, 2)

        # Deadlines (wall-clock seconds) and counters for periodic work/statistics.
        self.next_pub_model_time = time.time()
        self.next_check_time = time.time()
        self.wait_data_times = []
        self.sgd_num = 0
        self.total_sgd_count = 0
        self.sgd_total_time_list = []
        self.next_send_model_stat_time = time.time()
        self.next_hot_update_check_time = time.time()

    def _remove_stale_models(self, path):
        """Delete files matching ``<path>/<p2p_filename>_*`` without using a shell.

        Uses glob/os.remove rather than ``os.system('rm ...')`` so directory
        names containing spaces or shell metacharacters are handled correctly.
        """
        import glob

        # NOTE(review): snapshots are written under ``p2p_path + p2p_filename``
        # (plain concatenation) but matched here via os.path.join; the two agree
        # only when p2p_path ends with a path separator -- confirm the config.
        prefix = os.path.join(path, self.json_config["p2p_filename"])
        for stale in glob.glob(glob.escape(prefix) + '_*'):
            if os.path.isdir(stale) and not os.path.islink(stale):
                continue  # `rm` without -r never deleted directories either
            try:
                os.remove(stale)
            except FileNotFoundError:
                pass  # already gone; nothing to clean up

    def send_model(self, iteration_num):
        """Serialize the current weights and announce them to the config server.

        Only usable on rank 0, the only rank that creates ``model_sender`` and
        ``next_model_update_time``. Rate-limited: does nothing until
        ``latest_update_interval`` seconds have passed since the previous send.

        Args:
            iteration_num: current SGD step count; unused, kept for callers.
        """
        if time.time() <= self.next_model_update_time:
            return

        url_path, _ = serialize_model(self.model_save_path_prefix,
                                      self.model_save_url_prefix,
                                      self.net.module,  # unwrap DDP
                                      self.json_config["p2p_cache_size"],
                                      self.log_handler)

        model_info = {"model_side": self.self_play_side, "url": url_path}
        self.model_sender.send(pickle.dumps(model_info))
        self.next_model_update_time = time.time() + self.latest_update_interval
        self.log_handler.info('send model to config server')

        self.log_handler.info("save parameter done")

    def update(self, training_batch):
        """Run one PPO optimisation step on ``training_batch``.

        Loss = dual-clipped PPO surrogate - ``ent_c`` * entropy
        + 0.5 * squared value error. Advantages are normalised per batch.
        Statistics are reported via ``send_log`` at most every 3 seconds; this
        happens on every rank, not only rank 0.

        Args:
            training_batch: dict (already moved to this rank's device) with keys
                ``states``, ``styles``, ``actions``, ``q_values``,
                ``advantages`` and ``old_log_prob``.
        """
        if self.global_rank == 0:
            self.log_handler.info("batch size %d" % len(training_batch["q_values"]))

        avg_q_value = training_batch["q_values"].mean()
        states = training_batch["states"]  # passed to the network as-is
        styles = training_batch["styles"].float()
        actions = training_batch["actions"].long()  # used as gather indices, presumably (b, 1)
        q_values = training_batch["q_values"].float()  # presumably (b, 1)
        raw_advantages = training_batch["advantages"]
        ad_mean, ad_std = th.mean(raw_advantages), th.std(raw_advantages)
        # The epsilon guards against division by zero for constant advantages.
        advantages = ((raw_advantages - ad_mean) / (ad_std + 1e-9)).float()
        # Behaviour-policy log-probs over actions after dropping dim 1.
        fixed_log_probs = training_batch["old_log_prob"].squeeze(1)
        # "old_state_value" is present in the batch but unused: no value clipping.

        _, logprobs, values_pred = self.net(states, styles)

        value_loss = (values_pred - q_values).pow(2).mean()  # scalar

        # Log-probs of the actions actually taken, under new and old policy.
        log_probs = logprobs.gather(1, actions)
        fixed_log_probs = fixed_log_probs.gather(1, actions)
        action_dist = Categorical(th.exp(logprobs))
        dist_entropy = action_dist.entropy().mean()
        ratio = th.exp(log_probs - fixed_log_probs)
        surr1 = ratio * advantages
        surr2 = th.clamp(ratio, 1.0 - self.clip_epsilon, 1.0 + self.clip_epsilon) * advantages
        surr = th.min(surr1, surr2)

        # Dual clip: for negative advantages, bound the objective below by c * A.
        surr3 = th.min(self.dual_clip_c * advantages, th.zeros_like(advantages))
        surr = th.max(surr, surr3)
        psurr = -surr.mean()
        ent = -self.ent_c * dist_entropy

        policy_surr = psurr + ent + 0.5 * value_loss

        # Approximate KL divergence; see Schulman's blog
        # (http://joschu.net/blog/kl-approx.html) for this low-variance,
        # unbiased estimator.
        with th.no_grad():
            log_ratio = log_probs - fixed_log_probs
            approx_kl = th.mean((th.exp(log_ratio) - 1) - log_ratio)  # scalar
            clipfrac = ((ratio - 1.0).abs() > self.clip_epsilon).float().mean()  # scalar
            ratio_abs = th.mean((ratio - 1.0).abs())

        self.optimizer_policy.zero_grad()
        policy_surr.backward()
        th.nn.utils.clip_grad_norm_(self.net.parameters(), self.grad_clip)
        self.optimizer_policy.step()

        if time.time() > self.next_send_model_stat_time:
            self.next_send_model_stat_time = time.time() + 3
            self.send_log({"model/entropy_mean": dist_entropy.item()})
            self.send_log({"model/value_loss": value_loss.item()})
            self.send_log({"model/policy_loss": psurr.item()})
            self.send_log({"model/total_loss": policy_surr.item()})
            self.send_log({"model/avg_q_value": avg_q_value.item()})
            self.send_log({"model/advantage": ad_mean.item()})
            self.send_log({"model/advantage_norm": th.mean(advantages).item()})
            self.send_log({"model/advantage_std": ad_std.item()})
            self.send_log({"model/approx_kl": approx_kl.item()})
            self.send_log({"model/clip_frac": clipfrac.item()})
            self.send_log({"model/ratio": ratio_abs.item()})

    def hot_update(self):
        """Reload tunable hyper-parameters from the hot-update JSON file.

        Checked at most once every 60 seconds. When the file's ``"hot_update"``
        flag is truthy, ``ent_c`` is replaced with the file's ``"ent"`` value.
        An unreadable, half-written or incomplete file is logged and skipped
        rather than crashing training; it is retried on the next check.
        """
        if time.time() <= self.next_hot_update_check_time:
            return

        # Looked up outside the try: a missing config key is a setup error.
        hot_update_file = self.json_config['hot_update_file']
        try:
            with open(hot_update_file, 'r') as f:
                hot_args = json.load(f)
            if hot_args["hot_update"]:
                self.ent_c = hot_args['ent']
        except (OSError, ValueError, KeyError) as exc:
            self.log_handler.error("hot update skipped: {!r}".format(exc))

        self.next_hot_update_check_time = time.time() + 60

    def learn(self):
        """Read one batch from plasma, take an SGD step and rotate object ids.

        Blocks until the object under the current id is available. Outside test
        mode the object is deleted from the store and its id moved to the back of
        the rotation. Test mode keeps re-reading the same object. Rank 0 also
        tries to publish the model every 10 seconds and logs timing statistics.
        """
        start_time = time.time()
        # blocking call
        raw_data = self.plasma_client.get(self.plasma_data_id_list[0])
        if self.global_rank == 0:
            self.wait_data_times.append(time.time() - start_time)

        all_data = pa.deserialize(raw_data)
        training_batch = to_device(all_data, self.local_rank)

        if self.global_rank == 0:
            self.log_handler.info(self.plasma_data_id_list[0])
            self.log_handler.info("after deserialize")

        self.update(training_batch)
        self.sgd_num += 1
        self.total_sgd_count += 1

        if self.global_rank == 0:
            self.log_handler.info("after sgd")

        if self.test_mode == 0:
            self.plasma_client.delete([self.plasma_data_id_list[0]])
            del all_data
            # Rotate: the consumed id goes to the back of the queue.
            self.plasma_data_id_list.append(self.plasma_data_id_list.pop(0))

        if self.global_rank == 0:
            self.log_handler.info("after delete")

        # publish new model
        if self.global_rank == 0 and time.time() > self.next_pub_model_time:
            self.send_model(self.total_sgd_count)
            self.next_pub_model_time = time.time() + 10

        if self.global_rank == 0:
            elapsed = time.time() - start_time
            self.sgd_total_time_list.append(elapsed)
            self.log_handler.info("total training time :%f" % elapsed)
            self.send_log({'training_avg_time': elapsed})
            self.send_log({'sgd_round_per_min': [1, time.time()]})

    def run(self):
        """Train forever; rank 0 reports throughput statistics once a minute."""
        self.log_handler.info("leaner server rank %d start running" % self.global_rank)
        while True:
            self.hot_update()
            self.learn()
            if time.time() > self.next_check_time:
                self.next_check_time = time.time() + 60
                if self.global_rank == 0:
                    self.send_log({"leaner_server/sgd_round_per_min": self.sgd_num})
                    self.send_log({"leaner_server/wait_data_time_per_min": sum(self.wait_data_times)})
                    # Guard the average so an empty window cannot divide by zero.
                    if self.sgd_total_time_list:
                        self.send_log(
                            {"leaner_server/sgd_total_time":
                                 sum(self.sgd_total_time_list) / len(self.sgd_total_time_list)})

                    self.sgd_num = 0
                    self.wait_data_times = []
                    self.sgd_total_time_list = []


def parse_args(argv=None):
    """Parse the learner's command-line options.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``ip``, ``rank``, ``world_size``,
        ``init_method``, ``model_side`` and ``test_mode``.
    """
    parser = argparse.ArgumentParser()
    # --ip is required: it names the per-machine log directory, and a missing
    # value would otherwise fail much later inside os.path.join with a TypeError.
    parser.add_argument("--ip", type=str, required=True, help="local machine ip in config file")
    parser.add_argument('--rank', default=0, type=int, help='rank of current process')
    parser.add_argument('--world_size', default=1, type=int, help="world size")
    parser.add_argument('--init_method', default='tcp://127.0.0.1:23456', help="init-method")
    parser.add_argument('--model_side', default="self", help="model_side")
    parser.add_argument('--test_mode', default=0, type=int, help="test_mode")
    return parser.parse_args(argv)


def main(argv=None):
    """Build a learner from command-line options and run its (endless) training loop."""
    learner_server = GpuLeanerServer(parse_args(argv))
    learner_server.run()


if __name__ == '__main__':
    main()
