import argparse
import json
import math
import os
from typing import List, Tuple


def stream_jsonl(path: str):
    """Yield one parsed JSON value per non-blank line of the JSONL file at *path*.

    Each line is stripped of surrounding whitespace; blank lines are skipped.
    The file is decoded as ``utf-8-sig`` so that a leading UTF-8 byte-order
    mark is dropped instead of making ``json.loads`` reject the first line;
    files without a BOM decode exactly as with plain ``utf-8``.

    Raises:
        json.JSONDecodeError: if a line is not valid JSON. The message is
            prefixed with ``<path>:<line number>:`` (1-based physical line,
            blank lines included) so the bad record can be located.
    """
    with open(path, 'r', encoding='utf-8-sig') as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            # Parse outside the ``yield`` so exceptions thrown into the
            # generator by a consumer are never mistaken for decode errors.
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                raise json.JSONDecodeError(f"{path}:{lineno}: {err.msg}", err.doc, err.pos) from err
            yield record


def load_jsonl(path: str) -> List[dict]:
    """Read every record of the JSONL file at *path* into memory.

    Eager counterpart of ``stream_jsonl``; records are returned in file order.
    """
    records: List[dict] = []
    records.extend(stream_jsonl(path))
    return records


def compute_histogram_mode(values: List[float], bins: int) -> float:
    """Estimate the mode of the strictly positive, finite entries of *values*.

    The usable values are bucketed into ``bins`` equal-width bins spanning
    ``[min, max]`` (every bin is half-open except the last, which also takes
    the maximum), and the centre of the most populated bin is returned.
    Ties between bins resolve to the lowest bin.

    Args:
        values: Sample values. Entries <= 0, NaN and +/-inf are ignored.
        bins: Number of histogram bins; must be >= 1.

    Returns:
        0.0 if no usable values remain; the minimum if min and max are
        ``math.isclose``; otherwise the centre of the fullest bin.

    Raises:
        ValueError: if ``bins`` < 1.
    """
    if bins < 1:
        raise ValueError(f"bins must be >= 1, got {bins}")

    # Non-finite values would make the bin width infinite/NaN, so drop them.
    positive_values = [v for v in values if v > 0.0 and math.isfinite(v)]
    if not positive_values:
        return 0.0

    min_v = min(positive_values)
    max_v = max(positive_values)
    if math.isclose(min_v, max_v):
        return float(min_v)

    span = max_v - min_v
    counts = [0] * bins
    for v in positive_values:
        if v == max_v:
            # Close the right edge of the last bin.
            idx = bins - 1
        else:
            # Clamp guards against float rounding pushing idx out of range.
            idx = min(max(int((v - min_v) / span * bins), 0), bins - 1)
        counts[idx] += 1

    peak_idx = max(range(bins), key=counts.__getitem__)
    # Bin edges computed with the same arithmetic as min_v + span * i / bins.
    left = min_v + span * peak_idx / bins
    right = min_v + span * (peak_idx + 1) / bins
    return float((left + right) / 2.0)


def partition_indices(rouge_values: List[float], m: float, n: float, eps: float) -> Tuple[List[int], List[int], List[int]]:
    """Bucket the positions of *rouge_values* into three Rouge_L bands.

    Bands (checked in this order, first match wins):
      * a: ``abs(v) <= eps``            -- treated as Rouge_L == 0
      * b: ``v <= m + eps``             -- 0 < Rouge_L <= m
      * c: ``v <= n + eps``             -- m < Rouge_L <= n
    Values above ``n + eps`` (and NaN, which fails every comparison) are in
    no band. Note that a negative value below ``-eps`` satisfies the band-b
    test and is placed there.

    Returns:
        ``(a_idx, b_idx, c_idx)``, each an ascending list of indices.
    """
    zero_band: List[int] = []
    low_band: List[int] = []
    mid_band: List[int] = []
    low_cap = m + eps
    mid_cap = n + eps

    for position, score in enumerate(rouge_values):
        if abs(score) <= eps:
            target = zero_band
        elif score <= low_cap:
            target = low_band
        elif score <= mid_cap:
            target = mid_band
        else:
            # Above n: deliberately left out of every band.
            continue
        target.append(position)

    return zero_band, low_band, mid_band


