import argparse
import json
import os
from tqdm import tqdm
from search import *
from flashrag.utils import get_retriever
from api import OpenAIWrapper


def search_and_format_knowledge(prompt, text_retriever, num=15, start=1):
    """Retrieve passages for ``prompt`` and render them as numbered text.

    Args:
        prompt: Query string passed to ``text_retriever.search``.
        text_retriever: Object exposing ``search(query, num=..., return_score=...)``
            that returns ``(documents, scores)`` when ``return_score=True``;
            each document is a dict.
        num: Number of passages to request from the retriever.
        start: Number assigned to the first passage in the output.

    Returns:
        A string with, per passage, a "Passage #i Title: ..." line, a
        "Passage #i Text: ..." line and a blank line. Returns an empty string
        if the retriever returned no documents.

    Raises:
        KeyError: If a document has neither both "title"/"text" nor "contents".

    Note:
        Documents lacking "title" or "text" are filled in *in place* from
        "contents": its first line becomes the title, the remaining lines
        the text.
    """
    # Scores are requested to keep the retriever's (documents, scores) return
    # shape, but they are not needed for formatting.
    retrieval_results, _scores = text_retriever.search(
        prompt, num=num, return_score=True
    )

    chunks = []
    # Iterate the documents themselves (not zipped with the scores) so that
    # every document is normalized even if the score list is shorter.
    for idx, doc_item in enumerate(retrieval_results, start):
        if "title" not in doc_item or "text" not in doc_item:
            title, _, text = doc_item["contents"].partition("\n")
            doc_item["title"] = title
            doc_item["text"] = text
        chunks.append(
            f"Passage #{idx} Title: {doc_item['title']}\n"
            f"Passage #{idx} Text: {doc_item['text']}\n\n"
        )

    return "".join(chunks)


def _count_processed_lines(path):
    """Return how many complete records a previous run wrote to ``path``.

    Only newline-terminated lines are counted. If the file ends with a partial
    line (e.g. the previous run was killed mid-write), that fragment is
    truncated away so newly appended records start on a fresh line instead of
    being glued onto a corrupt JSON fragment. Returns 0 if ``path`` does not
    exist.
    """
    if not os.path.exists(path):
        return 0
    complete = 0
    good_end = 0  # byte offset just past the last complete line
    with open(path, "rb") as f:
        for raw in f:
            if not raw.endswith(b"\n"):
                break
            complete += 1
            good_end += len(raw)
    if good_end < os.path.getsize(path):
        print(f"Warning: truncating partial trailing line in {path}")
        with open(path, "r+b") as f:
            f.truncate(good_end)
    return complete


def main():
    """Add a ``reasoning_record`` field to each record of an InfoSeek JSONL file.

    For every input record, each (question, answer, image_path) triple is used
    to query the text retriever (``--topk`` passages) and the multimodal FAISS
    retriever (image plus "question answer" caption). Both results are then
    passed to ``model_wrapper.get_prediction``. The returned values are stored
    as a list under ``item["reasoning_record"]`` and the record is appended to
    ``--output`` as one JSON line.

    The run is resumable: one output line is written per non-blank input line,
    so records already present in ``--output`` are skipped.
    """
    parser = argparse.ArgumentParser(
        description="Add reasoning_record to InfoSeek few-shot JSONL"
    )
    parser.add_argument(
        "--input", type=str, required=True, help="Path to InfoSeek_few_shot.jsonl"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="InfoSeek_few_shot_with_reasoning_record.jsonl",
        help="Output JSONL path",
    )
    parser.add_argument(
        "--topk", type=int, default=10, help="Number of text passages to retrieve"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="google/gemma-3-4b-it",
        help="Model to use for summarization",
    )
    # Multimodal retriever settings (previously hard-coded); the defaults
    # reproduce the original values.
    parser.add_argument(
        "--mm-index",
        type=str,
        default="/drl_nas2/ckddls1321/data/InfoSeek/mixed_faiss.index",
        help="FAISS index used by the image/caption retriever",
    )
    parser.add_argument(
        "--mm-metadata",
        type=str,
        default="/drl_nas2/ckddls1321/data/InfoSeek/image_metadata_updated.csv",
        help="Metadata CSV paired with --mm-index",
    )
    parser.add_argument(
        "--mm-topk",
        type=int,
        default=5,
        help="Number of image-text pairs to retrieve per question",
    )
    args = parser.parse_args()

    # NOTE(review): `default_config` and `FaissSearch` are presumably provided
    # by `from search import *` at the top of the file.
    text_retriever = get_retriever(default_config)
    mm_retriever = FaissSearch(
        index_path=args.mm_index,
        metadata_path=args.mm_metadata,
        model_name="hf-hub:timm/ViT-SO400M-16-SigLIP2-384",
        search_caption=True,
        top_k=args.mm_topk,
    )

    # Initialize model wrapper
    api_key = os.environ.get("OPENROUTER_API_KEY", "")
    if not api_key:
        print("Warning: OPENROUTER_API_KEY is not set; API requests will likely fail")
    model_wrapper = OpenAIWrapper(
        model=args.model,
        api_base="https://openrouter.ai/api/v1/chat/completions",
        key=api_key,
    )

    # Determine how many records have already been processed (complete lines only).
    processed = _count_processed_lines(args.output)
    print(f"Resuming from line {processed + 1} of input")

    # Prompt pieces that do not depend on the current question.
    passage_prompt = "Use the above passages to support your reasoning."
    instruction = (
        "Based on image, knowledge, concisely summarize correct and relevant "
        "information with image and question.\n"
    )

    with open(args.input, "r", encoding="utf-8") as fin, open(
        args.output, "a", encoding="utf-8"
    ) as fout:
        # Skip blank lines (e.g. a trailing empty line) instead of crashing in
        # json.loads; numbering counts records, matching the output lines.
        records = (line for line in fin if line.strip())
        for idx, line in enumerate(records, 1):
            if idx <= processed:
                continue
            item = json.loads(line)

            questions = item["questions"]
            answers = item["answers"]
            image_paths = item["image_paths"]
            if not len(questions) == len(answers) == len(image_paths):
                # zip() below stops at the shortest list; make that visible.
                print(
                    f"Warning: record {idx} has {len(questions)} questions, "
                    f"{len(answers)} answers and {len(image_paths)} image paths; "
                    "extra entries are ignored"
                )

            reasoning_records = []
            for question, answer, image_path in zip(questions, answers, image_paths):
                prompt_text = f"Question: {question}\n{answer}"
                knowledge = search_and_format_knowledge(
                    prompt_text, text_retriever, num=args.topk
                )

                # Separate question and answer with a space: a bare
                # `question + answer` fused the last question word with the
                # first answer word in the caption query.
                image_text_pairs, _distances = mm_retriever.search(
                    {"image_path": image_path, "caption": f"{question} {answer}"}
                )

                prompt = f"Question: {question}\n{knowledge}\n\n{instruction}"
                reasoning = model_wrapper.get_prediction(
                    image_path=image_path,
                    prompt=prompt,
                    passages=image_text_pairs,
                    passage_prompt=passage_prompt,
                )
                reasoning_records.append(reasoning)

            item["reasoning_record"] = reasoning_records
            fout.write(json.dumps(item) + "\n")
            # Each record costs several API calls; flushing per record means an
            # abrupt kill loses at most the record in flight.
            fout.flush()

    print(f"Finished writing with reasoning records to {args.output}")


if __name__ == "__main__":
    # Script entry point: run the pipeline only when executed directly,
    # never as a side effect of importing this module.
    main()
