import argparse
import json
import os
import shlex
import time


def start_one_machine(machine_ip, gpu_num, cur_rank, world_size, model_side, root_address, data_server_local_num,
                      log_dir=None, start_gpu_id=0, plasma_memory_bytes=15000000000):
    """Start the plasma store, data servers and learner servers on this machine.

    Every process is launched in the background (``nohup ... &`` via
    ``os.system``), so this returns as soon as the commands are issued; it
    does not wait for or monitor the servers.

    Args:
        machine_ip: forwarded to each learner as ``--ip``.
        gpu_num: number of learner processes to start on this machine.
        cur_rank: rank of this machine's first learner; learner ``i`` and its
            data servers use rank ``cur_rank + i``.
        world_size: total number of learners across all machines.
        model_side: forwarded to each learner as ``--model_side``.
        root_address: forwarded to each learner as ``--init_method``.
        data_server_local_num: number of data servers started per learner.
        log_dir: directory that receives the per-process log files. If
            omitted, the module-level ``log_dir`` assigned by the ``__main__``
            block is used, which keeps the script's existing call working.
        start_gpu_id: first GPU index listed in ``CUDA_VISIBLE_DEVICES``.
        plasma_memory_bytes: value passed to ``plasma_store -m``.

    Raises:
        ValueError: if ``log_dir`` is omitted and no module-level ``log_dir``
            exists (e.g. the module was imported rather than run as a script).
    """
    if log_dir is None:
        # Backward compatibility: callers that do not pass log_dir rely on the
        # global set in the __main__ block. Fail clearly instead of NameError.
        log_dir = globals().get("log_dir")
        if log_dir is None:
            raise ValueError("log_dir must be given when not running as a script")

    plasma_path = os.path.join(os.path.expanduser("~"), "plasma")
    os.system("plasma_store -s {} -m {} &".format(shlex.quote(plasma_path), int(plasma_memory_bytes)))
    # Crude fixed wait so the plasma store is up before the servers start.
    time.sleep(2)

    # Every process on this machine sees the same GPU range; presumably each
    # one picks its own device from its rank -- TODO confirm in the servers.
    cuda_env_setting = "CUDA_VISIBLE_DEVICES=" + ",".join(map(str, range(start_gpu_id, start_gpu_id + gpu_num)))
    for i in range(gpu_num):
        rank = cur_rank + i
        for data_i in range(data_server_local_num):
            # One log file per process: several processes redirecting with `>`
            # into the same file truncate it and overwrite each other's output.
            data_log = os.path.join(log_dir, "data_{}_{}.log".format(rank, data_i))
            cur_command = cuda_env_setting + " nohup python learner/gpu_data_server.py "
            cur_command += "--rank %d --world_size %d --data_server_local_rank %d" % (rank, world_size, data_i)
            cur_command += " > {} 2>&1 &".format(shlex.quote(data_log))
            os.system(cur_command)

        learner_log = os.path.join(log_dir, "learner_{}.log".format(rank))
        cur_command = cuda_env_setting + " nohup python learner/gpu_learner_server.py "
        cur_command += "--ip %s --init_method %s --rank %d --world_size %d --model_side %s" % (
            shlex.quote(str(machine_ip)), shlex.quote(str(root_address)), rank, world_size,
            shlex.quote(str(model_side)))
        cur_command += " > {} 2>&1 &".format(shlex.quote(learner_log))
        os.system(cur_command)


if __name__ == '__main__':
    # Launch all learner-side processes for one machine, configured from
    # ./config.json. Machine 0 additionally hosts the log server, the config
    # server and TensorBoard.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_side', default="self", help="model_side")
    parser.add_argument("--machine_index", default=0, type=int, help="machine index in machines")
    args = parser.parse_args()

    with open('config.json', encoding="utf-8") as f:
        json_config = json.load(f)

    # start_one_machine may read this name as a module-level global, so it must
    # stay at module scope (do not move this block into a function).
    log_dir = json_config["log_dir"]
    os.makedirs(log_dir, exist_ok=True)

    machine_config = json_config["learner_machine_ips"][args.model_side]
    machines = machine_config["machines"]
    # Fail fast on an out-of-range index; a negative one would otherwise
    # silently select a machine from the end of the list.
    if not 0 <= args.machine_index < len(machines):
        parser.error("--machine_index must be in [0, {}), got {}".format(len(machines), args.machine_index))

    # Prefix that hides all GPUs from the CPU-only helper services.
    unable_cuda_devices = "CUDA_VISIBLE_DEVICES=-1"
    if args.machine_index == 0:
        # log server
        cur_command = unable_cuda_devices + " nohup python learner/log_server.py"
        cur_command += " > {} 2>&1 &".format(shlex.quote(os.path.join(log_dir, "log.log")))
        os.system(cur_command)
        # config server
        cur_command = unable_cuda_devices + " nohup python learner/config_server.py"
        cur_command += " > {} 2>&1 &".format(shlex.quote(os.path.join(log_dir, "config.log")))
        os.system(cur_command)
        # TensorBoard, bound to loopback only. The env assignment must directly
        # prefix the command: a standalone `VAR=value;` only sets an unexported
        # shell variable. log_dir is passed as-is so absolute paths also work.
        cur_command = unable_cuda_devices + " nohup python -m tensorboard.main"
        cur_command += " --logdir={} --host=127.0.0.1 --port 8081".format(shlex.quote(log_dir))
        cur_command += " > /dev/null 2>&1 &"
        os.system(cur_command)

    gpu_num = json_config["gpu_num_per_machine"]
    # Ranks are assigned in contiguous blocks of gpu_num per machine.
    cur_rank = args.machine_index * gpu_num
    world_size = len(machines) * gpu_num
    data_server_local_num = json_config["data_server_to_learner_num"]
    m_ip = machines[args.machine_index]

    start_one_machine(m_ip, gpu_num, cur_rank, world_size, args.model_side,
                      machine_config["root_address"], data_server_local_num)
