# -*- coding: utf-8 -*-
import argparse
import json
import random
from typing import Any, Dict, List, Optional, Sequence, Tuple

from pipeline import run_pipeline

# Tokens accepted (case-insensitively) by the boolean command-line options.
_TRUE_TOKENS = frozenset({"true", "t", "yes", "y", "on", "1"})
_FALSE_TOKENS = frozenset({"false", "f", "no", "n", "off", "0"})


def _str2bool(value: str) -> bool:
    """Convert a command-line token to a bool.

    Surrounding whitespace and letter case are ignored.

    Raises:
        argparse.ArgumentTypeError: if the token is neither a recognised
            true token nor a recognised false token.
    """
    token = str(value).strip().lower()
    if token in _TRUE_TOKENS:
        return True
    if token in _FALSE_TOKENS:
        return False
    # Rejecting unknown tokens (instead of treating them as True) keeps a typo
    # such as "flase" from silently enabling a destructive option.
    raise argparse.ArgumentTypeError(
        f"expected a boolean (true/false, yes/no, on/off, 1/0), got {value!r}"
    )


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what a no-argument call did before
            this parameter existed.

    Returns:
        Namespace with one attribute per option defined below. The boolean
        options (``--shuffle``, ``--write_incremental``,
        ``--truncate_existing_jsonl``) take an explicit value such as
        ``true``/``false``, ``yes``/``no``, ``on``/``off`` or ``1``/``0``.

    Raises:
        SystemExit: raised by argparse on invalid or missing arguments.
    """
    ap = argparse.ArgumentParser(
        description="Per-level streaming generation (premise' & hypothesis'); wiki gating; facts extraction; NLI filtering; neutral max-clique; parallel; append-friendly."
    )

    # Input selection and batching.
    ap.add_argument("--input_roots_json", type=str, required=True)
    ap.add_argument("--total_roots", type=int, default=50000)
    ap.add_argument("--batch_size", type=int, default=1000)
    ap.add_argument("--shuffle", type=_str2bool, default=False)
    ap.add_argument("--start_offset", type=int, default=0)

    # Generation depth, wiki options and output paths.
    ap.add_argument("--depth", type=int, default=2)
    ap.add_argument("--wiki_lang", type=str, default="en")
    ap.add_argument("--wiki_k_pages", type=int, default=20)
    ap.add_argument("--wiki_sent_max", type=int, default=2)
    ap.add_argument("--output_json", type=str, default="gen_detail_fromjson.json")
    ap.add_argument("--output_new", type=str, default="gen_samples_fromjson.json")

    # NLI
    ap.add_argument("--nli_model_name", type=str, default="cross-encoder/nli-deberta-v3-base")
    ap.add_argument("--nli_batch_size", type=int, default=128)

    # Remaining options are tuning knobs whose meaning is defined by the
    # pipeline that consumes them, not by this parser.
    ap.add_argument("--sim_threshold", type=float, default=0.30)
    ap.add_argument("--max_snippets", type=int, default=0)
    ap.add_argument("--fallback_top_m", type=int, default=3)
    ap.add_argument("--wiki_workers", type=int, default=8)

    ap.add_argument("--facts_k", type=int, default=6)
    ap.add_argument("--facts_k_fuse_premise", type=int, default=4)
    ap.add_argument("--max_facts_extract", type=int, default=10)

    ap.add_argument("--api_workers", type=int, default=16)
    ap.add_argument("--write_incremental", type=_str2bool, default=True)
    ap.add_argument("--truncate_existing_jsonl", type=_str2bool, default=False)

    return ap.parse_args(argv)

def main():
    """Drive ``run_pipeline`` over the requested roots in fixed-size batches.

    Parses the command line, seeds ``random`` with 42, calls ``run_pipeline``
    once per batch while summing the per-batch stats it returns, then dumps
    the two accumulator lists as JSON arrays and prints a stats summary.

    Raises:
        ValueError: if ``--total_roots`` or ``--batch_size`` is not positive.
    """
    args = parse_args()
    random.seed(42)

    total = int(args.total_roots)
    bsz = int(args.batch_size)
    if not (total > 0 and bsz > 0):
        raise ValueError("--total_roots --batch_size must be positive integers")

    # Ceiling division: a trailing partial batch still counts as a batch.
    num_batches = -(-total // bsz)
    print(f"[batched] total_roots={total} | batch_size={bsz} | num_batches={num_batches} | shuffle={args.shuffle} | start_offset={args.start_offset}")
    print(f"[append-mode] truncate_existing_jsonl_on_first_batch={bool(args.truncate_existing_jsonl)}")

    # Accumulators handed to every run_pipeline call as acc_rows / acc_flat.
    # Presumably the pipeline appends to them in place (TODO confirm); the
    # lists it returns per call are not used here.
    all_rows: List[Dict[str, Any]] = []
    all_flat: List[Dict[str, Any]] = []
    total_stats: Dict[str, int] = dict.fromkeys(
        ("ok", "dropped_disqualified", "no_wiki", "below_threshold", "facts_empty"), 0
    )

    # Keyword arguments that are the same for every batch.
    shared = dict(
        input_roots_json=args.input_roots_json,
        shuffle=args.shuffle,
        depth=args.depth,
        wiki_lang=args.wiki_lang,
        wiki_k_pages=args.wiki_k_pages,
        wiki_sent_max=args.wiki_sent_max,
        output_json=args.output_json,
        output_new=args.output_new,
        nli_model_name=args.nli_model_name,
        nli_batch_size=args.nli_batch_size,
        sim_threshold=args.sim_threshold,
        max_snippets=args.max_snippets,
        fallback_top_m=args.fallback_top_m,
        wiki_workers=args.wiki_workers,
        facts_k=args.facts_k,
        facts_k_fuse_premise=args.facts_k_fuse_premise,
        max_facts_extract=args.max_facts_extract,
        api_workers=args.api_workers,
        write_incremental=args.write_incremental,
        acc_rows=all_rows,
        acc_flat=all_flat,
    )

    first_offset = int(args.start_offset)
    # ``consumed`` is the number of roots covered by earlier batches.
    for bi, consumed in enumerate(range(0, total, bsz)):
        take = min(bsz, total - consumed)  # the last batch may be short
        offset = first_offset + consumed
        print(f"\n========== [Batch {bi+1}/{num_batches}] offset={offset} take={take} ==========")

        _rows, _flat, batch_stats = run_pipeline(
            roots=take,
            offset=offset,
            # Only the first batch is allowed to request truncation.
            truncate_existing_jsonl=(bi == 0 and bool(args.truncate_existing_jsonl)),
            **shared,
        )

        for key, count in batch_stats.items():
            total_stats[key] = total_stats.get(key, 0) + int(count)

    # Best-effort final dump: a failure is reported but does not abort the
    # stats summary below. The first failure skips any remaining file.
    try:
        for path, payload in ((args.output_json, all_rows), (args.output_new, all_flat)):
            with open(path, "w", encoding="utf-8") as out:
                json.dump(payload, out, ensure_ascii=False, indent=2)
        print(f"\n[final] wrote arrays → {args.output_json} & {args.output_new} (detail_count={len(all_rows)} per-level_count={len(all_flat)})")
    except Exception as e:
        print(f"[final][warn] failed to write arrays: {e}")

    # NOTE(review): dropped_disqualified is reported as total - ok, not the
    # accumulated "dropped_disqualified" counter — confirm this is intended.
    summary = (
        "[stats-total] "
        f"ok(full)={total_stats.get('ok',0)} | dropped_disqualified={total - total_stats.get('ok',0)} | "
        f"no_wiki={total_stats.get('no_wiki',0)} | below_threshold={total_stats.get('below_threshold',0)} | facts_empty={total_stats.get('facts_empty',0)}"
    )
    print(summary)

# Run the batched driver only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
