# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Entry point for PRIME training.

This main is intentionally kept separate from ray_trainer, because ray_trainer is shared by other entry points.
"""

import hydra
import ray
from omegaconf import OmegaConf

from verl.trainer.ppo.utils import need_reference_policy
from verl.utils.config import validate_config

from .prime_ray_trainer import RayPRIMETrainer


@hydra.main(config_path="config", config_name="prime_trainer", version_base=None)
def main(config):
    """Hydra CLI entry point: compose the ``prime_trainer`` config and hand it to ``run_prime``."""
    return run_prime(config)


def run_prime(config, compute_score=None):
    """Start Ray if necessary, then run ``main_task`` remotely and block until it finishes.

    When Ray is not yet initialised, ``config.ray_kwargs.ray_init`` (if present)
    is forwarded to ``ray.init``. Its ``runtime_env`` entry is merged on top of a
    default environment (``TOKENIZERS_PARALLELISM=true``, ``NCCL_DEBUG=WARN``),
    so values from the config take precedence over the defaults.

    Args:
        config: Hydra/OmegaConf training config.
        compute_score: Optional scoring callable forwarded to ``main_task``.
    """
    if not ray.is_initialized():
        # Local Ray cluster: defaults first, config-provided runtime_env merged over them.
        base_env = {"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}
        init_options = config.ray_kwargs.get("ray_init", {})
        merged_env = OmegaConf.merge(base_env, init_options.get("runtime_env", {}))
        init_options = OmegaConf.create({**init_options, "runtime_env": merged_env})
        print(f"ray init kwargs: {init_options}")
        ray.init(**OmegaConf.to_container(init_options))

    task_ref = main_task.remote(config, compute_score)
    ray.get(task_ref)


@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
def main_task(config, compute_score=None):
    """Build the PRIME workers, reward functions and trainer from ``config``, then train.

    Runs as a Ray task. It resolves ``config`` in place and picks worker classes
    for the configured actor strategy (``fsdp``/``fsdp2`` or ``megatron``). It
    registers the actor/rollout role and, when enabled, the reference-policy,
    reward-model and critic roles, all on one shared resource pool. It then
    instantiates the configured reward manager for training and validation and
    runs ``RayPRIMETrainer.init_workers()`` followed by ``fit()``.

    Args:
        config: Hydra/OmegaConf training config; resolved in place.
        compute_score: Optional scoring callable passed to both reward managers.

    Raises:
        NotImplementedError: If the actor strategy or the ``reward_manager``
            name is not supported.
    """
    from pprint import pprint

    from omegaconf import OmegaConf

    from verl.utils.fs import copy_local_path_from_hdfs

    # print initial config
    pprint(
        OmegaConf.to_container(config, resolve=True)
    )  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    # A critic is used only when the advantage-estimator name contains
    # "gae" or "wogroup".
    adv_estimator = config.algorithm.adv_estimator
    use_critic = "gae" in adv_estimator or "wogroup" in adv_estimator

    # define worker classes. Actor and critic must use the same backend (see
    # the asserts), so every worker class, including CriticWorker, is imported
    # from that backend's module.
    strategy = config.actor_rollout_ref.actor.strategy
    if strategy in {"fsdp", "fsdp2"}:
        assert config.critic.strategy in {"fsdp", "fsdp2"}, (
            f"critic strategy {config.critic.strategy!r} must be fsdp/fsdp2 "
            f"when actor strategy is {strategy!r}"
        )
        from verl.workers.fsdp_workers import (
            ActorRolloutRefWorker,
            AsyncActorRolloutRefWorker,
            CriticWorker,
        )
    elif strategy == "megatron":
        assert config.critic.strategy == strategy, (
            f"critic strategy {config.critic.strategy!r} must match "
            f"actor strategy {strategy!r}"
        )
        from verl.workers.megatron_workers import (
            ActorRolloutRefWorker,
            AsyncActorRolloutRefWorker,
            CriticWorker,
        )
    else:
        raise NotImplementedError(
            f"Unsupported actor strategy {strategy!r}; "
            "expected one of 'fsdp', 'fsdp2', 'megatron'"
        )

    from verl.single_controller.ray import RayWorkerGroup
    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

    ray_worker_group_cls = RayWorkerGroup
    actor_rollout_cls = (
        AsyncActorRolloutRefWorker
        if config.actor_rollout_ref.rollout.mode == "async"
        else ActorRolloutRefWorker
    )

    # Every role is placed on one pool: n_gpus_per_node GPUs on each of nnodes nodes.
    global_pool_id = "global_pool"
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    role_worker_mapping = {Role.ActorRollout: ray.remote(actor_rollout_cls)}
    mapping = {Role.ActorRollout: global_pool_id}

    # use reference model when KL is applied in the reward or in the actor loss
    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
        role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
        mapping[Role.RefPolicy] = global_pool_id

    if config.reward_model.enable:
        # NOTE(review): this worker is imported from the FSDP worker module
        # regardless of ``strategy`` — confirm it is usable with Megatron.
        from .prime_fsdp_workers import PRIMERewardModelWorker

        role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker)
        mapping[Role.RewardModel] = global_pool_id

    if use_critic:
        # CriticWorker was imported above from the module matching ``strategy``.
        role_worker_mapping[Role.Critic] = ray.remote(CriticWorker)
        mapping[Role.Critic] = global_pool_id

    # validate config
    # TODO: Additional config checks can be added with proper function under prime recipe
    validate_config(
        config=config,
        use_reference_policy=need_reference_policy(role_worker_mapping),
        use_critic=use_critic,
    )

    # download the checkpoint from hdfs
    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

    # instantiate tokenizer
    from verl.utils import hf_tokenizer

    tokenizer = hf_tokenizer(local_path)

    # Config name -> class name exported by ``verl.workers.reward_manager``.
    reward_manager_class_names = {
        "naive": "NaiveRewardManager",
        "prime": "PrimeRewardManager",
        "collabllm": "CollabLLMRewardManager",
        "collabllm_judge": "CollabLLMJudgeRewardManager",
        "medical": "MedicalRewardManager",
        "medical_judge": "MedicalJudgeRewardManager",
    }
    reward_manager_name = config.reward_model.get("reward_manager", "naive")
    reward_manager_class_name = reward_manager_class_names.get(reward_manager_name)
    if reward_manager_class_name is None:
        raise NotImplementedError(
            f"Unknown reward_manager {reward_manager_name!r}; "
            f"expected one of {sorted(reward_manager_class_names)}"
        )
    from verl.workers import reward_manager as reward_manager_module

    reward_manager_cls = getattr(reward_manager_module, reward_manager_class_name)

    reward_kwargs = config.reward_model.get("reward_kwargs", {})
    reward_fn = reward_manager_cls(
        tokenizer=tokenizer,
        num_examine=0,
        compute_score=compute_score,
        **reward_kwargs,
    )

    # Note that we always use function-based RM for validation
    val_reward_fn = reward_manager_cls(
        tokenizer=tokenizer,
        num_examine=1,
        compute_score=compute_score,
        **reward_kwargs,
    )

    resource_pool_manager = ResourcePoolManager(
        resource_pool_spec=resource_pool_spec, mapping=mapping
    )

    trainer = RayPRIMETrainer(
        config=config,
        tokenizer=tokenizer,
        role_worker_mapping=role_worker_mapping,
        resource_pool_manager=resource_pool_manager,
        ray_worker_group_cls=ray_worker_group_cls,
        reward_fn=reward_fn,
        val_reward_fn=val_reward_fn,
    )
    trainer.init_workers()
    trainer.fit()


if __name__ == "__main__":
    # main() is wrapped by @hydra.main, which builds ``config`` from the command line.
    main()
