#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse, os, math, csv
from collections import defaultdict

def parse_args():
    """Parse the evaluation script's command-line options from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``qrels``, ``runs``,
        ``k_list``, ``mrr_at``, ``limit_qids`` and ``csv_out``.
    """
    # (flag, add_argument keyword options), registered in declaration order.
    option_table = [
        ("--qrels", {"required": True}),
        ("--runs", {"nargs": "+", "required": True,
                    "help": 'FORMAT: "<path>:<label>" or "<path>"'}),
        ("--k-list", {"nargs": "+", "type": int, "default": [10, 50]}),
        ("--mrr-at", {"type": int, "default": 10}),
        ("--limit-qids", {"default": None}),
        ("--csv-out", {"default": None,
                       "help": "Optional path to write a summary CSV of all metrics."}),
    ]
    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()

def read_qrels(path):
    """Load relevance judgments and return the relevant documents per query.

    Each non-empty line is split on whitespace and interpreted by column count:

    * exactly three columns: ``qid docid relevance``
    * four or more columns (TREC qrels layout): ``qid iteration docid relevance``

    A document is recorded as relevant only when its relevance value parses as
    a number strictly greater than zero, so ``0``, ``0.0`` and negative grades
    are all treated as non-relevant. Lines whose relevance column is not
    numeric (for example a ``query-id corpus-id score`` header) and lines with
    fewer than three columns are skipped.

    Args:
        path: Path to the qrels file. It is read as UTF-8; undecodable bytes
            are dropped.

    Returns:
        A ``defaultdict(set)`` mapping query id (str) to the set of relevant
        document ids (str). Queries without any relevant document get no key.
    """
    rel = defaultdict(set)
    with open(path, encoding="utf-8", errors="ignore") as f:
        for line in f:
            parts = line.split()
            if len(parts) < 3:
                continue
            if len(parts) >= 4:
                # TREC layout: the second column is the (unused) iteration field.
                qid, doc_id, grade = parts[0], parts[2], parts[3]
            else:
                qid, doc_id, grade = parts
            try:
                is_relevant = float(grade) > 0
            except ValueError:
                # Header or otherwise malformed line: no usable judgment.
                continue
            if is_relevant:
                rel[qid].add(doc_id)
    return rel

def read_run(path, cut=1000):
    """Read a ranked run file and group its entries by query.

    Each line is split on whitespace and must have at least six columns, read
    as ``qid, <ignored>, docid, rank, score, <ignored>`` (the column order of a
    TREC run file). Lines with fewer columns, or whose rank is not an integer
    or whose score is not a float, are skipped.

    Args:
        path: Path to the run file. It is read as UTF-8; undecodable bytes are
            dropped.
        cut: Only entries with ``rank <= cut`` are kept.

    Returns:
        A ``defaultdict(list)`` mapping query id (str) to a list of
        ``(docid, rank, score)`` tuples sorted by ascending rank. Entries with
        equal rank keep their file order (the sort is stable).
    """
    mp = defaultdict(list)  # qid -> [(doc, rank, score)]
    with open(path, encoding="utf-8", errors="ignore") as f:
        for line in f:
            parts = line.split()
            if len(parts) < 6:
                continue
            qid, _, doc_id, rank_text, score_text = parts[:5]
            try:
                rank = int(rank_text)
                score = float(score_text)
            except ValueError:
                # Best-effort parsing: ignore lines with malformed numbers, but
                # do not swallow unrelated exceptions such as KeyboardInterrupt.
                continue
            if rank <= cut:
                mp[qid].append((doc_id, rank, score))
    for entries in mp.values():
        entries.sort(key=lambda entry: entry[1])
    return mp

def dcg_at_k(docs, relset, k):
    """Return the binary-gain DCG of the first ``k`` entries of ``docs``.

    Args:
        docs: Ranked list of ``(docid, rank, score)`` tuples; only list order
            and the docid are used.
        relset: Container of relevant docids.
        k: Number of leading entries to consider.

    Returns:
        Sum of ``1 / log2(p + 1)`` over the 1-based positions ``p`` (within
        ``docs[:k]``) whose docid is in ``relset``.
    """
    total = 0.0
    position = 0
    for doc_id, _rank, _score in docs[:k]:
        position += 1
        if doc_id not in relset:
            continue
        # Binary gain: every relevant hit contributes a discounted 1.
        total += 1.0 / math.log2(position + 1)
    return total

