import argparse, collections

def read_run(path, cut=None):
    """Load a TREC-format run file into per-query ``(docid, rank)`` lists.

    Every non-blank line must contain exactly six whitespace-separated
    fields: ``qid Q0 docid rank score tag``. Only ``qid``, ``docid`` and
    ``rank`` are used; the score and tag are ignored.

    Args:
        path: Path of the run file. It is decoded as UTF-8 and undecodable
            bytes are silently dropped.
        cut: Keep only entries whose rank is ``<= cut``. ``None`` keeps
            every entry.

    Returns:
        A ``defaultdict(list)`` mapping qid to ``[(docid, rank), ...]`` in
        file order.

    Raises:
        ValueError: If a non-blank line does not have six fields or its
            rank is not an integer.
    """
    byq = collections.defaultdict(list)
    with open(path, encoding="utf-8", errors="ignore") as f:
        for lineno, ln in enumerate(f, start=1):
            fields = ln.split()
            if not fields:
                # Blank or trailing empty lines are common in hand-edited
                # or concatenated run files; they carry no data.
                continue
            if len(fields) != 6:
                raise ValueError(
                    f"{path}:{lineno}: expected 6 fields, got {len(fields)}"
                )
            qid, _, docid, rank, _, _ = fields
            try:
                r = int(rank)
            except ValueError as err:
                raise ValueError(
                    f"{path}:{lineno}: rank {rank!r} is not an integer"
                ) from err
            if cut is None or r <= cut:
                byq[qid].append((docid, r))
    return byq

def wrrf_fuse(run_dicts, K, weights=None):
    """Fuse several runs with weighted reciprocal rank fusion (WRRF).

    For each query, a document's fused score is the sum over runs of
    ``weight / (K + rank)``, where ``rank`` is the document's rank in that
    run. Documents missing from a run contribute nothing for that run.

    Args:
        run_dicts: List of dicts ``{qid -> [(docid, rank), ...]}``.
        K: Rank offset added to every rank in the denominator.
        weights: Optional per-run weights, one per entry of ``run_dicts``.
            ``None`` (the default) weights every run by 1.0.

    Returns:
        Dict ``{qid -> [(docid, fused_score), ...]}`` covering the union of
        all query ids, with each list sorted by descending score and then
        by ascending docid.

    Raises:
        ValueError: If ``weights`` is given and its length differs from the
            number of runs.
    """
    run_dicts = list(run_dicts)
    if weights is None:
        weights = [1.0] * len(run_dicts)
    else:
        weights = list(weights)
        if len(weights) != len(run_dicts):
            raise ValueError(
                f"expected {len(run_dicts)} weights, got {len(weights)}"
            )

    all_qids = set()
    for rd in run_dicts:
        all_qids.update(rd)

    out = {}
    for q in sorted(all_qids):
        sc = collections.defaultdict(float)
        for rd, w in zip(run_dicts, weights):
            for doc, r in rd.get(q, []):
                sc[doc] += w / (K + r)
        # Sort by fused score descending; docid breaks ties so the output
        # is deterministic.
        out[q] = sorted(sc.items(), key=lambda x: (-x[1], x[0]))
    return out

def write_trec(fused, out_path, topk):
    """Write fused rankings to ``out_path`` in TREC run format.

    Queries are written in sorted qid order. For each query the first
    ``topk`` entries are emitted with ranks renumbered from 1, one line per
    document: ``qid Q0 docid rank score FUSED_UNION_WRRF``.

    Args:
        fused: Dict ``{qid -> [(docid, score), ...]}`` with each list
            already in the desired output order.
        out_path: Destination file path; an existing file is overwritten.
        topk: Maximum number of documents written per query.
    """
    # Write UTF-8 explicitly so ids survive a round trip regardless of the
    # platform's default encoding.
    with open(out_path, "w", encoding="utf-8") as w:
        for q in sorted(fused):
            w.writelines(
                f"{q} Q0 {d} {r} {s:.6f} FUSED_UNION_WRRF\n"
                for r, (d, s) in enumerate(fused[q][:topk], start=1)
            )

def main(argv=None):
    """Command-line entry point: fuse TREC runs with uniform WRRF.

    Reads the BM25 run up to ``--cutoff`` ranks, and the dense run plus any
    ``--extra`` runs up to ``min(--dense_cutoff, --cutoff)`` ranks, fuses
    them, and writes the top ``--cutoff`` documents per query to ``--out``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser(
        description="Fuse TREC run files with reciprocal rank fusion."
    )
    ap.add_argument("--bm25", required=True, help="BM25 TREC run file")
    ap.add_argument("--dense", required=True, help="Dense-retrieval TREC run file")
    ap.add_argument("--extra", nargs="*", default=[], help="Optional additional TREC run(s) to include in the union")
    ap.add_argument("--cutoff", type=int, default=100,
                    help="Ranks read from the BM25 run and documents written per query")
    ap.add_argument("--dense_cutoff", type=int, default=50,
                    help="Ranks read from the dense and extra runs (capped at --cutoff)")
    # A smaller K makes 1/(K + rank) fall off faster, so top ranks count
    # for more than with the commonly used K=60.
    ap.add_argument("--wrrf_K", type=int, default=30,
                    help="Rank offset K in the 1/(K + rank) fusion score")
    ap.add_argument("--out", required=True, help="Output TREC run file")
    args = ap.parse_args(argv)

    # The dense and extra runs share one cap, which never exceeds the
    # overall cutoff.
    dense_cut = min(args.dense_cutoff, args.cutoff)
    bm25 = read_run(args.bm25, cut=args.cutoff)
    dense = read_run(args.dense, cut=dense_cut)
    extras = [read_run(p, cut=dense_cut) for p in args.extra]

    fused = wrrf_fuse([bm25, dense] + extras, K=args.wrrf_K)
    write_trec(fused, args.out, topk=args.cutoff)
    print("WROTE", args.out)

# Run the command-line interface only when executed as a script, so
# importing this module has no side effects.
if __name__ == "__main__":
    main()
