import multiprocessing as mp
import os
import socket
import time
import unittest

import backend as F
import pytest
from numpy.testing import assert_array_equal
from utils import generate_ip_config, reset_envs

import dgl

if os.name != "nt":
    import fcntl
    import struct

# Integer and string payload values that the Hello request/response tests
# send and then compare against on both the server and the client side.
INTEGER = 2
STR = "hello world!"
# Service id under which HelloRequest/HelloResponse are registered.
HELLO_SERVICE_ID = 901231
# Tensor payload built with F.zeros((1000, 1000), F.int64, F.cpu()); the
# clients compare the tensor in each response against it.
TENSOR = F.zeros((1000, 1000), F.int64, F.cpu())


def foo(x, y):
    """Assert that ``x`` equals 123 and ``y`` equals "abc".

    Returns None; raises AssertionError on the first value that differs.
    """
    for actual, expected in ((x, 123), (y, "abc")):
        assert actual == expected


class MyRequest(dgl.distributed.Request):
    """Request holding an int, a string, a random tensor and a function.

    Used by the serialization tests; its state is the 4-tuple
    ``(x, y, z, foo)``.
    """

    def __init__(self):
        self.x, self.y = 123, "abc"
        self.z = F.randn((3, 4))
        self.foo = foo

    def __getstate__(self):
        """Return the fields as a tuple in ``(x, y, z, foo)`` order."""
        state = (self.x, self.y, self.z, self.foo)
        return state

    def __setstate__(self, state):
        """Restore the fields from a tuple produced by ``__getstate__``."""
        x, y, z, func = state
        self.x = x
        self.y = y
        self.z = z
        self.foo = func

    def process_request(self, server_state):
        """Do nothing; this request is never processed by a server here."""
        return None


class MyResponse(dgl.distributed.Response):
    """Response with a single integer field ``x`` (initially 432)."""

    def __init__(self):
        self.x = 432

    def __getstate__(self):
        """Return ``x`` itself as the whole serialized state."""
        state = self.x
        return state

    def __setstate__(self, state):
        """Restore ``x`` from the value produced by ``__getstate__``."""
        self.x = state


def simple_func(tensor):
    """Identity function: hand back ``tensor`` unchanged.

    The client helpers pass it as the ``func`` field of ``HelloRequest``.
    """
    result = tensor
    return result


class HelloResponse(dgl.distributed.Response):
    """Response carrying a string, an integer and a tensor.

    Its state is the 3-tuple ``(hello_str, integer, tensor)``.
    """

    def __init__(self, hello_str, integer, tensor):
        self.hello_str, self.integer, self.tensor = hello_str, integer, tensor

    def __getstate__(self):
        """Return the fields as a tuple in constructor-argument order."""
        state = (self.hello_str, self.integer, self.tensor)
        return state

    def __setstate__(self, state):
        """Restore the fields from a tuple produced by ``__getstate__``."""
        hello_str, integer, tensor = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor


class HelloRequest(dgl.distributed.Request):
    """Request carrying a string, an integer, a tensor and a callable.

    Its state is the 4-tuple ``(hello_str, integer, tensor, func)``.
    """

    def __init__(self, hello_str, integer, tensor, func):
        self.hello_str, self.integer = hello_str, integer
        self.tensor, self.func = tensor, func

    def __getstate__(self):
        """Return the fields as a tuple in constructor-argument order."""
        state = (self.hello_str, self.integer, self.tensor, self.func)
        return state

    def __setstate__(self, state):
        """Restore the fields from a tuple produced by ``__getstate__``."""
        hello_str, integer, tensor, func = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor
        self.func = func

    def process_request(self, server_state):
        """Validate the payload and build the matching ``HelloResponse``.

        Asserts that the string and integer equal the module constants
        ``STR`` and ``INTEGER``, applies ``func`` to the tensor and returns
        a ``HelloResponse`` holding the string, the integer and that result.
        """
        assert self.hello_str == STR
        assert self.integer == INTEGER
        processed = self.func(self.tensor)
        return HelloResponse(self.hello_str, self.integer, processed)