def idcg_at_k(relset, k):
    """Return the ideal binary-gain DCG at cutoff ``k``.

    The ideal ranking places ``min(len(relset), k)`` relevant documents at the
    top, so the result is the sum of ``1 / log2(p + 1)`` for positions
    ``p = 1 .. min(len(relset), k)``.

    Args:
        relset: Collection of relevant docids (only its size is used).
        k: Rank cutoff.

    Returns:
        The ideal DCG as a float; ``0.0`` when there is nothing to place.
    """
    ideal_hits = min(len(relset), k)
    total = 0.0
    # Position p uses log2(p + 1), i.e. log arguments 2 .. ideal_hits + 1.
    for log_arg in range(2, ideal_hits + 2):
        total += 1.0 / math.log2(log_arg)
    return total

def rr_at_k(docs, relset, k):
    """Return the reciprocal rank of the first relevant entry in ``docs[:k]``.

    Args:
        docs: Ranked list of ``(docid, rank, score)`` tuples; positions are
            counted from list order, starting at 1.
        relset: Container of relevant docids.
        k: Number of leading entries to inspect.

    Returns:
        ``1 / position`` of the first entry whose docid is in ``relset``, or
        ``0.0`` if none of the first ``k`` entries is relevant.
    """
    reciprocal_ranks = (
        1.0 / position
        for position, (doc_id, _rank, _score) in enumerate(docs[:k], start=1)
        if doc_id in relset
    )
    return next(reciprocal_ranks, 0.0)

def ap_at_k(docs, relset, k):
    """Return the (binary) average precision of ``docs`` truncated at ``k``.

    Precision is accumulated at every position in ``docs[:k]`` that holds a
    relevant document, and the sum is divided by the total number of relevant
    documents (``len(relset)``), following the trec_eval definition. Dividing
    by the number of relevant documents *retrieved* instead would give a run
    that finds one of ten relevant documents at rank 1 a perfect score.

    Args:
        docs: Ranked list of ``(docid, rank, score)`` tuples; positions are
            counted from list order, starting at 1.
        relset: Collection of relevant docids.
        k: Rank cutoff.

    Returns:
        Average precision in ``[0, 1]`` (assuming ``docs`` has no duplicate
        docids); ``0.0`` when ``relset`` is empty or nothing relevant is
        found within the cutoff.
    """
    if not relset:
        return 0.0
    hits = 0
    precision_sum = 0.0
    for i, (d, _, _) in enumerate(docs[:k], start=1):
        if d in relset:
            hits += 1
            precision_sum += hits / i  # precision at this relevant position
    return precision_sum / len(relset)

def eval_run_file(run_path, qrels, k_list=(10, 50), mrr_at=10, limit_qids=None):
    """Evaluate one run file against ``qrels`` and return averaged metrics.

    Only queries that appear in both the run (after the optional
    ``limit_qids`` filter) and ``qrels`` are evaluated; every metric is the
    mean over exactly those queries.

    Args:
        run_path: Path to the run file, loaded with ``read_run``.
        qrels: Mapping of query id to the set of relevant docids.
        k_list: Iterable of integer cutoffs for ``R@K`` and ``nDCG@K``
            (any iterable is accepted, including the default tuple).
        mrr_at: Cutoff for the MRR metric.
        limit_qids: Optional comma-separated string of query ids; when given,
            only these queries are kept. Whitespace around ids is ignored.

    Returns:
        ``{"n": 0}`` when no query can be evaluated. Otherwise a dict with:

        * ``"n"``: number of evaluated queries,
        * ``"R@K"``: fraction of queries with at least one relevant document
          in the top ``K`` (a per-query hit rate, not set-based recall),
        * ``"MRR@<mrr_at>"``: mean of ``rr_at_k``,
        * ``"nDCG@K"``: mean binary nDCG (``dcg_at_k / idcg_at_k``),
        * ``"MAP@100"``: mean of ``ap_at_k`` at cutoff 100.
    """
    k_list = list(k_list)  # the default is a tuple; tuple + list would raise TypeError
    Kmap = 100  # MAP cutoff is fixed at 100
    mp = read_run(run_path, cut=max(k_list + [mrr_at, Kmap]))
    if limit_qids:
        keep = {qid.strip() for qid in limit_qids.split(",")}
        keep.discard("")
        mp = {q: v for q, v in mp.items() if q in keep}
    qs = sorted(set(mp) & set(qrels))
    if not qs:
        return {"n": 0}

    n = len(qs)
    out = {"n": n}
    # R@K: share of queries with any relevant document in the top K
    for K in k_list:
        hit_queries = sum(1 for q in qs if any(d in qrels[q] for d, _, _ in mp[q][:K]))
        out[f"R@{K}"] = hit_queries / n
    # MRR@k
    out[f"MRR@{mrr_at}"] = sum(rr_at_k(mp[q], qrels[q], mrr_at) for q in qs) / n
    # nDCG@K (binary)
    for K in k_list:
        ndcgs = []
        for q in qs:
            rel = qrels[q]
            dcg = dcg_at_k(mp[q], rel, K)
            idcg = idcg_at_k(rel, K)
            ndcgs.append(dcg / idcg if idcg > 0 else 0.0)
        out[f"nDCG@{K}"] = sum(ndcgs) / n
    # MAP@100
    out[f"MAP@{Kmap}"] = sum(ap_at_k(mp[q], qrels[q], Kmap) for q in qs) / n
    return out

