from dataclasses import asdict

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser


def get_prompts(num_prompts: int):
    # The default sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    if num_prompts != len(prompts):
        prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts]

    return prompts


def main(args):
    # Create prompts
    prompts = get_prompts(args.num_prompts)

    # Create a sampling params object.
    sampling_params = SamplingParams(n=args.n,
                                     temperature=args.temperature,
                                     top_p=args.top_p,
                                     top_k=args.top_k,
                                     max_tokens=args.max_tokens)

    # Create an LLM.
    # The default model is 'facebook/opt-125m'
    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**asdict(engine_args))

    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    group = parser.add_argument_group("SamplingParams options")
    group.add_argument("--num-prompts",
                       type=int,
                       default=4,
                       help="Number of prompts used for inference")
    group.add_argument("--max-tokens",
                       type=int,
                       default=16,
                       help="Generated output length for sampling")
    group.add_argument('--n',
                       type=int,
                       default=1,
                       help='Number of generated sequences per prompt')
    group.add_argument('--temperature',
                       type=float,
                       default=0.8,
                       help='Temperature for text generation')
    group.add_argument('--top-p',
                       type=float,
                       default=0.95,
                       help='top_p for text generation')
    group.add_argument('--top-k',
                       type=int,
                       default=-1,
                       help='top_k for text generation')

    args = parser.parse_args()
    main(args)
