#!/usr/bin/env python3
"""
Combine two prompt/response JSONL datasets, shuffle, and save both the full
combined dataset and a random subset (e.g., 20K).

Assumes each input JSONL line is a JSON object containing at least:
  { "prompt": str, "response": str }

Usage example:
  python scripts/combine_and_subset_jsonl.py \
    --alpaca /data/home/Yunsheng/alignment-handbook/datasets/alpaca_dataset.jsonl \
    --smoltalk /data/home/Yunsheng/alignment-handbook/datasets/smoltalk_single_round.jsonl \
    --output-combined /data/home/Yunsheng/alignment-handbook/datasets/combined_alpaca_smoltalk_dataset.jsonl \
    --output-subset /data/home/Yunsheng/alignment-handbook/datasets/combined_subset_20K.jsonl \
    --subset-size 20000 \
    --seed 42
"""

import argparse
import json
from pathlib import Path
from typing import List, Dict
import random


def read_jsonl(path: Path) -> List[Dict]:
    """Load a JSONL file into a list of dicts, one per non-blank line.

    Each returned dict lists ``"prompt"`` and ``"response"`` first, followed
    by every other key of the source object in its original order. A key
    that is absent from the source object is kept with the value ``None``,
    so the number of returned rows equals the number of non-blank lines.

    Args:
        path: Location of the JSONL file. It is decoded as UTF-8; a leading
            byte-order mark, if present, is ignored.

    Returns:
        The parsed rows, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message names the file and the 1-based line number.
        AttributeError: If a line holds valid JSON that is not an object.
    """
    rows: List[Dict] = []
    # "utf-8-sig" decodes plain UTF-8 identically but also drops a leading
    # BOM, which json.loads would otherwise reject on the first record.
    with path.open("r", encoding="utf-8-sig") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as exc:
                # Same exception type, but tell the user where the bad
                # record is; the bare error only knows the in-line offset.
                raise json.JSONDecodeError(
                    f"{path}, line {lineno}: {exc.msg}", exc.doc, exc.pos
                ) from exc
            prompt = obj.get("prompt")
            response = obj.get("response")
            extras = {
                k: v for k, v in obj.items() if k not in {"prompt", "response"}
            }
            rows.append({"prompt": prompt, "response": response, **extras})
    return rows


def write_jsonl(path: Path, rows: List[Dict]):
    """Write ``rows`` to ``path`` as JSONL, one JSON object per line.

    The file is opened in write mode with UTF-8 encoding, so any existing
    content is replaced. Non-ASCII characters are written as-is rather than
    as ``\\uXXXX`` escapes, and every record is followed by a newline.

    Args:
        path: Destination file.
        rows: Records to serialize, written in the given order.
    """
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in rows
        )

def main(argv=None):
    """Combine two JSONL datasets, shuffle them, and write full + subset files.

    Reads the ``--alpaca`` and ``--smoltalk`` files, concatenates them,
    shuffles the result with a generator seeded from ``--seed``, writes the
    whole shuffled list to ``--output-combined`` and its first
    ``--subset-size`` rows to ``--output-subset``, then prints row counts.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). ``None`` (the default) makes argparse read ``sys.argv``.

    Raises:
        SystemExit: On invalid arguments, including a negative
            ``--subset-size``.
    """
    parser = argparse.ArgumentParser(description="Combine and subset JSONL prompt/response datasets")
    parser.add_argument("--alpaca", type=str, required=True, help="Path to alpaca JSONL")
    parser.add_argument("--smoltalk", type=str, required=True, help="Path to smoltalk JSONL")
    parser.add_argument("--output-combined", type=str, required=True, help="Path to write combined JSONL")
    parser.add_argument("--output-subset", type=str, required=True, help="Path to write subset JSONL")
    parser.add_argument("--subset-size", type=int, default=20000, help="Subset size (default: 20000)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args(argv)

    # A negative size would turn the slice below into "drop the last N rows"
    # instead of "take the first N", so reject it up front.
    if args.subset_size < 0:
        parser.error(f"--subset-size must be >= 0 (got {args.subset_size})")

    alpaca_path = Path(args.alpaca)
    smoltalk_path = Path(args.smoltalk)
    out_combined = Path(args.output_combined)
    out_subset = Path(args.output_subset)

    # Load
    alpaca = read_jsonl(alpaca_path)
    smoltalk = read_jsonl(smoltalk_path)

    # Combine
    combined = alpaca + smoltalk

    # Shuffle with a private generator: same order as seeding the module-level
    # RNG with this seed, but without touching global random state.
    rng = random.Random(args.seed)
    rng.shuffle(combined)

    # Create output directories only once the inputs have loaded, so a bad
    # input path does not leave empty directories behind.
    out_combined.parent.mkdir(parents=True, exist_ok=True)
    out_subset.parent.mkdir(parents=True, exist_ok=True)

    # Write combined
    write_jsonl(out_combined, combined)

    # Subset: a prefix of the shuffled list is a random sample without
    # replacement; the size is capped at the number of available rows.
    subset_size = min(args.subset_size, len(combined))
    subset = combined[:subset_size]
    write_jsonl(out_subset, subset)

    # Simple stats
    print(f"✅ Combined: {len(alpaca)} alpaca + {len(smoltalk)} smoltalk = {len(combined)} rows")
    print(f"✅ Subset written: {len(subset)} rows -> {out_subset}")
    print(f"✅ Combined written to: {out_combined}")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