def _split_run_spec(spec):
    """Split one ``--runs`` entry into ``(path, label)``.

    The entry is ``"<path>:<label>"`` or just ``"<path>"``. The first colon is
    taken as the separator whenever that yields an existing file. Otherwise the
    whole entry, and then a split at the last colon, are tried so that paths
    containing a colon themselves (e.g. a Windows drive letter) still resolve.
    An empty or missing label falls back to the path's base name.
    """
    if ":" not in spec:
        return spec, os.path.basename(spec)
    path, label = spec.split(":", 1)
    if not os.path.isfile(path):
        if os.path.isfile(spec):
            path, label = spec, ""
        else:
            alt_path, alt_label = spec.rsplit(":", 1)
            if os.path.isfile(alt_path):
                path, label = alt_path, alt_label
    return path, label or os.path.basename(path)


def main():
    """Evaluate every run given on the command line and report the metrics.

    Prints a fixed-width table to stdout with one row per run (columns: name,
    n, R@K..., MRR@k, nDCG@K..., MAP@100, path). A run with no evaluable
    queries is printed as an all-zero row. When ``--csv-out`` is given, the
    unrounded metrics of every run that had at least one evaluable query are
    also written to that CSV file; zero rows are not written.
    """
    args = parse_args()
    qrels = read_qrels(args.qrels)
    ks = sorted(set(args.k_list))
    rows_for_csv = []

    # Metric columns, in the order shared by the table and the CSV.
    metric_names = (
        [f"R@{K}" for K in ks]
        + [f"MRR@{args.mrr_at}"]
        + [f"nDCG@{K}" for K in ks]
        + ["MAP@100"]
    )
    header = ["name", "n"] + metric_names + ["-> path"]

    # Pretty-printer: fixed widths so columns line up
    def fmt_line(values):
        # values is a list of strings matching header order
        parts = [
            f"{values[0]:<20s}",  # name (left, width 20)
            f"{values[1]:>6s}",   # n (right, width 6)
        ]
        # metrics (right, width 10), everything except last column (path)
        parts.extend(f"{v:>10s}" for v in values[2:-1])
        # path printed raw at the end (keeps full path)
        parts.append(values[-1])
        return " ".join(parts)

    # Print header once
    print(fmt_line(header))

    for spec in args.runs:
        path, label = _split_run_spec(spec)

        m = eval_run_file(
            path, qrels,
            k_list=ks, mrr_at=args.mrr_at,
            limit_qids=args.limit_qids
        )

        if not m or m.get("n", 0) == 0:
            # zero row (aligned with header); not added to the CSV
            print(fmt_line([label, "0"] + ["0.000"] * len(metric_names) + [path]))
            continue

        # Build the row (strings) in the same order as header
        row_vals = [label, str(m["n"])]
        row_vals += [f"{m[name]:.3f}" for name in metric_names]
        row_vals.append(path)
        print(fmt_line(row_vals))

        # CSV rows keep the unrounded metric values
        csv_row = {"name": label, "path": path, "n": m["n"]}
        csv_row.update((name, m[name]) for name in metric_names)
        rows_for_csv.append(csv_row)

    if args.csv_out:
        fieldnames = ["name", "path", "n"] + metric_names
        with open(args.csv_out, "w", newline="", encoding="utf-8") as fw:
            cw = csv.DictWriter(fw, fieldnames=fieldnames)
            cw.writeheader()
            cw.writerows(rows_for_csv)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()