#!/usr/bin/env python
import argparse
from tqdm import tqdm
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

from vllm import LLM, SamplingParams
from utils.utils import read_jsonl, write_jsonl

# Model identifier used when --model_name_or_path is not given on the command line.
DEFAULT_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Content of the "system" chat message placed in front of every user prompt.
SYSTEM_PROMPT = "You are a strict answer parser."

# User-prompt templates keyed by task type. build_user_prompt substitutes the
# literal placeholders "<model_answer>" (every task) and "<true_answer>"
# ("math" only) with the record's predicted and gold answers.
#
# These are ordinary (non-raw) string literals, so every backslash that must
# reach the model literally (e.g. LaTeX commands) has to be written as "\\".
TASK_PROMPTS = {
    # Few-shot extraction prompt: the model is asked to pull the final answer
    # out of the response filled into Example 6.
    "gsm": """Instruction:
We have a user's question and a model's generated response:

Your task:
1. Carefully read the question and the generated response in **Example 6 only**.
2. Extract the final answer based on the following rules:
    - If the response contains a number (with or without units), **extract only the numeric value**.
    - If the response is purely textual (no numbers), **extract the exact string as it appears**.
3. Use the following output format:
    - **Model's Final Answer is:** [Your extracted answer]

Rules:
- Only process **Example 6** for extraction. Ignore all other examples.
- Do not include units, symbols, or extra text when extracting numbers.
- Provide the answer strictly in the requested format without additional explanations.

### Examples

#### Example 1:
- Model's Generated Response: It takes about 160 minutes.
- **Model's Final Answer is:** 160

#### Example 2:
- Model's Generated Response: The nearest star is approximately 4.24 light years away.
- **Model's Final Answer is:** 4.24

#### Example 3:
- Model's Generated Response: The tallest mountain is Mount Everest.
- **Model's Final Answer is:** Mount Everest

#### Example 4:
- Model's Generated Response: It weighs 5 kg.
- **Model's Final Answer is:** 5

#### Example 5:
- Model's Generated Response: 81 + 221 - 24 = 278.
- **Model's Final Answer is:** 278

### Example 6:
- Model's Generated Response: <model_answer>""",

    # Few-shot yes/no grading prompt comparing the model answer with the true
    # answer filled into Example 6.
    "math": """Instruction: 
You are given the true answer and the final answer generated by a model for a math problem.

Your task:
1. Only examine **Example 6**.
2. Compare the **model's final answer** and the **true answer**.
3. Respond with "yes" if they exactly match, otherwise respond with "no".
4. Do not include any explanation or extra words — just respond with "yes" or "no".

Example 1:  
True Answer: 0.5  
Model Answer: 1/2  
Is it correct?: yes

Example 2:  
True Answer: 18  
Model Answer: 22  
Is it correct?: no

Example 3:  
True Answer: \\dfrac{3}{15}  
Model Answer: 18 / 90  
Is it correct?: yes

Example 4:  
True Answer: \\frac{5}{2} 
Model Answer: \\frac{15}{8}  
Is it correct?: no

Example 5:  
True Answer: 162  
Model Answer: 150 + 12 = 162
Is it correct?: yes

Example 6:  
True Answer: <true_answer>  
Model Answer: <model_answer>  
Is it correct?: """,

    # NOTE(review): this template contains neither "<model_answer>" nor
    # "<true_answer>", so placeholder substitution leaves it unchanged and the
    # gold/predicted answers never appear in the prompt — confirm whether the
    # placeholders are missing here.
    "coding": """Extract the final answer for a coding task.
- If the gold/predicted answer includes code, keep only the code block contents.
- If it includes a final output/result, keep only that output.
- If both appear, prefer the final output.
Return JSON with keys: parsed_gold, parsed_pred.
""",
}


@dataclass
class Record:
    """One input row reduced to the fields the parsing pipeline uses."""

    # Question text; build_record stores "" when no question-like key is found.
    question: str
    # Reference answer taken from the input row, or None when absent.
    gold_answer: Optional[str]
    # Model output to be parsed (build_record keeps the text after the last
    # "<answer>" marker).
    predicted_answer: Optional[str]
    # Key into TASK_PROMPTS selecting which prompt template to use.
    task_type: str


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the answer-parsing script.

    ``--input_file``, ``--output_file`` and ``--task_type`` are required; the
    remaining options configure the vLLM engine and sampling and have defaults.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Parse answers with vLLM.")
    parser.add_argument("--input_file", required=True)
    parser.add_argument("--output_file", required=True)
    parser.add_argument("--model_name_or_path", default=DEFAULT_MODEL)
    # The task type selects the prompt template, so it must be one of the
    # TASK_PROMPTS keys. It is required: a missing value would otherwise only
    # fail later as a bare KeyError when the prompts are built.
    parser.add_argument(
        "--task_type",
        required=True,
        choices=list(TASK_PROMPTS),
        help="Prompt template to use for every record.",
    )
    parser.add_argument("--tensor_parallel_size", type=int, default=1)
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.9)
    parser.add_argument("--max_tokens", type=int, default=10)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--batch_size", type=int, default=32)
    return parser.parse_args()

