#!/usr/bin/env python3
import argparse, collections, math, sys

# ---------- IO ----------

def read_trec(path, cutoff=None):
    """
    Read a TREC run file and group its rows by query.

    A usable line has at least six whitespace-separated fields, read as
    ``qid <ignored> docid rank score tag``. Lines with fewer fields, or whose
    rank is not an integer or whose score is not a float, are skipped.

    Args:
        path: Path of the run file. It is decoded as UTF-8 and undecodable
            bytes are dropped.
        cutoff: If truthy, keep only the first ``cutoff`` rows of every query
            after sorting by rank.

    Returns:
        defaultdict mapping qid -> list of ``[docid, base_rank, base_score, tag]``
        rows sorted by base_rank ascending (rows with equal rank keep file order).
    """
    byq = collections.defaultdict(list)
    with open(path, encoding="utf-8", errors="ignore") as f:
        for ln in f:
            parts = ln.split()
            if len(parts) < 6:
                continue
            qid, _, docid, rank_txt, score_txt, tag = parts[:6]
            try:
                rank = int(rank_txt)
                score = float(score_txt)
            except ValueError:
                # Non-numeric rank/score (header or corrupt row): skip the line.
                # Only ValueError is caught so that unrelated errors such as
                # KeyboardInterrupt are not swallowed.
                continue
            byq[qid].append([docid, rank, score, tag])
    # Order each query by the rank given in the file, optionally truncate.
    for rows in byq.values():
        rows.sort(key=lambda row: row[1])
        if cutoff:
            del rows[cutoff:]
    return byq

def read_ges_trec(path):
    """
    Read a sparse GES run stored in TREC run layout.

    Only the qid (field 1), docid (field 3) and score (field 5) of each line
    are used; the score column is taken as the GES weight. Lines with fewer
    than six whitespace-separated fields, or with a non-float score, are
    skipped. If a (qid, docid) pair occurs more than once, the last row wins.

    Args:
        path: Path of the run file. It is decoded as UTF-8 and undecodable
            bytes are dropped.

    Returns:
        defaultdict mapping qid -> dict[docid] = ges_weight (float).
    """
    out = collections.defaultdict(dict)
    with open(path, encoding="utf-8", errors="ignore") as f:
        for ln in f:
            parts = ln.split()
            if len(parts) < 6:
                continue
            qid, docid, score_txt = parts[0], parts[2], parts[4]
            try:
                weight = float(score_txt)
            except ValueError:
                # Unparseable score: skip the row. Only ValueError is caught so
                # that unrelated errors are not silently swallowed.
                continue
            out[qid][docid] = weight
    return out

def topk_sets(path, k):
    """
    Collect, per query, the docids ranked in the top-k of a TREC run.

    Used for agreement gating. A row is considered when its rank column is an
    integer <= k; of those rows, the k smallest (rank, docid) pairs per query
    are kept. Lines with fewer than six whitespace-separated fields or a
    non-integer rank are skipped. Queries with no qualifying row are absent
    from the result.

    Args:
        path: Path of the run file. It is decoded as UTF-8 and undecodable
            bytes are dropped.
        k: Rank threshold and maximum set size (int).

    Returns:
        dict mapping qid -> set of docids in the top-k.
    """
    byq = collections.defaultdict(list)
    with open(path, encoding="utf-8", errors="ignore") as f:
        for ln in f:
            parts = ln.split()
            if len(parts) < 6:
                continue
            qid, docid, rank_txt = parts[0], parts[2], parts[3]
            try:
                rank = int(rank_txt)
            except ValueError:
                # Non-integer rank: skip the row. Only ValueError is caught so
                # that unrelated errors are not silently swallowed.
                continue
            if rank <= k:
                byq[qid].append((rank, docid))
    # The extra [:k] slice bounds the set size when more than k rows qualify
    # (e.g. 0-based ranks or repeated ranks).
    return {q: {d for _, d in sorted(pairs)[:k]} for q, pairs in byq.items()}

# ---------- Helpers ----------

def safe_quantile(xs, q):
    """
    Quantile of ``xs`` at fraction ``q`` using linear interpolation.

    The values are sorted and the fractional index ``(len(xs) - 1) * q`` is
    interpolated between its two neighbouring elements. Returns None when
    ``xs`` is empty.
    """
    if not xs:
        return None
    ordered = sorted(xs)
    position = (len(ordered) - 1) * q
    below, above = math.floor(position), math.ceil(position)
    if below == above:
        # The index is a whole number: no interpolation needed.
        return ordered[below]
    frac = position - below
    return ordered[below] * (1 - frac) + ordered[above] * frac

