

import ray
import torch

from verl.single_controller.base.decorator import Dispatch, Execute, collect_all_to_all, register
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup

def two_to_all_dispatch_fn(worker_group, *args, **kwargs):
    """Expand every two-element argument by cycling its two items.

    Each positional and keyword argument must hold exactly two items. For each
    one, ``worker_group.world_size - 2`` extra entries are added by repeating
    the pair cyclically (item 0, item 1, item 0, ...). When ``world_size`` is
    2 or less nothing is added and the pair is returned as a two-element list.

    The expansion is built in fresh lists, so the sequences supplied by the
    caller are never modified and can safely be reused after the call.

    Args:
        worker_group: Object exposing an integer ``world_size`` attribute.
        *args: Two-element sequences to expand.
        **kwargs: Two-element sequences to expand, keyed by parameter name.

    Returns:
        A ``(args, kwargs)`` pair: a tuple with one expanded list per
        positional argument, and a dict mapping each keyword to its expanded
        list.

    Raises:
        AssertionError: If any argument does not contain exactly two items.
    """
    extra = worker_group.world_size - 2

    def _cycle_pair(pair):
        # Expand one two-item sequence into a new list without touching it.
        assert len(pair) == 2
        seed = list(pair)
        return seed + [seed[i % 2] for i in range(extra)]

    expanded_args = tuple(_cycle_pair(arg) for arg in args)
    expanded_kwargs = {key: _cycle_pair(value) for key, value in kwargs.items()}
    return expanded_args, expanded_kwargs

@ray.remote
class TestActor(Worker):
    """Ray actor used by the worker-group tests in this file.

    It stores a single value at construction time and exposes several methods
    that add that value to their arguments. The methods differ only in how
    they are registered for dispatch through a worker group.
    """

    def __init__(self, x) -> None:
        super().__init__()
        # Value added to the arguments of every method below.
        self._x = x

    def foo(self, y):
        """Return the stored value plus ``y`` (not registered for dispatch)."""
        stored = self._x
        return stored + y

    @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
    def foo_rank_zero(self, x, y):
        """Return the stored value plus ``y`` plus ``x``.

        Registered with ALL_TO_ALL dispatch and RANK_ZERO execution.
        """
        partial = self._x + y
        return partial + x

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    def foo_one_to_all(self, x, y):
        """Return the stored value plus ``y`` plus ``x``.

        Registered with ONE_TO_ALL dispatch as a non-blocking call.
        """
        partial = self._x + y
        return partial + x

    @register(dispatch_mode=Dispatch.ALL_TO_ALL, blocking=False)
    def foo_all_to_all(self, x, y):
        """Return the stored value plus ``y`` plus ``x``.

        Registered with ALL_TO_ALL dispatch as a non-blocking call.
        """
        partial = self._x + y
        return partial + x

    @register(
        dispatch_mode={
            "dispatch_fn": two_to_all_dispatch_fn,
            "collect_fn": collect_all_to_all,
        }
    )
    def foo_custom(self, x, y):
        """Return the stored value plus ``y`` plus ``x``.

        Registered with a custom dispatch mode: ``two_to_all_dispatch_fn``
        prepares the arguments and ``collect_all_to_all`` gathers the results.
        """
        partial = self._x + y
        return partial + x

@ray.remote(num_gpus=0.1)
def remote_call_wg(worker_names):
    """Attach to existing workers by name from inside a Ray task and call them.

    Builds a worker group from ``worker_names`` via
    ``RayWorkerGroup.from_detached``, then checks the results of the custom
    dispatch method and the rank-zero method on ``TestActor``.

    Args:
        worker_names: Names of the already-running workers to attach to.

    Returns:
        The ``worker_names`` reported by the attached worker group.
    """
    actor_spec = RayClassWithInitArgs(cls=TestActor, x=2)
    attached_group = RayWorkerGroup.from_detached(
        name_prefix=None,
        ray_cls_with_init=actor_spec,
        worker_names=worker_names,
    )
    print(attached_group.worker_names)

    # Custom dispatch: the two-element inputs are cycled across the workers.
    custom_result = attached_group.foo_custom(x=[1, 2], y=[5, 6])
    assert custom_result == [8, 10, 8, 10]

    # Rank-zero execution yields a single value rather than a list.
    rank_zero_result = attached_group.foo_rank_zero(x=1, y=2)
    assert rank_zero_result == 5

    return attached_group.worker_names

def add_one(data):
    """Add one to ``data`` on the CUDA device and return the result on the CPU.

    Args:
        data: Tensor-like object supporting ``.to(device)`` and in-place ``+=``.

    Returns:
        The incremented data after moving it back with ``.to("cpu")``.
    """
    on_gpu = data.to("cuda")
    # In-place increment of the device-side tensor, as in a plain ``+=``.
    on_gpu += 1
    on_cpu = on_gpu.to("cpu")
    return on_cpu

def test_basics():
    """Exercise the basic execution and dispatch paths of ``RayWorkerGroup``.

    Starts a local Ray runtime, creates a worker group of ``TestActor``
    instances from a ``RayResourcePool([4], use_gpu=True)``, and checks:

    * ``execute_all_sync`` / ``execute_all_async`` on an unregistered method,
    * the registered ONE_TO_ALL and ALL_TO_ALL non-blocking methods,
    * attaching to the same workers from inside a Ray task
      (``remote_call_wg``),
    * ``execute_func_rank_zero`` with a plain function and a tensor argument.

    The Ray runtime is shut down in a ``finally`` clause so that a failing
    assertion does not leave it initialised for later tests in the process.
    """
    ray.init(num_cpus=100)
    try:
        resource_pool = RayResourcePool([4], use_gpu=True)
        class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)

        worker_group = RayWorkerGroup(
            resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic"
        )

        print(worker_group.worker_names)

        # Synchronous call: results come back directly, one per worker.
        output = worker_group.execute_all_sync("foo", y=3)
        assert output == [5, 5, 5, 5]

        # Asynchronous call: results must be resolved with ray.get.
        output_ref = worker_group.execute_all_async("foo", y=4)
        print(output_ref)

        assert ray.get(output_ref) == [6, 6, 6, 6]

        # Registered with blocking=False, so these also need ray.get.
        output_ref = worker_group.foo_one_to_all(x=1, y=2)
        assert ray.get(output_ref) == [5, 5, 5, 5]

        output_ref = worker_group.foo_all_to_all(x=[1, 2, 3, 4], y=[5, 6, 7, 8])
        assert ray.get(output_ref) == [8, 10, 12, 14]

        # Re-attach to the same workers by name from within a Ray task.
        print(ray.get(remote_call_wg.remote(worker_group.worker_names)))

        output = worker_group.execute_func_rank_zero(add_one, torch.ones(2, 2))
        torch.testing.assert_close(output, torch.ones(2, 2) + 1)
    finally:
        # Always release the Ray runtime, even when a check above fails.
        ray.shutdown()

# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
    test_basics()
