#!/usr/bin/env python3
import argparse
import json
import math
import os
from pathlib import Path
from typing import List, Tuple


def read_jsonl(path: Path) -> List[dict]:
    """Load a JSON Lines file into a list of parsed records.

    Blank (whitespace-only) lines are skipped. The file is decoded as UTF-8;
    a leading UTF-8 byte-order mark, if present, is ignored.

    Args:
        path: JSONL file to read.

    Returns:
        One parsed JSON value per non-blank line, in file order.

    Raises:
        ValueError: if a line is not valid JSON. The message carries the
            1-based physical line number in the file (blank lines counted).
    """
    items = []
    # 'utf-8-sig' decodes plain UTF-8 identically but also drops a BOM, which
    # json.loads would otherwise reject on the first line.
    with path.open('r', encoding='utf-8-sig') as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                items.append(json.loads(line))
            except json.JSONDecodeError as e:
                # Report the real file line, not the count of parsed records.
                raise ValueError(f"Invalid JSON on line {lineno} in {path}: {e}") from e
    return items


def extract_lm_loss(rec: dict) -> float:
    """Return the lm_loss value stored in a metrics record.

    Lookup order:
      1. a numeric top-level ``lm_loss`` key;
      2. otherwise the mapping at ``rec['metrics']`` (or ``rec['metric']`` when
         ``metrics`` is missing or falsy). The first numeric value among
         ``lm_loss``, ``lm_loss_avg``, ``lm_loss_mean`` and ``loss`` is used.

    Only ``int``/``float`` values count. ``bool`` is rejected even though it
    subclasses ``int``, so a stray ``true`` is not read as a loss of 1.0.

    Raises:
        KeyError: if no numeric loss is found (also when ``rec`` is not a dict).
    """
    def _is_number(value) -> bool:
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    if isinstance(rec, dict):
        top = rec.get('lm_loss')
        if _is_number(top):
            return float(top)
        metrics = rec.get('metrics') or rec.get('metric')
        if isinstance(metrics, dict):
            for k in ('lm_loss', 'lm_loss_avg', 'lm_loss_mean', 'loss'):
                if _is_number(metrics.get(k)):
                    return float(metrics[k])
    raise KeyError("lm_loss not found in record")


def build_indices_sorted_by_lmloss_desc(sgd: List[dict]) -> List[int]:
    """Return record indices of ``sgd`` ordered by lm_loss, highest first.

    Records whose loss cannot be extracted, or whose loss is NaN, get a
    score of -inf and therefore sort to the end. The sort is stable, so
    records with equal scores keep their original relative order.

    Args:
        sgd: metrics records, each passed to ``extract_lm_loss``.

    Returns:
        A permutation of ``range(len(sgd))``.
    """
    scored: List[Tuple[int, float]] = []
    for idx, rec in enumerate(sgd):
        try:
            loss_val = extract_lm_loss(rec)
        except (KeyError, TypeError, ValueError, OverflowError):
            # Treat missing/invalid loss as smallest priority (end of list).
            loss_val = float('-inf')
        if math.isnan(loss_val):
            # NaN compares False against everything and would scramble the
            # ordering of every other record, so demote it like a missing value.
            loss_val = float('-inf')
        scored.append((idx, loss_val))
    # high to low (stable: ties keep input order)
    scored.sort(key=lambda x: x[1], reverse=True)
    return [i for i, _ in scored]


def write_subset(od: List[dict], indices: List[int], out_path: Path) -> None:
    """Write the OD records selected by ``indices`` to ``out_path`` as JSONL.

    Records are emitted in the order given by ``indices``, one JSON document
    per line, with non-ASCII characters written verbatim. Missing parent
    directories are created, and an existing file is overwritten.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open('w', encoding='utf-8') as sink:
        sink.writelines(
            json.dumps(od[idx], ensure_ascii=False) + "\n" for idx in indices
        )


def main():
    """Command-line entry point: write nested lm_loss-ranked subsets of OD.

    OD (original dataset) and SGD (student-generated dataset with metrics)
    are read as JSON Lines. They must contain the same number of records and
    are treated as index-aligned. OD records are ranked by the lm_loss of the
    SGD record at the same index, highest first. The top 75%, 50% and 25% of
    that ranking go to ``<out-base>/sub1|sub2|sub3/gpt2/train.jsonl``. A short
    summary is written to ``<out-base>/stage_stats.txt`` and printed.

    Raises:
        ValueError: on an OD/SGD length mismatch, or, with ``--check-prompts``,
            when an aligned pair carries two different prompts.
    """
    cli = argparse.ArgumentParser(description="Curate Dolly sub-datasets by lm_loss percentiles")
    cli.add_argument('--od', required=True, help='Path to original dataset train.jsonl')
    cli.add_argument('--sgd', required=True, help='Path to student-generated dataset answers_with_metrics_*.jsonl')
    cli.add_argument('--out-base', required=True, help='Base output dir: .../gpt2_curated/stage')
    cli.add_argument('--check-prompts', action='store_true', help='Verify prompts align when available')
    opts = cli.parse_args()

    out_base = Path(opts.out_base)
    od = read_jsonl(Path(opts.od))
    sgd = read_jsonl(Path(opts.sgd))

    if len(od) != len(sgd):
        raise ValueError(f"Length mismatch: OD={len(od)} vs SGD={len(sgd)}")

    if opts.check_prompts:
        def prompt_of(rec):
            # 'prompt' wins; fall back to 'instruction' when it is absent or falsy.
            return rec.get('prompt') or rec.get('instruction')

        for pos, (od_rec, sgd_rec) in enumerate(zip(od, sgd)):
            od_prompt = prompt_of(od_rec)
            sgd_prompt = prompt_of(sgd_rec)
            # Only compare when both sides actually carry a prompt.
            both_present = od_prompt is not None and sgd_prompt is not None
            if both_present and od_prompt != sgd_prompt:
                raise ValueError(f"Prompt mismatch at index {pos}")

    ranking = build_indices_sorted_by_lmloss_desc(sgd)
    total = len(ranking)
    k75, k50, k25 = (int(round(total * frac)) for frac in (0.75, 0.50, 0.25))

    summary = []
    for name, cutoff in (("sub1", k75), ("sub2", k50), ("sub3", k25)):
        # Every subset is a prefix of the same ranking, so sub3 is contained
        # in sub2, which is contained in sub1.
        chosen = ranking[:cutoff]
        out_file = out_base / name / 'gpt2' / 'train.jsonl'
        write_subset(od, chosen, out_file)
        summary.append(f"{name}: {len(chosen)} items -> {out_file}")

    stats_path = out_base / 'stage_stats.txt'
    report = [
        f"OD: {len(od)} | SGD: {len(sgd)}",
        f"thresholds (counts): 75%={k75}, 50%={k50}, 25%={k25}",
        *summary,
    ]
    stats_path.write_text("\n".join(report) + "\n", encoding='utf-8')

    print("\n".join(["[done] Wrote sub-datasets by lm_loss:", *summary, f"Stats -> {stats_path}"]))


# Run the CLI only when executed as a script; importing this module has no side effects.
if __name__ == "__main__":
    raise SystemExit(main())


