import json
import os
import tempfile
import unittest

from launch import *


class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
    """wrap_udf_in_torch_dist_launcher()"""

    def _wrap(self, udf_command):
        """Wrap *udf_command* using the launcher settings shared by all tests.

        The expected strings in the tests below hard-code these values
        (2 trainers, 2 nodes, node rank 1, master 127.0.0.1:1234).
        """
        return wrap_udf_in_torch_dist_launcher(
            udf_command=udf_command,
            num_trainers=2,
            num_nodes=2,
            node_rank=1,
            master_addr="127.0.0.1",
            master_port=1234,
        )

    def test_simple(self):
        # test that a simple udf_command is correctly wrapped
        udf_command = "python3.7 path/to/some/trainer.py arg1 arg2"
        expected = (
            "python3.7 -m torch.distributed.launch "
            "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
            "--master_port=1234 path/to/some/trainer.py arg1 arg2"
        )
        self.assertEqual(self._wrap(udf_command), expected)

    def test_chained_udf(self):
        # test that a chained udf_command is properly handled: only the
        # python invocation after "&&" gets the launcher inserted
        udf_command = (
            "cd path/to && python3.7 path/to/some/trainer.py arg1 arg2"
        )
        expected = (
            "cd path/to && python3.7 -m torch.distributed.launch "
            "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
            "--master_port=1234 path/to/some/trainer.py arg1 arg2"
        )
        self.assertEqual(self._wrap(udf_command), expected)

    def test_py_versions(self):
        # test that this correctly handles different py versions/binaries
        py_binaries = (
            "python3.7",
            "python3.8",
            "python3.9",
            "python3",
            "python",
        )
        for py_bin in py_binaries:
            # subTest reports each failing binary separately instead of
            # stopping at the first mismatch.
            with self.subTest(py_bin=py_bin):
                udf_command = f"{py_bin} path/to/some/trainer.py arg1 arg2"
                expected = (
                    f"{py_bin} -m torch.distributed.launch "
                    "--nproc_per_node=2 --nnodes=2 --node_rank=1 "
                    "--master_addr=127.0.0.1 "
                    "--master_port=1234 path/to/some/trainer.py arg1 arg2"
                )
                self.assertEqual(self._wrap(udf_command), expected)


class TestWrapCmdWithLocalEnvvars(unittest.TestCase):
    """Tests for wrap_cmd_with_local_envvars()."""

    def test_simple(self):
        # The env-var string is exported inside a parenthesised shell group,
        # which then runs the (possibly chained) command.
        command = "ls && pwd"
        env_vars = "VAR1=value1 VAR2=value2"
        expected = "(export VAR1=value1 VAR2=value2; ls && pwd)"

        actual = wrap_cmd_with_local_envvars(command, env_vars)
        self.assertEqual(actual, expected)


class TestConstructDglServerEnvVars(unittest.TestCase):
    """Tests for construct_dgl_server_env_vars()."""

    def test_simple(self):
        # Expected variables in emission order. Each one is rendered as
        # "KEY=value " -- the trailing space is present even on the last one.
        expected_pairs = [
            ("DGL_ROLE", "server"),
            ("DGL_NUM_SAMPLER", "2"),
            ("OMP_NUM_THREADS", "3"),
            ("DGL_NUM_CLIENT", "4"),
            ("DGL_CONF_PATH", "path/to/part.config"),
            ("DGL_IP_CONFIG", "path/to/ip.config"),
            ("DGL_NUM_SERVER", "5"),
            ("DGL_GRAPH_FORMAT", "csc"),
            ("DGL_KEEP_ALIVE", "0"),
        ]
        expected = "".join(f"{key}={value} " for key, value in expected_pairs)

        actual = construct_dgl_server_env_vars(
            num_samplers=2,
            num_server_threads=3,
            tot_num_clients=4,
            part_config="path/to/part.config",
            ip_config="path/to/ip.config",
            num_servers=5,
            graph_format="csc",
            keep_alive=False,
        )
        self.assertEqual(actual, expected)


