

import gc

import torch
import torch.distributed
import torch.distributed as dist
from omegaconf import OmegaConf
from transformers import AutoConfig, AutoTokenizer

from verl import DataProto
from verl.utils.distributed import initialize_global_process_group
from verl.utils.model import compute_position_id_with_mask
from verl.workers.rollout.vllm_rollout.vllm_rollout_spmd import vLLMRollout

def test_vllm_rollout_with_yarn_position_embeddings():
    """Smoke-test greedy vLLMRollout generation on very long prompts.

    Uses the ``OldKingMeister/Qwen2.5-1.5B-Instruct-YaRN`` checkpoint (YaRN
    position embeddings per its name), prompts padded to ``prompt_length``
    tokens, and a tensor-parallel size of 4. On rank 0, every generated
    response must contain the ``<|im_end|>`` token.

    Teardown (engine release, CUDA cache cleanup, barrier, process-group
    destruction) runs in a ``finally`` block. If the rank-0 assertion fails,
    rank 0 still reaches ``dist.barrier()`` instead of leaving the other ranks
    blocked there forever.
    """
    _local_rank, rank, _world_size = initialize_global_process_group()

    prompt_length = 35000
    response_length = 512
    # NOTE(review): tensor_model_parallel_size is fixed at 4; presumably the
    # test must be launched with exactly that many processes/GPUs — confirm.
    config = OmegaConf.create(
        {
            "model_path": "OldKingMeister/Qwen2.5-1.5B-Instruct-YaRN",
            "prompt_length": prompt_length,
            "response_length": response_length,
            "dtype": "bfloat16",
            "enforce_eager": True,
            "gpu_memory_utilization": 0.4,
            "enable_chunked_prefill": False,
            "free_cache_engine": False,
            "disable_log_stats": True,
            # The context window must hold the full padded prompt plus the response.
            "max_model_len": prompt_length + response_length,
            "load_format": "auto",
            "val_kwargs": {
                "top_k": -1,
                "top_p": 1.0,
                "temperature": 0,
                "n": 1,
                "do_sample": False,
            },
            "tensor_model_parallel_size": 4,
            "trust_remote_code": True,
            "calculate_log_probs": False,
            "do_sample": False,
            "temperature": 0.0,
            # Chunked prefill is disabled, so one batch must fit a whole sequence.
            "max_num_batched_tokens": prompt_length + response_length,
        }
    )

    vllm_rollout = None
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            config.model_path, trust_remote_code=config.trust_remote_code, padding_side="left"
        )
        tokenizer.pad_token = tokenizer.eos_token
        # Honour the same remote-code setting as the tokenizer and the rollout config.
        model_hf_config = AutoConfig.from_pretrained(config.model_path, trust_remote_code=config.trust_remote_code)

        input_dataproto = prepare_input_dataproto(tokenizer, config, validate=True, do_sample=False)

        vllm_rollout = vLLMRollout(
            model_path=config.model_path,
            config=config,
            tokenizer=tokenizer,
            model_hf_config=model_hf_config,
        )

        rollout_response = vllm_rollout.generate_sequences(
            prompts=input_dataproto,
        )
        if rank == 0:
            print("VLLM Rollout Outputs:")
            # Decode once; batch_decode applies decode() to each row with the same kwargs.
            decoded_responses = tokenizer.batch_decode(rollout_response.batch["responses"], skip_special_tokens=False)
            print(decoded_responses)
            for decoded in decoded_responses:
                assert "<|im_end|>" in decoded, "Response should contain <|im_end|> token"
            # Only the rank that actually ran the checks reports success.
            print("Checks passed.")
    finally:
        # Run on every rank even when a check failed, so no rank is left waiting
        # in the barrier and the process group is always torn down.
        del vllm_rollout
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        dist.barrier()
        torch.distributed.destroy_process_group()

def prepare_input_dataproto(tokenizer, config, validate, do_sample=False):
    """Build a batch of three long single-turn chat prompts as a ``DataProto``.

    Each prompt is one user message: the sentence ``"Roses are red, sky is blue. "``
    repeated 4096 times, followed by a short question. Each message is rendered
    through the tokenizer's chat template with a generation prompt appended.
    The results are tokenized and padded to ``config.prompt_length``. The padding
    side is whatever ``tokenizer`` is configured with.

    Args:
        tokenizer: A Hugging Face tokenizer that has a chat template.
        config: Config object providing ``prompt_length``, ``response_length``
            and ``temperature``.
        validate: Stored verbatim as ``meta_info["validate"]``.
        do_sample: Stored verbatim as ``meta_info["do_sample"]``.

    Returns:
        A ``DataProto`` holding ``input_ids``, ``attention_mask`` and
        ``position_ids`` (derived from the mask by
        ``compute_position_id_with_mask``), plus generation metadata.
    """
    filler_text = "Roses are red, sky is blue. " * 4096
    questions = (
        "Who won the Champions League in 2019?",
        "The founder of Apple is",
        "What's your name",
    )

    chat_texts = []
    for question in questions:
        single_turn = [{"role": "user", "content": filler_text + question}]
        chat_texts.append(tokenizer.apply_chat_template(single_turn, tokenize=False, add_generation_prompt=True))

    encoded = tokenizer(chat_texts, return_tensors="pt", padding="max_length", max_length=config.prompt_length)
    mask = encoded["attention_mask"]

    tensors = {
        "input_ids": encoded["input_ids"],
        "attention_mask": mask,
        "position_ids": compute_position_id_with_mask(mask),
    }
    generation_meta = {
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "validate": validate,
        "do_sample": do_sample,
        "response_length": config.response_length,
        "temperature": config.temperature,
    }
    return DataProto.from_dict(tensors, meta_info=generation_meta)

if __name__ == "__main__":
    # Allow running this check directly as a script, outside pytest.
    # NOTE(review): the test initializes a global process group and requests
    # tensor_model_parallel_size=4, so presumably it must be started through a
    # multi-process distributed launcher — confirm the intended launch command.
    test_vllm_rollout_with_yarn_position_embeddings()
