#!/usr/bin/env python3
import argparse, csv, gzip, json, math, os, sys, statistics as stats
from collections import defaultdict

def _ws_split(line: str):
    """
    Split one line of a text file into fields.

    If the line contains a tab it is split on tabs (so individual fields may
    contain spaces); otherwise it is split on any run of whitespace. Trailing
    CR/LF characters are removed first, so files with Windows (CRLF) line
    endings do not leave a stray carriage return on the last field.
    """
    line = line.rstrip("\r\n")
    return line.split("\t") if "\t" in line else line.split()

def read_qrels(path):
    """
    Robust qrels reader.

      - Accepts 3-col ``qid docid rel`` (e.g. BEIR-style TSV) or 4-col TREC
        ``qid iter docid rel``. The TREC layout is recognised when there are
        4+ columns and the second one is ``Q0`` or an integer (the TREC
        iteration field, conventionally ``0``). In that case the doc id is the
        third column; otherwise it is the second column.
      - Relevance is always taken from the last column.
      - Skips header-ish lines whose first token is 'qid', 'query-id',
        'query', 'topic', 'topics' or '#'. Lines whose relevance is not an
        integer are also skipped, which drops headers such as '... score'.
      - Accepts tab- or whitespace-separated columns.
      - Keeps only judgments with rel > 0.

    Args:
        path: path to the qrels file (read as UTF-8; undecodable bytes are dropped).

    Returns:
        defaultdict(set) mapping qid -> set of relevant doc ids. Queries with
        no positive judgment do not appear as keys.
    """
    qrels = defaultdict(set)
    n_pos = 0  # unique positive (qid, docid) pairs
    header_tokens = {"qid", "query-id", "query", "topic", "topics", "#"}
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            p = _ws_split(line)
            if len(p) < 3 or p[0].lower() in header_tokens:
                continue
            if len(p) >= 4 and (p[1].upper() == "Q0" or p[1].isdigit()):
                # TREC qrels: qid iter docid rel. Reading the doc from p[1]
                # here would record the iteration field ("0") as the doc id.
                qid, did = p[0], p[2]
            else:
                qid, did = p[0], p[1]
            try:
                rel = int(p[-1])
            except ValueError:
                continue  # non-numeric relevance (e.g. a header row); ignore line
            # Short-circuit keeps queries without positives out of the dict.
            if rel > 0 and did not in qrels[qid]:
                qrels[qid].add(did)
                n_pos += 1
    print(f"[qrels] loaded: q={len(qrels)} queries, pos_judgments={n_pos}, from={path}")
    return qrels

def _try_int(tok):
    """
    Parse *tok* as an int.

    Returns:
        ``int(tok)``, or None when *tok* is not an integer literal (or is None).
    """
    try:
        return int(tok)
    except (TypeError, ValueError):
        return None

def _try_float(tok):
    """
    Parse *tok* as a float.

    Returns:
        ``float(tok)``, or None when *tok* is not a float literal (or is None).
    """
    try:
        return float(tok)
    except (TypeError, ValueError):
        return None

def read_run(path, topk=100):
    """
    Robust TREC run reader.

      - The canonical layout is ``qid Q0 docid rank score tag``. Lines with
        fewer than 4 whitespace-separated tokens are skipped.
      - The doc id is always the third token.
      - The rank is the first integer token from the fourth token onward
        (normally the fourth token itself). Lines without one are skipped.
      - The score is the first float token *after* the rank token (normally
        the fifth). If there is none, ``-rank`` is used as a monotone proxy.
        Searching only after the rank prevents the rank itself from being
        reused as the score when the fourth token is not an integer.
      - For duplicate (qid, docid) pairs, the entry with the lowest rank is kept.

    Args:
        path: path to the run file (read as UTF-8; undecodable bytes are dropped).
        topk: number of entries kept per query after sorting by rank.

    Returns:
        dict mapping qid -> list of (doc, rank, score) tuples, sorted by
        ascending rank and truncated to ``topk``.
    """
    mp = defaultdict(dict)  # qid -> {doc: (rank, score)}
    n_lines = 0
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            n_lines += 1
            p = line.split()
            # The rank is searched for from index 3 on, so shorter lines can
            # never produce a usable entry.
            if len(p) < 4:
                continue
            qid, doc = p[0], p[2]

            rank_idx = rank = None
            for i in range(3, len(p)):
                rank = _try_int(p[i])
                if rank is not None:
                    rank_idx = i
                    break
            if rank is None:
                continue  # cannot order this entry without a rank

            score = None
            for tok in p[rank_idx + 1:]:
                score = _try_float(tok)
                if score is not None:
                    break
            if score is None:
                score = -float(rank)

            # keep best (lowest) rank per (qid, doc)
            prev = mp[qid].get(doc)
            if prev is None or rank < prev[0]:
                mp[qid][doc] = (rank, score)

    out = {}
    for qid, dmap in mp.items():
        ranked = sorted(((doc, r, s) for doc, (r, s) in dmap.items()), key=lambda x: x[1])
        out[qid] = ranked[:topk]
    print(f"[run] loaded: q={len(out)} from={path} (lines={n_lines})")
    return out

