#!/usr/bin/env python3
"""
Simple Dolly curated dataset creator (option2)

Deletes ODS samples (paired with the SGS/TGS records by position in the files) when:
 - SGS Rouge_L == 0, OR
 - TGS Rouge_L < 1

Usage:
 python scripts/gpt2/tools/data_curation_option2.py \
   --ods /hy-tmp/dc/processed_data/dolly/full/gpt2/train.jsonl \
   --sgs /hy-tmp/dc/processed_data/dolly/full/gpt2-base/answers_with_metrics_gpt2-base.new.jsonl \
   --tgs /hy-tmp/dc/processed_data/dolly/full/14290/answers_with_metrics_14290.new.jsonl \
   --out /hy-tmp/dc/processed_data/dolly/full/gpt2_curated/train.jsonl \
   --log /hy-tmp/dc/processed_data/dolly/full/gpt2_curated/train_option2.stats.txt
"""
import argparse
import json
from pathlib import Path
from typing import TextIO


def parse_args():
    """Parse the command line and return the resulting namespace.

    Every option takes a filesystem path (converted to ``Path``) and falls
    back to its built-in default when omitted:
    ``--ods``, ``--sgs``, ``--tgs`` are inputs; ``--out`` and ``--log`` are
    outputs.
    """
    # (option name, default path) in the order they are registered.
    path_options = (
        ('ods', '/hy-tmp/dc/processed_data/dolly/full/gpt2/train.jsonl'),
        ('sgs', '/hy-tmp/dc/processed_data/dolly/full/gpt2-base/answers_with_metrics_gpt2-base.new.jsonl'),
        ('tgs', '/hy-tmp/dc/processed_data/dolly/full/14290/answers_with_metrics_14290.new.jsonl'),
        ('out', '/hy-tmp/dc/processed_data/dolly/full/gpt2_curated/train.jsonl'),
        ('log', '/hy-tmp/dc/processed_data/dolly/full/gpt2_curated/train_option2.stats.txt'),
    )

    parser = argparse.ArgumentParser()
    for option, default_path in path_options:
        parser.add_argument('--' + option, type=Path, default=Path(default_path))
    return parser.parse_args()


def safe_float(x, default=0.0):
    """Convert *x* with ``float()``; return *default* if the conversion raises.

    This is a deliberate best-effort helper: any ``Exception`` from the
    conversion (e.g. for ``None`` or a non-numeric string) yields *default*.
    """
    try:
        converted = float(x)
    except Exception:
        converted = default
    return converted


def stream(path: Path):
    """Lazily yield one decoded JSON value per non-blank line of *path*.

    The file is read as UTF-8. Lines that are empty after stripping
    surrounding whitespace are skipped; every other line is passed to
    ``json.loads``.
    """
    with path.open('r', encoding='utf-8') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        yield from (json.loads(text) for text in stripped_lines if text)


def main():
    """Filter the ODS file using per-record SGS/TGS Rouge_L scores.

    The three JSONL inputs are read in lockstep, so the i-th record of each
    file is treated as describing the same sample. An ODS record is dropped
    when its SGS score equals 0 or its TGS score is below 1; every other ODS
    record is written unchanged to ``--out``. The score is read from the
    ``Rouge_L`` key, falling back to ``rougeL`` and then to 0.0. Summary
    counts are written to ``--log`` and a short summary is printed.

    Raises:
        ValueError: if the inputs do not contain the same number of records.
            Because records are paired purely by position, a length mismatch
            means the files are not aligned; the output written up to that
            point is incomplete and no stats file is produced.
    """
    # Local import keeps this function self-contained.
    from itertools import zip_longest

    args = parse_args()

    ods_iter = stream(args.ods)
    sgs_iter = stream(args.sgs)
    tgs_iter = stream(args.tgs)

    out_path = args.out
    out_path.parent.mkdir(parents=True, exist_ok=True)

    kept = 0
    deleted_by_sgs = 0
    deleted_by_tgs = 0
    deleted_both = 0
    total = 0

    # Unique marker for "this input ran out"; a decoded JSON value can never
    # be this object, unlike None which is a valid JSON record.
    exhausted = object()

    with out_path.open('w', encoding='utf-8') as out_f:
        # zip_longest (rather than zip) so that a shorter input is detected
        # instead of silently truncating the curated output.
        for o, s, t in zip_longest(ods_iter, sgs_iter, tgs_iter, fillvalue=exhausted):
            if o is exhausted or s is exhausted or t is exhausted:
                ended = ', '.join(
                    label
                    for label, record in (('ods', o), ('sgs', s), ('tgs', t))
                    if record is exhausted
                )
                raise ValueError(
                    f'input files are not aligned: {ended} ended after {total} '
                    f'records while another input still has data'
                )

            total += 1
            s_val = safe_float(s.get('Rouge_L', s.get('rougeL', 0.0)), 0.0)
            t_val = safe_float(t.get('Rouge_L', t.get('rougeL', 0.0)), 0.0)

            del_by_sgs = (s_val == 0.0)
            del_by_tgs = (t_val < 1.0)

            # Each deleted sample is counted in exactly one bucket.
            if del_by_sgs and del_by_tgs:
                deleted_both += 1
                continue
            if del_by_sgs:
                deleted_by_sgs += 1
                continue
            if del_by_tgs:
                deleted_by_tgs += 1
                continue

            # keep sample as-is
            out_f.write(json.dumps(o, ensure_ascii=False) + '\n')
            kept += 1

    # write log
    args.log.parent.mkdir(parents=True, exist_ok=True)
    with args.log.open('w', encoding='utf-8') as lf:
        lf.write(f'total_input: {total}\n')
        lf.write(f'kept: {kept}\n')
        lf.write(f'deleted_by_sgs_only: {deleted_by_sgs}\n')
        lf.write(f'deleted_by_tgs_only: {deleted_by_tgs}\n')
        lf.write(f'deleted_by_both: {deleted_both}\n')

    print('done')
    print('total', total, 'kept', kept)


# Run the curation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()


