#!/usr/bin/env python3
import argparse, json, re, sys
from collections import defaultdict

# Matches inline citation markers like "[CITE: D1, D3]"; group 1 captures the
# raw text between "CITE:" (plus any following whitespace) and the closing
# bracket, which count_cites() then splits on commas.
CITE_RE = re.compile(r"\[CITE:\s*([^\]]+)\]")

def parse_range(s):
    """Parse an inclusive "lo-hi" range string into a pair of ints.

    Only the first "-" separates the two bounds; whitespace around each
    number is tolerated by int(). Raises ValueError when the string has no
    "-" or when either side is not an integer.
    """
    start_text, _dash, end_text = s.partition("-")
    return int(start_text), int(end_text)

def load_union(path):
    """Read a TREC run file and group its docids by query, in file order.

    A usable line has at least six whitespace-separated fields
    (qid Q0 docid rank score tag); shorter lines are skipped. The rank
    column is ignored: list index i is treated as rank i+1, which is only
    correct if the file is already rank-sorted within each query.

    Returns a defaultdict(list) mapping qid -> [docid, ...].
    """
    by_query = defaultdict(list)
    with open(path, encoding="utf-8", errors="ignore") as handle:
        rows = (raw.split() for raw in handle)
        for fields in rows:
            if len(fields) >= 6:
                by_query[fields[0]].append(fields[2])
    return by_query

def count_cites(text):
    """Count how often each citation token occurs inside [CITE: ...] markers.

    Every marker may list several comma-separated tokens; each token is
    whitespace-stripped and empty tokens are dropped. None or "" yields an
    empty mapping.

    Returns a defaultdict(int) mapping token -> number of occurrences.
    """
    tally = defaultdict(int)
    tokens = (
        piece.strip()
        for match in CITE_RE.finditer(text or "")
        for piece in match.group(1).split(",")
    )
    for token in tokens:
        if token:
            tally[token] += 1
    return tally

# Legacy synthesis text marks each pack with a header line such as
# "// VARIANT 2 (13-24)"; the two groups capture the pack's union-rank bounds.
# Compiled once at import instead of once per JSONL line.
_VARIANT_HDR_RE = re.compile(
    r'^\s*//\s*VARIANT\s+\d+\s*\((\d+)\s*-\s*(\d+)\)\s*$', re.I)


def _load_pool(path):
    """Return the set of docids (third column) found in a TREC run file."""
    pool = set()
    with open(path, encoding="utf-8", errors="ignore") as f:
        for ln in f:
            p = ln.split()
            if len(p) >= 3:
                pool.add(p[2])
    return pool


def _parse_legacy_packs(text):
    """Split legacy synthesis text into packs delimited by VARIANT headers.

    Returns a list of {"range": "lo-hi", "synthesis": str} dicts. Lines before
    the first header belong to no pack and are dropped; a header followed by
    no lines produces no pack. If no pack results, the whole text is returned
    as a single "probe" pack whose D# tags map directly to union ranks.
    """
    packs = []
    bounds = None  # (lo, hi) strings from the most recent header
    body = []
    for text_line in (text or "").splitlines():
        m = _VARIANT_HDR_RE.match(text_line)
        if m is None:
            body.append(text_line)
            continue
        if bounds is not None and body:
            packs.append({"range": f"{bounds[0]}-{bounds[1]}",
                          "synthesis": "\n".join(body).strip()})
        bounds = m.groups()
        # Reset on every header so preamble text never leaks into a pack.
        body = []
    if bounds is not None and body:
        packs.append({"range": f"{bounds[0]}-{bounds[1]}",
                      "synthesis": "\n".join(body).strip()})
    if not packs:
        packs = [{"range": "probe", "synthesis": text}]
    return packs


def _dtag_index(dtag):
    """Return n for a "Dn" citation tag, or None when the tag is malformed."""
    if not dtag.startswith("D"):
        return None
    try:
        return int(dtag[1:])
    except ValueError:
        return None


