import os
import json
import random
import argparse
from typing import Iterable, Set, List, Tuple, Dict, Any

from tqdm import tqdm
from openai import OpenAI


# Catalogue of perturbation strategies that can be applied to a seed question.
# Each entry carries:
#   - "name": used to build the per-record resume key ("<id>::<name>") and
#     inserted into the prompt as the strategy name;
#   - "description": inserted into the prompt as the strategy instructions.
# Both strings are sent to the model verbatim, so editing them changes the
# prompt; editing a name also changes the resume keys of future records.
PERTURBATIONS: List[Dict[str, str]] = [
    {
        "name": "Numerical Substitution",
        "description": "Change some numerical data while minimizing textual edits, ensuring the question remains valid."
    },
    {
        "name": "Digit Expansion",
        "description": "Increase the number of digits of some numerical values while minimizing textual edits, ensuring the question remains valid."
    },
    {
        "name": "Integer-decimal-fraction Conversion",
        "description": "Convert some integers into decimal or fractional forms while keeping text mostly unchanged and the question valid."
    },
    {
        "name": "Adding Operation",
        "description": "Add extra statements that increase the number of reasoning steps; allowed ops: +, -, ×, ÷."
    },
    {
        "name": "Reversing Operation",
        "description": "Turn the original required answer into a known condition and transform one known variable into the new desired answer without adding constraints, yielding a different required answer."
    },
    {
        "name": "Problem Understanding",
        "description": "Rewrite with different wording or sentence structures without changing the original solution."
    },
    {
        "name": "Distractor Insertion",
        "description": "Insert distracting but irrelevant conditions (preferably with numbers) while keeping the final answer identical."
    },
    {
        "name": "Critical Thinking",
        "description": "Remove a crucial condition so the rewritten problem has no valid answer due to missing constraints."
    },
]

# Prompt template sent as the system message. It is filled with str.format
# using four placeholders: PerturbationName, PerturbationDescription,
# SeedQuestion and AnswerRationale.
# NOTE(review): the template asks for an "answerable" rewrite with exactly one
# numerical answer, while the "Critical Thinking" strategy asks for a question
# with no valid answer -- the two instructions look contradictory; confirm
# whether that strategy needs its own wording.
SYSTEM_PROMPT_TEMPLATE = (
    "You are a helpful assistant and good at following instructions.\n"
    "Your objective is to rewrite a given math question and give an answer using the specified perturbation strategy ({PerturbationName}).\n"
    "The rewritten question should be reasonable, understandable, and answerable.\n\n"
    "Perturbation strategy: {PerturbationDescription}\n"
    "The given question: {SeedQuestion}\n"
    "Answer of the given question: {AnswerRationale}\n\n"
    "Please rewrite the question using the specified perturbation strategies while minimizing edits to avoid significant deviation in the question content.\n"
    "Ensure the rewritten question has only one required numerical answer.\n\n"
    "The rewritten question and rewritten answer:"
)


def generate_response_api(client: OpenAI, question: str, answer: str, perturbation_name: str, perturbation_description: str, model_name: str) -> str:
    """Ask the chat model to rewrite one question with one perturbation strategy.

    The system message is ``SYSTEM_PROMPT_TEMPLATE`` filled with the strategy
    name/description, the seed question and its answer; the seed question is
    additionally sent as the user message.

    Args:
        client: OpenAI client used to issue the chat-completion request.
        question: Seed question text.
        answer: Answer (rationale) of the seed question.
        perturbation_name: Name of the perturbation strategy.
        perturbation_description: Instructions describing the strategy.
        model_name: Model identifier passed to the API.

    Returns:
        The text of the first choice with surrounding whitespace stripped.

    Raises:
        ValueError: If the response has no choices or the first choice carries
            no text content.
        Exception: Anything raised by the client call propagates unchanged.
    """
    system_prompt = SYSTEM_PROMPT_TEMPLATE.format(
        PerturbationName=perturbation_name,
        PerturbationDescription=perturbation_description,
        SeedQuestion=question,
        AnswerRationale=answer,
    )
    resp = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ],
        stream=False,
    )
    # Fail with a descriptive message instead of an opaque IndexError /
    # AttributeError when the API returns nothing usable.
    if not resp.choices:
        raise ValueError("Model response contained no choices.")
    content = resp.choices[0].message.content
    if content is None:
        raise ValueError("Model response contained no text content.")
    return content.strip()


def _safe_load_json(path: str):
    """Return the parsed JSON content of *path*, or ``None`` when unavailable.

    ``None`` is returned both when the path does not exist and when the file
    exists but does not contain valid JSON. Other I/O errors propagate.
    """
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as handle:
            raw_text = handle.read()
        try:
            return json.loads(raw_text)
        except json.JSONDecodeError:
            # Unparseable file is treated the same as a missing one.
            return None
    return None