# Service id under which TimeoutRequest/TimeoutResponse are registered.
TIMEOUT_SERVICE_ID = 123456789
# Marker string carried by timeout requests; the server asserts on it and
# the client checks it in the returned response.
TIMEOUT_META = "timeout_test"


class TimeoutResponse(dgl.distributed.Response):
    """Response that carries only the ``meta`` value of its request."""

    def __init__(self, meta):
        self.meta = meta

    def __getstate__(self):
        """Return ``meta`` itself as the whole serialized state."""
        state = self.meta
        return state

    def __setstate__(self, state):
        """Restore ``meta`` from the value produced by ``__getstate__``."""
        self.meta = state


class TimeoutRequest(dgl.distributed.Request):
    """Request that makes the handler sleep before (optionally) replying.

    ``timeout`` is the sleep duration in milliseconds; ``response`` selects
    whether ``process_request`` returns a ``TimeoutResponse`` or ``None``.
    """

    def __init__(self, meta, timeout, response=True):
        self.meta, self.timeout, self.response = meta, timeout, response

    def __getstate__(self):
        """Return the fields as a tuple ``(meta, timeout, response)``."""
        state = (self.meta, self.timeout, self.response)
        return state

    def __setstate__(self, state):
        """Restore the fields from a tuple produced by ``__getstate__``."""
        meta, timeout, response = state
        self.meta = meta
        self.timeout = timeout
        self.response = response

    def process_request(self, server_state):
        """Sleep for ``timeout`` milliseconds, then reply if requested.

        Asserts that ``meta`` equals ``TIMEOUT_META``. Returns a
        ``TimeoutResponse`` echoing ``meta`` when ``response`` is truthy,
        otherwise ``None``.
        """
        assert self.meta == TIMEOUT_META
        # time.sleep takes seconds, while ``timeout`` is in milliseconds.
        delay_seconds = self.timeout / 1000
        time.sleep(delay_seconds)
        return TimeoutResponse(self.meta) if self.response else None


def start_server(
    num_clients,
    ip_config,
    server_id=0,
    keep_alive=False,
    num_servers=1,
    net_type="tensorpipe",
):
    """Register the test services and run a DGL RPC server.

    Sleeps one second first so that clients started at the same time have
    to retry their connection. All arguments are forwarded to
    ``dgl.distributed.start_server``, except ``keep_alive`` which goes into
    the ``ServerState``.
    """
    print("Sleep 1 seconds to test client re-connect.")
    time.sleep(1)
    state = dgl.distributed.ServerState(
        None, local_g=None, partition_book=None, keep_alive=keep_alive
    )
    # The server handles both the Hello and the Timeout services.
    services = (
        (HELLO_SERVICE_ID, HelloRequest, HelloResponse),
        (TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse),
    )
    for service_id, request_cls, response_cls in services:
        dgl.distributed.register_service(service_id, request_cls, response_cls)
    print("Start server {}".format(server_id))
    dgl.distributed.start_server(
        server_id=server_id,
        ip_config=ip_config,
        num_servers=num_servers,
        num_clients=num_clients,
        server_state=state,
        net_type=net_type,
    )


def start_client(ip_config, group_id=0, num_servers=1, net_type="tensorpipe"):
    """Connect to the server and exercise the four request/response APIs.

    Sends the same ``HelloRequest`` through ``send_request``,
    ``remote_call``, ``send_request_to_machine`` and
    ``remote_call_to_machine`` (always to target 0) and asserts that every
    response echoes ``STR``, ``INTEGER`` and ``TENSOR``.
    """

    def check_response(response):
        # Every response must echo the payload that was sent.
        assert response.hello_str == STR
        assert response.integer == INTEGER
        assert_array_equal(F.asnumpy(response.tensor), F.asnumpy(TENSOR))

    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse
    )
    dgl.distributed.connect_to_server(
        ip_config=ip_config,
        num_servers=num_servers,
        group_id=group_id,
        net_type=net_type,
    )
    request = HelloRequest(STR, INTEGER, TENSOR, simple_func)

    # send_request + recv_response round trip
    dgl.distributed.send_request(0, request)
    check_response(dgl.distributed.recv_response())

    # remote_call with a batch of ten identical requests
    batch = [(0, request) for _ in range(10)]
    for response in dgl.distributed.remote_call(batch):
        check_response(response)

    # send_request_to_machine + recv_response round trip
    dgl.distributed.send_request_to_machine(0, request)
    check_response(dgl.distributed.recv_response())

    # remote_call_to_machine with a batch of ten identical requests
    batch = [(0, request) for _ in range(10)]
    for response in dgl.distributed.remote_call_to_machine(batch):
        check_response(response)


