# scripts/make_dense_run.py
import argparse, json, gzip, sys, os
from tqdm import tqdm
import numpy as np

def read_jsonl_gz(fp):
    """Lazily yield one parsed JSON value per non-blank line of a gzipped JSONL file.

    The file at *fp* is opened in text mode as UTF-8; bytes that cannot be
    decoded are dropped (``errors="ignore"``) rather than raising. Lines that
    are empty or whitespace-only are skipped. The file stays open only while
    the generator is being consumed.
    """
    with gzip.open(fp, mode="rt", encoding="utf-8", errors="ignore") as stream:
        # Filter out blank lines first, then parse the remaining ones on demand.
        non_blank = (raw for raw in stream if raw.strip())
        yield from map(json.loads, non_blank)

def batched(iterable, n):
    """Yield successive lists of up to *n* items drawn from *iterable*.

    Every batch has exactly *n* items except possibly the last, which holds
    whatever remains. An empty input yields nothing. Items are consumed
    lazily, so *iterable* may be a generator.

    Raises:
        ValueError: if *n* is smaller than 1. Because this is a generator,
            the error surfaces when iteration starts. Without this check a
            non-positive *n* would never match the batch length and the whole
            input would silently be collected into one unbounded batch.
    """
    if n < 1:
        raise ValueError(f"batch size must be >= 1, got {n!r}")
    batch = []
    for x in iterable:
        batch.append(x)
        if len(batch) == n:
            yield batch
            # Start a new list; the yielded one now belongs to the consumer.
            batch = []
    if batch:
        # Trailing partial batch.
        yield batch

def main():
    """Encode a corpus, index it with FAISS, search all queries, and write a run.

    Command-line driven. The corpus and query files are gzipped JSONL. Corpus
    records are encoded batch by batch and added straight to an inner-product
    FAISS index. Queries are then encoded in batches and searched for their
    top ``--k`` documents.

    Two output files are written:

    * ``--out_jsonl``: one line per query,
      ``{"qid": ..., "hits": [{"docid": ..., "score": ...}, ...]}``.
    * ``--out_trec``: one line per hit, ``qid Q0 docid rank score run_name``.

    Records lacking an id or text are skipped. Invalid ``--k`` or
    ``--batch_size`` values terminate the program through argparse's error
    handling.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--corpus", required=True, help=".../corpus.jsonl.gz (BEIR: _id, text/title)")
    p.add_argument("--queries", required=True, help=".../queries.jsonl.gz (BEIR: _id, text)")
    p.add_argument("--model", default="BAAI/bge-base-en-v1.5")
    p.add_argument("--device", default="cuda")
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--k", type=int, default=50)
    p.add_argument("--normalize", action="store_true",
                   help="L2-normalize (cosine via inner product). Recommended for BGE.")
    p.add_argument("--out_jsonl", required=True, help="Per-qid hits JSONL")
    p.add_argument("--out_trec", required=True, help="TREC run file")
    p.add_argument("--run_name", default="bge_dense")
    args = p.parse_args()

    # Fail fast on sizes that cannot work, before loading the model.
    if args.batch_size < 1:
        p.error("--batch_size must be >= 1")
    if args.k < 1:
        p.error("--k must be >= 1")

    # Imported after argument parsing, so bad arguments are reported without
    # first loading these packages.
    from sentence_transformers import SentenceTransformer
    import faiss

    print(f"Loading encoder: {args.model} on {args.device}", flush=True)
    model = SentenceTransformer(args.model, device=args.device)
    model.max_seq_length = 512

    # Prepare FAISS index (IP + optional L2 norm => cosine)
    dim = model.get_sentence_embedding_dimension()
    index = faiss.IndexFlatIP(dim)

    # For multi-threaded FAISS on CPU searches, you can uncomment:
    # faiss.omp_set_num_threads(min(16, os.cpu_count() or 1))

    # 1) Stream corpus -> encode by batch -> add to FAISS
    # doc_ids[i] is the external id of the i-th vector added to the index.
    doc_ids = []
    total_docs = 0

    def corpus_records():
        """Yield (doc_id, text) pairs, prefixing the title to the text when present."""
        for rec in read_jsonl_gz(args.corpus):
            did = rec.get("_id") or rec.get("docid") or rec.get("id")
            text = (rec.get("title") + " " if rec.get("title") else "") + (rec.get("text") or rec.get("contents") or "")
            if did and text:
                yield (str(did), text)

    print("Encoding corpus (streamed)…", flush=True)
    for batch in tqdm(batched(corpus_records(), args.batch_size), desc="Corpus", unit="batch"):
        ids_b = [did for did, _ in batch]
        txt_b = [txt for _, txt in batch]
        emb_b = model.encode(
            txt_b,
            batch_size=len(txt_b),
            convert_to_numpy=True,
            show_progress_bar=False,
            normalize_embeddings=False  # we normalize once, below
        ).astype("float32")

        if args.normalize:
            faiss.normalize_L2(emb_b)

        index.add(emb_b)              # add directly; no giant D matrix
        doc_ids.extend(ids_b)
        total_docs += len(ids_b)

    print(f"Corpus indexed: {total_docs} docs; FAISS ntotal={index.ntotal}", flush=True)

    # 2) Read queries (kept in memory; encoded batch by batch in step 3)
    qids, qtexts = [], []
    for rec in read_jsonl_gz(args.queries):
        qid = rec.get("_id") or rec.get("qid") or rec.get("id")
        qtxt = rec.get("text") or rec.get("query") or rec.get("question") or rec.get("contents")
        if qid and qtxt:
            qids.append(str(qid))
            qtexts.append(qtxt)
    Q = len(qids)
    print(f"Queries loaded: {Q}", flush=True)

    # 3) Search per query (batched)
    K = args.k
    for out_path in (args.out_jsonl, args.out_trec):
        # A bare filename has an empty dirname, and os.makedirs("") raises,
        # so only create the parent directory when there is one.
        parent = os.path.dirname(out_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    with open(args.out_jsonl, "w", encoding="utf-8") as outj, \
         open(args.out_trec, "w", encoding="utf-8") as outt:

        # zip() is consumed lazily by batched(); no intermediate list needed.
        for qb in tqdm(batched(zip(qids, qtexts), args.batch_size), desc="Searching", unit="batch"):
            qids_b = [q for q, _ in qb]
            qtxts_b = [t for _, t in qb]
            Qemb = model.encode(
                qtxts_b,
                batch_size=len(qtxts_b),
                convert_to_numpy=True,
                show_progress_bar=False,
                normalize_embeddings=False
            ).astype("float32")
            if args.normalize:
                faiss.normalize_L2(Qemb)

            sims, idxs = index.search(Qemb, K)  # (B, K)
            for i, qid in enumerate(qids_b):
                hits = []
                for di, sc in zip(idxs[i], sims[i]):
                    # FAISS pads with label -1 when the index holds fewer than
                    # K vectors; doc_ids[-1] would wrongly map that to the last
                    # document (or raise on an empty corpus), so skip padding.
                    if di < 0:
                        continue
                    did = doc_ids[int(di)]
                    hits.append({"docid": did, "score": float(sc)})
                    # Rank is 1-based and counts only real hits.
                    outt.write(f"{qid} Q0 {did} {len(hits)} {sc:.6f} {args.run_name}\n")
                outj.write(json.dumps({"qid": qid, "hits": hits}, ensure_ascii=False) + "\n")

    print(f"Wrote JSONL: {args.out_jsonl}")
    print(f"Wrote TREC : {args.out_trec}")

# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()