def read_queries(path):
    """
    Load query texts keyed by qid (always a str, to match qrels/run qids).

      - ``*.tsv`` / ``*.tsv.gz``: ``qid<TAB>text`` per line. Lines without a
        tab are split on whitespace: the first token is the qid and the rest
        is joined with single spaces. Rows without any text are skipped.
      - Anything else (e.g. ``*.jsonl`` / ``*.jsonl.gz``): one JSON object per
        line, with the id taken from the first non-empty of ``_id`` / ``id`` /
        ``qid`` and the text from ``text`` or ``query`` (default '').
        Blank, malformed, and non-object lines are skipped.

    Files ending in ``.gz`` are read through gzip.

    Returns:
        dict mapping qid (str) -> query text (str).
    """
    opn = gzip.open if path.endswith(".gz") else open
    d = {}

    if path.endswith((".tsv", ".tsv.gz")):
        with opn(path, 'rt', encoding='utf-8', errors='ignore') as f:
            for line in f:
                row = line.strip()
                if "\t" in row:
                    qid, _, text = row.partition("\t")
                    qid, text = qid.strip(), text.strip()
                else:
                    toks = row.split()
                    # Original format requires at least "qid text".
                    if len(toks) < 2:
                        continue
                    qid, text = toks[0], " ".join(toks[1:])
                if qid:
                    d[qid] = text
        print(f"[queries] loaded TSV: q={len(d)} from={path}")
        return d

    n = 0
    with opn(path, 'rt', encoding='utf-8', errors='ignore') as f:
        for line in f:
            try:
                obj = json.loads(line)
            except ValueError:  # json.JSONDecodeError: blank or malformed line
                continue
            if not isinstance(obj, dict):
                continue
            # Explicit emptiness check so an id of 0 is not skipped like the
            # old `or` chain did.
            qid = next((obj[k] for k in ('_id', 'id', 'qid') if obj.get(k) not in (None, "")), None)
            if qid is None:
                continue
            text = obj.get('text') or obj.get('query') or ''
            # Numeric JSON ids must become str to match qrels/run keys.
            d[str(qid)] = str(text)
            n += 1
    print(f"[queries] loaded JSONL: q={len(d)} (parsed_lines={n}) from={path}")
    return d

def rr_at_10(qid, run_list, qrels):
    """
    Reciprocal rank of the first relevant document within the top 10.

    The position is the 1-based index in ``run_list``, not the rank value
    stored in the run file. Using the index keeps RR correct for 0-based or
    gapped rank columns and keeps it consistent with the ``[:10]`` cutoff.
    ``run_list`` is expected to be sorted by rank already.

    Args:
        qid: query id.
        run_list: list of (doc, rank, score) tuples in rank order.
        qrels: mapping qid -> collection of relevant doc ids.

    Returns:
        1 / position of the first relevant doc in ``run_list[:10]``, else 0.0.
    """
    relevant = qrels.get(qid, ())
    for pos, (doc, _rank, _score) in enumerate(run_list[:10], start=1):
        if doc in relevant:
            return 1.0 / pos
    return 0.0

def recall_at_10(qid, run_list, qrels):
    """
    Binary hit indicator for the top 10 (a.k.a. success@10).

    Despite the name, this is not fractional recall: it returns 1 as soon as
    any relevant document appears in ``run_list[:10]``, and 0 otherwise.

    Args:
        qid: query id.
        run_list: list of (doc, rank, score) tuples in rank order.
        qrels: mapping qid -> collection of relevant doc ids.

    Returns:
        int 1 or 0.
    """
    relevant = qrels.get(qid, ())
    for doc, _rank, _score in run_list[:10]:
        if doc in relevant:
            return 1
    return 0

def curve_stats(scores10):
    """
    Summarise a score curve given in rank order (callers pass the top <= 10).

    Returns:
        (top1, gap12, gap1_10, std10, mean10), where
          - top1    = first score;
          - gap12   = first minus second score (0.0 if only one score);
          - gap1_10 = first minus the 10th score, or minus the last score when
                      fewer than 10 are given;
          - std10   = population standard deviation (0.0 for a single score);
          - mean10  = arithmetic mean over all given scores.
        All five are 0.0 for an empty input.
    """
    if not scores10:
        return (0.0, 0.0, 0.0, 0.0, 0.0)
    count = len(scores10)
    head = scores10[0]
    second = scores10[1] if count > 1 else head
    tenth = scores10[9] if count >= 10 else scores10[-1]
    spread = 0.0 if count == 1 else stats.pstdev(scores10)
    return (head, head - second, head - tenth, spread, stats.fmean(scores10))