def start_client_timeout(
    ip_config, group_id=0, num_servers=1, net_type="tensorpipe"
):
    """Connect to the server and exercise the ``timeout`` arguments.

    The server sleeps 1000 ms per request while the client waits only
    500 ms, so a timed ``recv_response`` must return ``None`` and a timed
    ``remote_call``/``remote_call_to_machine`` must raise ``dgl.DGLError``.
    """

    def raises_dgl_error(call, requests, limit):
        # True iff ``call`` raised dgl.DGLError under the given timeout.
        try:
            call(requests, timeout=limit)
        except dgl.DGLError:
            return True
        return False

    dgl.distributed.register_service(
        TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse
    )
    dgl.distributed.connect_to_server(
        ip_config=ip_config,
        num_servers=num_servers,
        group_id=group_id,
        net_type=net_type,
    )
    server_delay = 1 * 1000  # milliseconds
    recv_limit = int(server_delay / 2)

    # send_request: the timed receive gives up, the blocking one succeeds
    dgl.distributed.send_request(0, TimeoutRequest(TIMEOUT_META, server_delay))
    early = dgl.distributed.recv_response(timeout=recv_limit)
    assert early is None
    late = dgl.distributed.recv_response()
    assert late.meta == TIMEOUT_META

    # remote_call: requests that produce no response must time out
    silent = TimeoutRequest(TIMEOUT_META, server_delay, response=False)
    batch = [(0, silent) for _ in range(3)]
    assert raises_dgl_error(dgl.distributed.remote_call, batch, recv_limit)

    # send_request_to_machine: same expectations as send_request
    dgl.distributed.send_request_to_machine(
        0, TimeoutRequest(TIMEOUT_META, server_delay)
    )
    early = dgl.distributed.recv_response(timeout=recv_limit)
    assert early is None
    late = dgl.distributed.recv_response()
    assert late.meta == TIMEOUT_META

    # remote_call_to_machine: same expectations as remote_call
    silent = TimeoutRequest(TIMEOUT_META, server_delay, response=False)
    batch = [(0, silent) for _ in range(3)]
    assert raises_dgl_error(
        dgl.distributed.remote_call_to_machine, batch, recv_limit
    )


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"])
def test_rpc_timeout(net_type):
    """Run one server and one timeout-testing client in spawned processes.

    Parameters
    ----------
    net_type : str
        Transport name forwarded to both the server and the client.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context("spawn")
    pserver = ctx.Process(
        target=start_server, args=(1, ip_config, 0, False, 1, net_type)
    )
    pclient = ctx.Process(
        target=start_client_timeout, args=(ip_config, 0, 1, net_type)
    )
    pserver.start()
    pclient.start()
    pserver.join()
    pclient.join()
    # The real assertions run inside the child processes; a failure there
    # only shows up as a non-zero exit code, so it must be checked here.
    assert pclient.exitcode == 0
    assert pserver.exitcode == 0


def test_serialize():
    """Round-trip ``MyRequest`` and ``MyResponse`` through the payload API.

    Serializes each object with ``serialize_to_payload``, rebuilds it with
    ``deserialize_from_payload`` and checks that the fields survive.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    from dgl.distributed.rpc import (
        deserialize_from_payload,
        serialize_to_payload,
    )

    service_id = 12345
    dgl.distributed.register_service(service_id, MyRequest, MyResponse)

    # Request round trip, including the function-valued field.
    request = MyRequest()
    payload, payload_tensors = serialize_to_payload(request)
    restored_request = deserialize_from_payload(
        MyRequest, payload, payload_tensors
    )
    restored_request.foo(restored_request.x, restored_request.y)
    assert request.x == restored_request.x
    assert request.y == restored_request.y
    assert F.array_equal(request.z, restored_request.z)

    # Response round trip.
    response = MyResponse()
    payload, payload_tensors = serialize_to_payload(response)
    restored_response = deserialize_from_payload(
        MyResponse, payload, payload_tensors
    )
    assert response.x == restored_response.x


