import argparse
import json
import os
import re
import sys
from typing import List, Dict, Any, Optional

from vllm import LLM, SamplingParams


# -------------------------
# Prompt template (raw string; contains literal ``` fences and JSON braces)
# -------------------------
# Instruction prompt sent to the annotator model for every CoT record.
# The only placeholder is the literal text "{extracted_responses}", which
# PatternAnnotator._build_prompt substitutes with str.replace. Every other
# brace in the template is literal JSON, so the string must NOT be passed
# through str.format. The model is asked to answer with a ```json fenced
# object (pattern_list + how_CoT_utilizes_patterns_in_this_case), which is
# what extract_json_from_text looks for. The "name" fields are requested in
# Chinese on purpose.
PATTERN_PROMPT_BASE = r"""
Task Objective:
Systematically explore and summarize the Chain-of-Thought (CoT) processes employed by mainstream LLMs in reasoning tasks, analyzing the core reasoning patterns embedded within these processes.

Analysis Instructions:
Please conduct an in-depth examination of the reasoning paths taken by various AI models in reasoning tasks, demonstrating how different models approach and solve problems. Your goal is to summarize and categorize the general thinking patterns reflected in these reasoning processes, to help understand the essential characteristics of CoT reasoning in large models.

Analysis Steps:
For each reasoning process, please clearly identify the following elements:
1. Use of keywords and high-frequency phrases
2. Logical structure and organization of argumentation
3. Techniques or strategies used to solve the problem
4. The manner in which reasoning steps are unfolded

Classification Requirements:
Based on the following commonalities, accurately categorize similar reasoning processes into one or more general reasoning patterns:
1. Lexical pattern (organization and use of common terms and phrases)
2. Logical framework (structure of argumentation and reasoning flow)
3. Solution pathway (methods and paths to reach conclusions)

Important Notes:
1. You are required to summarize "general thinking patterns for problem solving," not specific problem types.
2. Each pattern should be applicable to any problem scenario, not limited to a particular type of task.
3. Focus on the thinking method itself, rather than specific solution steps or answer content.

Illustrative Examples:
You may categorize as follows:
- Knowledge retrieval-based reasoning
- Reasoning combined with verification
- Step-by-step deductive calculation
- Detailed stepwise derivation
- etc.

Attention Points:
- Precisely categorize the above reasoning processes into one or more patterns (>=1), defining each category based on its shared characteristics, explaining its role in reasoning for the given case, and providing examples.
- Briefly explain your analysis and classification criteria first, then output detailed annotation for each reasoning pattern in the JSON format below. The "name" field for each pattern must be output in Chinese.
- The "pattern_chain" field outputs a list, where the element order represents the sequence of patterns used in this CoT solution, e.g., [1,2,3,4]. If necessary, the reasoning pattern chain may contain loops.
- Output atomic patterns only (no pattern should contain words like "and", "or", etc.).

Output Format:
```json
{
  "pattern_list": [
    {"id":1, "name": "", "description": "", "features": "", "sample_input_flow": "", "role_in_this_case": "", "corresponding_CoT_content": ["", ""], "common_elements": "", "typical_expressions": ["", ""] },
    {"id":2, "name": "", "description": "", "features": "", "sample_input_flow": "", "role_in_this_case": "", "corresponding_CoT_content": ["", ""], "common_elements": "", "typical_expressions": ["", ""] }
  ],
  "how_CoT_utilizes_patterns_in_this_case": {
    "process_description": "",
    "pattern_chain": []
  }
}
```

Reasoning process to be analyzed: {
  {extracted_responses}
}
"""


# -------------------------
# JSONL I/O
# -------------------------
def read_jsonl(path: str):
    """Stream JSON records from a JSONL file, one per non-blank line.

    The file is decoded as UTF-8. A leading byte-order mark is tolerated:
    with plain "utf-8" the BOM would stay glued to the first line, because
    str.strip does not remove U+FEFF, and json.loads would reject it,
    silently dropping the first record. Blank lines are ignored. Lines that
    are not valid JSON are skipped, with a warning (1-based line number)
    written to stderr.

    Args:
        path: Path of the JSONL file to read.

    Yields:
        Each successfully decoded JSON value, in file order.
    """
    with open(path, "r", encoding="utf-8-sig") as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                yield json.loads(line)
            except json.JSONDecodeError:
                sys.stderr.write(f"[warn] Invalid JSON at line {line_no}\n")


