import random
import json
import os

import ir_datasets

import utils.const as C

def get_dataset(dataset_name, name):
    """Load a dataset and return its qrels, query texts and document texts.

    The dataset is loaded with ``ir_datasets.load(dataset_name)``. The qrels
    come from ``get_qrels`` (which may read them from the on-disk cache), the
    texts are collected for the ids those qrels reference, and queries that
    reference a document without text are dropped from the qrels. The
    filtered qrels are then written to ``cache/qrels/<name>.json``.

    Args:
        dataset_name: Identifier passed to ``ir_datasets.load``.
        name: Cache key; used as the file stem of the qrels cache file.

    Returns:
        A ``(qrels, query_texts, doc_texts)`` tuple. ``query_texts`` and
        ``doc_texts`` are collected *before* the missing-document filter
        runs, so they may contain entries for queries that are no longer
        present in the returned ``qrels``.
    """
    dataset = ir_datasets.load(dataset_name)
    qrels = get_qrels(dataset, name)
    query_texts = get_query_texts(dataset, qrels)
    doc_texts = get_document_texts(dataset, qrels)
    qrels = filter_qrels_with_missing_docs(qrels, doc_texts)

    cache_path = "cache/qrels/" + name + ".json"
    # Create the cache directory on first use; otherwise the write below
    # raises FileNotFoundError after all the loading work is already done.
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, "w") as f:
        json.dump(qrels, f)
    return qrels, query_texts, doc_texts

def get_qrels(dataset, name):
    """Return a mapping of query id to the list of doc ids judged for it.

    If ``cache/qrels/<name>.json`` exists, its parsed JSON content is
    returned as-is and ``dataset`` is not touched. Otherwise the qrels are
    built from ``dataset.qrels_iter()`` and passed through
    ``subsample_qrels`` before being returned.

    Args:
        dataset: Object exposing ``qrels_iter()``, whose items have
            ``query_id`` and ``doc_id`` attributes. Unused on a cache hit.
        name: Cache key; file stem of the cache file to look for.

    Returns:
        Dict mapping each query id to a list of doc ids.
    """
    cache_path = "cache/qrels/" + name + ".json"
    if os.path.isfile(cache_path):
        print("[cache] loading qrels")
        # Use a context manager so the file handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(cache_path) as f:
            return json.load(f)

    qrels = {}
    print(dataset)
    for qrel in dataset.qrels_iter():
        qrels.setdefault(qrel.query_id, []).append(qrel.doc_id)
    return subsample_qrels(qrels)

def get_qrels_by_len(qrels):
    """Group query ids by how many doc ids each query has.

    Args:
        qrels: Dict mapping query id to a list of doc ids.

    Returns:
        Dict mapping a list length to the query ids whose doc-id list has
        that length, in the iteration order of ``qrels``.
    """
    buckets = {}
    for query_id, doc_ids in qrels.items():
        buckets.setdefault(len(doc_ids), []).append(query_id)
    return buckets

def subsample_qrels(qrels):
    """Cap the number of queries kept per doc-list length.

    Queries are grouped by the length of their doc-id list (via
    ``get_qrels_by_len``). Any group with more than ``C.MAX_SAMPLES`` queries
    is reduced to ``C.MAX_SAMPLES`` distinct, randomly chosen queries;
    smaller groups are kept whole. The per-group sizes are printed.

    Args:
        qrels: Dict mapping query id to a list of doc ids.

    Returns:
        A new dict containing only the retained query ids, each mapped to
        the same doc-id list object as in ``qrels``.
    """
    qrels_by_len = get_qrels_by_len(qrels)
    for length, qids in qrels_by_len.items():
        if len(qids) <= C.MAX_SAMPLES:
            continue
        # random.sample draws WITHOUT replacement. random.choices would
        # repeat qids, and the repeats collapse in the dict built below,
        # leaving the group with fewer than MAX_SAMPLES queries.
        qrels_by_len[length] = random.sample(qids, k=C.MAX_SAMPLES)
    print([len(qids) for qids in qrels_by_len.values()])

    return {
        qid: qrels[qid]
        for qids in qrels_by_len.values()
        for qid in qids
    }

def get_query_texts(dataset, qrels):
    """Collect the text of every query that appears in ``qrels``.

    For each query the ``text`` attribute is used when present and truthy;
    otherwise the ``description`` attribute is used, and ``None`` is stored
    if that is missing too.

    Args:
        dataset: Object exposing ``queries_iter()``, whose items have a
            ``query_id`` attribute.
        qrels: Dict keyed by the query ids of interest.

    Returns:
        Dict mapping query id to its text (or ``None``).
    """
    wanted_ids = set(qrels)
    return {
        q.query_id: getattr(q, "text", None) or getattr(q, "description", None)
        for q in dataset.queries_iter()
        if q.query_id in wanted_ids
    }

def get_document_texts(dataset, qrels):
    """Collect the text of every document referenced by ``qrels``.

    For each document the first truthy value among the ``text`` and
    ``detailed_description`` attributes is used; if neither is available the
    ``title`` attribute is used, defaulting to an empty string.

    Args:
        dataset: Object exposing ``docs_iter()``, whose items have a
            ``doc_id`` attribute.
        qrels: Dict mapping query id to a list of doc ids.

    Returns:
        Dict mapping doc id to its text. Referenced doc ids that never show
        up in ``docs_iter()`` are absent from the result.
    """
    wanted_ids = set()
    for doc_ids in qrels.values():
        wanted_ids.update(doc_ids)

    def pick_text(doc):
        # Preferred fields first; stop at the first non-empty one.
        for attr in ("text", "detailed_description"):
            value = getattr(doc, attr, None)
            if value:
                return value
        return getattr(doc, "title", "")  # final fallback

    return {
        doc.doc_id: pick_text(doc)
        for doc in dataset.docs_iter()
        if doc.doc_id in wanted_ids
    }

def filter_qrels_with_missing_docs(qrels, doc_texts):
    """Drop every query whose doc-id list references an unknown document.

    A query is kept only if all of its doc ids are keys of ``doc_texts``;
    otherwise the whole query is skipped and a warning listing the missing
    doc ids is printed.

    Args:
        qrels: Dict mapping query id to a list of doc ids.
        doc_texts: Container of known doc ids (membership is tested with
            ``in``).

    Returns:
        A new dict with the surviving queries, each mapped to the same
        doc-id list object as in ``qrels``.
    """
    kept = {}
    for qid, dids in qrels.items():
        missing = [did for did in dids if did not in doc_texts]
        if missing:
            print(f"[warn] Skipping qid {qid} due to missing docs: {missing}")
            continue
        kept[qid] = dids
    return kept
