# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional

from verl.workers.rollout.replica import TokenOutput, RolloutReplica

# Module-level logger, registered under this source file's path.
logger = logging.getLogger(name=__file__)


class Agent_TokenOutput(TokenOutput):
    """Response token output extended with an optional ``finish_reason`` field.

    Declares ``token_ids`` and ``log_probs`` (presumably mirroring the fields of
    ``TokenOutput`` -- confirm against the parent) and adds ``finish_reason``.
    """

    token_ids: list[int]
    """response token ids"""
    log_probs: Optional[list[float]] = None
    """logprobs of response token ids"""
    # Optional list of finish reasons; individual entries may be None. Defaults to None.
    # NOTE(review): presumably one entry per generated sequence/turn as reported by the
    # inference engine (e.g. "stop" / "length") -- confirm against the producer.
    finish_reason: Optional[list[Optional[str]]] = None


def Agent_get_rollout_replica_class(rollout: str) -> type[RolloutReplica]:
    """Resolve the rollout replica class for a rollout backend name.

    Args:
        rollout: Backend name; ``"vllm"`` or ``"sglang"``.

    Returns:
        ``Agent_vLLMReplica`` from the sibling ``agent_vllm_async_server`` module
        for ``"vllm"``, or verl's ``SGLangReplica`` for ``"sglang"``.

    Raises:
        ValueError: If ``rollout`` is not a supported backend name.
    """
    if rollout == "vllm":
        # Imported inside the branch so the module is only loaded when this backend is requested.
        from .agent_vllm_async_server import Agent_vLLMReplica

        return Agent_vLLMReplica
    elif rollout == "sglang":
        # NOTE: verl driver is cpu only, avoid sglang fp8 quantization import error.
        # The flag is only meant to be set for the duration of the import below: remember
        # any pre-existing value and restore it in ``finally`` so that neither a failed
        # import nor a caller-provided setting is clobbered.
        previous_cpu_engine = os.environ.get("SGLANG_USE_CPU_ENGINE")
        os.environ["SGLANG_USE_CPU_ENGINE"] = "1"
        try:
            # TODO: remove this once we bump to sglang>=0.5.1
            try:
                import vllm  # noqa: F401
            except ImportError:
                import sys
                from unittest.mock import Mock

                # Stub ``vllm`` / ``vllm._custom_ops.scaled_fp8_quant`` so the sglang import
                # can proceed without vllm installed (presumably older sglang imports this
                # fp8 op unconditionally). The stub intentionally stays in sys.modules.
                mock_vllm = Mock()
                mock_vllm._custom_ops = Mock()
                mock_vllm._custom_ops.scaled_fp8_quant = Mock()

                sys.modules["vllm"] = mock_vllm
                sys.modules["vllm._custom_ops"] = mock_vllm._custom_ops

            from verl.workers.rollout.sglang_rollout.async_sglang_server import SGLangReplica
        finally:
            if previous_cpu_engine is None:
                os.environ.pop("SGLANG_USE_CPU_ENGINE", None)
            else:
                os.environ["SGLANG_USE_CPU_ENGINE"] = previous_cpu_engine

        return SGLangReplica
    else:
        raise ValueError(f"Unknown rollout mode: {rollout}")