def test_rpc_msg():
    """Build an ``RPCMessage`` from a serialized request and read it back.

    Checks that the service id, message sequence, client id, server id,
    data length and the single tensor are exposed as constructed.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    from dgl.distributed.rpc import (
        RPCMessage,
        deserialize_from_payload,
        serialize_to_payload,
    )

    service_id = 32452
    dgl.distributed.register_service(service_id, MyRequest, MyResponse)
    request = MyRequest()
    payload, payload_tensors = serialize_to_payload(request)
    message = RPCMessage(service_id, 23, 0, 1, payload, payload_tensors)

    # Header fields come back exactly as passed to the constructor.
    assert message.service_id == service_id
    assert message.msg_seq == 23
    assert message.client_id == 0
    assert message.server_id == 1
    # Body: same data length, and one tensor equal to the request's ``z``.
    assert len(message.data) == len(payload)
    assert len(message.tensors) == 1
    assert F.array_equal(message.tensors[0], request.z)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["tensorpipe"])
def test_rpc(net_type):
    """Run one server and one Hello client in spawned processes.

    Parameters
    ----------
    net_type : str
        Transport name forwarded to both the server and the client.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    generate_ip_config("rpc_ip_config.txt", 1, 1)
    ctx = mp.get_context("spawn")
    pserver = ctx.Process(
        target=start_server,
        args=(1, "rpc_ip_config.txt", 0, False, 1, net_type),
    )
    pclient = ctx.Process(
        target=start_client, args=("rpc_ip_config.txt", 0, 1, net_type)
    )
    pserver.start()
    pclient.start()
    pserver.join()
    pclient.join()
    # The real assertions run inside the child processes; a failure there
    # only shows up as a non-zero exit code, so it must be checked here.
    assert pclient.exitcode == 0
    assert pserver.exitcode == 0


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"])
def test_multi_client(net_type):
    """Run one server against twenty concurrent Hello clients.

    Parameters
    ----------
    net_type : str
        Transport name forwarded to the server and to every client.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config_mul_client.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context("spawn")
    num_clients = 20
    pserver = ctx.Process(
        target=start_server,
        args=(num_clients, ip_config, 0, False, 1, net_type),
    )
    pclient_list = [
        ctx.Process(target=start_client, args=(ip_config, 0, 1, net_type))
        for _ in range(num_clients)
    ]
    pserver.start()
    for pclient in pclient_list:
        pclient.start()
    for pclient in pclient_list:
        pclient.join()
    pserver.join()
    # The real assertions run inside the child processes; a failure there
    # only shows up as a non-zero exit code, so it must be checked here.
    for pclient in pclient_list:
        assert pclient.exitcode == 0
    assert pserver.exitcode == 0


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"])
def test_multi_thread_rpc(net_type):
    """Send requests from the main thread and a sub-thread of one client.

    Two servers are spawned; the current process acts as the client. It
    sends one request to server 0 from the main thread and one to server 1
    from a helper thread, then receives both responses.

    Parameters
    ----------
    net_type : str
        Transport name forwarded to the servers and used by the client.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    num_servers = 2
    ip_config = "rpc_ip_config_multithread.txt"
    generate_ip_config(ip_config, num_servers, num_servers)
    ctx = mp.get_context("spawn")
    pserver_list = []
    for i in range(num_servers):
        pserver = ctx.Process(
            target=start_server, args=(1, ip_config, i, False, 1, net_type)
        )
        pserver.start()
        pserver_list.append(pserver)

    def start_client_multithread(ip_config):
        """Run the client side of the test in the current process."""
        import threading

        dgl.distributed.connect_to_server(
            ip_config=ip_config, num_servers=1, net_type=net_type
        )
        dgl.distributed.register_service(
            HELLO_SERVICE_ID, HelloRequest, HelloResponse
        )

        req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
        dgl.distributed.send_request(0, req)

        def subthread_call(server_id):
            req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
            dgl.distributed.send_request(server_id, req)

        subthread = threading.Thread(target=subthread_call, args=(1,))
        subthread.start()
        subthread.join()

        res0 = dgl.distributed.recv_response()
        res1 = dgl.distributed.recv_response()
        # Order is not guaranteed
        assert_array_equal(F.asnumpy(res0.tensor), F.asnumpy(TENSOR))
        assert_array_equal(F.asnumpy(res1.tensor), F.asnumpy(TENSOR))
        dgl.distributed.exit_client()

    start_client_multithread(ip_config)
    # Wait for every spawned server, not just the last one created.
    for pserver in pserver_list:
        pserver.join()
    # A server-side failure only shows up as a non-zero exit code.
    for pserver in pserver_list:
        assert pserver.exitcode == 0


