#!/usr/bin/env python3
"""Minimal overlap sampling entry using a vLLM v1 generate interface."""

from __future__ import annotations

import argparse
import os
import time
from typing import List

# These variables are set before `vllm` is imported further down.
# NOTE(review): presumably vLLM reads them at import / engine-construction
# time (V1 engine on, V1 multiprocessing off), which would explain why they
# must precede the import — confirm against the vLLM version in use.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

from vllm import SamplingParams

from verl.tree_generation.overlap_generate import OverlapConfig, OverlapLLM


# Prompt used by main() when neither --prompt nor --prompt-file is given.
# Exactly one string literal inside the parentheses is active; the
# commented-out lines are alternative prompts kept for manual experiments.
DEFAULT_PROMPT = (
    # "write a very very long story that is at least 5000 words long. The story should be engaging and captivating, with well-developed characters and a compelling plot. "
    # (originally a Chinese-language prompt) "Tell me a joke about a hero and a demon king."
    # "Let $A$ be the set of positive integer divisors of $2025$. Let $B$ be a randomly selected subset of $A$. The probability that $B$ is a nonempty set with the property that the least common multiple of its element is $2025$ is $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m+n$."
    "Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop."
    # "The twelve letters $A,B,C,D,E,F,G,H,I,J,K$, and $L$ are randomly grouped into six pairs of letters. The two letters in each pair are placed next to each other in alphabetical order to form six two-letter words, and those six words are listed alphabetically. For example, a possible result is $AB,CJ,DG,EK,FL,HI$. The probability that the last word listed contains $G$ is $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m+n$."
)


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line interface of the overlap sampling script.

    Options are registered in three groups, in this order: overlap
    distiller settings (``--overlap-*``), sampling / prompt / output
    settings, and vLLM engine settings.

    Returns:
        A parser whose namespace exposes every option under its
        underscore-separated name (e.g. ``--top-p`` becomes ``top_p``).
    """
    cli = argparse.ArgumentParser(description=__doc__)

    def switch(name: str, help_text: str) -> None:
        # Boolean flag: False unless given on the command line.
        cli.add_argument(name, action="store_true", help=help_text)

    def option(name, kind, default, help_text, **extra) -> None:
        # Single-valued typed option with an explicit default.
        cli.add_argument(name, type=kind, default=default, help=help_text, **extra)

    # --- Overlap distiller -------------------------------------------------
    switch("--overlap-distiller", "Enable overlap distiller.")
    option("--overlap-beta", float, 0.5, "Beta for overlap distiller mixing.")
    option("--overlap-topk", int, 8, "Top-k size for overlap distiller mixing.")
    option("--overlap-buffer-slots", int, 2, "Ring buffer slots for overlap distiller.")
    option("--overlap-first-layer", int, 0, "Layer index to capture as first-layer hidden.")
    switch("--overlap-train-enabled", "Enable distiller training in overlap pipeline.")
    option(
        "--overlap-train-sync-interval",
        int,
        1,
        "Optimizer step interval for overlap distiller training.",
    )
    option(
        "--overlap-log-file",
        str,
        "overlap_mix_log.csv",
        "Output CSV for overlap distiller mix logs.",
    )
    option(
        "--overlap-mix-mode",
        str,
        "logits",
        "Mix mode for overlap distiller (logits or hidden).",
        choices=["logits", "hidden"],
    )
    option(
        "--overlap-loss-log-file",
        str,
        "overlap_loss_log.csv",
        "Output CSV for overlap distiller training losses.",
    )
    switch("--overlap-reward-scale", "Enable reward-scale mixing for overlap distiller.")

    # --- Sampling, prompts and output --------------------------------------
    option("--parallel-seqs", int, 4, "Number of parallel sequences to sample.")
    option("--temperature", float, 0.7, "Sampling temperature.")
    option("--top-p", float, 1.0, "Top-p nucleus sampling threshold.")
    option("--min-p", float, 0.01, "Min-p threshold.")
    option("--max-seq-len", int, 4096, "Maximum tokens per sequence.")
    option("--model-name", str, "Qwen/Qwen3-4B", "Model name or path for vLLM.")
    # Repeatable: each occurrence appends to a list (None when never given).
    cli.add_argument(
        "--prompt",
        action="append",
        help="Prompt text to seed each sequence (repeatable).",
    )
    option("--prompt-file", str, "", "Optional file with one prompt per line.")
    option(
        "--output-file",
        str,
        "generated_sequences.txt",
        "Output file for generated sequences.",
    )

    # --- vLLM engine -------------------------------------------------------
    option("--tensor-parallel-size", int, 1, "Tensor parallel size for vLLM.")
    option("--gpu-memory-utilization", float, 0.5, "GPU memory utilization for vLLM.")
    # Paired on/off switches sharing one destination; eager mode is the default.
    cli.add_argument(
        "--enforce-eager",
        dest="enforce_eager",
        action="store_true",
        default=True,
        help="Force eager execution for vLLM (default: True).",
    )
    cli.add_argument(
        "--no-enforce-eager",
        dest="enforce_eager",
        action="store_false",
        help="Allow vLLM to use CUDA graph/compiled execution.",
    )
    return cli


def _write_sequences(
    output_path: str,
    prompts: List[str],
    generations: List[List[str]],
) -> None:
    """Write every prompt and its sampled completions to a UTF-8 text file.

    The file at ``output_path`` is opened in write mode, so previous
    content is replaced. For each prompt a 1-based header, the prompt text
    and a rule of 80 ``=`` are written, followed by one numbered section
    per completion, each closed by a rule of 80 ``-``.

    Args:
        output_path: Destination file path.
        prompts: Prompt strings, in output order.
        generations: ``generations[i]`` holds the completions for
            ``prompts[i]``; it is indexed by position, so it must have at
            least as many entries as ``prompts``.
    """
    prompt_rule = "=" * 80
    sample_rule = "-" * 80
    with open(output_path, "w", encoding="utf-8") as sink:
        for prompt_no, prompt_text in enumerate(prompts, start=1):
            sink.write(f"Prompt {prompt_no}\n{prompt_text}\n" + prompt_rule + "\n")
            # Look up completions only after the header is written, so a
            # missing entry fails at the same point in the output stream.
            samples = generations[prompt_no - 1]
            for seq_no, sample in enumerate(samples, start=1):
                sink.write(
                    f"Parallel sampling - Seq {seq_no}\n{sample}\n" + sample_rule + "\n"
                )


def _load_raw_prompts(args: argparse.Namespace) -> List[str]:
    """Return the raw prompt strings selected on the command line.

    Precedence: ``--prompt-file`` (one prompt per non-blank line, stripped),
    then the non-empty ``--prompt`` values, then ``DEFAULT_PROMPT``.
    The result may be empty when the chosen source has no usable text.
    """
    if args.prompt_file:
        with open(args.prompt_file, "r", encoding="utf-8") as f:
            stripped_lines = (line.strip() for line in f)
            return [line for line in stripped_lines if line]
    if args.prompt:
        return [p for p in args.prompt if p]
    return [DEFAULT_PROMPT]


def _count_completion_tokens(completion, tokenizer) -> int:
    """Return the number of generated tokens in one completion.

    Uses the completion's ``token_ids`` when that attribute is present and
    not None; otherwise re-tokenizes the text without special tokens.
    """
    token_ids = getattr(completion, "token_ids", None)
    if token_ids is not None:
        return len(token_ids)
    return len(tokenizer.encode(completion.text, add_special_tokens=False))


def main() -> int:
    """Run one overlap-sampling session from the command line.

    Steps: parse arguments, resolve prompts, build the ``OverlapLLM``,
    wrap each prompt in the chat template, generate ``--parallel-seqs``
    completions per prompt, print throughput, write the completions to
    ``--output-file`` and, when the overlap distiller is enabled, export
    its logs and CUDA-graph statistics.

    Returns:
        ``0`` on success (used as the process exit status).

    Raises:
        SystemExit: via ``parser.error`` when no non-empty prompt is given.
        OSError: when the prompt file cannot be read or the output file
            cannot be replaced.
    """
    parser = build_arg_parser()
    args = parser.parse_args()

    # Resolve prompts before loading the model so that a bad --prompt-file
    # path or an empty prompt set fails immediately instead of after the
    # (expensive) engine start-up.
    raw_prompts = _load_raw_prompts(args)
    if not raw_prompts:
        parser.error("no non-empty prompt found in --prompt / --prompt-file")

    # argparse already converted every option to its declared type, so the
    # values are passed through without further casting.
    overlap_config = OverlapConfig(
        enabled=args.overlap_distiller,
        beta=args.overlap_beta,
        topk=args.overlap_topk,
        buffer_slots=args.overlap_buffer_slots,
        first_layer=args.overlap_first_layer,
        mix_mode=args.overlap_mix_mode,
        train_enabled=args.overlap_train_enabled,
        train_sync_interval=args.overlap_train_sync_interval,
        reward_scale=args.overlap_reward_scale,
    )

    llm = OverlapLLM(
        model_name=args.model_name,
        tensor_parallel_size=args.tensor_parallel_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
        max_model_len=args.max_seq_len,
        enforce_eager=args.enforce_eager,
        # The engine only receives a config when the distiller is switched on.
        overlap_config=overlap_config if overlap_config.enabled else None,
    )

    tokenizer = llm.get_tokenizer()
    prompts: List[str] = []
    for raw_prompt in raw_prompts:
        messages = [
            {"role": "system", "content": "Your answer should be novel and diverse."},
            {"role": "user", "content": raw_prompt},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        prompts.append(prompt)
        # Echo the fully templated prompt in red (ANSI escape codes).
        print("\033[31m" + f"Using prompt {len(prompts)}: {prompt}" + "\033[0m")

    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        min_p=args.min_p,
        max_tokens=args.max_seq_len,
        n=max(1, args.parallel_seqs),  # always sample at least one sequence
        ignore_eos=False,
    )

    gen_start = time.perf_counter()
    outputs = llm.generate(prompts, sampling_params)
    gen_end = time.perf_counter()
    # Clamp to a tiny positive value so the throughput division is safe.
    total_gen_s = max(gen_end - gen_start, 1e-9)

    # Single pass over the results: collect texts and count tokens together.
    generations: List[List[str]] = []
    total_tokens = 0
    for output in outputs:
        generations.append([completion.text for completion in output.outputs])
        total_tokens += sum(
            _count_completion_tokens(completion, tokenizer)
            for completion in output.outputs
        )
    print(
        f"[overlap] total_tokens={total_tokens} total_time_s={total_gen_s:.4f} "
        f"throughput={total_tokens / total_gen_s:.2f} tok/s"
    )

    # Replace any previous output file; EAFP avoids the exists/remove race.
    try:
        os.remove(args.output_file)
    except FileNotFoundError:
        pass
    _write_sequences(args.output_file, prompts, generations)
    # The file is rewritten from scratch, not appended to.
    print(f"Wrote sequences to {args.output_file}")

    if overlap_config.enabled:
        if args.overlap_log_file:
            llm.export_overlap_log(args.overlap_log_file)
        if overlap_config.train_enabled and args.overlap_loss_log_file:
            llm.export_overlap_loss_log(args.overlap_loss_log_file)
        cg_stats = llm.export_overlap_cudagraph_stats()
        last_capture = cg_stats.get("last_capture_s")
        if last_capture is not None:
            # NOTE(review): assumes last_capture_s is a time.perf_counter()
            # reading taken in this process — confirm in OverlapLLM.
            post_capture_s = max(gen_end - last_capture, 0.0)
            # NOTE(review): divides ALL generated tokens by the time after
            # the last capture, so this overstates the true post-capture rate
            # whenever tokens were produced before that capture.
            post_capture_tok_s = (
                (total_tokens / post_capture_s) if post_capture_s > 0 else 0.0
            )
            capture_offset_s = max(last_capture - gen_start, 0.0)
            print(
                "[overlap] cudagraph "
                f"pred_captures={cg_stats.get('pred_captures', 0)} "
                f"train_captures={cg_stats.get('train_captures', 0)} "
                f"capture_total_s={cg_stats.get('pred_capture_total_s', 0.0) + cg_stats.get('train_capture_total_s', 0.0):.4f} "
                f"last_capture_offset_s={capture_offset_s:.4f} "
                f"post_capture_time_s={post_capture_s:.4f} "
                f"post_capture_throughput={post_capture_tok_s:.2f} tok/s"
            )

    return 0


# Script entry point: run main() and hand its integer return value to
# SystemExit so it becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
