import argparse
import json
import os
import sys
from typing import List, Dict, Any, Optional

# vLLM is required for local inference over HF models, including DeepSeek-R1 distills.
from vllm import LLM, SamplingParams


# =========================
# Prompting configuration
# =========================

# CoT prompt template provided by the user.
# It is filled in with str.format(question=...), so the literal braces after
# \boxed must be doubled: a bare "{}" would be parsed as positional field 0 and
# raise IndexError because only the `question` keyword is supplied.
COT_PROMPT_TEMPLATE = (
    "[Round 0] USER:\n"
    "{question}\n"
    "Please reason step by step, and put your final answer within \\boxed{{}}. ASSISTANT:"
)

# Optionally wrap the user content into a chat template when the tokenizer supports it.
def render_prompt_with_chat_template(tokenizer, user_text: str) -> Optional[str]:
    """
    Wrap `user_text` as a single user turn using the tokenizer's chat template.

    Args:
        tokenizer: Any object; only its optional `apply_chat_template` attribute is used.
        user_text: Content of the single "user" message.

    Returns:
        Whatever `apply_chat_template` returns when called with tokenize=False and
        add_generation_prompt=True, or None when the tokenizer has no such
        attribute or the call raises.
    """
    if not hasattr(tokenizer, "apply_chat_template"):
        return None

    conversation = [{"role": "user", "content": user_text}]
    try:
        rendered = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )
    except Exception:
        # Best-effort: any failure while rendering means "no chat prompt available".
        return None
    return rendered


# =========================
# I/O utilities
# =========================

def read_jsonl(path: str):
    """
    Stream JSONL records from file.

    Blank lines are ignored. Lines that are not valid JSON, or whose top-level
    value is not a JSON object, are reported on stderr and skipped, so
    consumers only ever receive dicts.

    Args:
        path: Path to a UTF-8 encoded JSONL file.

    Yields:
        Parsed dicts, in file order.
    """
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                # Skip invalid JSON lines to keep the pipeline running.
                sys.stderr.write(f"[warn] Invalid JSON at line {line_no}\n")
                continue
            if not isinstance(rec, dict):
                # Valid JSON but not an object (e.g. a number, list or null):
                # it has no fields to look up, so it is skipped like a bad line.
                sys.stderr.write(f"[warn] Non-object JSON record at line {line_no}\n")
                continue
            # Yield outside the try so only the parse itself is guarded.
            yield rec


def write_jsonl(path: str, records: List[Dict[str, Any]]):
    """
    Serialize records as JSON Lines and append them to a file.

    The file is opened in append mode, so existing content is kept and the
    file is created if it does not exist. Non-ASCII characters are written
    as-is rather than escaped.

    Args:
        path: Destination file path.
        records: Dictionaries to serialize, one JSON document per line.
    """
    with open(path, "a", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in records
        )


# =========================
# CoT generation core
# =========================

class CoTGenerator:
    """
    Batch chain-of-thought generation on top of a vLLM engine.

    The model to load is chosen by `model_path` (a HF repo id or local path);
    the same value is also used to locate the tokenizer.
    """

    def __init__(
        self,
        model_path: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.0,
        top_p: float = 1.0,
        top_k: int = -1,
        gpu_memory_utilization: float = 0.90,
        tensor_parallel_size: int = 1,
        seed: int = 42
    ):
        """
        Load the engine and fix the decoding parameters used for every batch.

        Args:
            model_path: HF repo id or local path to the model weights.
            max_new_tokens: Upper bound on tokens generated per prompt.
            temperature: Sampling temperature. Use 0.0 for deterministic outputs.
            top_p: Nucleus sampling parameter. Set 1.0 to disable.
            top_k: Top-k sampling parameter. Set -1 to disable.
            gpu_memory_utilization: Fraction of GPU memory handed to the engine.
            tensor_parallel_size: Tensor parallel degree. >1 requires multiple GPUs.
            seed: Seed forwarded to the sampling parameters.
        """
        # Engine construction. Remote model code is trusted, and the engine is
        # capped at 64 concurrent sequences.
        engine_options = dict(
            model=model_path,
            tokenizer=model_path,
            trust_remote_code=True,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            max_num_seqs=64,
        )
        self.llm = LLM(**engine_options)

        # Decoding settings shared by all generate() calls.
        decoding_options = dict(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_tokens=max_new_tokens,
            seed=seed,
        )
        self.sampling_params = SamplingParams(**decoding_options)

        # Keep the engine's tokenizer around for optional chat-template rendering.
        self.tokenizer = self.llm.get_tokenizer()

    def _build_prompt(self, question_text: str) -> str:
        """
        Turn one raw question into the prompt string sent to the engine.

        Args:
            question_text: The raw question string.

        Returns:
            The chat-template rendering of the filled CoT template when the
            tokenizer can produce one, otherwise the filled template itself
            (with surrounding whitespace stripped).
        """
        filled = COT_PROMPT_TEMPLATE.format(question=question_text)
        plain_prompt = filled.strip()

        rendered = render_prompt_with_chat_template(self.tokenizer, plain_prompt)
        if rendered is None:
            return plain_prompt
        return rendered

    def generate_cot_batch(self, batch_questions: List[Dict[str, Any]]) -> List[str]:
        """
        Generate CoT text for every record in a batch with one engine call.

        Args:
            batch_questions: Records to answer. The "question" value is used as
                the question text; a missing key counts as an empty string and
                non-string values are converted with str().

        Returns:
            One stripped string per engine output, in the order the engine
            returned them. An output without candidates yields "".
        """
        prompts: List[str] = []
        for record in batch_questions:
            raw_question = record.get("question", "")
            question = raw_question if isinstance(raw_question, str) else str(raw_question)
            prompts.append(self._build_prompt(question))

        # Single batched call; only the first candidate of each output is kept.
        results = self.llm.generate(prompts, self.sampling_params)
        return [
            (result.outputs[0].text if result.outputs else "").strip()
            for result in results
        ]


