

import time

import ray

from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool

@ray.remote
class TestActor(Worker):
    """Minimal ``Worker`` subclass exposed as a Ray actor for placement tests.

    It adds no state of its own. Tests call inherited ``Worker`` methods on it
    (e.g. via ``execute_all_sync``) plus ``get_node_id`` defined below.
    """

    def __init__(self, cuda_visible_devices=None) -> None:
        # Forward the optional device spec to Worker unchanged.
        super().__init__(cuda_visible_devices)

    def get_node_id(self):
        """Return the ID of the Ray node this actor process is running on."""
        runtime_ctx = ray.get_runtime_context()
        node_id = runtime_ctx.get_node_id()
        return node_id

def _expected_devices(*range_args):
    """Return per-rank CUDA_VISIBLE_DEVICES strings for GPU indices ``range(*range_args)``."""
    return [str(i) for i in range(*range_args)]


def _test_single_node_no_partition(class_with_args):
    """Colocate four worker groups on one shared 8-GPU resource pool.

    Every group is built on the same pool, so each is expected to report
    CUDA_VISIBLE_DEVICES values "0".."7" across its ranks. Afterwards the
    pool's placement groups are removed so the next scenario can reserve
    the same GPUs.
    """
    print("test single-node-no-partition")
    resource_pool = RayResourcePool([8], use_gpu=True)

    print("create actor worker group")
    actor_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_actor")
    print("create critic worker group")
    critic_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_critic")
    print("create rm worker group")
    rm_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_rm")
    print("create ref worker group")
    ref_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_ref")

    all_gpus = _expected_devices(8)
    assert actor_wg.execute_all_sync("get_cuda_visible_devices") == all_gpus
    assert critic_wg.execute_all_sync("get_cuda_visible_devices") == all_gpus
    assert rm_wg.execute_all_sync("get_cuda_visible_devices") == all_gpus
    assert ref_wg.execute_all_sync("get_cuda_visible_devices") == all_gpus

    # Drop the worker-group handles before tearing down the placement groups
    # they were scheduled on.
    del actor_wg, critic_wg, rm_wg, ref_wg

    for pg in resource_pool.get_placement_groups():
        ray.util.remove_placement_group(pg)
    print("wait 5s to remove placement_group")
    # Give Ray time to finish the removal (it may complete asynchronously)
    # before the next scenario reserves the same GPUs again.
    time.sleep(5)


def _test_single_node_multi_partition(class_with_args):
    """Split the 8 GPUs into two 4-GPU pools and also use their merge.

    Groups on the merged pool are expected to see GPUs "0".."7". The "rm"
    group is expected to see "0".."3" and the "ref" group "4".."7".
    """
    print("test single-node-multi-partition")
    rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="rm")
    ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="ref")
    total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool)

    assert rm_resource_pool.world_size == 4
    assert ref_resource_pool.world_size == 4
    assert total_resource_pool.world_size == 8

    actor_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_actor")
    critic_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_critic")
    rm_wg = RayWorkerGroup(rm_resource_pool, class_with_args, name_prefix="high_level_api_rm")
    ref_wg = RayWorkerGroup(ref_resource_pool, class_with_args, name_prefix="high_level_api_ref")

    all_gpus = _expected_devices(8)
    assert actor_wg.execute_all_sync("get_cuda_visible_devices") == all_gpus
    assert critic_wg.execute_all_sync("get_cuda_visible_devices") == all_gpus
    assert rm_wg.execute_all_sync("get_cuda_visible_devices") == _expected_devices(4)
    assert ref_wg.execute_all_sync("get_cuda_visible_devices") == _expected_devices(4, 8)


def test():
    """Check GPU assignment of RayWorkerGroups on a single node with 8 GPUs.

    Runs a shared-pool scenario, then a partitioned-pool scenario. Ray is
    shut down in a ``finally`` so a failing assertion does not leave the
    Ray instance started by ``ray.init()`` running for later tests in the
    same process.
    """
    ray.init()
    try:
        class_with_args = RayClassWithInitArgs(cls=TestActor)
        _test_single_node_no_partition(class_with_args)
        _test_single_node_multi_partition(class_with_args)
    finally:
        ray.shutdown()
