#!/usr/bin/env python3
"""Wrapper around lm_eval CLI that fixes dp seed deduplication for vLLM.

Problem: lm_eval's vLLM dp workers all initialize with the same default seed,
so repeated identical prompts across workers produce identical outputs.
This makes repeats-based pass@k evaluation useless with data_parallel_size > 1.

Fix: Patch _model_generate to assign a unique seed to each SamplingParams
right before they're distributed to dp workers. This is the last point
before distribution, so seeds survive all intermediate processing
(Collator reordering, normalize_gen_kwargs, etc.).

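Usage: python scripts/evals/lm_eval_dp_diverse.py [all normal lm_eval args]

--log_samples is appended automatically when it is not already present,
because sample logs are needed to rescore outputs and compute pass@k.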

Example:
    python scripts/evals/lm_eval_dp_diverse.py \\
        --model vllm \\
        --model_args "pretrained=MODEL,data_parallel_size=4,..." \\
        --tasks bridges_5x5dm_pass8_train \\
        --batch_size auto --apply_chat_template --log_samples \\
        --output_path results/my_eval
"""

import copy
import itertools

from vllm import SamplingParams

from lm_eval.models.vllm_causallms import VLLM

# Handle on the unpatched VLLM._model_generate. It must be captured here, before
# the class attribute is overwritten with the wrapper further down, because the
# wrapper delegates to it for the actual generation work.
_orig_model_generate = VLLM._model_generate


# First seed handed out; subsequent requests get consecutive integers.
_BASE_SEED = 42

# Process-wide seed source. It lives at module level so seeds stay unique across
# *every* _model_generate call, not just within one call: lm_eval may split the
# work into several calls (e.g. a fixed --batch_size), and restarting from the
# same base on each call would hand identical prompts identical seeds.
_seed_counter = itertools.count(_BASE_SEED)


def _assign_diverse_seeds(sampling_params):
    """Return a new list where unseeded sampling entries get unique seeds.

    An entry qualifies when ``temperature > 0`` and ``seed is None``. It is
    replaced by a copy carrying the next value from the module-wide counter.
    All other entries (greedy decoding, or an explicitly seeded request) are
    passed through as-is. Neither the input list nor its elements are modified.
    """
    diversified = []
    for sp in sampling_params:
        if sp.temperature > 0 and sp.seed is None:
            # Copy rather than rebuild SamplingParams field by field, so every
            # other setting (penalties, stop_token_ids, logprobs, min_tokens,
            # ...) survives. Copying also matters when several list entries
            # alias one shared object: each request must get its own seed.
            seeded = copy.deepcopy(sp)
            seeded.seed = next(_seed_counter)
            diversified.append(seeded)
        else:
            diversified.append(sp)
    return diversified


def _diverse_model_generate(self, requests, generate=False, sampling_params=None):
    """Drop-in replacement for ``VLLM._model_generate`` that diversifies seeds.

    When generating with a per-request list of SamplingParams, each sampling
    entry without an explicit seed gets a unique seed before the list is handed
    to the original implementation. That implementation is where requests are
    distributed to the data-parallel workers. All other calls (log-likelihood
    scoring, a single shared SamplingParams, greedy decoding) are forwarded
    unchanged.
    """
    if generate and isinstance(sampling_params, list):
        sampling_params = _assign_diverse_seeds(sampling_params)
    return _orig_model_generate(self, requests, generate, sampling_params)


# Patch the class attribute at import time; every VLLM instance that lm_eval
# uses afterwards dispatches through the seed-diversifying wrapper.
VLLM._model_generate = _diverse_model_generate

if __name__ == "__main__":
    import sys

    # Sample logs are mandatory for this workflow: without them the outputs
    # cannot be rescored and pass@k cannot be computed, so force the flag on.
    log_flag = "--log_samples"
    if not any(arg == log_flag for arg in sys.argv):
        sys.argv += [log_flag]

    from lm_eval.__main__ import cli_evaluate

    cli_evaluate()
