#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Benchmark script for measuring multi-turn conversation performance.
"""

import argparse
import asyncio
import json
import random
import time
from typing import List, Dict, Any

import numpy as np
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.engine.metrics import LoggingStatLogger
from vllm.config import VllmConfig

def create_multi_turn_prompt(num_turns: int, base_prompt: str) -> str:
    """Create a multi-turn conversation prompt."""
    turns = []
    for i in range(num_turns):
        if i % 2 == 0:
            turns.append(f"User: {base_prompt}")
        else:
            turns.append("Assistant: ")
    return "\n".join(turns)

async def benchmark_multi_turn(args: argparse.Namespace, llm: LLM, tokenizer: Any) -> Dict[str, Any]:
    latencies_per_turn = []
    tokens_per_turn = []
    ttft_per_turn = []
    
    # 가령, num_prompts개를 만들고, 그 각각에 대해 num_turns만큼 돌려보는 식
    # prompts[i]는 i번째 사용자 발화들 리스트 형태로 관리하는 편
    # 예: prompts[0] = ["Tell me about AI", "Why is it important?", ... ]
    # 여기서는 간단히 같은 base_prompt 반복 예시
    user_prompts = [[f"Tell me about artificial intelligence. (Turn {t})"
                     for t in range(args.num_turns)] 
                    for _ in range(args.num_prompts)]
    
    # Warmup
    print("Warming up...")
    # 턴 한 번씩만 테스트 (또는 전체)해도 되고, 여기선 간단히
    for _ in range(args.num_iters_warmup):
        _ = llm.generate("Warmup prompt", SamplingParams(max_tokens=args.output_len))
    
    # Benchmark
    print("Benchmarking...")
    for _ in tqdm(range(args.num_iters)):
        # 매 iteration마다 독립적으로 측정할 수도 있고
        # 혹은 여러 prompts를 한 번씩 순회하며 측정할 수도 있습니다.
        for prompt_list in user_prompts:
            conversation_history = ""
            for turn_idx, user_msg in enumerate(prompt_list):
                # 현재까지의 대화 + 새 user 발화:
                conversation_history += f"User: {user_msg}\nAssistant: "
                
                # 시작 시각
                start_time = time.time()

                # generate를 스트리밍이나 partial하게 받아서 첫 토큰 시간을 측정하면 TTFT를 구할 수 있습니다.
                # vLLM의 streaming 기능 사용 예시 (가상의 코드):
                # tokens = []
                # async for output in llm.generate_stream(conversation_history, SamplingParams(...)):
                #     if len(tokens) == 0:
                #         # 첫 토큰 도착 시각
                #         ttft = time.time() - start_time
                #     tokens.append(output)
                #
                # 여기서는 단순히 generate() 후에 한 번에 받았다고 가정
                outputs = llm.generate(conversation_history, SamplingParams(max_tokens=args.output_len))

                end_time = time.time()
                latency = end_time - start_time
                
                # 첫 토큰까지 시간(TTFT)을 별도로 측정할 거면 streaming 형태로 받아야 하고,
                # 여기서는 단순히 latency만 예시:
                latencies_per_turn.append(latency)

                # 실제 생성된 토큰 수를 알기 위해서는
                # outputs 목록에서 토큰화된 결과나 text length 확인
                # 간단히 이하처럼 가능 (vLLM outputs가 여러 개 반환할 때 고려)
                generated_text = outputs[0].outputs[0].text  # 가장 첫 번째 가정
                generated_tokens = len(tokenizer.encode(generated_text))
                tokens_per_turn.append(generated_tokens)

                # 대화에 Assistant 답변을 추가
                conversation_history += generated_text + "\n"

    # 모든 턴에 대한 latency, tokens 등을 집계
    latencies = np.array(latencies_per_turn)
    tokens_count = np.array(tokens_per_turn)
    
    mean_latency = latencies.mean()
    p95_latency = np.percentile(latencies, 95)
    p99_latency = np.percentile(latencies, 99)
    
    # TPS(throughput)는 '총 생성 토큰 수 / 총 걸린 시간' 정도로 볼 수 있는데
    # 단순 평균만 내려면:
    total_tokens = tokens_count.sum()
    # mean(latencies)는 턴별 latency를 단순평균한 것 -> 정확한 total time이랑 다를 수 있음 
    # (각 iteration을 순차로 했으니 total_time = latencies.sum() 일 수도)
    total_time = latencies.sum()
    throughput = total_tokens / total_time

    return {
        "throughput": throughput,
        "mean_latency": mean_latency,
        "p95_latency": p95_latency,
        "p99_latency": p99_latency,
        # 필요하면 ttft_per_turn 등도 계산
    }

def main():
    parser = argparse.ArgumentParser(description="Benchmark multi-turn conversation performance")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm"],
                        default="vllm")
    parser.add_argument("--model",
                        type=str,
                        required=True,
                        help="Model to use")
    parser.add_argument("--tokenizer",
                        type=str,
                        default=None,
                        help="Tokenizer to use")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=10,
                        help="Number of prompts to process")
    parser.add_argument("--num-turns",
                        type=int,
                        default=5,
                        help="Number of turns in conversation")
    parser.add_argument("--input-len",
                        type=int,
                        default=100,
                        help="Input length for each turn")
    parser.add_argument("--output-len",
                        type=int,
                        default=100,
                        help="Output length for each turn")
    parser.add_argument("--num-iters",
                        type=int,
                        default=10,
                        help="Number of iterations to run")
    parser.add_argument("--num-iters-warmup",
                        type=int,
                        default=3,
                        help="Number of iterations to run for warmup")
    parser.add_argument("--max-model-len",
                        type=int,
                        default=8192,
                        help="Maximum model length")
    parser.add_argument("--preemption-mode",
                        type=str,
                        default="swap",
                        help="Preemption mode")
    parser.add_argument("--swap-space",
                        type=int,
                        default=4,
                        help="Swap space size in GiB")
    parser.add_argument("--swap-fc-space",
                        type=int,
                        default=0,
                        help="Flash cache space size in GiB")
    parser.add_argument("--num-threads",
                        type=int,
                        default=8,
                        help="Number of threads for flash cache")
    parser.add_argument("--io-size-mb",
                        type=int,
                        default=4,
                        help="IO size in MB for flash cache")
    parser.add_argument("--dtype",
                        type=str,
                        default="half",
                        help="Data type")
    parser.add_argument("--enable-chunked-prefill",
                        type=bool,
                        default=False,
                        help="Enable chunked prefill")
    parser.add_argument("--block-size",
                        type=int,
                        default=16,
                        help="Block size for KV cache")
    parser.add_argument("--disable-log-stats",
                        action="store_true",
                        help="Disable logging statistics")

    args = parser.parse_args()
    print(args)

    if args.tokenizer is None:
        args.tokenizer = args.model

    # Initialize LLM with EngineArgs
    engine_args = EngineArgs(
        model=args.model,
        tokenizer=args.tokenizer,
        preemption_mode=args.preemption_mode,
        swap_space=args.swap_space,
        swap_fc_space=args.swap_fc_space,
        num_threads=args.num_threads,
        io_size_mb=args.io_size_mb,
        dtype=args.dtype,
        enable_chunked_prefill=args.enable_chunked_prefill,
        block_size=args.block_size,
        max_model_len=args.max_model_len,
        disable_log_stats=args.disable_log_stats,
    )

    llm = LLM(**engine_args.__dict__)

    # Get tokenizer
    tokenizer = get_tokenizer(args.tokenizer or args.model)

    # Run benchmark
    results = asyncio.run(benchmark_multi_turn(args, llm, tokenizer))

    # Print results
    print("\nResults:")
    print(f"Throughput: {results['throughput']:.2f} tokens/second")
    print(f"Mean latency: {results['mean_latency']:.2f} seconds")
    print(f"P95 latency: {results['p95_latency']:.2f} seconds")
    print(f"P99 latency: {results['p99_latency']:.2f} seconds")

if __name__ == "__main__":
    main() 