# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import ray
import torch
from verl import DataProto
from tensordict import TensorDict

from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs
from verl.single_controller.ray import RayWorkerGroup

# Process-wide settings applied at import time, before test() starts Ray:
# RAY_DEDUP_LOGS=0 turns off Ray's log de-duplication and NCCL_DEBUG=WARN
# keeps NCCL's diagnostics at warning level.
os.environ.update({
    'RAY_DEDUP_LOGS': '0',
    'NCCL_DEBUG': 'WARN',
})


@ray.remote
class ModelActor(Worker):
    """Ray remote ``Worker`` subclass that defines no state or methods of its own."""

    def __init__(self):
        """Do nothing; ``Worker.__init__`` is not called from here."""


class HackSelf:
    """Empty stand-in passed as ``self`` when ``get_aux_metrics`` runs on the driver."""

    def __init__(self):
        """Create the stand-in; no attributes are set."""


def get_aux_metrics(self, test_proto):
    """Measure the length of every row of ``sequence_ids`` and return it with the ids.

    Args:
        self: Not used by this function. The driver-side call passes a
            ``HackSelf`` instance; presumably a worker instance fills this
            slot when the function runs remotely (confirm against
            ``execute_with_func_generator``).
        test_proto: Object whose ``batch["sequence_ids"]`` is a tensor; each
            row is converted with ``tolist()`` and measured with ``len``, so
            rows must themselves be sequences.

    Returns:
        A ``DataProto`` whose batch holds the original ``"sequence_ids"``
        tensor and a ``"decode_count"`` tensor with one length per row.
    """
    ids = test_proto.batch["sequence_ids"]
    num_rows = ids.size(0)

    # One entry per row: the number of elements in that row.
    row_lengths = [len(ids[row].tolist()) for row in range(num_rows)]

    batch = TensorDict(
        {
            "sequence_ids": ids,
            "decode_count": torch.tensor(row_lengths),
        },
        batch_size=num_rows,
    )
    return DataProto(batch=batch)


def test():
    """Check that ``get_aux_metrics`` yields the same ``decode_count`` remotely and locally.

    The function is run once through a ``RayWorkerGroup`` via
    ``execute_with_func_generator`` and once directly on the driver with a
    ``HackSelf`` placeholder as ``self``; the two ``decode_count`` tensors
    must match.

    Raises:
        AssertionError: If the two ``decode_count`` tensors differ.
    """
    # construct model
    ray.init()

    # Always shut Ray down, even when a step below raises or the assertion
    # fails, so the runtime does not stay initialised for later tests in the
    # same process.
    try:
        # create 2 workers, each holding a GPU
        resource_pool = RayResourcePool([2], use_gpu=True, name_prefix='a')

        class_with_args = RayClassWithInitArgs(cls=ModelActor)
        shard_wg = RayWorkerGroup(resource_pool, class_with_args)

        test_bs = 8
        test_proto = DataProto(TensorDict({
            "sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64),
        },
                                          batch_size=test_bs),
                               meta_info={"query_length": 1536})

        # Sharding among different ranks
        ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto)

        # compare with executing on the driver
        hs = HackSelf()
        ret_proto2 = get_aux_metrics(hs, test_proto)

        torch.testing.assert_close(ret_proto1.batch["decode_count"], ret_proto2.batch["decode_count"])
    finally:
        ray.shutdown()
