import os
import json
import argparse
from rl.prompt import rewrite_prompt


def build_rewrite_dataset(src_path: str, dst_path: str) -> int:
    """Convert a rewrite source file into instruction-style records.

    Reads a JSON array of objects from ``src_path`` and writes a JSON array
    to ``dst_path``. Each output record has three keys, in this order:

    * ``instruction`` -- ``rewrite_prompt`` formatted with the item's
      ``query`` (as ``question``) and ``answer``;
    * ``input`` -- always the empty string;
    * ``output`` -- the item's ``response``.

    A field that is missing or JSON ``null`` is treated as the empty string.

    Args:
        src_path: Path of the UTF-8 JSON file to read.
        dst_path: Path of the UTF-8 JSON file to write (overwritten).

    Returns:
        The largest ``len()`` over all ``output`` values, or 0 when the
        source array is empty.
    """
    with open(src_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    def _or_empty(item, key):
        # JSON null decodes to None, which a ``dict.get`` default does not
        # replace. Left alone it would be formatted into the prompt as the
        # text "None", or make ``len()`` raise TypeError for the response.
        value = item.get(key)
        return "" if value is None else value

    ret = []
    max_len = 0
    for item in data:
        rec = {
            "instruction": rewrite_prompt.format(
                question=_or_empty(item, "query"),
                answer=_or_empty(item, "answer"),
            ),
            "input": "",
            "output": _or_empty(item, "response"),
        }
        ret.append(rec)
        max_len = max(max_len, len(rec["output"]))

    with open(dst_path, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-ASCII text readable in the output file.
        json.dump(ret, f, ensure_ascii=False, indent=4)
    return max_len


def build_ride_dataset(src_path: str, dst_path: str) -> int:
    """Convert a ride source file into prompt/answer records.

    Reads a JSON array of objects from ``src_path`` and writes a JSON array
    to ``dst_path``. Each output record has two keys:

    * ``prompt`` -- the item's ``question`` followed by a fixed suffix asking
      for step-by-step reasoning and a ``\\boxed{}`` final answer;
    * ``answer`` -- the item's ``answer``, copied as-is (empty string when
      the key is missing).

    A ``question`` that is missing or JSON ``null`` is treated as the empty
    string, so the prompt then consists of the suffix alone.

    Args:
        src_path: Path of the UTF-8 JSON file to read.
        dst_path: Path of the UTF-8 JSON file to write (overwritten).

    Returns:
        The length of the longest ``prompt`` string, or 0 when the source
        array is empty.
    """
    prompt_suffix = "\nLet us think step by step and output the final answer in \\boxed{}.\n"
    with open(src_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    ret = []
    max_len = 0
    for item in data:
        question = item.get("question")
        if question is None:
            # JSON null decodes to None; interpolating it directly would put
            # the literal text "None" at the start of the prompt.
            question = ""
        rec = {
            "prompt": f"{question}{prompt_suffix}",
            "answer": item.get("answer", ""),
        }
        ret.append(rec)
        max_len = max(max_len, len(rec["prompt"]))

    with open(dst_path, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-ASCII text readable in the output file.
        json.dump(ret, f, ensure_ascii=False, indent=4)
    return max_len


def main():
    """Command-line entry point: build the rewrite and/or ride datasets.

    Each option defaults to an environment variable (``REWRITE_SRC``,
    ``REWRITE_DST``, ``RIDE_SRC``, ``RIDE_DST``, ``MODE``); the mode falls
    back to ``both`` when ``MODE`` is unset. For every dataset that is built,
    the builder's return value is printed on its own line.

    Raises:
        ValueError: If a selected mode is missing its source or destination
            path. All selected modes are checked before anything is written.
        SystemExit: Raised by argparse for invalid command-line input, and
            for a ``MODE`` environment value outside the allowed choices.
    """
    modes = ("rewrite", "ride", "both")
    parser = argparse.ArgumentParser()
    parser.add_argument("--rewrite-src", default=os.environ.get("REWRITE_SRC", ""))
    parser.add_argument("--rewrite-dst", default=os.environ.get("REWRITE_DST", ""))
    parser.add_argument("--ride-src", default=os.environ.get("RIDE_SRC", ""))
    parser.add_argument("--ride-dst", default=os.environ.get("RIDE_DST", ""))
    parser.add_argument("--mode", choices=modes, default=os.environ.get("MODE", "both"))
    args = parser.parse_args()

    # argparse validates ``choices`` only for values given on the command
    # line, not for the default. Without this check a mistyped MODE variable
    # would match neither branch below and the script would silently do
    # nothing.
    if args.mode not in modes:
        parser.error(
            f"invalid mode {args.mode!r} from the MODE environment variable "
            f"(choose from {', '.join(modes)})"
        )

    run_rewrite = args.mode in ("rewrite", "both")
    run_ride = args.mode in ("ride", "both")

    # Validate every selected mode before doing any work, so that "both"
    # cannot write the rewrite dataset and then fail on missing ride paths.
    if run_rewrite and not (args.rewrite_src and args.rewrite_dst):
        raise ValueError("rewrite mode requires --rewrite-src and --rewrite-dst")
    if run_ride and not (args.ride_src and args.ride_dst):
        raise ValueError("ride mode requires --ride-src and --ride-dst")

    if run_rewrite:
        ml = build_rewrite_dataset(args.rewrite_src, args.rewrite_dst)
        print(ml)

    if run_ride:
        ml = build_ride_dataset(args.ride_src, args.ride_dst)
        print(ml)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