def compute_lambda_needed(base_score, ges_val, threshold_score):
    """
    Lower bound on λ such that base*(1 + λ*ges) exceeds threshold.

    Rearranging base*(1 + λ*ges) > threshold gives
        λ > (threshold/base - 1) / ges.
    Returns +inf when base_score <= 0 or ges_val <= 0. Otherwise returns the
    bound, which is zero or negative when base_score already reaches the
    threshold.
    """
    if base_score <= 0 or ges_val <= 0:
        return math.inf
    ratio = threshold_score / base_score
    return (ratio - 1.0) / ges_val

# ---------- Core ----------

def poe_selective_multiply(
    base_run, ges_run, out_path,
    k_boost=10, max_cites=2, min_ges=0.0,
    bm25_path=None, bge_path=None, agree_in_k=None, agree_need_any=False,
    lambda_fixed=None, auto_lambda=False, lambda_cap=0.15, cutoff_target=10,
    freeze_top1=False, freeze_head_k=0, max_jump=None,
    inspect_k=0, cutoff=100
):
    """
    Re-rank a base run by multiplying selected documents' scores by
    (1 + λ * ges), subject to optional guardrails, and write a TREC run.

    Per query: pick up to ``max_cites`` candidates from the head of the base
    list (highest GES weight first), choose λ, boost the candidates' scores,
    apply the guardrails, re-sort the non-frozen part and write the result
    to ``out_path`` with run tag ``POE_SEL_MUL_GUARDED``. A one-line summary
    (plus head-movement statistics when ``inspect_k`` > 0) is printed.

    Args:
        base_run: Path of the base TREC run to re-rank.
        ges_run: Path of the sparse GES run; its score column is the weight.
        out_path: Path of the TREC run to write (overwritten).
        k_boost: Only the first ``k_boost`` base documents are eligible for
            boosting; if <= 0 the whole (truncated) list is eligible.
        max_cites: Maximum number of candidates kept per query.
        min_ges: Candidates whose GES weight (0.0 when absent from the GES
            run) is below this value are dropped.
        bm25_path, bge_path: Optional run files for agreement gating; they
            are only read when ``agree_in_k`` is truthy.
        agree_in_k: If truthy, a candidate must appear in the top-``agree_in_k``
            of both gating runs, or of either one when ``agree_need_any`` is
            set. A gating run that was not supplied counts as an empty set.
        agree_need_any: Switch the agreement gate from AND to OR.
        lambda_fixed: λ used when ``auto_lambda`` is false; None means 0.0.
        auto_lambda: If true, λ is chosen per query as the 0.25 quantile of
            the positive, finite λ values each candidate needs to exceed the
            threshold score, capped at ``lambda_cap`` (0.0 when there are none).
        lambda_cap: Upper bound for the automatically chosen λ.
        cutoff_target: 1-based position whose base score is the auto-λ
            threshold; when out of range the last document's score is used.
        freeze_top1, freeze_head_k: The first
            ``max(freeze_head_k, 1 if freeze_top1 else 0)`` base documents keep
            their positions.
        max_jump: If positive, each non-frozen document's score is capped just
            below the base score of the document sitting at position
            ``max(p - max_jump, frozen_head_size + 1) - 1``, where ``p`` is the
            document's own 1-based position in the base list.
        inspect_k: If > 0, count the queries whose top-``inspect_k`` list
            changed and print the counts.
        cutoff: If truthy, each query's base list is truncated to this many
            documents before any processing.

    Returns:
        None.

    NOTE(review): the score written for each document is its adjusted score,
    which is not forced to be monotone in the written rank (e.g. a boosted
    tail document may score above a frozen head document). Confirm whether
    downstream tools order by the rank column or by score.
    """
    base = read_trec(base_run)
    ges  = read_ges_trec(ges_run)

    # Optional agreement gating
    bm25_top = topk_sets(bm25_path, agree_in_k) if (bm25_path and agree_in_k) else {}
    bge_top  = topk_sets(bge_path,  agree_in_k) if (bge_path  and agree_in_k) else {}

    wrote_q = 0
    changed_k = 0
    same_k = 0

    with open(out_path, "w", encoding="utf-8") as w:
        for q, lst in base.items():
            if cutoff and len(lst) > cutoff:
                lst = lst[:cutoff]

            # Docids in base order; list position (1-based) is the base rank
            # used below.
            base_order = [row[0] for row in lst]

            # snapshot base head for movement stats
            base_head = base_order[:inspect_k] if inspect_k > 0 else []

            # Build quick index maps. Ranks are positional rather than taken
            # from the file's rank column, so the max-jump guard stays
            # consistent for runs whose rank column is 0-based or has gaps.
            # Because lst is already sorted by file rank, positional ranks
            # give the same tie-break order.
            base_scores = {docid: score for (docid, _, score, _) in lst}
            base_rank   = {docid: pos for pos, docid in enumerate(base_order, start=1)}

            # candidates within top-k_boost
            cand_pool = lst[:k_boost] if k_boost > 0 else lst
            ges_q = ges.get(q, {})
            # Per-query gating sets (empty when the run was not supplied).
            bm = bm25_top.get(q, set())
            bg = bge_top.get(q, set())
            cands = []
            for (docid, r, s, _) in cand_pool:
                g = ges_q.get(docid, 0.0)
                if g < min_ges:
                    continue
                # Optional agreement gate: doc must be present in (both|either)
                if agree_in_k:
                    if agree_need_any:
                        if not ((docid in bm) or (docid in bg)):
                            continue
                    else:
                        if not (docid in bm and docid in bg):
                            continue
                cands.append((docid, g, r, s))

            # choose up to max_cites by highest GES
            cands.sort(key=lambda x: x[1], reverse=True)
            cands = cands[:max_cites]

            # Decide λ
            if auto_lambda:
                # threshold is the score of doc at cutoff_target (1-indexed)
                if 1 <= cutoff_target <= len(lst):
                    thresh_score = lst[cutoff_target-1][2]
                else:
                    thresh_score = lst[-1][2] if lst else 0.0

                needed = [
                    compute_lambda_needed(s, g, thresh_score)
                    for (_, g, _, s) in cands
                ]
                # pick a gentle quantile and cap; candidates already at or
                # above the threshold (needed <= 0) do not contribute
                lam_eff = safe_quantile([x for x in needed if x > 0 and x != math.inf], 0.25)
                if lam_eff is None or lam_eff < 0:
                    lam_eff = 0.0
                lam_eff = min(lam_eff, lambda_cap)
            else:
                lam_eff = lambda_fixed if lambda_fixed is not None else 0.0

            # No document is boosted unless λ is strictly positive.
            boost_set = {d for (d, _, _, _) in cands if lam_eff > 0}

            # Compute preliminary new scores (multiplicative)
            new_scores = {}
            for (docid, _, s, _) in lst:
                if docid in boost_set:
                    g = ges_q.get(docid, 0.0)
                    new_scores[docid] = s * (1.0 + lam_eff * g)
                else:
                    new_scores[docid] = s

            # ---------- Guardrail 1: freeze head ----------
            freeze_k = max(freeze_head_k, 1 if freeze_top1 else 0)
            head_docs = base_order[:freeze_k]
            tail_docs = base_order[freeze_k:]

            # ---------- Guardrail 2: max jump ----------
            if max_jump is not None and max_jump > 0:
                # Cap any doc's score so it cannot move above (base_rank - max_jump).
                # Score thresholds come from the base ordering, keyed by
                # 1-based position.
                base_score_at_rank = {
                    pos: base_scores[d] for pos, d in enumerate(base_order, start=1)
                }
                for d in tail_docs:  # head already frozen anyway
                    r0 = base_rank[d]
                    r_allowed = max(1, r0 - max_jump)
                    # respect frozen head: never outrank position 'freeze_k'
                    r_allowed = max(r_allowed, freeze_k+1)  # cannot enter frozen head
                    if r_allowed <= 1:
                        continue
                    # The doc cannot exceed the score of the doc at (r_allowed-1)
                    if (r_allowed-1) in base_score_at_rank:
                        cap_score = base_score_at_rank[r_allowed-1] - 1e-9
                        if new_scores[d] > cap_score:
                            new_scores[d] = cap_score

            # Now sort tail by new score desc, tiebreak by base rank asc
            tail_sorted = sorted(
                tail_docs,
                key=lambda d: (-new_scores[d], base_rank[d])
            )

            # Final order = frozen head (preserve original order) + tail_sorted
            final_docs = head_docs + tail_sorted
            if cutoff:
                final_docs = final_docs[:cutoff]

            # Movement inspection
            if inspect_k > 0:
                new_head = final_docs[:inspect_k]
                changed_k += int(new_head != base_head)
                same_k    += int(new_head == base_head)

            # Write TREC
            for new_r, d in enumerate(final_docs, start=1):
                w.write(f"{q} Q0 {d} {new_r} {new_scores[d]:.6f} POE_SEL_MUL_GUARDED\n")

            wrote_q += 1

    print(f"WROTE {out_path} | queries={wrote_q} | "
          f"k_boost={k_boost} max_cites={max_cites} min_ges={min_ges} "
          f"{'auto_lambda cap='+str(lambda_cap) if auto_lambda else 'lambda='+str(lambda_fixed)} "
          f"| cutoff={cutoff}"
          f"{' | freeze_top1' if freeze_top1 else ''}"
          f"{' | freeze_head_k='+str(freeze_head_k) if freeze_head_k else ''}"
          f"{' | max_jump='+str(max_jump) if (max_jump is not None) else ''}")
    if inspect_k > 0:
        tot = changed_k + same_k
        frac = (changed_k / tot) if tot else 0.0
        print(f"Head@{inspect_k} movement: changed={changed_k} same={same_k} ({frac:.1%} changed)")

