from fastapi import FastAPI, HTTPException
from vllm import LLM, SamplingParams

# FastAPI application object; the route decorators in this file
# (`@app.get`, `@app.post`) register their handlers on it, and the script
# entry point passes it to uvicorn. The title/description/version values are
# metadata handed to the FastAPI constructor.
app = FastAPI(
    title="vLLM API Service",
    description="A REST API for vLLM inference engine",
    version="0.1.0"
)

class OriginalvLLMRollout:
    """Wrapper around a vLLM ``LLM`` engine for batched text and chat generation.

    The engine is constructed once in ``__init__``. ``generate`` and ``chat``
    build fresh sampling parameters on every call and return, for each input,
    the text of the first candidate output.
    """

    def __init__(
        self,
        model_name_or_path,
        max_tokens,
        *,
        tensor_parallel_size=4,
        gpu_memory_utilization=0.85,
    ):
        """Load the model into a vLLM engine.

        Args:
            model_name_or_path: Passed to vLLM as both the model and the
                tokenizer identifier.
            max_tokens: ``max_tokens`` for the default ``sampling_params``
                attribute. ``generate`` and ``chat`` do not read that
                attribute; they use their own ``max_tokens`` argument.
            tensor_parallel_size: Forwarded to ``LLM``. Defaults to 4, the
                value that used to be hard-coded.
            gpu_memory_utilization: Forwarded to ``LLM``. Defaults to 0.85,
                the value that used to be hard-coded.
        """
        self.rollout_model = LLM(
            model=model_name_or_path,
            tokenizer=model_name_or_path,
            pipeline_parallel_size=1,
            tensor_parallel_size=tensor_parallel_size,
            max_num_seqs=256,
            max_num_batched_tokens=18000,
            max_model_len=18000,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=True,
            dtype="bfloat16",
            enforce_eager=True,
        )
        # Kept as a public attribute for callers that may read it; the methods
        # of this class always build per-call parameters instead.
        self.sampling_params = self._build_sampling_params(1.0, max_tokens)

    @staticmethod
    def _build_sampling_params(temperature, max_tokens):
        """Return SamplingParams with the fixed top_p/top_k/stop settings used here."""
        return SamplingParams(
            temperature=temperature,
            top_p=1.0,
            top_k=-1,
            max_tokens=max_tokens,
            stop=[],
        )

    @staticmethod
    def _first_texts(completions):
        """Return the text of the first candidate of every completion, in order."""
        return [completion.outputs[0].text for completion in completions]

    def generate(self, input_texts, temperature=1.0, max_tokens=128):
        """Run ``LLM.generate`` on ``input_texts`` and return the generated strings.

        Args:
            input_texts: Prompts forwarded unchanged to ``LLM.generate``.
            temperature: Sampling temperature for this call.
            max_tokens: ``max_tokens`` sampling limit for this call.

        Returns:
            A list with one generated string per returned completion.
        """
        sampling_params = self._build_sampling_params(temperature, max_tokens)
        completions = self.rollout_model.generate(
            input_texts, sampling_params, use_tqdm=False
        )
        return self._first_texts(completions)

    def chat(self, input_messages, temperature=1.0, max_tokens=128):
        """Run ``LLM.chat`` on ``input_messages`` and return the generated strings.

        Args:
            input_messages: Chat messages forwarded unchanged to ``LLM.chat``.
            temperature: Sampling temperature for this call.
            max_tokens: ``max_tokens`` sampling limit for this call.

        Returns:
            A list with one generated string per returned completion.
        """
        sampling_params = self._build_sampling_params(temperature, max_tokens)
        completions = self.rollout_model.chat(
            input_messages, sampling_params, use_tqdm=False
        )
        return self._first_texts(completions)

@app.get("/health")
async def health_check():
    """Report liveness and the served model path.

    Returns:
        A dict with ``"status"`` set to ``"healthy"`` and ``"model"`` set to
        the module-level ``model_path`` value, or ``None`` when no such
        global has been defined.
    """
    # A bare reference to `model_path` raises NameError (HTTP 500) whenever
    # the global is absent, so look it up defensively instead.
    return {"status": "healthy", "model": globals().get("model_path")}

@app.post("/chat")
async def chat(data: dict):
    """Generate chat completions for the messages in the request body.

    Args:
        data: JSON body. ``input_messages`` is required; ``temperature`` and
            ``max_tokens`` are optional and fall back to the defaults of
            ``vllm_manager.chat`` when omitted.

    Returns:
        A dict with the key ``"generated_texts"`` holding the manager's result.

    Raises:
        HTTPException: 422 when ``input_messages`` is missing from the body.
    """
    if "input_messages" not in data:
        # Report a client error instead of letting a KeyError become a 500.
        raise HTTPException(
            status_code=422, detail="Missing required field: input_messages"
        )

    # Forward only the optional fields the client actually sent.
    optional = {
        key: data[key] for key in ("temperature", "max_tokens") if key in data
    }
    # NOTE: this call is synchronous inside an async handler, so the event
    # loop is blocked until generation finishes.
    generated_texts = vllm_manager.chat(data["input_messages"], **optional)
    return {"generated_texts": generated_texts}


@app.post("/generate")
async def generate(data: dict):
    """Generate raw text completions for the prompts in the request body.

    Args:
        data: JSON body. ``input_messages`` is required and is passed to
            ``vllm_manager.generate`` as its prompts; ``temperature`` and
            ``max_tokens`` are optional and fall back to the defaults of
            ``vllm_manager.generate`` when omitted.

    Returns:
        A dict with the key ``"generated_texts"`` holding the manager's result.

    Raises:
        HTTPException: 422 when ``input_messages`` is missing from the body.
    """
    if "input_messages" not in data:
        # Report a client error instead of letting a KeyError become a 500.
        raise HTTPException(
            status_code=422, detail="Missing required field: input_messages"
        )

    # Forward only the optional fields the client actually sent.
    optional = {
        key: data[key] for key in ("temperature", "max_tokens") if key in data
    }
    # NOTE: this call is synchronous inside an async handler, so the event
    # loop is blocked until generation finishes.
    generated_texts = vllm_manager.generate(data["input_messages"], **optional)
    return {"generated_texts": generated_texts}

if __name__ == "__main__":
    import argparse
    import uvicorn

    # Command-line options.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="",
        help="Model name or path passed to vLLM as both model and tokenizer.",
    )
    parser.add_argument(
        "--port", type=int, default=8007, help="Port for the HTTP server."
    )
    args = parser.parse_args()

    # The /health handler reads a module-level `model_path`; define it here so
    # that lookup succeeds instead of raising NameError.
    model_path = args.model_path

    # Module-level engine wrapper used by the /chat and /generate handlers.
    vllm_manager = OriginalvLLMRollout(args.model_path, max_tokens=128)

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=args.port,
        log_level="info",
        workers=1,  # vLLM typically has each process manage its own GPU resources
    )