import ray
from verl.workers.rollout.vllm_rollout.vllm_rollout import vLLMRollout
from omegaconf import DictConfig
import torch

# Per-process cache for the engine built by get_vllm_engine().
# Stays None until an initialization attempt succeeds.
_vllm_engine = None


def get_vllm_engine():
    """
    Initialize or retrieve the vLLM engine for inference.

    The engine is cached in the module-level ``_vllm_engine`` variable, so it
    is built at most once per Python process. The cache is NOT shared between
    processes: every process that calls this function (e.g. each separate Ray
    worker process) builds its own engine. There is no locking, so concurrent
    first calls from several threads in one process could each build one.

    On first use this function does the following:

    * builds a tokenizer with verl's ``hf_tokenizer``;
    * loads the HF config and float16 weights from a fixed local path;
    * wraps them in a ``vLLMRollout`` configured with 1024-token prompt and
      response lengths and temperature 0.0.

    Returns:
        The cached ``vLLMRollout`` instance, or ``None`` if initialization
        failed for any reason, including a missing or broken
        ``transformers``/``verl`` import. A failed attempt is not cached, so
        the next call retries the full (expensive) initialization.
    """
    global _vllm_engine

    if _vllm_engine is not None:
        return _vllm_engine

    model_path = "/workdir/Qwen2.5-3B-Instruct"

    # NOTE(review): depending on the installed verl version, vLLMRollout may
    # read further keys from this config. A missing key surfaces here as an
    # init failure (None) -- confirm against the verl version in use.
    config = {
        "tensor_model_parallel_size": 1,
        "prompt_length": 1024,
        "response_length": 1024,
        "dtype": "float16",  # keep in sync with torch_dtype below
        "enforce_eager": True,
        "gpu_memory_utilization": 0.4,
        "free_cache_engine": False,
        "load_format": "auto",
        "n": 1,
        "temperature": 0.0,
        "top_p": 1.0,
        "top_k": -1,
    }

    try:
        print("Initializing vLLM engine...")

        # Imported lazily and inside the try. A missing or broken dependency
        # is then reported like any other init failure (return None) instead
        # of escaping as an exception.
        from transformers import AutoConfig, AutoModelForCausalLM
        from verl.utils import hf_tokenizer

        tokenizer = hf_tokenizer(model_path)
        model_config = AutoConfig.from_pretrained(model_path)

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        engine = vLLMRollout(
            actor_module=model,
            config=DictConfig(config),
            tokenizer=tokenizer,
            model_hf_config=model_config,
        )
    except Exception as e:
        import traceback  # only needed on the failure path

        print(f"Error initializing vLLM engine: {e}")
        # Keep the full traceback: str(e) alone rarely explains why a model
        # load or engine construction failed.
        traceback.print_exc()
        # Return None when initialization fails. The cache stays None, so a
        # later call retries.
        return None

    _vllm_engine = engine
    print("vLLM engine initialized successfully")
    return _vllm_engine
