import argparse
import json
import time
from pathlib import Path

import jsonl
from haystack import Document, Pipeline
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from tqdm import tqdm


def truncate(query: str):
    """Shorten an over-long query by keeping only its beginning and end.

    Queries of at most 500 characters are returned unchanged. Longer queries
    are reduced to their first 250 and last 250 characters, joined by "...",
    so the result is 503 characters long.
    """
    if len(query) <= 500:
        return query

    head, tail = query[:250], query[-250:]
    return f"{head}...{tail}"


def get_content(query: str, response: str):
    """Render a query/response pair as a two-line "**Q)** ... / **A)** ..." text.

    Both values are interpolated with normal string formatting, so a
    non-string response (e.g. a numeric answer) is rendered via ``format()``.
    """
    question_line = f"**Q)** {query}"
    answer_line = f"**A)** {response}"
    return question_line + "\n" + answer_line


def get_document_content(doc: dict):
    """Render a document record as the Q/A text produced by ``get_content``.

    The question (``doc["question"]``) is shortened with ``truncate``.

    * If the record has a ``"choices"`` entry, every choice is labelled
      ``A.``, ``B.``, ``C.``, ... by position and appended to the question,
      one per line. The response is the labelled choice at
      ``doc["answer_idx"]``.
    * Otherwise the response is ``doc["answer"]``.

    Labels are derived from the choice position rather than from a fixed
    five-letter alphabet, so records with more than five choices keep all of
    their options and an ``answer_idx`` of 5 or more resolves correctly.

    Raises:
        KeyError: if a required key is missing from ``doc``.
        IndexError: if ``doc["answer_idx"]`` is outside the list of choices.
    """
    question = truncate(doc["question"])
    if "choices" in doc:
        first_label = ord("A")
        choice_texts = [f"{chr(first_label + i)}. {choice}" for i, choice in enumerate(doc["choices"])]
        question += "\n" + "\n".join(choice_texts)
        # The answer is the full labelled option, matching how it appears in the question.
        response = choice_texts[doc["answer_idx"]]
    else:
        response = doc["answer"]

    return get_content(query=question, response=response)


def retrieve(data: list[dict], top_k: int = 5):
    """Run BM25 retrieval for every row of ``data`` and score the "forget" rows.

    The documents of all rows whose ``"type"`` is ``"forget"`` are indexed in
    an in-memory store (duplicate document ids are skipped). Each row is then
    turned into a query from its ``"query"`` and ``"response"`` fields, and
    the ``top_k`` best BM25 matches are retrieved.

    Every row is mutated in place: a ``"retrieval"`` entry is added holding
    the query text, the retrieved documents, and -- for "forget" rows only --
    whether the row's own document was retrieved (``"hit"``) and at which
    1-based position (``"rank"``). For other rows both are ``None``.

    Args:
        data: Rows to process; each needs ``"type"``, ``"query"`` and
            ``"response"``, and "forget" rows additionally need ``"document"``.
        top_k: Number of documents to retrieve per query.

    Returns:
        A dict with ``"mrr"`` and ``"hit@<top_k>"`` averaged over the "forget"
        rows (``0.0`` when there are none), ``"processing_time_sec"`` and
        ``"queries_per_second"`` (``0.0`` when no time was measured).
    """
    # Local import: brings the enum into scope without touching the file's import block.
    from haystack.document_stores.types import DuplicatePolicy

    print("Prepare DocumentStore...")
    document_store = InMemoryDocumentStore()
    n_written = document_store.write_documents(
        [
            Document(id=row["document"]["id"], content=get_document_content(row["document"]), meta=row["document"])
            for row in data
            if row["type"] == "forget"
        ],
        # Pass the enum member the API is typed for; a bare "skip" string is not
        # guaranteed to be recognised as the skip policy by the store.
        policy=DuplicatePolicy.SKIP,
    )
    print(f"Wrote {n_written} documents to store.")

    print("Initialize pipeline:")
    pipeline = Pipeline()

    retriever = InMemoryBM25Retriever(document_store=document_store, top_k=top_k)
    pipeline.add_component(name="retriever", instance=retriever)

    n_forget = 0
    reciprocal_rank_sum = 0.0
    hits = 0
    # perf_counter is monotonic, so the elapsed time cannot go negative on clock adjustments.
    start_time = time.perf_counter()
    for item in tqdm(data, desc="Retrieve"):
        content = get_content(query=item["query"], response=item["response"])
        output = pipeline.run({"retriever": {"query": content}})

        # Drop unset fields so the stored result only carries populated values.
        retrieved_docs = [
            {k: v for k, v in doc.to_dict().items() if v is not None} for doc in output["retriever"]["documents"]
        ]

        if item["type"] == "forget":
            target_id = item["document"]["id"]
            # 1-based position of the row's own document, or None if it was not retrieved.
            rank = next(
                (position for position, d in enumerate(retrieved_docs, start=1) if d["id"] == target_id),
                None,
            )
            n_forget += 1
            hit = rank is not None
            if hit:
                reciprocal_rank_sum += 1 / rank
                hits += 1
        else:
            # Only "forget" rows have a target document to look for.
            hit = rank = None

        item["retrieval"] = {
            "query": content,
            "hit": hit,
            "rank": rank,
            "documents": retrieved_docs,
        }

    total_time = time.perf_counter() - start_time
    print(f"Finished retrieval: {total_time:.2f} seconds.")

    # Guard the averages: input without any "forget" row (or empty input) must not crash.
    return {
        "mrr": reciprocal_rank_sum / n_forget if n_forget else 0.0,
        f"hit@{top_k}": hits / n_forget if n_forget else 0.0,
        "processing_time_sec": total_time,
        "queries_per_second": len(data) / total_time if total_time > 0 else 0.0,
    }


def main():
    """Command-line entry point: run BM25 retrieval over a JSONL file.

    Options:
        --input  Path of the JSONL file to process (required).
        --k      Number of documents to retrieve per query (default 5).

    Two files are written next to the input: the rows with their retrieval
    results under ``<stem>-bm25<suffix>``, and the metrics returned by
    ``retrieve`` under ``<stem>-bm25_result.json``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--input", type=str, required=True)
    cli.add_argument("--k", type=int, default=5)
    options = cli.parse_args()

    source_path = Path(options.input)
    rows = list(jsonl.load(source_path))
    annotated_path = source_path.with_stem(f"{source_path.stem}-bm25")

    # NOTE: no existence check is made on the output paths before writing.
    metrics = retrieve(rows, top_k=options.k)

    # retrieve() annotated the rows in place, so the same list is written back out.
    jsonl.dump(rows, annotated_path)

    metrics_path = annotated_path.with_stem(f"{annotated_path.stem}_result").with_suffix(".json")
    with metrics_path.open("w") as fp:
        json.dump(metrics, fp, indent=4)


# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
