# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
usage: torchrun --standalone --nnodes=1 \
    --nproc_per_node=2 $(which pytest) \
    -s test_sglang_async_spmd.py
"""

import asyncio

import torch
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.utils import broadcast_pyobj
from torch.distributed.device_mesh import init_device_mesh
from utils_sglang import (
    are_lists_similar,
    clean_torchelastic_env,
    generate_hf_output,
    initialize_global_process_group,
    load_tokenizer_and_model,
    prepare_inputs,
)


def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor):
    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
    token_ids = prompt_token_ids[non_pad_index:].tolist()
    return token_ids


def test_sglang_spmd():
    assert torch.cuda.device_count() >= 2
    initialize_global_process_group(spmd=True)
    clean_torchelastic_env()

    max_prompt_length = 16
    max_response_length = 16

    local_model_path = "Qwen/Qwen2.5-0.5B"
    tokenizer, actor_model = load_tokenizer_and_model(local_model_path)

    preencode_prompts = ["Who won the Champions League in 2019?", "The founder of Apple is", "What's your name?"]
    input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length)

    hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)

    tensor_parallel_size = 2
    inference_device_mesh_cpu = init_device_mesh(
        "cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"]
    )
    tp_rank = inference_device_mesh_cpu["tp"].get_local_rank()

    if tp_rank == 0:
        llm = Engine(
            model_path=local_model_path,
            dtype="bfloat16",
            mem_fraction_static=0.5,
            enable_memory_saver=True,
            tp_size=inference_device_mesh_cpu["tp"].size(),
        )

        input_ids = input_ids.cuda()
        idx_list = []

        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        for i in range(input_ids.shape[0]):
            idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))

        sampling_params = dict(
            n=1,
            temperature=0,
            top_p=1,
            top_k=-1,
            max_new_tokens=max_response_length,
            presence_penalty=0.0,
            frequency_penalty=0.0,
            repetition_penalty=1.0,
            skip_special_tokens=True,
            spaces_between_special_tokens=True,
            ignore_eos=False,
        )

        loop = asyncio.get_event_loop()
        outputs = loop.run_until_complete(llm.async_generate(input_ids=idx_list, sampling_params=sampling_params))
    else:
        outputs = None

    [outputs] = broadcast_pyobj(
        [outputs],
        rank=inference_device_mesh_cpu["tp"].get_local_rank(),
        src=inference_device_mesh_cpu["tp"].mesh[0].item(),
        dist_group=inference_device_mesh_cpu["tp"].get_group(),
        force_cpu_device=False,
    )

    sglang_response_tokens = [output["text"] for output in outputs]

    print(f"sglang response: {sglang_response_tokens}")
    assert are_lists_similar(hf_response_tokens, sglang_response_tokens), "Strings differ more than 10%:\n"
    print("SPMD Test Passed!")

    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
