#!/usr/bin/env python3
"""
Offline batch inference script using vLLM.

Example usage:

python llm_inference_vllm.py \
  --model-name meta-llama/Llama-3-70b-instruct \
  --input-json data/demo.jsonl \
  --output-dir outputs \
  --task-name july23_exp \
  --batch-size 8 \
  --dataset-class dataset.DOCCIMCQAGenerationDataset \
  --tensor-parallel-size 8 \
  --dtype bfloat16 \
  --gpu-memory-utilization 0.92
"""

import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, List, Type


from vllm import LLM, SamplingParams


# ------------------------------
# Utility functions
# ------------------------------

def batched(iterable: Iterable[Dict[str, Any]], n: int) -> Iterator[List[Dict[str, Any]]]:
    """Yield successive lists of at most *n* items drawn from *iterable*.

    Items are consumed lazily, so *iterable* may be a list, a generator, or
    any other iterable. Every yielded batch holds exactly *n* items except
    possibly the last one, which holds the remainder. An empty *iterable*
    yields nothing.

    Raises:
        ValueError: if *n* is smaller than 1. Because this is a generator,
            the error surfaces when iteration starts, not at call time.
    """
    # A non-positive size would never satisfy ``len(batch) == n`` and would
    # silently collapse the whole input into a single batch.
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")

    batch: List[Dict[str, Any]] = []
    for item in iterable:
        batch.append(item)
        if len(batch) == n:
            yield batch
            # Start a fresh list so the batch already handed out is not mutated.
            batch = []
    if batch:
        # Trailing partial batch.
        yield batch


def locate_dataset(cls_path: str) -> Type[Any]:
    """Dynamically import a dataset class from the ``dataset`` package.

    *cls_path* must be in the form ``<module>.<ClassName>``, where *module*
    is a module path relative to ``dataset`` (without ``.py``). The path is
    split on its **last** dot, so nested modules such as
    ``sub.inner.ClassName`` resolve to ``dataset.sub.inner``.

    Returns:
        The attribute named *ClassName* found on the imported module.

    Raises:
        ValueError: if *cls_path* has no dot, or if its module or class part
            is empty.
        ImportError: if the module cannot be imported, or if it does not
            define *ClassName*.
    """
    import importlib

    # rpartition keeps everything before the final dot as the module path.
    mod_name, sep, class_name = cls_path.rpartition(".")
    if not sep or not mod_name or not class_name:
        raise ValueError(
            "--dataset-class must look like <module>.<ClassName>, e.g. "
            "`custom.CustomDataset`",
        )

    module = importlib.import_module(f"dataset.{mod_name}")
    try:
        return getattr(module, class_name)
    except AttributeError as exc:
        # Surface a missing class as an import problem, keeping the cause chained.
        raise ImportError(f"Class {class_name} not found in dataset.{mod_name}") from exc

def iter_dataset(ds):
    """Lazily yield ``ds[0]``, ``ds[1]``, ... up to ``ds[len(ds) - 1]``.

    Only ``len(ds)`` and integer indexing are used, so *ds* does not need to
    implement the iterator protocol itself.
    """
    size = len(ds)
    yield from (ds[idx] for idx in range(size))


# ------------------------------
# Main flow
# ------------------------------