# ---------- CLI ----------

def main():
    """Parse the command-line options and run poe_selective_multiply with them."""
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword arguments). Registration order determines
    # the order shown by --help.
    option_specs = [
        # Input / output run files
        ("--base_run", dict(required=True)),
        ("--ges_run", dict(required=True)),
        ("--out", dict(required=True)),
        # Candidate selection
        ("--k_boost", dict(type=int, default=10)),
        ("--max_cites", dict(type=int, default=2)),
        ("--min_ges", dict(type=float, default=0.0)),
        # Optional agreement gating
        ("--bm25", dict(default=None)),
        ("--bge", dict(default=None)),
        ("--agree_in_k", dict(
            type=int, default=None,
            help="If set, require doc in BM25/BGE top-K; default requires BOTH. Use --agree_need_any for OR.",
        )),
        ("--agree_need_any", dict(
            action="store_true",
            help="If set, doc may appear in either BM25 or BGE top-K.",
        )),
        # Lambda config
        ("--lambda_", dict(
            type=float, default=None,
            help="Fixed lambda (ignored if --auto_lambda).",
        )),
        ("--auto_lambda", dict(action="store_true")),
        ("--lambda_cap", dict(type=float, default=0.15)),
        ("--cutoff_target", dict(type=int, default=10)),
        # Guardrails
        ("--freeze_top1", dict(action="store_true")),
        ("--freeze_head_k", dict(
            type=int, default=0,
            help="Freeze top-K head from the base run.",
        )),
        ("--max_jump", dict(
            type=int, default=None,
            help="Limit max upward rank improvement (in absolute rank positions).",
        )),
        # Misc
        ("--inspect_k", dict(type=int, default=0)),
        ("--cutoff", dict(type=int, default=100)),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    opts = parser.parse_args()

    # Map parsed options onto the keyword parameters of the core routine.
    settings = {
        "base_run": opts.base_run,
        "ges_run": opts.ges_run,
        "out_path": opts.out,
        "k_boost": opts.k_boost,
        "max_cites": opts.max_cites,
        "min_ges": opts.min_ges,
        "bm25_path": opts.bm25,
        "bge_path": opts.bge,
        "agree_in_k": opts.agree_in_k,
        "agree_need_any": opts.agree_need_any,
        "lambda_fixed": opts.lambda_,
        "auto_lambda": opts.auto_lambda,
        "lambda_cap": opts.lambda_cap,
        "cutoff_target": opts.cutoff_target,
        "freeze_top1": opts.freeze_top1,
        "freeze_head_k": opts.freeze_head_k,
        "max_jump": opts.max_jump,
        "inspect_k": opts.inspect_k,
        "cutoff": opts.cutoff,
    }
    poe_selective_multiply(**settings)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