def save_subset(od_data: List[dict], indices: List[int], out_path: str) -> int:
    """Write ``od_data[i]`` for each ``i`` in *indices* to *out_path* as JSONL.

    Records are written in the order given by *indices*, one JSON object per
    line, with non-ASCII characters kept as-is (UTF-8). The parent directory
    is created if needed. An existing file at *out_path* is overwritten.

    Returns:
        The number of records written.

    Raises:
        IndexError: if an index is out of range for *od_data* (the file is
            left partially written in that case).
    """
    parent = os.path.dirname(out_path)
    if parent:
        # os.makedirs('') raises FileNotFoundError, so only create a directory
        # when out_path actually has a directory component.
        os.makedirs(parent, exist_ok=True)
    count = 0
    with open(out_path, 'w', encoding='utf-8') as f:
        for idx in indices:
            f.write(json.dumps(od_data[idx], ensure_ascii=False) + '\n')
            count += 1
    return count


def _spot_check_alignment(sgd: List[dict], od: List[dict]) -> int:
    """Compare prompts at the first/middle/last index; return the mismatch count.

    Best-effort sanity check: each distinct index is checked once (so tiny
    datasets are not over-counted), and records whose fields are missing or
    not strings are skipped rather than failing the run.
    """
    if not sgd or not od:
        return 0
    mismatches = 0
    for i in sorted({0, len(sgd) // 2, len(sgd) - 1}):
        try:
            ps = (sgd[i].get('prompt') or sgd[i].get('input') or '').strip()
            po = (od[i].get('prompt') or od[i].get('instruction') or od[i].get('Instruction') or '').strip()
        except (AttributeError, IndexError, TypeError):
            continue
        if ps and po and ps != po:
            mismatches += 1
    return mismatches


def _parse_rouge_values(sgd: List[dict]) -> List[float]:
    """Return each record's ``Rouge_L`` as a float.

    Missing, unparseable or non-finite scores become 0.0.
    """
    rouge_values: List[float] = []
    for rec in sgd:
        raw = rec.get('Rouge_L', 0.0)
        try:
            v = float(raw)
        except (TypeError, ValueError, OverflowError):
            v = 0.0
        if not math.isfinite(v):
            # A NaN would make max() / n NaN and silently empty band c.
            v = 0.0
        rouge_values.append(v)
    return rouge_values


def _estimate_thresholds(rouge_values: List[float], bins: int, eps: float) -> Tuple[float, float, float]:
    """Return ``(m, k, n)``: positive-mode estimate, maximum, and their midpoint."""
    k = max(rouge_values) if rouge_values else 0.0
    m = compute_histogram_mode(rouge_values, bins=bins)
    if m <= eps:
        # ensure m > 0 if there are positives, else keep as 0
        positives = [v for v in rouge_values if v > eps]
        if positives:
            m = min(positives)
    n = (m + k) / 2.0
    return m, k, n


def main():
    """CLI entry point: split the OD file into sub1/sub2/sub3 by SGD Rouge_L.

    Records are paired by index between the SGD metrics file and the OD file
    (the longer file is truncated). With m = histogram mode of positive scores,
    k = max score and n = (m + k) / 2:
      sub1 = Rouge_L <= n, sub2 = Rouge_L <= m, sub3 = Rouge_L == 0,
    each written as ``<out-dir-base>/subX/train.jsonl`` in original order, plus
    a ``partition_stats.txt`` summary.
    """
    parser = argparse.ArgumentParser(description="Partition Dolly OD by SGD Rouge_L into sub1/sub2/sub3")
    parser.add_argument('--sgd-file', type=str, default='/hy-tmp/dc/processed_data/dolly/full/gpt2-base/answers_with_metrics_gpt2-base.new.jsonl', help='Path to SGD answers_with_metrics JSONL')
    parser.add_argument('--od-file', type=str, default='/hy-tmp/dc/processed_data/dolly/full/gpt2/train.jsonl', help='Path to original Dolly train.jsonl')
    parser.add_argument('--out-dir-base', type=str, default='/hy-tmp/dc/processed_data/dolly/full/gpt2_curated', help='Base directory to write sub-datasets')
    parser.add_argument('--bins', type=int, default=100, help='Number of bins for histogram mode estimation (m)')
    parser.add_argument('--epsilon', type=float, default=1e-9, help='Zero threshold for Rouge_L')
    args = parser.parse_args()
    if args.bins < 1:
        # Fail fast with a usage message instead of a crash inside the histogram.
        parser.error(f"--bins must be >= 1 (got {args.bins})")

    print(f"[info] Loading SGD from: {args.sgd_file}")
    sgd = load_jsonl(args.sgd_file)
    print(f"[info] Loading OD from: {args.od_file}")
    od = load_jsonl(args.od_file)

    if len(sgd) != len(od):
        min_len = min(len(sgd), len(od))
        print(f"[warn] Length mismatch: SGD={len(sgd)} OD={len(od)}; truncating to {min_len}")
        sgd = sgd[:min_len]
        od = od[:min_len]

    # Optional sanity: check prompt alignment on a sample
    mismatches = _spot_check_alignment(sgd, od)
    if mismatches > 0:
        print(f"[warn] Detected {mismatches} prompt mismatches in spot-check; proceeding by index order as specified.")

    rouge_values = _parse_rouge_values(sgd)
    m, k, n = _estimate_thresholds(rouge_values, args.bins, args.epsilon)

    a_idx, b_idx, c_idx = partition_indices(rouge_values, m, n, args.epsilon)

    # Build sub-dataset indices in original order (a, b, c are disjoint, so
    # the set only merges them; sorting restores file order).
    sub1_idx = sorted(set(a_idx + b_idx + c_idx))
    sub2_idx = sorted(set(a_idx + b_idx))
    sub3_idx = sorted(set(a_idx))

    sub1_dir = os.path.join(args.out_dir_base, 'sub1')
    sub2_dir = os.path.join(args.out_dir_base, 'sub2')
    sub3_dir = os.path.join(args.out_dir_base, 'sub3')

    sub1_path = os.path.join(sub1_dir, 'train.jsonl')
    sub2_path = os.path.join(sub2_dir, 'train.jsonl')
    sub3_path = os.path.join(sub3_dir, 'train.jsonl')

    print(f"[info] Writing sub-datasets to: {args.out_dir_base}/sub{{1,2,3}}")
    c1 = save_subset(od, sub1_idx, sub1_path)
    c2 = save_subset(od, sub2_idx, sub2_path)
    c3 = save_subset(od, sub3_idx, sub3_path)

    # Write log file (out_dir_base exists: save_subset created its subdirs).
    log_path = os.path.join(args.out_dir_base, 'partition_stats.txt')
    with open(log_path, 'w', encoding='utf-8') as f:
        f.write(f"m (mode>0) = {m}\n")
        f.write(f"k (max)    = {k}\n")
        f.write(f"n=(m+k)/2  = {n}\n")
        f.write(f"count a (Rouge_L == 0)         = {len(a_idx)}\n")
        f.write(f"count b (0 < Rouge_L <= m)     = {len(b_idx)}\n")
        f.write(f"count c (m < Rouge_L <= n)     = {len(c_idx)}\n")
        f.write(f"sub1 size (a+b+c unique order) = {c1}\n")
        f.write(f"sub2 size (a+b unique order)   = {c2}\n")
        f.write(f"sub3 size (a unique order)     = {c3}\n")
    print(f"[info] Saved stats to: {log_path}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()





