#!/usr/bin/env python
"""
convert_to_chat_jsonl.py

Turn the raw motion-forecasting dataset into chat-style JSONL
(train.jsonl, val.jsonl, test.jsonl — one file per split).

Usage:
    python convert_to_chat_jsonl.py \
        --ds_root /path/to/ds \
        --prompt_file /path/to/prompt.txt \
        --out_dir /path/to/ds          # default: same as ds_root
"""
import argparse, json, os, sys
from pathlib import Path
from tqdm import tqdm


def make_chat_line(scene_id: str, hist_obj: dict, fut_obj: dict, system_prompt: str) -> str:
    """Build the JSON text for a single chat-format record.

    The history and forecast objects are each serialized to compact JSON
    (no spaces after separators) and embedded as *strings* in the user and
    assistant messages respectively; the system prompt is stored verbatim.

    Args:
        scene_id: Identifier stored under the ``"scene_id"`` key.
        hist_obj: Object serialized into the user message content.
        fut_obj: Object serialized into the assistant message content.
        system_prompt: Text used as the system message content.

    Returns:
        The record as a JSON string, without a trailing newline.
    """
    compact = (",", ":")
    user_text = json.dumps(hist_obj, separators=compact)
    assistant_text = json.dumps(fut_obj, separators=compact)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": assistant_text},
    ]
    record = {"scene_id": scene_id, "messages": messages}
    return json.dumps(record)


def convert_split(split_name: str, ds_root: Path, prompt: str, out_dir: Path):
    """Convert one dataset split into a chat-style JSONL file.

    Every ``*.json`` file in ``ds_root/<split_name>/history`` is paired with
    the same-named file in ``ds_root/<split_name>/forecast_gt`` and written
    as one line of ``out_dir/<split_name>.jsonl``.

    Args:
        split_name: Name of the split sub-folder under ``ds_root``.
        ds_root: Root folder of the dataset.
        prompt: System prompt text placed in every record.
        out_dir: Folder that receives the output file; created if missing.

    Returns:
        None. The whole split is skipped (with a warning, and without
        creating any output) when either sub-folder is missing or the
        history folder holds no ``*.json`` files. A single scene is skipped
        when its forecast file is missing or either file cannot be read or
        parsed.
    """
    hist_dir = ds_root / split_name / "history"
    fut_dir = ds_root / split_name / "forecast_gt"
    out_path = out_dir / f"{split_name}.jsonl"

    if not hist_dir.exists() or not fut_dir.exists():
        print(f"[WARN] {split_name} split missing expected sub-folders – skipped.")
        return

    # Sorted so the output line order is deterministic across runs.
    scene_files = sorted(f.name for f in hist_dir.glob("*.json"))
    if not scene_files:
        print(f"[WARN] no *.json files in {hist_dir} – skipped.")
        return

    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as fout:
        for fname in tqdm(scene_files, desc=f"[{split_name}]"):
            scene_id = fname[: -len(".json")]
            hist_path = hist_dir / fname
            fut_path = fut_dir / fname
            if not fut_path.exists():
                print(f"  > forecast_gt missing for {scene_id} – skipped.")
                continue

            # Open in binary mode so json.load detects the encoding itself
            # instead of relying on the platform default; the context
            # managers close each handle as soon as it has been parsed.
            try:
                with hist_path.open("rb") as hist_file:
                    hist_obj = json.load(hist_file)
                with fut_path.open("rb") as fut_file:
                    fut_obj = json.load(fut_file)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"  > ERROR reading {scene_id}: {e} – skipped.")
                continue

            fout.write(make_chat_line(scene_id, hist_obj, fut_obj, prompt) + "\n")

    print(f"✓ Wrote {out_path} ({out_path.stat().st_size/1e6:.2f} MB)")


def main():
    """Parse command-line arguments and convert the requested splits.

    Options:
        --ds_root: Root dataset folder (required).
        --prompt_file: File holding the constant system prompt (required).
        --out_dir: Output folder for the ``*.jsonl`` files; defaults to
            ``--ds_root`` when omitted.
        --splits: One or more split names to convert; defaults to
            ``train val test``.

    Exits with an error message if the prompt file does not exist.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ds_root", required=True, type=Path, help="root dataset folder (/path/to/ds)")
    ap.add_argument("--prompt_file", required=True, type=Path, help="file containing constant instruction")
    ap.add_argument("--out_dir", type=Path, default=None, help="where to put *.jsonl (default = ds_root)")
    ap.add_argument(
        "--splits",
        nargs="+",
        default=["train", "val", "test"],
        help="split sub-folders to convert (default: train val test)",
    )
    args = ap.parse_args()

    out_dir = args.ds_root if args.out_dir is None else args.out_dir

    if not args.prompt_file.exists():
        sys.exit(f"Prompt file {args.prompt_file} not found.")

    # Read as UTF-8 explicitly so the prompt does not depend on the
    # platform's default locale encoding.
    system_prompt = args.prompt_file.read_text(encoding="utf-8").strip()

    for split in args.splits:
        convert_split(split, args.ds_root, system_prompt, out_dir)


# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
