"""Convert DSR teacher puzzle parquets to lm_eval-style samples JSONL.

Output filenames match the pattern the existing pipeline expects:
  samples_<task>_<YYYY-MM-DD>T<HH-MM-SS>.<microseconds>.jsonl

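This lets us run `python -m analysis.exploration.pipeline` over DSR teacher
traces using the same code path as v90_sft_puzzles, with no new aggregation.
All input and output paths are relative to the current working directory, and
every file written in a single run shares one timestamp.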
"""

import json
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd


def _to_jsonable(v):
    """Recursively convert numpy containers/scalars into JSON-serializable values.

    Handles ``np.ndarray`` (including object-dtype arrays whose elements are
    themselves arrays), lists, tuples, dicts, and numpy bool/integer/floating
    scalars. Anything else is returned unchanged.

    Args:
        v: Value to convert.

    Returns:
        An equivalent structure built from lists, dicts and Python scalars.
    """
    if isinstance(v, np.ndarray):
        # tolist() turns numeric elements into Python scalars, but the elements
        # of an object-dtype array come back untouched (e.g. the nested
        # ndarrays pyarrow typically yields for nested list columns), so
        # recurse over its result instead of returning it directly.
        return _to_jsonable(v.tolist())
    if isinstance(v, (list, tuple)):
        return [_to_jsonable(x) for x in v]
    if isinstance(v, dict):
        return {k: _to_jsonable(x) for k, x in v.items()}
    if isinstance(v, np.bool_):
        # np.bool_ is not a bool subclass, so json.dumps rejects it.
        return bool(v)
    if isinstance(v, np.integer):
        return int(v)
    if isinstance(v, np.floating):
        return float(v)
    return v

# Task name -> source parquet. Each key becomes the <task> part of the output
# filename; all paths are relative to the current working directory.
DSR_FILES = dict(
    bridges_5x5de_dsr="data/dsr_5pct_truncated/bridges_5x5de_dsr_intformat_json.parquet",
    galaxies_3x3de_dsr="data/dsr_5pct_truncated/galaxies_3x3de_dsr_intformat_json.parquet",
    galaxies_4x4de_dsr="data/dsr_5pct_truncated/galaxies_4x4de_dsr_intformat_json.parquet",
    pattern_3x3_dsr="data/dsr_5pct_truncated/pattern_3x3_dsr.parquet",
    pattern_4x4_dsr="data/dsr_5pct_truncated/pattern_4x4_dsr.parquet",
    undead_3x3de_dsr="data/dsr_5pct_truncated/undead_3x3de_dsr.parquet",
)

# Destination directory for the generated samples_*.jsonl files.
OUT = Path("data/dsr_5pct_truncated/lm_eval_jsonl")

# Single timestamp captured at import time and shared by every file written in
# this run, formatted as <YYYY-MM-DD>T<HH-MM-SS>.<microseconds>.
TS = f"{datetime.now():%Y-%m-%dT%H-%M-%S.%f}"


def _is_missing(v):
    """Return True if ``v`` is a scalar null: None, ``pd.NA`` or a float NaN.

    Arrays and other containers are never treated as missing, so this is safe
    to call on array-valued parquet cells (a bare ``pd.isna`` would return an
    array there and break truth testing).
    """
    if v is None or v is pd.NA:
        return True
    return isinstance(v, (float, np.floating)) and bool(np.isnan(v))


def _text(v):
    """Stringify a cell value, mapping scalar nulls to "" instead of "None"/"nan"."""
    return "" if _is_missing(v) else str(v)


def _row_to_record(doc_id, row):
    """Build one lm_eval-style sample dict from a DSR parquet row.

    The teacher trace is wrapped as a single response of the form
    ``<reasoning>...</reasoning>\\n<answer>...</answer>``, stored both as the
    sole entry of ``resps`` and as ``filtered_resps``.

    Args:
        doc_id: Value written (as int) to the record's ``doc_id`` field.
        row: One row as yielded by ``DataFrame.iterrows``. Absent columns and
            null cells fall back to empty strings / None.

    Returns:
        A JSON-serializable dict.
    """
    reasoning = _text(row.get("reasoning_full"))
    answer = _text(row.get("answer", ""))
    response = f"<reasoning>\n{reasoning}\n</reasoning>\n<answer>\n{answer}\n</answer>"
    return {
        "doc_id": int(doc_id),
        "target": answer,
        "resps": [[response]],
        "filtered_resps": [response],
        "doc": {
            "problem": _to_jsonable(row.get("problem")),
            "solution": _to_jsonable(row.get("solution")),
            "answer": _to_jsonable(row.get("answer")),
            "filename": _text(row.get("filename", "")),
            "gridsize": _text(row.get("gridsize", "")),
            "difficulty": _text(row.get("difficulty", "")),
        },
    }


def main():
    """Convert every parquet in DSR_FILES into one samples JSONL under OUT.

    Each output is named ``samples_<task>_<TS>.jsonl`` and holds one JSON record
    per DataFrame row. ``doc_id`` is the row's index label, so it is only
    0..n-1 if the parquet carries a default index.
    """
    OUT.mkdir(exist_ok=True, parents=True)
    for task, src in DSR_FILES.items():
        df = pd.read_parquet(src)
        out_path = OUT / f"samples_{task}_{TS}.jsonl"
        with open(out_path, "w", encoding="utf-8") as f:
            for i, row in df.iterrows():
                f.write(json.dumps(_row_to_record(i, row)) + "\n")
        print(f"{task}: wrote {len(df)} traces → {out_path.name}")


if __name__ == "__main__":
    main()