@unittest.skipIf(
    True,
    reason="Tests of multiple groups may fail and let's disable them for now.",
)
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_multi_client_groups():
    """Run several client groups against keep-alive servers (disabled).

    Servers are started with ``keep_alive=True``, so they must still be
    alive once every client has finished; they are then stopped with
    ``dgl.distributed.shutdown_servers``.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config_mul_client_groups.txt"
    num_machines = 5
    # A larger value would be a better test, but risks port-in-use errors.
    num_servers = 1
    generate_ip_config(ip_config, num_machines, num_servers)
    # Pressure test: several clients in each of several groups.
    num_clients = 2
    num_groups = 2
    ctx = mp.get_context("spawn")

    servers = []
    for server_id in range(num_servers * num_machines):
        server_args = (num_clients, ip_config, server_id, True, num_servers)
        server = ctx.Process(target=start_server, args=server_args)
        server.start()
        servers.append(server)

    clients = []
    for _ in range(num_clients):
        for group_id in range(num_groups):
            client_args = (ip_config, group_id, num_servers)
            client = ctx.Process(target=start_client, args=client_args)
            client.start()
            clients.append(client)

    for client in clients:
        client.join()
    # Keep-alive servers must outlive their clients.
    for server in servers:
        assert server.is_alive()
    # Force the servers to shut down, then wait for them.
    dgl.distributed.shutdown_servers(ip_config, num_servers)
    for server in servers:
        server.join()


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["socket", "tensorpipe"])
def test_multi_client_connect(net_type):
    """Check client connection retries via ``DGL_DIST_MAX_TRY_TIMES``.

    With a retry limit of 1 and no server running, connecting in the
    current process must raise ``DistConnectError``. With a limit of 1024,
    a client started before the server must eventually connect and finish.

    Parameters
    ----------
    net_type : str
        Transport name forwarded to the server and the clients.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    ip_config = "rpc_ip_config_mul_client.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context("spawn")
    num_clients = 1
    pserver = ctx.Process(
        target=start_server,
        args=(num_clients, ip_config, 0, False, 1, net_type),
    )

    try:
        # small max try times
        os.environ["DGL_DIST_MAX_TRY_TIMES"] = "1"
        expect_except = False
        try:
            start_client(ip_config, 0, 1, net_type)
        except dgl.distributed.DistConnectError as err:
            print("Expected error: {}".format(err))
            expect_except = True
        assert expect_except

        # large max try times
        os.environ["DGL_DIST_MAX_TRY_TIMES"] = "1024"
        pclient = ctx.Process(
            target=start_client, args=(ip_config, 0, 1, net_type)
        )
        pclient.start()
        pserver.start()
        pclient.join()
        pserver.join()
        # The real assertions run inside the child processes; a failure
        # there only shows up as a non-zero exit code.
        assert pclient.exitcode == 0
        assert pserver.exitcode == 0
    finally:
        # Always undo the environment changes, even when the test fails,
        # so the retry limit does not leak into later tests.
        reset_envs()


if __name__ == "__main__":
    # Running the file directly bypasses pytest's parametrization, so each
    # parametrized test has to be given its net_type explicitly.
    test_serialize()
    test_rpc_msg()
    test_rpc("tensorpipe")
    test_multi_client("socket")
    test_multi_client("tensorpipe")
    test_multi_thread_rpc("socket")
    test_multi_thread_rpc("tensorpipe")
    test_multi_client_connect("socket")
    test_multi_client_connect("tensorpipe")