# =========================
# Main pipeline
# =========================

def process_file(
    input_path: str,
    output_path: str,
    generator: CoTGenerator,
    batch_size: int
):
    """
    Read questions from a JSONL file, generate CoT per batch, and append results.

    Records without a "question" key are skipped. Each output record carries
    only "identity", "question", "question_type" and "answer" from the input
    (None when absent) plus the generated "cot"; other input fields are not
    copied. Output is appended to `output_path` after every batch, and the
    total written is reported on stderr at the end.

    Args:
        input_path: Path to input JSONL.
        output_path: Path to output JSONL; its parent directory is created if needed.
        generator: An initialized CoTGenerator.
        batch_size: Number of records collected before a generation call is made.
    """
    parent_dir = os.path.dirname(output_path)
    os.makedirs(parent_dir or ".", exist_ok=True)

    copied_fields = ("identity", "question", "question_type", "answer")

    def emit(batch: List[Dict[str, Any]]) -> int:
        # Generate CoT for one batch, append the rows, and report how many were written.
        cots = generator.generate_cot_batch(batch)
        rows = []
        for source, cot in zip(batch, cots):
            row = {field: source.get(field, None) for field in copied_fields}
            row["cot"] = cot
            rows.append(row)
        write_jsonl(output_path, rows)
        return len(rows)

    pending: List[Dict[str, Any]] = []
    written = 0

    for rec in read_jsonl(input_path):
        # Minimal schema check: a record needs a "question" to be processed.
        if "question" not in rec:
            continue
        pending.append(rec)
        if len(pending) >= batch_size:
            written += emit(pending)
            pending = []

    # Handle the final, possibly partial, batch.
    if pending:
        written += emit(pending)

    sys.stderr.write(f"[info] Wrote {written} records -> {output_path}\n")


def parse_args() -> argparse.Namespace:
    """
    Build the command-line parser and parse the process arguments.

    Returns:
        The parsed namespace with the three required paths plus decoding,
        batching, parallelism and seed options (defaults shown in --help).
    """
    parser = argparse.ArgumentParser(
        description="Generate CoT JSONL using vLLM.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Required string options: (flag, help).
    required_options = (
        ("--input_path", "Input JSONL file."),
        ("--output_path", "Output JSONL file."),
        ("--model_path", "HF repo id or local model path (e.g., deepseek-ai/DeepSeek-R1-Distill-Qwen-7B)."),
    )
    for flag, help_text in required_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    # Optional options: (flag, type, default, help), registered in display order.
    optional_options = (
        # Decoding and batching config
        ("--batch_size", int, 64, "Batch size for generation."),
        ("--max_new_tokens", int, 1024, "Maximum new tokens for CoT."),
        ("--temperature", float, 0.0, "Sampling temperature."),
        ("--top_p", float, 1.0, "Top-p nucleus sampling."),
        ("--top_k", int, -1, "Top-k sampling cutoff."),
        # Performance and parallelism
        ("--gpu_memory_utilization", float, 0.90, "GPU memory fraction to use."),
        ("--tensor_parallel_size", int, 1, "Tensor parallel degree (>1 requires multiple GPUs)."),
        # Reproducibility
        ("--seed", int, 42, "Seed for deterministic decoding where applicable."),
    )
    for flag, value_type, default_value, help_text in optional_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    return parser.parse_args()


def main():
    """
    CLI entry point: parse options, build the generator, and process the input file.
    """
    args = parse_args()

    # Model and decoding settings come straight from the command line.
    generator_options = {
        "model_path": args.model_path,
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
        "gpu_memory_utilization": args.gpu_memory_utilization,
        "tensor_parallel_size": args.tensor_parallel_size,
        "seed": args.seed,
    }
    generator = CoTGenerator(**generator_options)

    process_file(
        input_path=args.input_path,
        output_path=args.output_path,
        generator=generator,
        batch_size=args.batch_size
    )


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