def main() -> None:
    """Parse CLI arguments, run batched vLLM generation, and write JSONL results.

    Steps:
      1. Parse and validate command-line arguments.
      2. Build the dataset via ``locate_dataset(args.dataset_class)``; it is
         constructed as ``DatasetCls(args.input_json, load_num=args.load_num)``.
      3. Create the vLLM engine and sampling parameters.
      4. Generate batch by batch and append one JSON line per record to
         ``<output-dir>/<task-name>.jsonl`` with keys ``"generated"`` and
         ``"meta"``.

    Each dataset item is indexed with ``item["prompt"]`` and ``item["meta"]``.
    If the dataset object has an ``output_processing`` attribute, it is called
    on every generated text; a record whose post-processing raises is reported
    on stdout and skipped.
    """
    import sys

    from tqdm import tqdm

    parser = argparse.ArgumentParser(description="Offline batch inference with vLLM")
    # Required
    parser.add_argument("--model-name", required=True, help="Hugging Face model identifier or local path")
    parser.add_argument("--output-dir", required=True, help="Folder to write results")
    parser.add_argument("--task-name", required=True, help="File prefix for saved outputs")

    # Dataset plugin
    parser.add_argument(
        "--dataset-class",
        default="json_dataset.JSONDataset",
        help="Dataset class in dataset/ (module.ClassName), e.g. json_dataset.JSONDataset",
    )
    parser.add_argument("--input-json", required=False, default=None, help="Path to *.jsonl input file")
    parser.add_argument("--load-num", required=False, type=int, default=-1, help="Number of records to load from input file")

    # Optional
    parser.add_argument("--batch-size", type=int, default=8, help="Prompts per generation batch")

    # vLLM engine tuning
    parser.add_argument("--tensor-parallel-size", type=int, default=8, help="Number of GPUs for tensor parallelism (TP)")
    parser.add_argument("--dtype", default="auto", choices=["auto", "float16", "bfloat16", "float32"], help="Model weight dtype")
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.9, help="Fraction of GPU RAM reserved for vLLM")
    parser.add_argument("--max-seq-len", type=int, default=8192, help="Maximum sequence length captured by CUDA graphs")
    parser.add_argument("--swap-space", type=int, default=4, help="CPU swap space (GiB) per GPU")
    parser.add_argument("--trust-remote-code", action="store_true", help="Allow HF repos with custom code")

    # Sampling options
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--max-tokens", type=int, default=512, help="Max new tokens to generate")

    args = parser.parse_args()

    # Reject unusable batch sizes up front, before the model is loaded.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")

    os.makedirs(args.output_dir, exist_ok=True)
    output_path = Path(args.output_dir) / f"{args.task_name}.jsonl"

    DatasetCls = locate_dataset(args.dataset_class)
    dataset = DatasetCls(args.input_json, load_num=args.load_num)

    # Instantiate vLLM LLM engine
    llm = LLM(
        model=args.model_name,
        tensor_parallel_size=args.tensor_parallel_size,
        dtype=args.dtype,
        gpu_memory_utilization=args.gpu_memory_utilization,
        swap_space=args.swap_space,
        max_seq_len_to_capture=args.max_seq_len,
        trust_remote_code=args.trust_remote_code,
    )

    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
    )

    has_post_processing = hasattr(dataset, "output_processing")
    # Ceiling division: the exact number of batches, including a trailing
    # partial one, so the progress bar ends at 100%.
    num_batches = (len(dataset) + args.batch_size - 1) // args.batch_size

    with output_path.open("w", encoding="utf-8") as fout:
        for batch in tqdm(
            batched(iter_dataset(dataset), args.batch_size),
            desc="Running inference",
            total=num_batches,
            file=sys.stdout,
        ):
            prompts = [item["prompt"] for item in batch]
            outputs = llm.generate(prompts, sampling_params)
            for record, out in zip(batch, outputs):
                generated_text = out.outputs[0].text
                if has_post_processing:
                    # Only the dataset hook is guarded: a failure there skips
                    # this record (best effort), while write or schema errors
                    # still propagate.
                    try:
                        generated_text = dataset.output_processing(generated_text)
                    except Exception as e:
                        print(f"Error occurred while processing output: {e}")
                        continue
                fout.write(
                    json.dumps(
                        {
                            "generated": generated_text,
                            "meta": record["meta"],
                        },
                        ensure_ascii=False,
                    )
                    + "\n"
                )
            # Persist finished batches so a crash mid-run loses little work.
            fout.flush()

    print(f"[✓] Results saved to {output_path}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