def load_existing_aug_keys(output_file: str) -> Set[str]:
    """Collect the resume keys of every record already stored in *output_file*.

    A record's key is its ``aug_id`` when that field is present and truthy;
    otherwise it is rebuilt as ``"<id>::<perturbation_name>"``. A missing,
    unparseable or empty output file yields an empty set.
    """
    records = _safe_load_json(output_file)
    if not records:
        return set()
    return {
        rec["aug_id"]
        if rec.get("aug_id")
        else f"{rec.get('id')}::{rec.get('perturbation_name', '')}".strip()
        for rec in records
    }


def append_to_json(output_file: str, record: Dict[str, Any]) -> None:
    """Append *record* to the JSON array stored in *output_file*.

    The existing array is loaded (a missing or unparseable file is treated as
    an empty array), the record is appended, and the whole array is written
    back.

    The new content is first written to ``<output_file>.tmp`` and then moved
    over the target with ``os.replace``. Writing in place would leave a
    truncated file if the process were interrupted mid-dump, and a truncated
    file is read back as "no data", silently discarding every earlier record.

    Args:
        output_file: Path of the JSON file holding the list of records.
        record: JSON-serialisable record to append.
    """
    data = _safe_load_json(output_file)
    if data is None:
        data = []
    data.append(record)

    # Same directory as the target so the final rename stays on one filesystem.
    tmp_path = f"{output_file}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, output_file)


def _parse_only_ids(value: str) -> Set[int]:
    """Parse a comma-separated string of integers into a set.

    Tokens are stripped of surrounding whitespace; empty tokens and tokens
    that ``int()`` rejects are skipped without raising.
    """
    parsed: Set[int] = set()
    for raw in value.split(","):
        candidate = raw.strip()
        if not candidate:
            continue
        try:
            number = int(candidate)
        except ValueError:
            # Best-effort parsing: ignore anything that is not an integer.
            continue
        parsed.add(number)
    return parsed


