#!/usr/bin/env python3
"""
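Convert the BIOSSES sentence-pair TSV into SPR-RAFT JSONL (ShareGPT format).

Each row becomes a human/gpt conversation whose numeric target ``y`` is the
mean of the annotator scores. Rows are shuffled with a fixed seed and written
to ``train.jsonl``, ``val.jsonl`` and ``test.jsonl`` in the output directory.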

Example:
    uv run python scripts/convert_biosses_to_spr_raft.py \
        --input_path dataset/annotation_pairs_scores.tsv \
        --output_dir data/spr_raft_biosses \
        --decimals 2
"""

from __future__ import annotations

import argparse
import json
import math
import sys
from pathlib import Path
from typing import Iterable, Tuple


def build_prompt(sentence_1: str, sentence_2: str) -> str:
    """Render the human-turn prompt that asks for a 0-4 similarity score.

    Args:
        sentence_1: First sentence of the pair.
        sentence_2: Second sentence of the pair.

    Returns:
        The prompt: a header line, both sentences, a blank line, then the
        instruction to output a single number.
    """
    prompt_lines = [
        "Sentence pair:",
        f"Sentence 1: {sentence_1}",
        f"Sentence 2: {sentence_2}",
        "",  # blank separator line before the instruction
        "Predict semantic similarity score (0-4). Output a single number.",
    ]
    return "\n".join(prompt_lines)


def build_response(value: float, decimals: int) -> str:
    """Render the gpt-turn answer: a ``[REG]`` tag followed by the score.

    Args:
        value: Numeric score to render.
        decimals: Number of digits after the decimal point (fixed-point).

    Returns:
        A string such as ``"[REG] 2.50"`` for ``value=2.5, decimals=2``.
    """
    number_text = format(value, f".{decimals}f")
    return "[REG] " + number_text


def split_counts(total: int, ratios: Tuple[float, float, float]) -> Tuple[int, int, int]:
    """Turn train/val/test ratios into integer row counts for ``total`` rows.

    The train and val counts are floored; the test split absorbs the
    remainder, so the three counts always sum to ``total``.

    Args:
        total: Number of rows to split; must be non-negative.
        ratios: ``(train, val, test)`` fractions. Each must be non-negative
            and together they must sum to 1.0 (within 1e-6).

    Returns:
        ``(train_n, val_n, test_n)`` with ``train_n + val_n + test_n == total``.

    Raises:
        ValueError: if ``total`` is negative or ``ratios`` is malformed.
    """
    if total < 0:
        raise ValueError(f"total must be non-negative, got {total}")
    if len(ratios) != 3:
        raise ValueError(f"expected 3 split ratios, got {len(ratios)}")
    train_ratio, val_ratio, test_ratio = ratios
    if min(ratios) < 0 or not math.isclose(
        train_ratio + val_ratio + test_ratio, 1.0, abs_tol=1e-6
    ):
        raise ValueError(
            f"split ratios must be non-negative and sum to 1.0, got {tuple(ratios)}"
        )
    # Round away float noise before flooring: 100 * 0.29 == 28.999999999999996,
    # which a bare int() would truncate to 28.
    train_n = min(int(round(total * train_ratio, 9)), total)
    # Clamp so the 1e-6 tolerance above can never push test_n below zero.
    val_n = min(int(round(total * val_ratio, 9)), total - train_n)
    test_n = total - train_n - val_n
    return train_n, val_n, test_n


def write_jsonl(output_path: Path, samples: Iterable[dict]) -> None:
    """Write ``samples`` to ``output_path`` as JSON Lines, one object per line.

    The file is overwritten. ``ensure_ascii=True`` escapes non-ASCII text, so
    the payload is pure ASCII. The encoding is still pinned to UTF-8 so the
    result does not depend on the platform's default locale.

    Args:
        output_path: Destination file path.
        samples: JSON-serializable records, one per output line.

    Raises:
        ValueError: if a sample contains NaN or infinity. ``allow_nan=False``
            refuses to emit these non-standard tokens, which strict JSON
            readers reject.
    """
    with output_path.open("w", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(sample, ensure_ascii=True, allow_nan=False) + "\n"
            for sample in samples
        )


def mean_score(row, score_cols: Iterable[str]) -> float:
    """Average the annotator scores available in ``row``.

    Columns absent from ``row`` are ignored. Missing values (``None``, or NaN
    as pandas produces for empty TSV cells) are also skipped, so one blank
    annotator cell does not turn the whole target into NaN.

    Args:
        row: Mapping-like record (e.g. a pandas ``Series`` from ``iterrows``
            or a ``dict``) supporting ``in`` and ``[]`` by column name.
        score_cols: Candidate score column names.

    Returns:
        Arithmetic mean of the present, non-missing scores.

    Raises:
        ValueError: if no usable score is found. Errors from ``float()`` on
            non-numeric cells also propagate.
    """
    values = []
    for col in score_cols:
        if col not in row:
            continue
        raw = row[col]
        if raw is None:
            continue
        value = float(raw)
        if math.isnan(value):
            continue
        values.append(value)
    if not values:
        raise ValueError("No annotator scores found for row.")
    return sum(values) / len(values)


def _parse_args() -> argparse.Namespace:
    """Build the CLI parser, parse the command line, and validate numeric options."""
    parser = argparse.ArgumentParser(description="Convert BIOSSES TSV to SPR-RAFT JSONL.")
    parser.add_argument("--input_path", type=str, required=True, help="Input TSV path")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--sentence_1_col", type=str, default="sentence_1", help="Sentence 1 column")
    parser.add_argument("--sentence_2_col", type=str, default="sentence_2", help="Sentence 2 column")
    parser.add_argument(
        "--score_cols",
        type=str,
        nargs="+",
        default=["annotator_a", "annotator_b", "annotator_c", "annotator_d", "annotator_e"],
        help="Annotator score columns",
    )
    parser.add_argument("--dataset_name", type=str, default="biosses", help="Dataset name prefix for IDs")
    parser.add_argument("--decimals", type=int, default=2, help="Decimals for y rendering")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for split")
    parser.add_argument(
        "--split_ratios",
        type=float,
        nargs=3,
        default=(0.8, 0.1, 0.1),
        help="Train/val/test ratios (must sum to 1.0)",
    )
    args = parser.parse_args()

    # parser.error() prints usage and exits with status 2, like argparse's own errors.
    ratios = list(args.split_ratios)
    if any(ratio < 0 for ratio in ratios) or abs(sum(ratios) - 1.0) > 1e-6:
        parser.error(f"--split_ratios must be non-negative and sum to 1.0, got {ratios}")
    if args.decimals < 0:
        parser.error(f"--decimals must be >= 0, got {args.decimals}")
    return args


def _build_sample(sample_id: str, sentence_1, sentence_2, y_value: float, decimals: int) -> dict:
    """Assemble one ShareGPT-style record: human prompt, ``[REG]`` answer, numeric ``y``."""
    return {
        "id": sample_id,
        "conversations": [
            {"from": "human", "value": build_prompt(sentence_1, sentence_2)},
            {"from": "gpt", "value": build_response(y_value, decimals=decimals)},
        ],
        "y": float(y_value),
    }


def main() -> None:
    """Convert the BIOSSES TSV into train/val/test JSONL files.

    Reads ``--input_path`` as tab-separated values and drops rows missing
    either sentence. It then shuffles with ``--seed``, splits by
    ``--split_ratios``, and writes ``train.jsonl``, ``val.jsonl`` and
    ``test.jsonl`` into ``--output_dir``. Each record's target ``y`` is the
    mean of the score columns found in the TSV.

    Raises:
        SystemExit: if pandas is unavailable, the input file, sentence
            columns, or all score columns are missing, a row has no usable
            score, or the CLI arguments are invalid.
    """
    args = _parse_args()

    try:
        import pandas as pd  # pylint: disable=import-error
    except ImportError as exc:
        raise SystemExit(
            "pandas is required to read TSV files. Install with `uv pip install pandas`."
        ) from exc

    input_path = Path(args.input_path)
    if not input_path.exists():
        raise SystemExit(f"Missing input file: {input_path}")

    df = pd.read_csv(input_path, sep="\t")
    for col in [args.sentence_1_col, args.sentence_2_col]:
        if col not in df.columns:
            raise SystemExit(f"Missing column '{col}' in {input_path}")

    # Fail fast on a wrong --score_cols instead of erroring on every row later.
    score_cols = [col for col in args.score_cols if col in df.columns]
    if not score_cols:
        raise SystemExit(f"None of the score columns {args.score_cols} found in {input_path}")
    missing_cols = [col for col in args.score_cols if col not in df.columns]
    if missing_cols:
        print(
            f"Warning: score columns {missing_cols} not found in {input_path}; "
            f"averaging over {score_cols}",
            file=sys.stderr,
        )

    # Create the output directory only once the input has validated.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    df = df.dropna(subset=[args.sentence_1_col, args.sentence_2_col]).reset_index(drop=True)
    # Fixed-seed shuffle keeps the split reproducible across runs.
    df = df.sample(frac=1.0, random_state=args.seed).reset_index(drop=True)

    train_n, val_n, test_n = split_counts(len(df), tuple(args.split_ratios))
    splits = {
        "train": df.iloc[:train_n],
        "val": df.iloc[train_n:train_n + val_n],
        "test": df.iloc[train_n + val_n:train_n + val_n + test_n],
    }

    for split_name, split_df in splits.items():
        samples = []
        # IDs are numbered 0..len-1 within each split, not by original TSV row.
        for idx, row in split_df.reset_index(drop=True).iterrows():
            sample_id = f"{args.dataset_name}-{split_name}-{idx}"
            sentence_1 = row[args.sentence_1_col]
            sentence_2 = row[args.sentence_2_col]
            try:
                y_value = mean_score(row, score_cols)
            except ValueError as exc:
                raise SystemExit(
                    f"Cannot compute score for {sample_id} "
                    f"(sentence 1: {str(sentence_1)[:60]!r}): {exc}"
                ) from exc
            samples.append(
                _build_sample(sample_id, sentence_1, sentence_2, y_value, args.decimals)
            )

        output_path = output_dir / f"{split_name}.jsonl"
        write_jsonl(output_path, samples)
        print(f"Wrote {len(samples)} samples to {output_path}")


# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