def extract_first(record: Dict[str, Any], keys: Iterable[str]) -> Optional[str]:
    """Return the value of the first candidate key that exists in *record*.

    Candidates are tried in the order given by *keys*. Only key presence is
    checked, so a stored ``None`` or empty value is returned as-is. ``None``
    is returned when none of the candidates is present.
    """
    hits = (record[name] for name in keys if name in record)
    return next(hits, None)

def build_record(record: Dict[str, Any], task_type: Optional[str]) -> Record:
    """Convert one raw input row into a ``Record``.

    Each field is read from the first matching key among several accepted
    aliases (see the key lists below).

    Args:
        record: Raw row as loaded from the input file.
        task_type: Task type to attach to the record.

    Returns:
        A ``Record`` whose ``question`` falls back to ``""``, whose
        ``gold_answer`` is the raw value (or ``None``), and whose
        ``predicted_answer`` is the text after the last ``"<answer>"`` marker,
        or ``None`` when the row contains no prediction.
    """
    question = extract_first(record, ["question", "prompt", "input"]) or ""
    gold = extract_first(record, ["true_answer", "answer", "reference", "label"])
    pred = extract_first(
        record,
        ["completion", "predicted_answer", "prediction", "pred", "response", "output"],
    )
    if pred is not None:
        # Rows without a prediction stay None instead of crashing on .split();
        # str() tolerates non-string predictions (e.g. numbers). Only the text
        # after the last "<answer>" marker is kept; without a marker the whole
        # prediction is kept.
        pred = str(pred).split("<answer>")[-1]
    return Record(
        question=question,
        gold_answer=gold,
        predicted_answer=pred,
        task_type=task_type,
    )


def build_user_prompt(record: Record) -> str:
    """Fill the task's prompt template with the record's answers.

    ``"<model_answer>"`` is replaced with the predicted answer for every task;
    ``"<true_answer>"`` is additionally replaced with the gold answer for the
    ``"math"`` task. Missing (``None``) answers are inserted as ``""``.

    Args:
        record: Record providing ``task_type``, ``predicted_answer`` and
            ``gold_answer``.

    Returns:
        The user-message text.

    Raises:
        KeyError: If ``record.task_type`` is not a key of ``TASK_PROMPTS``.
    """
    try:
        template = TASK_PROMPTS[record.task_type]
    except KeyError:
        raise KeyError(
            f"Unknown task_type {record.task_type!r}; "
            f"expected one of {sorted(TASK_PROMPTS)}"
        ) from None

    predicted = "" if record.predicted_answer is None else str(record.predicted_answer)
    if record.task_type != "math":
        return template.replace("<model_answer>", predicted)

    gold = "" if record.gold_answer is None else str(record.gold_answer)
    # Substitute both placeholders in a single pass over the template: split
    # on "<true_answer>", fill "<model_answer>" inside each piece, then join
    # with the gold answer. Inserted text is never rescanned, so a literal
    # "<true_answer>" inside the model's answer is not replaced by the gold
    # answer (chained .replace() calls would do that).
    pieces = template.split("<true_answer>")
    return gold.join(piece.replace("<model_answer>", predicted) for piece in pieces)

def run_parsing(records: List[Record], args: argparse.Namespace) -> List[Dict[str, Any]]:
    """Generate a parsed answer for every record with a vLLM chat model.

    Args:
        records: Records to process; one two-message conversation (system +
            user) is built per record.
        args: Parsed CLI options supplying the model, engine and sampling
            settings as well as ``batch_size``.

    Returns:
        One dict per record with the keys ``question``, ``gold_answer``,
        ``predicted_answer`` and ``parsed`` (text of the first completion).
    """
    engine = LLM(
        model=args.model_name_or_path,
        tensor_parallel_size=args.tensor_parallel_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )
    decoding = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
    )

    # One conversation per record: the shared system message, then the
    # task-specific user prompt.
    conversations = []
    for rec in records:
        system_turn = {"role": "system", "content": SYSTEM_PROMPT}
        user_turn = {"role": "user", "content": build_user_prompt(rec)}
        conversations.append([system_turn, user_turn])

    # Send the conversations in slices of batch_size; results are collected in
    # input order so they can be paired with the records below.
    step = args.batch_size
    generations = []
    for start in tqdm(range(0, len(conversations), step)):
        chunk = conversations[start: start + step]
        generations += engine.chat(chunk, decoding, use_tqdm=False)

    return [
        {
            "question": rec.question,
            "gold_answer": rec.gold_answer,
            "predicted_answer": rec.predicted_answer,
            "parsed": generation.outputs[0].text,
        }
        for rec, generation in zip(records, generations)
    ]


def main() -> None:
    """Script entry point: read the input rows, parse them, write the results.

    The parent directory of ``--output_file`` is created if it does not exist.
    """
    cli = parse_args()
    records = [build_record(raw, cli.task_type) for raw in read_jsonl(cli.input_file)]
    results = run_parsing(records, cli)

    destination = Path(cli.output_file)
    destination.parent.mkdir(parents=True, exist_ok=True)
    write_jsonl(str(destination), results)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