def _accumulate_pack(pk, ranked, glob_counts):
    """Add one pack's citation counts into glob_counts, keyed by union docid.

    A "probe" pack maps Dn straight to union rank n. Any other pack carries a
    "lo-hi" range and maps Dn to union rank lo-1+n. Tags outside the pack
    (n < 1 or n > hi-lo+1) are ignored because they cannot name a document
    that was shown in that pack. Packs with an unparseable range are skipped.
    """
    r = str(pk.get("range", "")).replace(" ", "")
    local = count_cites(pk.get("synthesis", "") or "")
    if r == "probe":
        offset, size = 0, len(ranked)
    else:
        try:
            lo, hi = parse_range(r)
        except ValueError:
            return
        offset, size = lo - 1, hi - lo + 1
    for dtag, c in local.items():
        idx = _dtag_index(dtag)
        if idx is None or not 1 <= idx <= size:
            continue
        pos = offset + idx  # 1-based union rank
        if 1 <= pos <= len(ranked):
            glob_counts[ranked[pos - 1]] += c


def _order_by_citations(glob_counts, ranked):
    """Sort (docid, count) pairs by descending count, then by union position."""
    first_pos = {}
    for i, docid in enumerate(ranked):
        first_pos.setdefault(docid, i)  # first occurrence wins, like list.index
    fallback = len(ranked)  # docids absent from the union sort last
    return sorted(glob_counts.items(),
                  key=lambda kv: (-kv[1], first_pos.get(kv[0], fallback)))


def main():
    """Convert citation counts from portfolio syntheses into a TREC run.

    For each query in portfolio_jsonl, the D# tags cited in every pack are
    mapped back to union docids and summed across packs. The most-cited docs
    are written to out_trec, with ties broken by union order. At most
    min(--max_cites, number of --weights) docs are written per query, and
    --weights supplies the score for each output rank. --restrict_to_pool
    drops docids that are not in the given run. A one-line summary is printed
    at the end.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("union_trec")     # to map pack ranges -> docids
    ap.add_argument("portfolio_jsonl")# usually responses.synthesis.jsonl from portfolio script
    ap.add_argument("out_trec")
    ap.add_argument("--max_cites", type=int, default=3)
    ap.add_argument("--weights", default="1.0,0.6,0.6",
                    help="scores for top-k cited docs per query (comma sep)")
    ap.add_argument("--restrict_to_pool", default=None,
                    help="Optional TREC run; if set, only emit docids present in this pool (e.g., BGE).")
    args = ap.parse_args()

    weights = [float(x) for x in args.weights.split(",")]
    union = load_union(args.union_trec)
    pool_ok = _load_pool(args.restrict_to_pool) if args.restrict_to_pool else None

    wrote = 0
    q_seen = set()
    with open(args.out_trec, "w", encoding="utf-8") as out, \
         open(args.portfolio_jsonl, "r", encoding="utf-8", errors="ignore") as pj:
        for line in pj:
            if not line.strip():
                continue
            obj = json.loads(line)
            raw_qid = obj.get("qid")
            # A null qid must not turn into the literal string "None".
            qid = "" if raw_qid is None else str(raw_qid)
            if not qid:
                continue
            q_seen.add(qid)

            # Portfolio structure: packs = [{"range": "1-12", "synthesis": "..."}, ...];
            # fall back to the legacy single-text format with VARIANT headers.
            packs = obj.get("packs") or []
            if not packs and "synthesis" in obj:
                packs = _parse_legacy_packs(obj["synthesis"])

            ranked = union.get(qid, [])
            glob_counts = defaultdict(float)
            for pk in packs:
                _accumulate_pack(pk, ranked, glob_counts)

            ordered = _order_by_citations(glob_counts, ranked)
            if pool_ok is not None:
                ordered = [(d, c) for (d, c) in ordered if d in pool_ok]

            k = min(args.max_cites, len(ordered), len(weights))
            for i, (docid, _count) in enumerate(ordered[:k]):
                out.write(f"{qid} Q0 {docid} {i+1} {weights[i]:.3f} GES_PORT\n")
                wrote += 1

    print(f"WROTE {args.out_trec} | lines={wrote} | qids={len(q_seen)} | avg_per_q={(wrote/len(q_seen)) if q_seen else 0:.2f}")

if __name__ == "__main__":
    # Script entry point: main() reads its arguments from the command line.
    main()