import json
from tqdm import tqdm
from search import *
import argparse


def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Whitespace-only lines (for example an extra newline at the end of the file)
    are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        path: path to a UTF-8 encoded ``.jsonl`` file with one JSON value per line.

    Returns:
        list: the decoded JSON values, in file order.
    """
    # Explicit encoding so the result does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def check_match(retrieved_texts, answer_list, entity_text):
    """Check if any retrieved text matches the ground truth.

    Matching is a case-insensitive substring test against each retrieved text
    (every value is converted with ``str()`` first).

    If ``entity_text`` is truthy it takes precedence over ``answer_list``:
      * a list of entities matches if ANY entity matches; an entity containing
        "|" (2-hop case) matches only if EVERY "|"-separated name appears in
        some retrieved text (the names may appear in different texts);
      * any other value is matched as one whole string (no "|" splitting).
    Otherwise ``answer_list`` matches if any answer appears in any retrieved
    text. A single non-collection answer (e.g. a bare string) is treated as a
    one-element list; ``None`` means "no answers".

    Args:
        retrieved_texts: iterable of retrieved documents; iterated only once.
        answer_list: list/tuple/set of ground-truth answers, or a single answer.
        entity_text: entity string, list of entity strings, or a falsy value.

    Returns:
        bool: True if a match was found.
    """
    # Materialize and lowercase once: the texts are scanned once per needle,
    # and a generator argument would otherwise be exhausted after the first one.
    docs = [str(doc).lower() for doc in retrieved_texts]

    def _found(needle):
        """Return True if ``needle`` occurs (case-insensitively) in any doc."""
        needle = str(needle).lower()
        return any(needle in doc for doc in docs)

    if entity_text:
        if isinstance(entity_text, list):
            # An entity without "|" splits into itself, so one rule covers both cases.
            return any(
                all(_found(name) for name in str(entity).split("|"))
                for entity in entity_text
            )
        return _found(entity_text)

    if answer_list is None:
        return False
    if not isinstance(answer_list, (list, tuple, set, frozenset)):
        # A bare string must not be iterated character by character.
        answer_list = [answer_list]
    return any(_found(ans) for ans in answer_list)


def compute_prr(dataset, retriever, top_ks=(5, 10, 20), use_question=True):
    """Compute PRR@k: the fraction of evaluated items with a hit in the top k.

    For every item the retriever is queried with the item's image path and its
    question (sent under the "caption" key). The item is a hit at cutoff k if
    ``check_match`` finds the ground-truth entity/answer in the first k
    retrieved texts.

    Args:
        dataset: iterable of dicts with keys "image_path", "question", "answer"
            (one answer or a list of answers), optionally "entity_text", and
            optionally "question_id" (used only in error messages).
        retriever: object whose ``search(query_input)`` returns a 2-tuple whose
            first element is a list of result dicts, each with a "caption" or
            "text" key.
        top_ks: cutoffs k to report; duplicates are counted only once.
        use_question: NOTE(review): currently unused. The question is always
            sent as the query "caption". It is kept for interface compatibility;
            presumably the retriever's own configuration (e.g. its
            ``search_caption`` setting) decides whether the text is used. Confirm.

    Returns:
        dict: maps "PRR@k" to the hit rate in [0, 1]. Every value is 0.0 when no
        item was evaluated.

    Items with no answers are skipped. Items whose retrieval raises are printed
    and skipped too. Neither counts toward the denominator.
    """
    # Keyed by unique k, so duplicate cutoffs cannot double-count hits.
    hits = {k: 0 for k in top_ks}
    total = 0
    n_failed = 0

    for item in tqdm(dataset, desc="Evaluating Recall"):
        image_path = item["image_path"]
        query = item["question"]
        raw_answer = item["answer"]
        answers = raw_answer if isinstance(raw_answer, list) else [raw_answer]
        # TODO(review): an earlier draft also appended item["answer_eval"] here.

        entity_text = item.get("entity_text", "")

        # Skip if no valid answers
        if not answers:
            continue

        query_input = {"image_path": image_path, "caption": query}

        try:
            results, _ = retriever.search(query_input)
            docs = [
                res["caption"] if "caption" in res else res["text"] for res in results
            ]
        except Exception as e:
            # Best-effort evaluation: report and skip the item. Use .get so a
            # missing "question_id" cannot raise KeyError inside this handler.
            print(f"Retrieval error for {item.get('question_id', '<unknown>')}: {e}")
            n_failed += 1
            continue

        total += 1
        for k in hits:
            if check_match(docs[:k], answers, entity_text):
                hits[k] += 1

    if n_failed:
        print(f"Skipped {n_failed} item(s) due to retrieval errors (excluded from PRR).")

    return {f"PRR@{k}": hits[k] / total if total > 0 else 0.0 for k in hits}


if __name__ == "__main__":
    # Defaults reproduce the original InfoSeek experiment; all of them can be
    # overridden from the command line. Earlier runs used
    # model_name="hf-hub:timm/ViT-SO400M-16-SigLIP2-384",
    # text_model="Alibaba-NLP/gte-modernbert-base", and the
    # image_faiss.index / mixed_faiss.index indexes.
    data_dir = "/drl_nas2/ckddls1321/data/InfoSeek"
    default_index = {
        "image": f"{data_dir}/image_index_large.index",
        "image+text": f"{data_dir}/mixed_index_large2.index",
    }

    parser = argparse.ArgumentParser(
        description="Compute PRR@k of a FaissSearch retriever on a JSONL dataset."
    )
    parser.add_argument("--jsonl_path", type=str, required=True)
    parser.add_argument("--topk", type=int, default=20)
    parser.add_argument(
        "--mode", type=str, choices=["image", "image+text"], default="image"
    )
    parser.add_argument(
        "--metadata_path", type=str, default=f"{data_dir}/mixed_metadata.csv"
    )
    parser.add_argument(
        "--index_path",
        type=str,
        default=None,
        help="FAISS index file (default: a mode-specific index under the data dir)",
    )
    parser.add_argument(
        "--model_name", type=str, default="hf-hub:timm/ViT-gopt-16-SigLIP2-384"
    )
    parser.add_argument("--text_model", type=str, default="Qwen/Qwen3-Embedding-0.6B")
    parser.add_argument(
        "--eval_ks",
        type=int,
        nargs="+",
        default=[5, 10, 20],
        help="cutoffs k at which to report PRR@k (values outside 1..--topk are dropped)",
    )
    args = parser.parse_args()

    # The retriever is built with top_k=args.topk. Assuming it returns at most
    # that many results, PRR@k for k > --topk would silently repeat PRR@topk,
    # so those cutoffs are dropped (with a notice) rather than reported.
    top_ks = sorted({k for k in args.eval_ks if 0 < k <= args.topk})
    if not top_ks:
        parser.error(f"no --eval_ks value lies in 1..{args.topk} (--topk)")
    dropped = sorted(set(args.eval_ks) - set(top_ks))
    if dropped:
        print(f"Ignoring --eval_ks {dropped}: outside 1..{args.topk} (--topk)")

    data = load_jsonl(args.jsonl_path)

    retriever_kwargs = {
        "metadata_path": args.metadata_path,
        "index_path": args.index_path or default_index[args.mode],
        "model_name": args.model_name,
        "text_model": args.text_model,
        "top_k": args.topk,
    }
    if args.mode == "image+text":
        # Matches the original setup: only the mixed mode enables search_caption.
        retriever_kwargs["search_caption"] = True
    retriever = FaissSearch(**retriever_kwargs)
    use_question = args.mode == "image+text"

    prr_results = compute_prr(
        data, retriever, top_ks=top_ks, use_question=use_question
    )
    print("\n==== PRR Results ====")
    for k, v in prr_results.items():
        print(f"{k}: {v:.4f}")
