import os
import json
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm


def generate_response(model, tokenizer, question: str, max_new_tokens: int = 65536) -> str:
    """Generate the model's answer to a single math question.

    The question is placed after a fixed instruction asking for the final
    answer to be marked with ``\\boxed{}``. The result is wrapped in the
    tokenizer's chat template as one user turn, and passed to
    ``model.generate``. Apart from ``max_new_tokens``, the model's own
    generation settings are used.

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tokenizer: tokenizer for *model*; must support ``apply_chat_template``.
        question: problem text appended after the instruction.
        max_new_tokens: upper bound on generated tokens (default 65536, the
            value previously hard-coded here).

    Returns:
        Only the newly generated text (prompt tokens are stripped), decoded
        with special tokens removed.
    """
    base_prompt = (
        "You are a mathematical reasoning assistant. "
        "Now solve the following question. The final answer should be marked with \\boxed{}"
    )
    # Separate the instruction from the question. Without a separator the text
    # runs together as "...marked with \boxed{}<question>".
    messages = [{"role": "user", "content": f"{base_prompt}\n\n{question}"}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    # The output sequence starts with the prompt tokens; slice them off so only
    # the completion is decoded.
    prompt_len = len(model_inputs.input_ids[0])
    output_ids = generated_ids[0][prompt_len:].tolist()
    return tokenizer.decode(output_ids, skip_special_tokens=True)


def load_existing_ids(output_file: str):
    """Return the set of record ids already stored in *output_file*.

    The file is expected to hold a JSON list of dict records (the format
    written by ``append_to_json``). An empty set is returned in these cases:
    the file does not exist, it is not valid JSON, or its top-level value is
    not a list. Non-dict entries, and entries whose ``id`` is missing or
    ``None``, are ignored. A ``None`` id cannot identify a finished item;
    counting it would make the caller skip every id-less input on resume.
    """
    if not os.path.exists(output_file):
        return set()
    with open(output_file, "r", encoding="utf-8") as f:
        try:
            data = json.load(f)
        except json.JSONDecodeError:
            return set()
    if not isinstance(data, list):
        return set()
    return {
        item["id"]
        for item in data
        if isinstance(item, dict) and item.get("id") is not None
    }


def append_to_json(output_file: str, record: dict) -> None:
    """Append *record* to the JSON list stored in *output_file*.

    The file is created if it does not exist. The whole list is re-read and
    rewritten on every call, so the cost of each call grows with the number
    of stored records.

    The new content is written to a sibling ``.tmp`` file and then moved into
    place with ``os.replace``. This means a crash mid-write cannot leave a
    truncated, unparseable results file behind.

    If the existing file is non-empty but is not a valid JSON list, it is
    renamed to ``<output_file>.corrupt`` and a fresh list is started.
    Previously saved results are therefore never silently discarded.
    """
    data = []
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            raw = f.read()
        if raw.strip():
            try:
                loaded = json.loads(raw)
            except json.JSONDecodeError:
                loaded = None
            if isinstance(loaded, list):
                data = loaded
            else:
                # Keep the unreadable content for manual recovery instead of
                # overwriting it.
                os.replace(output_file, output_file + ".corrupt")
    data.append(record)
    tmp_path = output_file + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, output_file)


def inference(args):
    """Run generation over a JSON dataset and append each result to a JSON file.

    Configuration is read from *args* (``model_path``, ``data_path``,
    ``output_file``, ``gpu_id``). Each value falls back to the
    ``MODEL_PATH``, ``DATA_PATH``, ``OUTPUT_FILE`` or ``GPU_ID``
    environment variable respectively.

    The dataset is presumably a JSON list of objects carrying ``id``,
    ``question`` and ``answer`` keys; this is inferred from the ``.get``
    calls below and should be confirmed against the data.

    Items whose id already appears in *output_file* are skipped, so an
    interrupted run can be resumed. Items without an id are always processed,
    because they cannot be matched against saved records.

    Raises:
        FileNotFoundError: if the model or data path is unset or missing.
        ValueError: if no output file is configured.
    """
    model_path = args.model_path or os.environ.get("MODEL_PATH")
    data_path = args.data_path or os.environ.get("DATA_PATH")
    output_file = args.output_file or os.environ.get("OUTPUT_FILE")
    gpu_id = args.gpu_id if args.gpu_id is not None else int(os.environ.get("GPU_ID", "0"))

    if not model_path or not os.path.exists(model_path):
        raise FileNotFoundError(f"Model path not found: {model_path}")
    if not data_path or not os.path.exists(data_path):
        raise FileNotFoundError(f"Data path not found: {data_path}")
    if not output_file:
        raise ValueError("Output file path is required (use --output-file or set OUTPUT_FILE).")

    # Match the model load below: repos that ship custom tokenizer code need
    # trust_remote_code on the tokenizer as well.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype="auto",
    )
    model.to(device)
    model.eval()

    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    finished_ids = load_existing_ids(output_file)

    for item in tqdm(data):
        item_id = item.get("id")
        # A missing id is not evidence of completion. Testing `None in
        # finished_ids` would skip every id-less item once one such record
        # had been written.
        if item_id is not None and item_id in finished_ids:
            continue
        question = item.get("question", "")
        try:
            prediction = generate_response(model, tokenizer, question)
        except Exception as e:
            # Best-effort: record the failure rather than abort the whole run.
            # The record is still written, so this item is not retried on resume.
            prediction = f"[ERROR] {e}"
        record = {
            "id": item_id,
            "query": question,
            "answer": item.get("answer", ""),
            "prediction": prediction,
        }
        append_to_json(output_file, record)


if __name__ == "__main__":
    # Every option is optional on the command line; inference() falls back to
    # the matching environment variable named in each help string.
    parser = argparse.ArgumentParser(
        description=(
            "Run a causal LM over a JSON question dataset and append "
            "predictions to a JSON results file (resumable by item id)."
        ),
    )
    parser.add_argument(
        "--model-path",
        help="Local model directory (default: $MODEL_PATH).",
    )
    parser.add_argument(
        "--data-path",
        help="Input JSON file with a list of question items (default: $DATA_PATH).",
    )
    parser.add_argument(
        "--output-file",
        help="Results JSON file; ids already present are skipped (default: $OUTPUT_FILE).",
    )
    parser.add_argument(
        "--gpu-id",
        type=int,
        help="CUDA device index (default: $GPU_ID, else 0).",
    )
    args = parser.parse_args()
    inference(args)