class TestConstructDglClientEnvVars(unittest.TestCase):
    """Tests for construct_dgl_client_env_vars()."""

    def test_simple(self):
        # Arguments shared by both cases; only `pythonpath` varies.
        shared_kwargs = dict(
            num_samplers=1,
            tot_num_clients=2,
            part_config="path/to/part.config",
            ip_config="path/to/ip.config",
            num_servers=3,
            graph_format="csc",
            num_omp_threads=4,
            group_id=0,
        )
        # Variables emitted in both cases, each as "KEY=value " (with a
        # trailing space).
        shared_expected = (
            "DGL_DIST_MODE=distributed "
            "DGL_ROLE=client "
            "DGL_NUM_SAMPLER=1 "
            "DGL_NUM_CLIENT=2 "
            "DGL_CONF_PATH=path/to/part.config "
            "DGL_IP_CONFIG=path/to/ip.config "
            "DGL_NUM_SERVER=3 "
            "DGL_GRAPH_FORMAT=csc "
            "OMP_NUM_THREADS=4 "
            "DGL_GROUP_ID=0 "
        )

        # with pythonpath: PYTHONPATH is appended after the shared variables
        self.assertEqual(
            construct_dgl_client_env_vars(
                pythonpath="some/pythonpath/", **shared_kwargs
            ),
            shared_expected + "PYTHONPATH=some/pythonpath/ ",
        )

        # without pythonpath: no PYTHONPATH entry at all
        self.assertEqual(
            construct_dgl_client_env_vars(**shared_kwargs),
            shared_expected,
        )


def test_submit_jobs():
    """submit_jobs(): dry-run command construction for clients and servers.

    Writes an ip_config file listing ``num_machines`` local addresses and a
    partition config into a temporary workspace. It then calls
    ``submit_jobs(..., dry_run=True)`` and checks the returned client and
    server command strings.
    """

    class Args:
        # Minimal stand-in for the argument namespace passed to submit_jobs().
        pass

    args = Args()

    with tempfile.TemporaryDirectory() as test_dir:
        num_machines = 8
        ip_config = os.path.join(test_dir, "ip_config.txt")
        with open(ip_config, "w") as f:
            for i in range(num_machines):
                f.write("{} {}\n".format("127.0.0." + str(i), 30050))
        part_config = os.path.join(test_dir, "ogb-products.json")
        with open(part_config, "w") as f:
            json.dump({"num_parts": num_machines}, f)
        args.num_trainers = 8
        args.num_samplers = 1
        args.num_servers = 4
        args.workspace = test_dir
        args.part_config = "ogb-products.json"
        args.ip_config = "ip_config.txt"
        args.server_name = "ogb-products"
        args.keep_alive = False
        args.num_server_threads = 1
        args.graph_format = "csc"
        args.extra_envs = ["NCCL_DEBUG=INFO"]
        args.num_omp_threads = 1
        udf_command = "python3 train_dist.py --num_epochs 10"
        clients_cmd, servers_cmd = submit_jobs(args, udf_command, dry_run=True)

        # Guard against a vacuous pass: the per-command loops below assert
        # nothing if no commands were produced.
        assert clients_cmd, "submit_jobs() produced no client commands"
        assert servers_cmd, "submit_jobs() produced no server commands"

        def common_checks(cmd):
            """Assertions shared by every client and server command."""
            assert "cd " + test_dir in cmd
            assert "export " + args.extra_envs[0] in cmd
            assert f"DGL_NUM_SAMPLER={args.num_samplers}" in cmd
            assert (
                f"DGL_NUM_CLIENT={args.num_trainers*(args.num_samplers+1)*num_machines}"
                in cmd
            )
            assert f"DGL_CONF_PATH={args.part_config}" in cmd
            assert f"DGL_IP_CONFIG={args.ip_config}" in cmd
            assert f"DGL_NUM_SERVER={args.num_servers}" in cmd
            assert f"DGL_GRAPH_FORMAT={args.graph_format}" in cmd
            assert f"OMP_NUM_THREADS={args.num_omp_threads}" in cmd
            # The script and its arguments must survive, whatever prefix
            # precedes them.
            assert udf_command[len("python3 ") :] in cmd

        for cmd in clients_cmd:
            common_checks(cmd)
            assert "DGL_DIST_MODE=distributed" in cmd
            assert "DGL_ROLE=client" in cmd
            assert "DGL_GROUP_ID=0" in cmd
            assert (
                f"python3 -m torch.distributed.launch --nproc_per_node={args.num_trainers} --nnodes={num_machines}"
                in cmd
            )
            assert "--master_addr=127.0.0" in cmd
            assert "--master_port=1234" in cmd
        for cmd in servers_cmd:
            common_checks(cmd)
            assert "DGL_ROLE=server" in cmd
            assert "DGL_KEEP_ALIVE=0" in cmd
            assert "DGL_SERVER_ID=" in cmd


if __name__ == "__main__":

    def load_tests(loader, standard_tests, pattern):
        """unittest ``load_tests`` hook that also runs ``test_submit_jobs``.

        unittest's loader only collects ``TestCase`` subclasses. Without this
        hook, the pytest-style module-level ``test_submit_jobs()`` would be
        silently skipped when this file is executed directly.
        """
        standard_tests.addTest(unittest.FunctionTestCase(test_submit_jobs))
        return standard_tests

    unittest.main()
