#!/usr/bin/env python3
"""
Convert a SMILES regression CSV to SPR-RAFT JSONL (ShareGPT format).

Example:
    uv run python scripts/convert_molecule_csv_to_spr_raft.py \
        --input_path dataset/freesolv.csv \
        --output_dir data/spr_raft_freesolv \
        --smiles_col smiles \
        --target_col y \
        --task_description "hydration free energy (kcal/mol)" \
        --decimals 4
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Optional, Tuple


def build_prompt(smiles: str, task_description: str) -> str:
    """Build the human-turn prompt text for a single molecule.

    The prompt is a header line, the SMILES string on its own line, a blank
    line, and an instruction naming the property to predict.

    Args:
        smiles: SMILES string of the molecule.
        task_description: Phrase describing the quantity to predict; it is
            inserted verbatim after the word "Predict".

    Returns:
        The assembled prompt string.
    """
    pieces = [
        "Molecule (SMILES):\n",
        f"{smiles}\n\n",
        f"Predict {task_description}. Output a single number.",
    ]
    return "".join(pieces)


def build_response(value: float, decimals: int) -> str:
    """Render the assistant-turn answer for a regression target.

    Args:
        value: Target value to render.
        decimals: Number of digits after the decimal point (fixed-point).

    Returns:
        The string ``"[REG] "`` followed by ``value`` in fixed-point notation.
    """
    rendered = format(value, f".{decimals}f")
    return f"[REG] {rendered}"


def split_counts(total: int, ratios: Tuple[float, float, float]) -> Tuple[int, int, int]:
    """Return the (train, val, test) row counts for a dataset of ``total`` rows.

    The train and validation counts are ``total * ratio`` rounded down; the
    test split receives every remaining row, so the three counts always sum
    to ``total``.

    Args:
        total: Number of rows to distribute.
        ratios: ``(train, val, test)`` fractions; each must be non-negative
            and together they must sum to 1.0.

    Returns:
        A ``(train_n, val_n, test_n)`` tuple of integer counts.

    Raises:
        ValueError: If any ratio is negative or the ratios do not sum to 1.0
            (within a small tolerance).
    """
    train_ratio, val_ratio, test_ratio = ratios
    if train_ratio < 0 or val_ratio < 0 or test_ratio < 0:
        raise ValueError(f"Split ratios must be non-negative, got {tuple(ratios)}")
    # Written as `not (... <= tol)` so that NaN ratios are rejected as well.
    if not abs(train_ratio + val_ratio + test_ratio - 1.0) <= 1e-6:
        raise ValueError(f"Split ratios must sum to 1.0, got {tuple(ratios)}")

    # Binary floating point can put a product just below the intended integer
    # (100 * 0.29 == 28.999999999999996). Rounding that noise away before
    # truncating keeps the round-down behaviour without dropping a row.
    train_n = int(round(total * train_ratio, 8))
    val_n = int(round(total * val_ratio, 8))
    test_n = total - train_n - val_n
    return train_n, val_n, test_n


def write_jsonl(output_path: Path, samples) -> None:
    """Write ``samples`` to ``output_path`` as JSON Lines.

    Each sample is serialized with ``json.dumps`` (non-ASCII characters are
    escaped) and written on its own line. An existing file is overwritten.

    Args:
        output_path: Destination file path.
        samples: Iterable of JSON-serializable objects.
    """
    with output_path.open("w") as handle:
        handle.writelines(
            json.dumps(record, ensure_ascii=True) + "\n" for record in samples
        )


def main() -> None:
    """Command-line entry point: convert a SMILES regression CSV to JSONL splits.

    Steps:
      1. Parse and validate the command-line arguments.
      2. Load the CSV with pandas and drop rows whose SMILES or target value
         is missing.
      3. Shuffle the rows with ``--seed`` and cut them into train/val/test
         according to ``--split_ratios``.
      4. Write ``train.jsonl``, ``val.jsonl`` and ``test.jsonl`` into
         ``--output_dir``, one ShareGPT-style record per row.

    Raises:
        SystemExit: If an argument is invalid, pandas is not installed, the
            input file does not exist, or a requested column is missing.
    """
    parser = argparse.ArgumentParser(description="Convert SMILES regression CSV to SPR-RAFT JSONL.")
    parser.add_argument("--input_path", type=str, required=True, help="Input CSV path")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--smiles_col", type=str, default="smiles", help="SMILES column name")
    parser.add_argument("--target_col", type=str, required=True, help="Target column name")
    parser.add_argument("--id_col", type=str, default=None, help="Optional ID column name")
    parser.add_argument("--dataset_name", type=str, default="dataset", help="Dataset name prefix for IDs")
    parser.add_argument("--task_description", type=str, required=True, help="Task description for prompt")
    parser.add_argument("--decimals", type=int, default=3, help="Decimals for y rendering")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for split")
    parser.add_argument(
        "--split_ratios",
        type=float,
        nargs=3,
        default=(0.8, 0.1, 0.1),
        help="Train/val/test ratios (must sum to 1.0)",
    )
    args = parser.parse_args()

    # Reject bad numeric arguments before any file is read or written, so a
    # typo does not surface later as a traceback or a silently skewed split.
    if args.decimals < 0:
        parser.error("--decimals must be non-negative")
    split_ratios = tuple(args.split_ratios)
    # `not (... <= tol)` also rejects NaN ratios.
    if any(ratio < 0 for ratio in split_ratios) or not abs(sum(split_ratios) - 1.0) <= 1e-6:
        parser.error(f"--split_ratios must be non-negative and sum to 1.0, got {split_ratios}")

    try:
        import pandas as pd  # pylint: disable=import-error
    except ImportError as exc:
        raise SystemExit(
            "pandas is required to read CSV files. Install with `uv pip install pandas`."
        ) from exc

    input_path = Path(args.input_path)
    if not input_path.exists():
        raise SystemExit(f"Missing input file: {input_path}")

    df = pd.read_csv(input_path)

    # An explicitly requested ID column must exist; silently falling back to
    # generated IDs would hide a misspelled --id_col.
    required_cols = [args.smiles_col, args.target_col]
    if args.id_col:
        required_cols.append(args.id_col)
    for col in required_cols:
        if col not in df.columns:
            raise SystemExit(f"Missing column '{col}' in {input_path}")

    df = df.dropna(subset=[args.smiles_col, args.target_col]).reset_index(drop=True)
    # Deterministic shuffle so the same seed always yields the same splits.
    df = df.sample(frac=1.0, random_state=args.seed).reset_index(drop=True)

    train_n, val_n, test_n = split_counts(len(df), split_ratios)
    splits = {
        "train": df.iloc[:train_n],
        "val": df.iloc[train_n:train_n + val_n],
        "test": df.iloc[train_n + val_n:train_n + val_n + test_n],
    }

    # Create the output directory only once all validation has passed.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for split_name, split_df in splits.items():
        # Column-wise access avoids building one Series per row (as iterrows
        # does) and keeps each column's own dtype.
        smiles_values = split_df[args.smiles_col].tolist()
        target_values = split_df[args.target_col].tolist()
        if args.id_col:
            sample_ids = split_df[args.id_col].tolist()
        else:
            # Generated IDs are numbered from 0 within each split.
            sample_ids = [
                f"{args.dataset_name}-{split_name}-{idx}" for idx in range(len(split_df))
            ]

        samples = []
        for sample_id, smiles, target in zip(sample_ids, smiles_values, target_values):
            y_value = float(target)
            samples.append(
                {
                    "id": str(sample_id),
                    "conversations": [
                        {"from": "human", "value": build_prompt(smiles, args.task_description)},
                        {"from": "gpt", "value": build_response(y_value, decimals=args.decimals)},
                    ],
                    "y": y_value,
                }
            )

        output_path = output_dir / f"{split_name}.jsonl"
        write_jsonl(output_path, samples)
        print(f"Wrote {len(samples)} samples to {output_path}")


# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