def write_jsonl(path: str, records: List[Dict[str, Any]]):
    """Serialize each record as one JSON line and append it to ``path``.

    The parent directory is created when missing. When ``path`` has no
    directory component, the current directory is used. Non-ASCII
    characters are written verbatim (``ensure_ascii=False``) and the file
    is encoded as UTF-8.
    """
    target_dir = os.path.dirname(path) or "."
    os.makedirs(target_dir, exist_ok=True)
    with open(path, "a", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in records
        )


# -------------------------
# JSON extraction (user-specified strategy)
# -------------------------
def extract_json_from_text(text: str) -> Optional[Dict[str, Any]]:
    """
    Extract the first JSON object from a fenced code block in model output.

    Strategy:
    1) If the output contains a reasoning section closed by "</think>"
       (DeepSeek-R1 style) and a fence follows it, only the text after the
       last "</think>" is searched. Draft JSON inside the reasoning is then
       ignored. Otherwise the whole text is searched.
    2) Cross-line match of ```json ... ``` (info string optional); the first
       fenced block is taken.
    3) The block is parsed as-is, trimmed to its outermost {...} so notes
       around the object inside the fence are ignored.
    4) If that fails, trailing commas before } or ] are removed and parsing
       is retried.
    5) Only if both attempts fail are the repair heuristics for over-escaped
       output applied: literal backslash-r-backslash-n, backslash-n and
       backslash-t sequences become real whitespace, and backslash-quote
       becomes a bare quote. The result is then parsed leniently (control
       characters allowed inside strings).

    Valid JSON is never run through the heuristics. Running it through them
    would turn legitimate escapes inside string values (e.g. quoted CoT
    text) into raw newlines or unbalanced quotes and make it unparseable.

    Args:
        text: Raw generated text. None is treated as an empty string.

    Returns:
        The decoded JSON value, or None if no fenced block exists or no
        attempt parses. A warning is written to stderr in the latter case.
    """
    region = text or ""

    # Prefer the final answer over drafts inside an R1-style reasoning section.
    think_end = region.rfind("</think>")
    if think_end != -1:
        tail = region[think_end + len("</think>"):]
        if "```" in tail:
            region = tail

    matches = re.findall(r'```(?:json)?\s*([\s\S]*?)```', region, re.IGNORECASE)
    if not matches:
        return None
    block = matches[0].strip()

    def _outermost_object(s: str) -> str:
        # Keep only the outermost {...} if extra notes exist inside fences.
        m = re.search(r'\{[\s\S]*\}', s)
        return m.group(0) if m else s

    def _drop_trailing_commas(s: str) -> str:
        return re.sub(r',\s*(?=[}\]])', '', s)

    # Heuristic repair for models that emit the JSON double-escaped.
    repaired = re.sub(r'\\r\\n', '\n', block)
    repaired = repaired.replace(r'\n', '\n').replace(r'\t', '\t')
    repaired = repaired.replace(r'\"', '"')

    attempts = (
        (_outermost_object(block), True),
        (_outermost_object(_drop_trailing_commas(block)), True),
        # strict=False: the repair may leave real newlines/tabs inside strings.
        (_outermost_object(_drop_trailing_commas(repaired)), False),
    )
    last_error: Optional[json.JSONDecodeError] = None
    for candidate, strict in attempts:
        try:
            return json.loads(candidate, strict=strict)
        except json.JSONDecodeError as e:
            last_error = e

    sys.stderr.write(f"[warn] JSON parse failed: {last_error}\n")
    return None


