# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import ray
import torch
from tensordict import TensorDict

from verl import DataProto
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray import RayWorkerGroup
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool

# Process-wide environment settings applied when this module is imported.
# NOTE(review): presumably "0" disables Ray's log de-duplication and "WARN"
# limits NCCL logging to warnings -- confirm against the Ray / NCCL docs.
os.environ.update(
    {
        "RAY_DEDUP_LOGS": "0",
        "NCCL_DEBUG": "WARN",
    }
)


@ray.remote
class ModelActor(Worker):
    """Ray remote class deriving from ``Worker`` with an empty constructor.

    It defines no methods of its own; ``test`` uses it as the worker class
    of the ``RayWorkerGroup``.
    """

    def __init__(self):
        """Do nothing; ``Worker.__init__`` is not called from here."""


class HackSelf:
    """Attribute-less placeholder passed as ``self`` to ``get_aux_metrics``.

    ``test`` uses it to call ``get_aux_metrics`` directly on the driver,
    where no worker instance exists.
    """

    def __init__(self):
        """Create the placeholder without setting any attributes."""


def get_aux_metrics(self, test_proto):
    """Return the input ``sequence_ids`` together with a per-row length count.

    Args:
        self: Unused. Present so the function has a method-like signature;
            ``test`` passes a ``HackSelf`` instance when calling it directly.
        test_proto: Object whose ``batch["sequence_ids"]`` is a tensor with at
            least two dimensions (``test`` passes an int64 tensor of shape
            ``[8, 2048]``).

    Returns:
        A ``DataProto`` whose batch holds the original ``sequence_ids`` and
        ``decode_count``, an int64 tensor with one entry per row, each equal
        to ``sequence_ids.size(1)``.
    """
    sequence_ids = test_proto.batch["sequence_ids"]
    num_rows = sequence_ids.size(0)
    # Every row of a tensor has the same length along dim 1, so read it from
    # the shape instead of converting each row to a Python list to measure it.
    row_length = sequence_ids.size(1)
    # An explicit dtype keeps the result int64 even when the batch is empty
    # (torch.tensor([]) would otherwise produce a float32 tensor).
    decode_count = torch.full((num_rows,), row_length, dtype=torch.int64)
    return DataProto(
        batch=TensorDict(
            {"sequence_ids": sequence_ids, "decode_count": decode_count},
            batch_size=num_rows,
        )
    )


def test():
    """Check that ``get_aux_metrics`` agrees between the worker group and the driver.

    The function starts Ray, builds a ``RayWorkerGroup`` of ``ModelActor``
    workers, runs ``get_aux_metrics`` through
    ``execute_with_func_generator`` and directly on the driver, and asserts
    that both results carry the same ``decode_count`` tensor.

    Raises:
        AssertionError: If the two ``decode_count`` tensors differ.
    """
    ray.init()
    try:
        # create 2 workers, each hold a GPU
        resource_pool = RayResourcePool([2], use_gpu=True, name_prefix="a")

        class_with_args = RayClassWithInitArgs(cls=ModelActor)
        shard_wg = RayWorkerGroup(resource_pool, class_with_args)

        test_bs = 8
        test_proto = DataProto(
            TensorDict(
                {
                    "sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64),
                },
                batch_size=test_bs,
            ),
            meta_info={"query_length": 1536},
        )

        # Sharding among different ranks
        ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto)

        # compare execute on driver
        hs = HackSelf()
        ret_proto2 = get_aux_metrics(hs, test_proto)

        torch.testing.assert_close(ret_proto1.batch["decode_count"], ret_proto2.batch["decode_count"])
    finally:
        # Shut Ray down even when setup or the assertion fails, so a failing
        # run does not leave an initialized Ray runtime behind.
        ray.shutdown()
