import json
import re
import os
import argparse
from tqdm import tqdm
from zhipuai import ZhipuAI


def generate_response_api(client, question, model_name, max_tokens=65536):
    """Ask the chat model to solve *question* and return its visible answer.

    The request is streamed and the content deltas are concatenated. If the
    model emitted a reasoning section terminated by ``</think>``, everything up
    to and including the first ``</think>`` is dropped.

    Args:
        client: Client object exposing ``chat.completions.create`` (a ZhipuAI
            client in this script).
        question: Problem statement sent as the user message.
        model_name: Model identifier passed through to the API.
        max_tokens: Generation cap forwarded to the API (default 65536).

    Returns:
        The answer text with surrounding whitespace stripped.
    """
    system_prompt = "You are a mathematical reasoning assistant. Now solve the following question. The final answer should be marked with \\boxed{}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]
    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
    )
    pieces = []
    for chunk in response:
        # Skip chunks with an empty/missing choices list instead of raising IndexError.
        choices = getattr(chunk, "choices", None)
        if not choices:
            continue
        content = getattr(choices[0].delta, "content", None)
        if content:
            pieces.append(content)
    # join() avoids quadratic string concatenation on long streamed outputs.
    full_text = "".join(pieces)
    # Strip the reasoning preamble; do not require a newline right after the tag.
    _, sep, tail = full_text.partition("</think>")
    return (tail if sep else full_text).strip()


def load_existing_ids(output_file):
    """Return the set of ``id`` values already recorded in *output_file*.

    Used to resume a run: inputs whose id is in this set are skipped.

    A missing file, unparsable JSON, or a top-level value that is not a list
    all yield an empty set. Entries that are not dicts, or whose ``id`` is
    missing or ``None``, are ignored. A ``None`` id cannot identify a specific
    input item, so counting it as "done" would make every id-less input look
    finished.
    """
    if not os.path.exists(output_file):
        return set()
    with open(output_file, "r", encoding="utf-8") as f:
        try:
            data = json.load(f)
        except json.JSONDecodeError:
            return set()
    if not isinstance(data, list):
        return set()
    return {
        item["id"]
        for item in data
        if isinstance(item, dict) and item.get("id") is not None
    }


def append_to_json(output_file, record):
    """Append *record* to the JSON array stored in *output_file*.

    The whole array is re-read and re-written on every call. The new content
    is first written to ``<output_file>.tmp`` and then moved into place with
    ``os.replace``, so an interrupt mid-write cannot leave a truncated file.

    A missing or blank file starts a fresh array. An existing file that is not
    valid JSON is renamed to ``<output_file>.corrupt`` before a fresh array is
    started, so earlier results are not silently destroyed.
    """
    data = []
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            raw = f.read()
        if raw.strip():
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                # Keep the unreadable content for manual recovery.
                os.replace(output_file, output_file + ".corrupt")
                data = []
    data.append(record)

    tmp_path = output_file + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, output_file)
    finally:
        # Only still present if dumping/replacing failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def extract_last_boxed(text: str):
    """Return the stripped contents of the final ``boxed{...}`` in *text*.

    Nested braces inside the box are balanced, so ``\\boxed{\\frac{1}{2}}``
    yields ``\\frac{1}{2}``. Returns None when *text* contains no
    ``boxed{`` or when the last one is never closed.
    """
    last_match = None
    for last_match in re.finditer(r'boxed\{', text):
        pass
    if last_match is None:
        return None

    begin = last_match.end()
    level = 1
    for pos in range(begin, len(text)):
        ch = text[pos]
        if ch == '{':
            level += 1
        elif ch == '}':
            level -= 1
            if level == 0:
                # pos is the closing brace that balances the opening one.
                return text[begin:pos].strip()
    return None


def inference_api(args):
    """Answer every not-yet-processed question in ``args.data_path``.

    Client settings come from the environment: ZHIPU_API_KEY,
    ZHIPU_BASE_URL (optional), ZHIPU_TIMEOUT (seconds, default "120") and
    ZHIPU_MODEL (default "glm-4.5").

    Items whose ``id`` already appears in ``args.output_file`` are skipped.
    Each remaining item produces one record appended to that file. If the API
    call raises, the record's prediction is the string ``"[ERROR] <exc>"``.
    """
    env = os.environ
    key = env.get("ZHIPU_API_KEY")
    endpoint = env.get("ZHIPU_BASE_URL")
    timeout = int(env.get("ZHIPU_TIMEOUT", "120"))
    model = env.get("ZHIPU_MODEL", "glm-4.5")

    options = {"api_key": key}
    if endpoint:
        options["base_url"] = endpoint
    if timeout:
        options["timeout"] = timeout
    client = ZhipuAI(**options)

    with open(args.data_path, "r", encoding="utf-8") as fh:
        samples = json.load(fh)

    done = load_existing_ids(args.output_file)

    for sample in tqdm(samples):
        sample_id = sample.get("id")
        if sample_id in done:
            continue
        question = sample.get("question", "")
        try:
            prediction = generate_response_api(client, question, model)
        except Exception as err:
            prediction = f"[ERROR] {err}"
        append_to_json(
            args.output_file,
            {
                "id": sample_id,
                "id_": sample.get("id_"),
                "query": question,
                "answer": sample.get("answer", ""),
                "prediction": prediction,
            },
        )


def retry(args):
    """Re-query the model for records in ``args.output_file`` lacking a usable answer.

    A record is retried when its ``prediction`` is missing/None, contains
    ``"[ERROR]"``, or has no extractable final ``boxed{...}`` answer. A retry
    that raises stores ``"[ERROR] <exc>"`` as the new prediction.

    If the output file does not exist yet (e.g. ``--retry-first`` on a fresh
    run), there is nothing to retry and the function returns immediately.

    The file is rewritten in a ``finally`` block, so predictions retried before
    an interrupt are kept. The write goes to ``<output_file>.tmp`` and is then
    moved into place with ``os.replace``, so the file is never left truncated.
    """
    if not os.path.exists(args.output_file):
        return

    api_key = os.environ.get("ZHIPU_API_KEY")
    base_url = os.environ.get("ZHIPU_BASE_URL")
    timeout_s = int(os.environ.get("ZHIPU_TIMEOUT", "120"))
    model_name = os.environ.get("ZHIPU_MODEL", "glm-4.5")

    client_kwargs = {"api_key": api_key}
    if base_url:
        client_kwargs["base_url"] = base_url
    if timeout_s:
        client_kwargs["timeout"] = timeout_s
    client = ZhipuAI(**client_kwargs)

    with open(args.output_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    try:
        for item in tqdm(data):
            pred = item.get("prediction")
            need_retry = (
                pred is None
                or "[ERROR]" in str(pred)
                or extract_last_boxed(str(pred)) is None
            )
            if not need_retry:
                continue
            try:
                prediction = generate_response_api(client, item.get("query", ""), model_name)
            except Exception as e:
                prediction = f"[ERROR] {e}"
            # Items are updated in place, so `data` always holds the latest state.
            item["prediction"] = prediction
    finally:
        tmp_path = args.output_file + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, args.output_file)
        finally:
            # Only still present if dumping/replacing failed part-way.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)


if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    # Both paths are mandatory; the output file doubles as the resume checkpoint.
    for flag in ("--data-path", "--output-file"):
        cli.add_argument(flag, required=True)
    cli.add_argument("--retry-first", action="store_true")
    cli_args = cli.parse_args()

    # Optionally re-run failed / answer-less records before resuming inference.
    if cli_args.retry_first:
        retry(cli_args)
    inference_api(cli_args)