# -------------------------
# LLM wrapper
# -------------------------
class PatternAnnotator:
    """Batch pattern annotator using vLLM. Supports DeepSeek-R1 and other HF models.

    Each instance owns one vLLM ``LLM`` engine and one shared
    ``SamplingParams``. ``annotate_batch`` turns each record's ``"cot"``
    value into a prompt, generates all completions in a single
    ``generate`` call, and extracts the JSON annotation from each one.
    """

    def __init__(
        self,
        model_path: str,
        max_new_tokens: int = 2048,
        temperature: float = 0.0,
        top_p: float = 1.0,
        top_k: int = -1,
        gpu_memory_utilization: float = 0.90,
        tensor_parallel_size: int = 1,
        seed: int = 42,
        max_num_seqs: int = 64,
    ):
        """Initialize model and decoding parameters.

        Args:
            model_path: HF repo id or local path. It is used for both the
                weights and the tokenizer.
            max_new_tokens: Generation budget per prompt
                (``SamplingParams.max_tokens``).
            temperature: Passed to ``SamplingParams`` unchanged.
            top_p: Passed to ``SamplingParams`` unchanged.
            top_k: Passed to ``SamplingParams`` unchanged.
            gpu_memory_utilization: Passed to ``LLM`` unchanged.
            tensor_parallel_size: Passed to ``LLM`` unchanged.
            seed: Per-request sampling seed (``SamplingParams.seed``).
            max_num_seqs: Passed to ``LLM`` as ``max_num_seqs``. It was
                previously hard-coded to 64, which remains the default.
        """
        self.llm = LLM(
            model=model_path,
            tokenizer=model_path,
            trust_remote_code=True,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            max_num_seqs=max_num_seqs
        )
        self.sampling = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_tokens=max_new_tokens,
            seed=seed
        )
        self.tokenizer = self.llm.get_tokenizer()
        # Set after the first chat-template failure has been reported, so the
        # fallback is logged once per instance rather than once per record.
        self._chat_template_warned = False

    @staticmethod
    def _coerce_cot_text(cot: Any) -> str:
        """Return ``cot`` as prompt text.

        None becomes "" and strings pass through unchanged. Any other value
        (e.g. a list of responses) is serialized as JSON instead of crashing
        ``str.replace``, which only accepts strings.
        """
        if cot is None:
            return ""
        if isinstance(cot, str):
            return cot
        return json.dumps(cot, ensure_ascii=False)

    def _build_prompt(self, cot_text: Any) -> str:
        """Build the prompt for one CoT via str.replace on the base template.

        The tokenizer's chat template is applied with a generation prompt
        when one is available. If it is missing, fails, or renders empty
        text, the raw injected template is returned instead, and a failure
        is reported once on stderr.
        """
        injected = PATTERN_PROMPT_BASE.replace(
            "{extracted_responses}", self._coerce_cot_text(cot_text)
        )
        # Prefer chat template if available to match chat-style models.
        if hasattr(self.tokenizer, "apply_chat_template"):
            try:
                messages = [{"role": "user", "content": injected}]
                rendered = self.tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                if isinstance(rendered, str) and rendered.strip():
                    return rendered
            except Exception as e:  # best-effort: any template error -> raw prompt
                if not getattr(self, "_chat_template_warned", False):
                    sys.stderr.write(
                        f"[warn] apply_chat_template failed ({e!r}); "
                        "falling back to the raw prompt\n"
                    )
                    self._chat_template_warned = True
        return injected

    def annotate_batch(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Generate pattern annotations for a batch of input records.

        Args:
            batch: Input records. Each record's ``"cot"`` value (default "")
                is used as the reasoning text.

        Returns:
            One dict per vLLM output, which vLLM returns in prompt order.
            ``"raw"`` holds the stripped completion text ("" when there is
            no completion). ``"parsed"`` holds the ``extract_json_from_text``
            result (None on failure).
        """
        prompts = [self._build_prompt(rec.get("cot", "")) for rec in batch]
        outputs = self.llm.generate(prompts, self.sampling)

        results: List[Dict[str, Any]] = []
        for out in outputs:
            text = out.outputs[0].text if out.outputs else ""
            parsed = extract_json_from_text(text)
            results.append({"raw": text.strip(), "parsed": parsed})
        return results


# -------------------------
# Processing pipeline
# -------------------------
def process_file(
    input_path: str,
    output_path: str,
    annotator: PatternAnnotator,
    batch_size: int
):
    """Annotate every record of a JSONL file and append the results to another.

    Records are streamed from ``input_path``. Only JSON objects that contain
    a ``"cot"`` key are annotated; anything else (including non-object
    lines such as bare numbers or strings) is skipped and counted. Records
    are sent to ``annotator`` in groups of ``batch_size``, and each finished
    group is appended to ``output_path`` straight away. Because the output
    is opened in append mode, re-running against an existing output file
    adds records after the ones already there.

    Each output record copies identity/question/question_type/answer/cot
    from the input (None when absent) and adds "pattern". That field holds
    the parsed JSON annotation, or the raw model text when parsing failed.

    Args:
        input_path: Source JSONL file.
        output_path: Destination JSONL file (appended to).
        annotator: Object providing ``annotate_batch``.
        batch_size: Number of records per generation call.
    """
    buffer: List[Dict[str, Any]] = []
    total = 0
    skipped = 0

    def flush():
        # Only ``total`` is rebound here; ``buffer`` is mutated in place.
        nonlocal total
        if not buffer:
            return
        anns = annotator.annotate_batch(buffer)
        out_records: List[Dict[str, Any]] = []
        for rec, ann in zip(buffer, anns):
            pattern_value: Any = ann["parsed"] if ann["parsed"] is not None else ann["raw"]
            out_records.append({
                "identity": rec.get("identity"),
                "question": rec.get("question"),
                "question_type": rec.get("question_type"),
                "answer": rec.get("answer"),
                "cot": rec.get("cot"),
                "pattern": pattern_value
            })
        write_jsonl(output_path, out_records)
        total += len(out_records)
        buffer.clear()

    for rec in read_jsonl(input_path):
        # A non-dict JSON value would either raise in the membership test
        # (e.g. an int) or pass it and then fail on ``rec.get`` (e.g. a
        # string containing "cot"), aborting the whole run.
        if not isinstance(rec, dict) or "cot" not in rec:
            skipped += 1
            continue
        buffer.append(rec)
        if len(buffer) >= batch_size:
            flush()

    flush()
    if skipped:
        sys.stderr.write(
            f"[info] Skipped {skipped} records that are not JSON objects with a 'cot' field\n"
        )
    sys.stderr.write(f"[info] Wrote {total} records -> {output_path}\n")


# -------------------------
# CLI
# -------------------------
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse and validate command-line arguments.

    Args:
        argv: Argument list to parse. None (the default) means
            ``sys.argv[1:]``, as before.

    Returns:
        The parsed namespace.

    Raises:
        SystemExit: Raised by argparse on missing or unknown flags, or when
            a value is out of range.
    """
    p = argparse.ArgumentParser(
        description="Annotate CoT patterns in JSONL using vLLM (supports DeepSeek-R1).",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    p.add_argument("--input_path", type=str, required=True, help="Input JSONL file.")
    p.add_argument("--output_path", type=str, required=True, help="Output JSONL file.")
    p.add_argument("--model_path", type=str, required=True,
                   help="HF repo id or local model path, e.g., deepseek-ai/DeepSeek-R1-Distill-Qwen-7B.")

    # Decoding and batching config
    p.add_argument("--batch_size", type=int, default=64, help="Batch size for generation.")
    p.add_argument("--max_new_tokens", type=int, default=2048, help="Maximum new tokens.")
    p.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature.")
    p.add_argument("--top_p", type=float, default=1.0, help="Top-p nucleus sampling.")
    p.add_argument("--top_k", type=int, default=-1, help="Top-k sampling cutoff.")

    # Performance and parallelism
    p.add_argument("--gpu_memory_utilization", type=float, default=0.90, help="GPU memory fraction to use.")
    p.add_argument("--tensor_parallel_size", type=int, default=1, help="Tensor parallel degree.")

    # Reproducibility
    p.add_argument("--seed", type=int, default=42, help="Random seed for decoding.")

    args = p.parse_args(argv)

    # Reject nonsensical values up front with a clear CLI error instead of
    # leaving them to vLLM after the model has started loading. A batch_size
    # below 1 would otherwise silently degrade to one record per call.
    if args.batch_size < 1:
        p.error("--batch_size must be >= 1")
    if args.max_new_tokens < 1:
        p.error("--max_new_tokens must be >= 1")
    if args.tensor_parallel_size < 1:
        p.error("--tensor_parallel_size must be >= 1")
    if not 0.0 < args.gpu_memory_utilization <= 1.0:
        p.error("--gpu_memory_utilization must be in (0, 1]")

    return args


def main():
    """Command-line entry point: parse flags, load the model, annotate the file."""
    cli = parse_args()

    # The annotator's keyword names match the CLI destinations one-to-one.
    engine_options = {
        option: getattr(cli, option)
        for option in (
            "model_path",
            "max_new_tokens",
            "temperature",
            "top_p",
            "top_k",
            "gpu_memory_utilization",
            "tensor_parallel_size",
            "seed",
        )
    }

    process_file(
        input_path=cli.input_path,
        output_path=cli.output_path,
        annotator=PatternAnnotator(**engine_options),
        batch_size=cli.batch_size,
    )


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
