import os
import json
import random
import logging
from typing import List, Dict, Optional

import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest


def load_model(model: str, adapter_path: Optional[str] = None) -> LLM:
    """Construct the vLLM engine used for inference.

    The fixed engine settings are intended to mirror those used in eval.py.

    Args:
        model: Model name or path handed straight to ``LLM``.
        adapter_path: Optional LoRA adapter location. It is only used here to
            decide whether LoRA support is switched on; the adapter itself is
            not loaded by this function.

    Returns:
        The initialized ``LLM`` instance.
    """
    # Any truthy adapter path turns LoRA support on; None or "" leaves it off.
    lora_enabled = bool(adapter_path)
    # Tensor parallelism spans every CUDA device torch reports.
    num_devices = torch.cuda.device_count()

    return LLM(
        model=model,
        enable_prefix_caching=True,
        enable_lora=lora_enabled,
        tensor_parallel_size=num_devices,
        max_num_seqs=32,
        gpu_memory_utilization=0.9,
        max_model_len=2048,
        enforce_eager=True,
        max_lora_rank=128,
    )


def sample(
    llm: LLM,
    conversations: List[List[Dict[str, str]]],
    *,
    top_p: float = 1.0,
    max_tokens: int = 128,
    temperature: float = 1.0,
    stop: Optional[List[str]] = None,
    min_tokens: int = 1,
    adapter_path: Optional[str] = None,
) -> List[str]:
    """Generate one completion per chat conversation.

    Each conversation is rendered to a prompt string with the tokenizer's chat
    template (with the generation prompt appended) and the whole batch is sent
    to ``llm.generate`` in a single call.

    Args:
        llm: The vLLM engine to sample from.
        conversations: One list of ``{"role": ..., "content": ...}`` messages
            per conversation.
        top_p: Nucleus-sampling threshold passed to ``SamplingParams``.
        max_tokens: Maximum number of tokens to generate per completion.
        temperature: Sampling temperature passed to ``SamplingParams``.
        stop: Extra stop strings, used in addition to the tokenizer's EOS
            token. The caller's list is not modified.
        min_tokens: Minimum number of tokens to generate per completion.
        adapter_path: Optional LoRA adapter path. When truthy, every prompt is
            generated with that adapter; otherwise no LoRA request is sent.

    Returns:
        The text of the first output of each completion, in the order that
        ``llm.generate`` returns them.
    """
    tokenizer = llm.get_tokenizer()

    # A tokenizer may define no EOS token (None) or an empty one. Putting
    # either into the stop list makes the sampling parameters invalid, so only
    # a non-empty EOS string is used as a stop sequence.
    stop_tokens: List[str] = []
    if tokenizer.eos_token:
        stop_tokens.append(tokenizer.eos_token)
    if stop:
        stop_tokens.extend(stop)

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        skip_special_tokens=True,
        stop=stop_tokens,
        min_tokens=min_tokens,
    )

    # Render every conversation to plain prompt text; tokenization is left to
    # the engine.
    texts: List[str] = [
        tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        for messages in conversations
    ]

    # One LoRA request per prompt, all naming the same adapter (id 1).
    lora_request = (
        [LoRARequest("lora_adapter", 1, adapter_path) for _ in texts]
        if adapter_path
        else None
    )

    completions = llm.generate(
        texts,
        sampling_params=sampling_params,
        use_tqdm=True,
        lora_request=lora_request,
    )
    return [completion.outputs[0].text for completion in completions]


def generate_prompt() -> str:
    """Build a sequence-continuation prompt from three random seed numbers.

    The seeds are drawn one after another from 100..999 (inclusive) using the
    module-level ``random`` generator, so seeding ``random`` beforehand makes
    the prompt reproducible.

    Returns:
        The prompt text: the three seeds followed by fixed instructions.
    """
    first, second, third = (random.randint(100, 999) for _ in range(3))

    opening = f"The sequence starts with: {first}, {second}, {third}. "
    instructions = (
        "Add a maximum of 10 more values (no more than 3 digits each) to continue the sequence. "
        "Provide the numbers separated by commas. Skip any explanation and give only numbers."
    )
    return opening + instructions


def ensure_chat_template(llm: LLM) -> None:
    """Give the engine's tokenizer a chat template if it does not have one.

    When the tokenizer's ``chat_template`` is ``None``, the template of the
    ``allenai/OLMo-2-0325-32B-Instruct`` tokenizer is loaded and assigned to
    it. A tokenizer that already has a template is left untouched.

    Args:
        llm: Engine whose tokenizer is inspected and possibly modified in place.
    """
    tokenizer = llm.get_tokenizer()
    if tokenizer.chat_template is not None:
        # A template is already present; nothing to do.
        return

    logging.info("Setting chat template to OLMo Instruct template")
    donor = AutoTokenizer.from_pretrained("allenai/OLMo-2-0325-32B-Instruct")
    tokenizer.chat_template = donor.chat_template


def run(
    model: str,
    *,
    num_samples: int = 100,
    output_jsonl: str = "/lustre/scratch/data/***/ema/data/sequence_inference_messages.jsonl",
    adapter_path: Optional[str] = None,
    temperature: float = 1.0,
    top_p: float = 1.0,
    max_tokens: int = 64,
    seed: Optional[int] = None,
) -> str:
    """Generate sequence completions and append them to a messages JSONL file.

    Loads the model, builds ``num_samples`` single-turn conversations from
    randomly generated prompts, samples one answer per conversation, and
    appends one JSON record per prompt/answer pair to ``output_jsonl``.

    Args:
        model: Model name or path to load.
        num_samples: Number of prompts to generate and answer.
        output_jsonl: Destination file. It is opened in append mode, so
            repeated runs add to existing content. Missing parent directories
            are created; a bare filename is written to the current directory.
        adapter_path: Optional LoRA adapter path used for model loading and
            generation.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        max_tokens: Maximum number of tokens per answer.
        seed: If given, seeds Python's ``random`` module, which drives prompt
            generation. It is not forwarded to the sampler.

    Returns:
        The path of the JSONL file that was written (``output_jsonl``).
    """
    logging.basicConfig(level=logging.INFO)

    # Prepare model
    llm = load_model(model, adapter_path)
    ensure_chat_template(llm)

    # Seed immediately before the prompts are drawn so that model start-up,
    # which may touch the global RNG, cannot disturb the seeded sequence.
    if seed is not None:
        random.seed(seed)

    # Build one single-turn conversation per prompt.
    prompts: List[str] = [generate_prompt() for _ in range(num_samples)]
    conversations: List[List[Dict[str, str]]] = [
        [{"role": "user", "content": prompt}] for prompt in prompts
    ]

    # Generate answers
    answers = sample(
        llm,
        conversations,
        top_p=top_p,
        max_tokens=max_tokens,
        temperature=temperature,
        stop=[],
        min_tokens=1,
        adapter_path=adapter_path,
    )

    # Ensure the output directory exists. A bare filename has an empty
    # dirname, and os.makedirs("") raises FileNotFoundError, so skip it then.
    output_dir = os.path.dirname(output_jsonl)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Write results in messages format, one JSON object per line.
    written = 0
    with open(output_jsonl, "a", encoding="utf-8") as f:
        for prompt, answer in zip(prompts, answers):
            record = {
                "messages": [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": answer},
                ]
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
            written += 1

    logging.info("Wrote %d conversations to %s", written, output_jsonl)
    return output_jsonl


if __name__ == "__main__":
    # Script entry point: expose ``run`` as a command-line interface. The
    # import sits inside the guard so that importing this module does not
    # require ``fire``.
    from fire import Fire

    Fire(run)