def inference_api_with_resume(args):
    """Generate one new perturbed variant per dataset example, resuming prior work.

    The dataset JSON at ``args.data_path`` is read, and for every example the
    perturbations whose key ``"<id>::<name>"`` is not yet present in
    ``args.output_file`` are collected. One of them is picked at random, the
    model is queried, and the resulting record is appended to the output file
    immediately after each call.

    A failed API call is stored with a prediction of ``"[ERROR] <message>"``;
    such a record still counts as finished on later runs.

    Environment variables read: ``OPENAI_API_KEY``, ``OPENAI_BASE_URL``,
    ``OPENAI_MODEL`` (default ``qwen3-8b``), ``OPENAI_TIMEOUT`` (default
    ``120``, passed to the client as its timeout) and ``ONLY_IDS``
    (comma-separated example ids, used when ``args.only_ids`` is absent or
    empty).

    Args:
        args: Namespace providing ``data_path`` and ``output_file``. An
            optional ``only_ids`` attribute (comma-separated string or
            iterable of ids) restricts processing to those example ids.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    base_url = os.environ.get("OPENAI_BASE_URL")
    model_name = os.environ.get("OPENAI_MODEL", "qwen3-8b")
    timeout_s = int(os.environ.get("OPENAI_TIMEOUT", "120"))

    client_kwargs: Dict[str, Any] = {"timeout": timeout_s}
    if api_key:
        client_kwargs["api_key"] = api_key
    if base_url:
        client_kwargs["base_url"] = base_url
    client = OpenAI(**client_kwargs)

    with open(args.data_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    finished_aug_ids = load_existing_aug_keys(args.output_file)

    # The namespace is not guaranteed to carry `only_ids` (the CLI may not
    # register the flag), so read it defensively instead of raising
    # AttributeError.
    cli_only = getattr(args, "only_ids", None)
    if cli_only:
        if isinstance(cli_only, str):
            target_ids = _parse_only_ids(cli_only)
        else:
            target_ids = {int(v) for v in cli_only}
    else:
        env_only = os.environ.get("ONLY_IDS", "")
        target_ids = _parse_only_ids(env_only) if env_only else set()

    for item in tqdm(dataset, desc="examples"):
        item_id = int(item.get("id"))
        if target_ids and item_id not in target_ids:
            continue

        q = item.get("question", "")
        a = item.get("answer", "")

        remaining: List[Tuple[Dict[str, str], str]] = [
            (p, f"{item_id}::{p['name']}")
            for p in PERTURBATIONS
            if f"{item_id}::{p['name']}" not in finished_aug_ids
        ]
        if not remaining:
            continue

        # Only one (randomly chosen) outstanding perturbation is generated
        # per example in a single run.
        p, aug_id = random.choice(remaining)

        try:
            pred = generate_response_api(
                client=client,
                question=q,
                answer=a,
                perturbation_name=p["name"],
                perturbation_description=p["description"],
                model_name=model_name,
            )
        except Exception as e:
            # Keep going: the failure is recorded in the output instead of
            # aborting the whole run.
            pred = f"[ERROR] {e}"

        record = {
            "aug_id": aug_id,
            "id": item_id,
            "query": q,
            "answer": a,
            "perturbation_name": p["name"],
            "perturbation_description": p["description"],
            "prediction": pred,
        }
        append_to_json(args.output_file, record)
        # Keep the in-memory view in sync so a dataset with repeated ids
        # cannot emit the same aug_id twice within one run.
        finished_aug_ids.add(aug_id)


def retry(args):
    """Re-query the model for records already stored in ``args.output_file``.

    Every selected record has its ``prediction`` replaced by a fresh model
    response (or by ``"[ERROR] <message>"`` when the call fails); records that
    are not selected are kept unchanged. The file is rewritten once, after all
    records have been processed.

    Selection: when ``args.only_ids`` (optional attribute; comma-separated
    string or iterable of ids) or the ``ONLY_IDS`` environment variable
    provides ids, only records with those ``id`` values are retried. When
    neither is given, every record is retried.

    Environment variables read: ``OPENAI_API_KEY``, ``OPENAI_BASE_URL``,
    ``OPENAI_MODEL`` (default ``qwen3-8b``), ``OPENAI_TIMEOUT`` (default
    ``480``, passed to the client as its timeout) and ``ONLY_IDS``.

    Args:
        args: Namespace providing ``output_file`` and, optionally, ``only_ids``.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    base_url = os.environ.get("OPENAI_BASE_URL")
    model_name = os.environ.get("OPENAI_MODEL", "qwen3-8b")
    timeout_s = int(os.environ.get("OPENAI_TIMEOUT", "480"))

    client_kwargs: Dict[str, Any] = {"timeout": timeout_s}
    if api_key:
        client_kwargs["api_key"] = api_key
    if base_url:
        client_kwargs["base_url"] = base_url
    client = OpenAI(**client_kwargs)

    data = _safe_load_json(args.output_file)
    if data is None:
        print("No output file to retry.")
        return

    # Previously the id filter was hard-wired to an empty set and could never
    # take effect; it is now fed from the same sources as augment mode.
    cli_only = getattr(args, "only_ids", None)
    if cli_only:
        if isinstance(cli_only, str):
            target_ids = _parse_only_ids(cli_only)
        else:
            target_ids = {int(v) for v in cli_only}
    else:
        env_only = os.environ.get("ONLY_IDS", "")
        target_ids = _parse_only_ids(env_only) if env_only else set()

    ret = []
    for item in tqdm(data, desc="retry"):
        if target_ids and int(item.get("id")) not in target_ids:
            ret.append(item)
            continue
        try:
            new_pred = generate_response_api(
                client=client,
                question=item.get("query", ""),
                answer=item.get("answer", ""),
                perturbation_name=item.get("perturbation_name", ""),
                perturbation_description=item.get("perturbation_description", ""),
                model_name=model_name,
            )
        except Exception as e:
            new_pred = f"[ERROR] {e}"
        item["prediction"] = new_pred
        ret.append(item)

    # Write to a sibling temp file and rename so an interruption during the
    # dump cannot leave a truncated output file behind.
    tmp_path = f"{args.output_file}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(ret, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, args.output_file)


def main():
    """Parse command-line arguments and run the selected mode.

    Flags:
        --data-path: Dataset JSON file; required for ``augment`` mode only.
        --output-file: JSON file that accumulates generated records.
        --mode: ``augment`` (default) or ``retry``.
        --only-ids: Optional comma-separated example ids to restrict the run
            to; tokens that are not integers are ignored.
    """
    parser = argparse.ArgumentParser()
    # Only augment mode reads the dataset, so this flag is validated below
    # rather than being required for every mode.
    parser.add_argument("--data-path", default=None)
    parser.add_argument("--output-file", required=True)
    parser.add_argument("--mode", choices=["augment", "retry"], default="augment")
    # The augment routine reads `args.only_ids`; without this flag the
    # attribute does not exist on the namespace and the lookup raises.
    parser.add_argument(
        "--only-ids",
        type=_parse_only_ids,
        default=None,
        help="Comma-separated example ids to process (default: all).",
    )
    args = parser.parse_args()

    if args.mode == "augment":
        if not args.data_path:
            parser.error("--data-path is required when --mode is augment")
        inference_api_with_resume(args)
    else:
        retry(args)


if __name__ == "__main__":
    main()
