#!/usr/bin/env python3
"""
Convert NHANES parquet splits to SPR-RAFT JSONL format.

Expected outputs are JSONL with ShareGPT-style conversations:
{
  "id": "...",
  "conversations": [{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}],
  "y": float,
  "event": int (optional),
  "time": float (optional)
}
"""

from __future__ import annotations

import argparse
import json
import math
from pathlib import Path
from typing import Iterable, List, Optional, Sequence, Tuple


def infer_label_columns(columns: Sequence[str]) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Guess the target, event and time columns from common NHANES column names.

    Every role has a list of known names ordered by preference; the first
    name that is present in ``columns`` is chosen. A role for which no known
    name is present resolves to ``None``.

    Returns:
        A ``(target, event, time)`` tuple of column names (each may be ``None``).
    """
    # One tuple of candidate names per role, in the order (target, event, time).
    # Within a role, earlier names take precedence over later ones.
    preferences = (
        (
            "bio_age",
            "biological_age",
            "bioage",
            "phenotypic_age",
            "pheno_age",
            "age",
            "chronological_age",
        ),
        (
            "mortality_event",
            "mortstat",
            "death",
            "death_event",
            "event",
        ),
        (
            "mortality_followup_months",
            "mortality_followup_days",
            "permth_exm",
            "permth_int",
            "followup_months",
            "followup_days",
            "time_to_event",
        ),
    )

    resolved = []
    for candidates in preferences:
        match = None
        for name in candidates:
            if name in columns:
                match = name
                break
        resolved.append(match)

    target, event, time = resolved
    return target, event, time


def load_feature_names(metadata_path: Path) -> Optional[List[str]]:
    """Read the ``feature_names`` entry from a preprocessing metadata JSON file.

    Args:
        metadata_path: Location of the metadata JSON file.

    Returns:
        The value stored under ``"feature_names"``, or ``None`` when the file
        does not exist, the top-level JSON value is not an object, or the key
        is absent.

    Raises:
        json.JSONDecodeError: If the file exists but does not contain valid JSON.
    """
    if not metadata_path.exists():
        return None

    # Decode as UTF-8 explicitly so the result does not depend on the
    # platform's locale encoding.
    with metadata_path.open("r", encoding="utf-8") as f:
        payload = json.load(f)

    # Only a JSON object can carry the key; any other top-level value is
    # treated the same as a missing key instead of failing on `.get`.
    if not isinstance(payload, dict):
        return None
    return payload.get("feature_names")


def format_value(value: float, decimals: int) -> str:
    """Render ``value`` in fixed-point notation with ``decimals`` fractional digits."""
    spec = f".{decimals}f"
    return format(value, spec)


def build_prompt(feature_names: Sequence[str], row, value_decimals: int) -> str:
    """Build the human-turn prompt text for one record.

    Each feature is rendered as ``name=value`` and the pairs are joined with
    ``", "``. A feature is rendered as ``NA`` when its value is absent from
    ``row``, is ``None``, cannot be converted to ``float``, or is NaN.

    Args:
        feature_names: Names of the features to include, in output order.
        row: Object exposing ``get(name)`` (for example a ``dict`` or a
            pandas ``Series``) that yields the feature value or ``None``.
        value_decimals: Number of fractional digits used for feature values.

    Returns:
        The full prompt: a header line, the feature list, a blank line and
        the instruction sentence.
    """
    parts = []
    for name in feature_names:
        rendered = "NA"
        value = row.get(name)
        if value is not None:
            # Only the conversion is guarded, so an invalid ``value_decimals``
            # surfaces as an error instead of turning every value into "NA".
            try:
                number = float(value)
            except (TypeError, ValueError):
                number = math.nan
            # NaN marks a missing value just like None does; render both the
            # same way rather than emitting the literal text "nan".
            if not math.isnan(number):
                rendered = format_value(number, value_decimals)
        parts.append(f"{name}={rendered}")

    features_text = ", ".join(parts)
    return (
        "Patient record (standardized values):\n"
        f"{features_text}\n\n"
        "Predict biological age in years. Output a single number."
    )


def build_response(y_value: float, decimals: int) -> str:
    """Build the assistant-turn text: the ``[REG]`` tag followed by the formatted target."""
    rendered_target = format_value(y_value, decimals)
    return "[REG] " + rendered_target


def convert_split(
    df,
    split: str,
    feature_names: Sequence[str],
    target_col: str,
    event_col: Optional[str],
    time_col: Optional[str],
    decimals: int,
    max_samples: Optional[int],
    seed: int,
    value_decimals: int = 4,
):
    """Turn the rows of one split into SPR-RAFT sample dictionaries.

    Args:
        df: pandas ``DataFrame`` holding the split.
        split: Split name; used in each sample id and in the skip notice.
        feature_names: Columns rendered into the prompt, in order.
        target_col: Column holding the regression target.
        event_col: Optional event column; ignored when falsy or not in ``df``.
        time_col: Optional time column; ignored when falsy or not in ``df``.
        decimals: Fractional digits used for the target in the response text.
        max_samples: When set and smaller than ``len(df)``, that many rows
            are drawn at random using ``seed``.
        seed: Random state passed to ``DataFrame.sample``.
        value_decimals: Fractional digits used for feature values in the
            prompt (default ``4``).

    Returns:
        A list of dicts with ``id``, ``conversations`` and ``y`` keys, plus
        ``event`` / ``time`` when those columns exist and the row has a value.
        Rows whose target is missing are skipped; the ``id`` suffix is the row
        position after optional sampling, so ids may have gaps.
    """
    # `df` is a pandas DataFrame, so pandas is importable here; the import is
    # local to mirror the lazy pandas import in main().
    import pandas as pd  # pylint: disable=import-error,import-outside-toplevel

    if max_samples is not None and max_samples < len(df):
        df = df.sample(n=max_samples, random_state=seed)
    df = df.reset_index(drop=True)

    # Each row is indexed by the frame's columns, so column presence can be
    # decided once instead of once per row.
    include_event = bool(event_col) and event_col in df.columns
    include_time = bool(time_col) and time_col in df.columns

    samples = []
    skipped = 0
    for idx, row in df.iterrows():
        y_value = row[target_col]
        if pd.isna(y_value):
            # Without a label there is nothing to regress on; emitting it
            # would produce "[REG] nan" and a non-standard NaN in the JSON.
            skipped += 1
            continue

        sample = {
            "id": f"NHANES-{split}-{idx}",
            "conversations": [
                {"from": "human", "value": build_prompt(feature_names, row, value_decimals=value_decimals)},
                {"from": "gpt", "value": build_response(y_value, decimals=decimals)},
            ],
            "y": float(y_value),
        }

        # `event` and `time` are optional fields: leave them out when the row
        # has no value (int(NaN) raises, and float(NaN) is not valid JSON).
        if include_event:
            event_value = row[event_col]
            if not pd.isna(event_value):
                sample["event"] = int(event_value)
        if include_time:
            time_value = row[time_col]
            if not pd.isna(time_value):
                sample["time"] = float(time_value)

        samples.append(sample)

    if skipped:
        print(f"Skipped {skipped} rows with missing '{target_col}' in split '{split}'")

    return samples


def main() -> None:
    """Command-line entry point: convert each requested parquet split to JSONL.

    For every split name, ``<data_dir>/<split>.parquet`` is read and
    ``<output_dir>/<split>.jsonl`` is written with one JSON object per line.

    Column resolution:
        * An explicit ``--target_col`` must exist in the split; otherwise the
          target is inferred from the column names.
        * ``--event_col`` and ``--time_col`` are each inferred independently
          when not given, whether or not the target was given explicitly.
        * Feature columns are the names from ``preprocessing_metadata.json``
          (or, without that file/key, the columns of the first split read)
          that exist in the split and are not a label column.

    Raises:
        SystemExit: If pandas is not installed, a split file is missing, the
            explicit target column is absent, or no target can be inferred.
    """
    parser = argparse.ArgumentParser(description="Convert NHANES parquet to SPR-RAFT JSONL.")
    parser.add_argument("--data_dir", type=str, default="dataset/NHANES", help="NHANES dataset directory")
    parser.add_argument("--output_dir", type=str, default="data/spr_raft_nhanes", help="Output directory")
    parser.add_argument("--splits", nargs="+", default=["train", "val", "test"], help="Splits to convert")
    parser.add_argument("--target_col", type=str, default=None, help="Target column name (biological age)")
    parser.add_argument("--event_col", type=str, default=None, help="Event column name for mortality")
    parser.add_argument("--time_col", type=str, default=None, help="Time-to-event column name")
    parser.add_argument("--decimals", type=int, default=1, help="Decimals for y rendering")
    parser.add_argument("--max_samples", type=int, default=None, help="Optional max samples per split")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")
    args = parser.parse_args()

    # Imported lazily so a missing dependency produces a readable message
    # instead of a traceback at module import time.
    try:
        import pandas as pd  # pylint: disable=import-error
    except ImportError as exc:
        raise SystemExit(
            "pandas is required to read parquet files. Install with `uv pip install pandas pyarrow`."
        ) from exc

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    feature_names = load_feature_names(data_dir / "preprocessing_metadata.json")

    for split in args.splits:
        split_path = data_dir / f"{split}.parquet"
        if not split_path.exists():
            raise SystemExit(f"Missing split file: {split_path}")

        df = pd.read_parquet(split_path)
        columns = list(df.columns)
        if feature_names is None:
            # No metadata available: fall back to the columns of the first
            # split read and reuse that list for the remaining splits.
            feature_names = columns

        # An explicit but wrong target is a user error; report it as such
        # rather than claiming that inference failed.
        if args.target_col and args.target_col not in df.columns:
            raise SystemExit(
                f"Target column {args.target_col!r} not found in {split_path}. "
                f"Available columns: {columns[:30]}..."
            )

        # Each label column is resolved on its own: an explicit value wins,
        # otherwise the inferred one is used. This keeps mortality columns
        # out of the prompt features even when --target_col is given.
        inferred_target, inferred_event, inferred_time = infer_label_columns(columns)
        target_col = args.target_col or inferred_target
        event_col = args.event_col or inferred_event
        time_col = args.time_col or inferred_time

        if target_col is None:
            raise SystemExit(
                "Could not infer target column. Use --target_col. "
                f"Available columns: {columns[:30]}..."
            )

        label_columns = {target_col, event_col, time_col}
        feature_columns = [c for c in feature_names if c in df.columns and c not in label_columns]

        samples = convert_split(
            df=df,
            split=split,
            feature_names=feature_columns,
            target_col=target_col,
            event_col=event_col,
            time_col=time_col,
            decimals=args.decimals,
            max_samples=args.max_samples,
            seed=args.seed,
        )

        output_path = output_dir / f"{split}.jsonl"
        with output_path.open("w", encoding="utf-8") as f:
            f.writelines(json.dumps(sample, ensure_ascii=True) + "\n" for sample in samples)

        print(f"Wrote {len(samples)} samples to {output_path}")


# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