def jaccard(a, b):
    """
    Jaccard similarity of two sets: len(a & b) / len(a | b).

    Returns 0.0 when both sets are empty, instead of dividing by zero.
    """
    union = a | b
    if not union:
        return 0.0
    return len(a & b) / len(union)

def _best_system(rr_b, rr_d, rr_f):
    """
    Return the name of the system with the highest RR@10 for one query.

    Ties are broken in favour of 'fused', then 'dense', then 'bm25'. In
    particular, a query on which all three systems score 0.0 is labelled 'fused'.
    """
    best_name, best_rr = 'bm25', rr_b
    # Later entries are preferred on ties, hence '>=' rather than '>'.
    for name, rr in (('dense', rr_d), ('fused', rr_f)):
        if rr >= best_rr:
            best_name, best_rr = name, rr
    return best_name


def main():
    """
    Write a per-query feature CSV comparing BM25, dense, and fused runs.

    A row is written only for qids present in all three runs AND having at
    least one positive qrel; rows are in sorted qid order. Columns:
      - query length in whitespace tokens and characters (0 when no query text
        was found; a warning reports how many qids are affected);
      - top-10 score-curve statistics for BM25 and dense (see ``curve_stats``);
      - size and Jaccard similarity of the BM25/dense top-10 doc sets;
      - RR@10 and hit@10 (the ``recall10_*`` columns) for each system;
      - ``label_best_rr10``: system with the best RR@10 (ties: fused > dense > bm25).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--qrels', required=True)
    ap.add_argument('--bm25', required=True)
    ap.add_argument('--dense', required=True)
    ap.add_argument('--fused', required=True)
    ap.add_argument('--queries', required=True)
    ap.add_argument('--out', required=True)
    ap.add_argument('--topk', type=int, default=100)
    args = ap.parse_args()

    if args.topk < 10:
        # Runs are truncated to topk on load, so the @10 features would
        # silently be computed on fewer than 10 documents.
        print(f"[warn] --topk={args.topk} < 10; @10 features will use at most {args.topk} docs", file=sys.stderr)

    qrels = read_qrels(args.qrels)
    bm25 = read_run(args.bm25, args.topk)
    dense = read_run(args.dense, args.topk)
    fused = read_run(args.fused, args.topk)
    qtext = read_queries(args.queries)

    # Intersect on what we actually have everywhere
    qids = sorted(set(bm25) & set(dense) & set(fused) & set(qrels))
    print(f"[intersect] qids_in_all={len(qids)}  (bm25={len(bm25)}, dense={len(dense)}, fused={len(fused)}, qrels={len(qrels)})")
    if not qids:
        # Helpful hints: mismatched id formats are the usual cause.
        sample_b = next(iter(bm25), None)
        sample_q = next(iter(qrels), None)
        print(f"[warn] Intersection is empty. Example bm25 qid: {sample_b} ; example qrels qid: {sample_q}", file=sys.stderr)

    # Missing query text would otherwise silently produce q_len_* = 0.
    missing_text = [q for q in qids if q not in qtext]
    if missing_text:
        print(f"[warn] {len(missing_text)}/{len(qids)} qids have no query text (q_len_* = 0). Example: {missing_text[0]}", file=sys.stderr)

    cols = [
      'qid','q_len_tok','q_len_char',
      'bm25_top1','bm25_gap12','bm25_gap1_10','bm25_std10','bm25_mean10',
      'dense_top1','dense_gap12','dense_gap1_10','dense_std10','dense_mean10',
      'overlap_bm25_dense_at10','jaccard_bm25_dense_at10',
      'rr10_bm25','rr10_dense','rr10_fused',
      'recall10_bm25','recall10_dense','recall10_fused',
      'label_best_rr10'
    ]
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(args.out, 'w', newline='', encoding='utf-8') as fout:
        out = csv.writer(fout)
        out.writerow(cols)

        for qid in qids:
            t = qtext.get(qid, '')
            b10 = bm25[qid][:10]
            d10 = dense[qid][:10]
            bset = {doc for doc, _, _ in b10}
            dset = {doc for doc, _, _ in d10}

            # Order (bm25, dense, fused) matches the rr10_* / recall10_* columns.
            systems = (bm25, dense, fused)
            rrs = [rr_at_10(qid, run[qid], qrels) for run in systems]
            recs = [recall_at_10(qid, run[qid], qrels) for run in systems]

            out.writerow([
              qid, len(t.split()), len(t),
              *curve_stats([s for _, _, s in b10]),
              *curve_stats([s for _, _, s in d10]),
              len(bset & dset), jaccard(bset, dset),
              *rrs,
              *recs,
              _best_system(*rrs),
            ])
    print(f"WROTE features: {args.out}  n={len(qids)}")


if __name__ == '__main__':
    main()