import argparse
import json
import os
import re
import pandas as pd
import pickle
import time
from api import (
    OpenAIWrapper,
    QWenWrapper,
    GemmaWrapper,
    OvisWrapper,
    InternVLWrapper,
)  # Assuming OpenAIWrapper is implemented in models folder
from search import *
from flashrag.utils import get_retriever, get_generator


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``data``, ``model``,
        ``direct_answer``, ``iterretgen``, ``update``, ``history``, ``ask``,
        ``expansion``, ``itercount``, ``topk``, ``image_search`` and
        ``verbose``.
    """
    parser = argparse.ArgumentParser(description="Test Medical VQA Model")
    parser.add_argument(
        "--data", type=str, required=True, help="Path to the input JSONL file"
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help=(
            "Model identifier used to select the wrapper: a name containing "
            "'gpt', 'openrouter:<org>/<model>', or a Qwen2-VL / Qwen2.5-VL / "
            "gemma / Ovis2 checkpoint id"
        ),
    )
    parser.add_argument(
        "--direct-answer",
        action="store_true",
        help="Answer directly (zero-shot) without text retrieval",
    )
    # The next three switches are only used to tag the output file name.
    parser.add_argument(
        "--iterretgen",
        action="store_true",
        help="Tag the output file name with '_iterret'",
    )
    parser.add_argument(
        "--update",
        action="store_true",
        help="Tag the output file name with '_update'",
    )
    parser.add_argument(
        "--history",
        action="store_true",
        help="Tag the output file name with '_history'",
    )
    parser.add_argument(
        "--ask", action="store_true", help="Iterative Follow-up Question"
    )
    parser.add_argument(
        "--expansion", action="store_true", help="Query Expansion for search"
    )
    parser.add_argument(
        "--itercount", type=int, default=0, help="Number of iterative reasoning steps"
    )
    parser.add_argument(
        "--topk", type=int, default=10, help="Number of passages to retrieve"
    )
    parser.add_argument(
        "--image-search",
        action="store_true",
        help="Also retrieve similar image/caption pairs and pass them to the model",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Verbose output during processing"
    )
    return parser.parse_args()


def parse_questions(output):
    """Collect the text of every ``Question <n>: ...`` line found in ``output``.

    Leading/trailing whitespace of ``output`` is removed first; each match
    runs from just after the ``Question <n>: `` prefix to the end of its line.
    """
    numbered_question = re.compile(r"Question \d+: (.+?)(?=\n|$)")
    cleaned = output.strip()
    return [hit.group(1) for hit in numbered_question.finditer(cleaned)]


def load_data(data_path):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        data_path: Path to a file holding one JSON document per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Blank lines (such as an extra trailing newline) are skipped rather than
    # handed to json.loads, which rejects empty input.
    with open(data_path, "r", encoding="utf-8") as file:
        return [json.loads(line) for line in file if line.strip()]


def save_predictions(predictions, output_path):
    """Write the prediction records to ``output_path`` as a CSV file.

    Args:
        predictions: Records accepted by ``pandas.DataFrame`` (here a list of
            dicts, one per example).
        output_path: Destination CSV path; missing parent directories are
            created.
    """
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        # Do not fail at the very end of a run just because the output
        # directory has not been created yet.
        os.makedirs(parent_dir, exist_ok=True)
    pd.DataFrame(predictions).to_csv(output_path, index=False)


def save_intermediate_results(predictions, temp_path):
    """Pickle ``predictions`` to ``temp_path`` as a resumable checkpoint.

    The data is first written to ``<temp_path>.tmp`` and then moved over the
    target with ``os.replace``, so an interrupted write cannot leave a
    truncated checkpoint in place of the previous good one.

    Args:
        predictions: Picklable object to store (here the list of prediction
            dicts gathered so far).
        temp_path: Destination path; missing parent directories are created.
    """
    parent_dir = os.path.dirname(temp_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    partial_path = f"{temp_path}.tmp"
    with open(partial_path, "wb") as temp_file:
        pickle.dump(predictions, temp_file)
    os.replace(partial_path, temp_path)


def load_intermediate_results(temp_path):
    """Return the object pickled at ``temp_path``, or ``[]`` if no such file exists.

    NOTE: unpickling can execute arbitrary code, so ``temp_path`` should only
    ever point at checkpoints written by this script.
    """
    if not os.path.exists(temp_path):
        return []
    with open(temp_path, "rb") as checkpoint:
        restored = pickle.load(checkpoint)
    return restored


def extract_questions(text):
    """Pull the two follow-up questions out of a model response.

    Matching is case-insensitive and tolerant of whitespace around the number
    and the colon.

    Args:
        text: Raw model output expected to contain ``Question 1: ...`` and
            ``Question 2: ...``.

    Returns:
        A ``(question1, question2)`` tuple of stripped strings. A question
        that cannot be found is returned as ``""``.
    """
    if not text:
        return "", ""

    flags = re.IGNORECASE | re.DOTALL

    # Question 1 runs up to "Question 2:" or, when the model produced only one
    # question, to the end of the text (previously that single question was
    # dropped because the lookahead insisted on a "Question 2:" marker).
    first = re.search(
        r"question\s*1\s*:\s*(.*?)\s*(?=question\s*2\s*:|\Z)", text, flags
    )
    # Question 2 is everything after its marker until the end of the text.
    second = re.search(r"question\s*2\s*:\s*(.*)", text, flags)

    question1 = first.group(1).strip() if first else ""
    question2 = second.group(1).strip() if second else ""
    return question1, question2


def _knowledge_contains_reference(prediction, knowledge):
    """Return True if lower-cased ``knowledge`` mentions any reference of ``prediction``."""
    # Empty reference strings are ignored throughout: "" is a substring of
    # every text and would otherwise mark every item as a hit.
    entity_text = prediction.get("entity_text")
    if entity_text and entity_text.lower() in knowledge:
        return True

    answer_eval = prediction.get("answer_eval")
    if isinstance(answer_eval, dict) and "range" in answer_eval:
        range_vals = answer_eval["range"]
        if isinstance(range_vals, list) and len(range_vals) == 2:
            low, high = range_vals
            # Any number appearing in the knowledge that falls inside the
            # accepted range counts as a hit.
            for num_str in re.findall(r"[-+]?\d*\.\d+|\d+", knowledge):
                try:
                    num = float(num_str)
                except ValueError:
                    continue
                if low <= num <= high:
                    return True
    elif isinstance(answer_eval, list):
        if any(ans and ans.lower() in knowledge for ans in answer_eval):
            return True

    answer = prediction["answer"]
    if isinstance(answer, list):
        return any(ans and ans.lower() in knowledge for ans in answer)
    if isinstance(answer, str) and answer:
        return answer.lower() in knowledge
    return False


def evaluate_retrieval(predictions):
    """Measure how often the retrieved knowledge contains a reference string.

    A prediction counts as a hit when its lower-cased ``"knowledge"`` text
    contains the entity text, one of the ``answer_eval`` / ``answer`` strings,
    or a number inside the ``answer_eval["range"]`` interval.

    Args:
        predictions: List of dicts with at least ``"knowledge"`` and
            ``"answer"``; ``"entity_text"`` and ``"answer_eval"`` are optional.

    Returns:
        ``(accuracy, correctness_flags)`` where ``accuracy`` is hits divided
        by ``len(predictions)`` (``0`` for an empty list) and
        ``correctness_flags`` holds one bool per prediction, in order.
    """
    correctness_flags = []
    for prediction in predictions:
        knowledge = prediction["knowledge"].lower()
        # Items with no retrieved knowledge are misses, but they still get a
        # flag so the flags stay aligned with ``predictions``.
        is_correct = bool(knowledge) and _knowledge_contains_reference(
            prediction, knowledge
        )
        correctness_flags.append(is_correct)

    correct = sum(correctness_flags)
    accuracy = correct / len(predictions) if predictions else 0
    return accuracy, correctness_flags


def _prediction_matches_reference(prediction):
    """Return True if the model output of ``prediction`` matches any of its references."""
    predicted = prediction["prediction"]
    predicted_lower = predicted.lower()

    # Multiple-choice: the output starts with the label, or contains "<label>:".
    # Slicing instead of indexing keeps an empty output from raising IndexError.
    label = prediction["label"]
    if label and (predicted[:1] == label or label + ":" in predicted):
        return True

    answer_eval = prediction.get("answer_eval")
    if isinstance(answer_eval, dict) and "range" in answer_eval:
        range_vals = answer_eval["range"]
        if isinstance(range_vals, list) and len(range_vals) == 2:
            low, high = range_vals
            # Numeric answers: any number in the output inside the range counts.
            for num_str in re.findall(r"[-+]?\d*\.\d+|\d+", predicted):
                try:
                    num = float(num_str)
                except ValueError:
                    continue
                if low <= num <= high:
                    return True
    elif isinstance(answer_eval, list):
        # Both sides are lower-cased; comparing a lower-cased reference with
        # the raw output missed every capitalised answer.
        if any(ans and ans.lower() in predicted_lower for ans in answer_eval):
            return True

    # Empty reference strings are skipped: "" is a substring of every output.
    answer = prediction["answer"]
    if isinstance(answer, list):
        return any(ans and ans.lower() in predicted_lower for ans in answer)
    if isinstance(answer, str) and answer:
        return answer.lower() in predicted_lower
    return False


def evaluate_accuracy(predictions):
    """Score model outputs against labels and reference answers.

    A prediction is correct when its ``"prediction"`` text starts with the
    ``"label"``, contains ``"<label>:"``, contains one of the ``answer_eval``
    or ``answer`` strings (case-insensitively), or contains a number inside
    the ``answer_eval["range"]`` interval.

    Args:
        predictions: List of dicts with ``"label"``, ``"prediction"`` and
            ``"answer"``; ``"answer_eval"`` is optional.

    Returns:
        ``(accuracy, correctness_flags)`` where ``accuracy`` is the fraction
        of correct predictions (``0`` for an empty list) and
        ``correctness_flags`` holds one bool per prediction, in order.
    """
    correctness_flags = [
        _prediction_matches_reference(prediction) for prediction in predictions
    ]
    correct = sum(correctness_flags)
    accuracy = correct / len(predictions) if predictions else 0
    return accuracy, correctness_flags


def search_and_format_knowledge(prompt, text_retriever, num=15, start=1):
    """Query the text retriever and render the hits as numbered passages.

    Args:
        prompt: Query string handed to ``text_retriever.search``.
        text_retriever: Object whose ``search(prompt, num=..., return_score=True)``
            returns ``(documents, scores)``; each document is a dict holding
            either ``title`` and ``text`` or a ``contents`` string.
        num: Number of passages requested from the retriever.
        start: Number given to the first passage in the rendered output.

    Returns:
        One string with a ``Passage #<n> Title`` / ``Passage #<n> Text`` pair
        per document, each pair followed by a blank line.
    """
    hits, hit_scores = text_retriever.search(prompt, num=num, return_score=True)

    # Documents lacking a title or text are split in place: the first line of
    # "contents" becomes the title, the remaining lines the text.
    for hit, _score in zip(hits, hit_scores):
        if "title" in hit and "text" in hit:
            continue
        head, *body = hit["contents"].split("\n")
        hit["title"] = head
        hit["text"] = "\n".join(body)

    rendered = [
        f"Passage #{rank} Title: {hit['title']}\n"
        f"Passage #{rank} Text: {hit['text']}\n\n"
        for rank, hit in enumerate(hits, start)
    ]
    return "".join(rendered)


def main():
    """Run the VQA pipeline over a JSONL dataset and save scored predictions.

    Steps:
      1. Parse the CLI arguments, load the dataset and build the model wrapper
         selected by the model identifier.
      2. Build the text retriever (unless ``--direct-answer``) and the image
         retriever (when ``--image-search``).
      3. Resume from the pickle checkpoint that belongs to the output file, if
         one exists, and process the remaining items. Each item is answered
         either directly or through describe -> retrieve -> summarise,
         followed by ``--itercount`` refinement iterations and a final answer.
      4. Checkpoint every five items, then print accuracy and retrieval
         scores and write the predictions to a CSV file under ``./outputs``.

    Raises:
        ValueError: If the model identifier matches none of the supported
            wrapper families.
    """
    args = parse_args()
    data = load_data(args.data)

    # VLM Agent: the wrapper is chosen from substrings of the model identifier.
    if "gpt" in args.model:
        model_wrapper = OpenAIWrapper(args.model)
        model_name = args.model
    elif "openrouter:" in args.model:
        args.model = args.model.replace("openrouter:", "")
        model_wrapper = OpenAIWrapper(
            model=args.model,
            api_base="https://openrouter.ai/api/v1/chat/completions",
            key=os.environ.get("OPENROUTER_API_KEY", ""),
        )
        model_name = args.model.split("/")[1].replace(":free", "")
    elif "Qwen2.5-VL" in args.model or "Qwen2-VL" in args.model:
        model_wrapper = QWenWrapper(args.model)
        model_name = args.model.split("/")[1]
    elif "gemma" in args.model:
        model_wrapper = GemmaWrapper(args.model)
        model_name = args.model.split("/")[1]
    elif "Ovis2" in args.model:
        model_wrapper = OvisWrapper(args.model)
        model_name = args.model.split("/")[1]
    else:
        # Without this branch an unknown model only failed later with an
        # UnboundLocalError on model_name / model_wrapper.
        raise ValueError(
            f"Unsupported model {args.model!r}: expected a name containing "
            "'gpt', 'openrouter:', 'Qwen2.5-VL', 'Qwen2-VL', 'gemma' or 'Ovis2'."
        )

    if not args.direct_answer:
        text_retriever = get_retriever(default_config)

    if args.image_search:
        retriever = FaissSearch(
            metadata_path="/drl_nas2/ckddls1321/data/InfoSeek/mixed_metadata.csv",
            # Alternative index / encoder configuration, kept for reference:
            #index_path="/drl_nas2/ckddls1321/data/InfoSeek/mixed_faiss.index",
            #model_name="hf-hub:timm/ViT-SO400M-16-SigLIP2-384",
            #text_model="Alibaba-NLP/gte-modernbert-base",
            index_path="/drl_nas2/ckddls1321/data/InfoSeek/mixed_index_large2.index",
            model_name="hf-hub:timm/ViT-gopt-16-SigLIP2-384",
            text_model="Qwen/Qwen3-Embedding-0.6B",
            search_caption=True,
            top_k=args.topk,
        )

    # The output file name encodes the model, dataset and every active option,
    # so each configuration gets its own CSV and its own checkpoint.
    input_file_name = os.path.basename(args.data)
    dataset_name = os.path.basename(os.path.dirname(args.data))
    output_file_name = (
        f"{model_name}_{dataset_name}_{os.path.splitext(input_file_name)[0]}"
        f"_da_{args.direct_answer}_iter_{args.itercount}.csv"
    )
    if args.iterretgen:
        output_file_name = output_file_name.replace(".csv", "_iterret.csv")
    if args.update:
        output_file_name = output_file_name.replace(".csv", "_update.csv")
    if args.history:
        output_file_name = output_file_name.replace(".csv", "_history.csv")
    if args.ask:
        output_file_name = output_file_name.replace(".csv", "_ask3.csv")
    if args.image_search:
        output_file_name = output_file_name.replace(".csv", "_image_search.csv")
    if args.topk != 10:
        output_file_name = output_file_name.replace(".csv", f"_topk{args.topk}.csv")
    output_path = os.path.join("./outputs", output_file_name)
    # Create the output directory up front so that neither the periodic
    # checkpoint nor the final CSV write fails after model calls were made.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    temp_path = output_path.replace(".csv", ".pkl")
    predictions = load_intermediate_results(temp_path)

    # Items already present in the checkpoint are skipped.
    start_index = len(predictions)
    print(f"Resuming from index {start_index}...")

    start_time = time.time()  # start of the wall-clock measurement
    total_retrieval_time = 0
    total_api_time = 0

    for i, item in enumerate(data[start_index:], start=start_index):
        question = item["question"]
        image_path = item["image_path"]
        entity_text = item.get("entity_text", "")  # optional, "" when absent
        answer_eval = item.get("answer_eval", "")  # optional, "" when absent
        label = item.get("label", "")  # Label may not exist for Open-Ended
        answer = item.get("answer", "")
        option_prompt = ""
        total_knowledge = ""
        total_prediction = ""
        if args.direct_answer:
            prompt = f"##Answer concisely to the question. \nQuestion: {question}\n"
            image_text_pairs = None
            passage_prompt = None
            if args.image_search:
                passage_prompt = (
                    "Here is relevant image and their corresponding description.\n\n"
                )
                image_text_pairs, _ = retriever.search(image_path)
            prediction = model_wrapper.get_prediction(
                image_path,
                prompt,
                passages=image_text_pairs,
                passage_prompt=passage_prompt,
            )
            total_prediction += prediction
        else:  # MI-RAG
            #################
            # Describe image-based question
            #################
            prompt = f"Question: {question}\n Concisely describe image which is relevant to question.\n"
            start_api = time.time()
            description = model_wrapper.get_prediction(
                image_path=image_path, prompt=prompt
            )
            total_api_time += time.time() - start_api

            if args.verbose:
                print(description)
            ###################
            # Retrieve knowledge related to question
            ###################
            prompt = f"Question: {question}\n{description}\n"
            start_ret = time.time()  # start timing retrieval
            knowledge = search_and_format_knowledge(
                prompt, text_retriever, num=2 * args.topk
            )
            total_retrieval_time += time.time() - start_ret
            if args.verbose:
                print(knowledge)

            if "lens_result" in item:  # For Encyclopedic VQA
                knowledge = item["lens_result"]
                knowledge = knowledge[: 64 * 1024]  # Truncate to 64k characters

            total_knowledge += knowledge
            image_text_pairs = None
            passage_prompt = None
            # NOTE(review): knowledge is cleared here, so the summarisation
            # prompt below carries an empty "Knowledge:" field and the text
            # passages survive only in total_knowledge. Confirm this is intended.
            knowledge = ""
            if args.image_search:
                passage_prompt = "Here is relevant pairs of image and their corresponding description.\n"
                start_ret = time.time()  # start timing retrieval
                image_text_pairs, _ = retriever.search(
                    {
                        "image_path": image_path,
                        "caption": question,
                    },
                    top_k=args.topk,
                )
                total_retrieval_time += time.time() - start_ret
                for pair in image_text_pairs:
                    total_knowledge += f"Passage: {pair['caption']}\n"

            #################
            # Summarize knowledge and description of internal knowledge
            #################
            prompt = f"Question: {question}\nDescription: {description}\nKnowledge: {knowledge}\n{option_prompt}\n"
            prompt += "Based on image, description and knowledge, summarize correct and relevant information with image and question.\n"
            start_api = time.time()
            prediction = model_wrapper.get_prediction(
                image_path=image_path,
                prompt=prompt,
                passages=image_text_pairs,
                passage_prompt=passage_prompt,
            )
            total_api_time += time.time() - start_api
            total_prediction += "Reasoning Record #0 :" + prediction
            if args.verbose:
                print(prediction)

            #################
            # iteration start
            #################
            for iter_num in range(1, args.itercount + 1):
                print(f"#### Iteration {iter_num} ####")
                # The previous summary becomes part of the next retrieval query.
                prompt = f"Question: {question}\n{prediction}\n"
                start_ret = time.time()
                knowledge = search_and_format_knowledge(
                    prompt, text_retriever, num=args.topk if args.ask else 2 * args.topk
                )
                total_retrieval_time += time.time() - start_ret
                if args.ask:
                    # The continuation lines below are part of the prompt text,
                    # including their leading indentation.
                    ask_prompt = f"Question: {question}\nKnowledge: {total_prediction}\n\nPlease first analyze all the information in a section named Analysis (## Analysis). \
                        Generate two follow-up questions to search for additional information and helpful to confirm knowledge, in a section named Queries (## Queries). \n\
                        Your output should be in the following format: \n\
                        ## Analysis \
                        Analysis question and knowledge to ask context-specific queries that helps to address question. \
                        ## Queries \
                        Question 1: question 1. \n\
                        Question 2: question 2. \n."
                    start_api = time.time()
                    questions = model_wrapper.get_prediction(
                        image_path,
                        ask_prompt,
                    )
                    total_api_time += time.time() - start_api
                    question1, question2 = extract_questions(questions)
                    # Each follow-up question retrieves topk // 2 more passages,
                    # numbered after the passages of the main query.
                    if question1:
                        start_ret = time.time()
                        knowledge += search_and_format_knowledge(
                            f"Question: {question1}\n",
                            text_retriever,
                            num=int(args.topk // 2),
                            start=args.topk,
                        )
                        total_retrieval_time += time.time() - start_ret
                    if question2:
                        start_ret = time.time()
                        knowledge += search_and_format_knowledge(
                            f"Question: {question2}\n",
                            text_retriever,
                            num=int(args.topk // 2),
                            start=args.topk + int(args.topk // 2),
                        )
                        total_retrieval_time += time.time() - start_ret

                if args.verbose:
                    print(knowledge)
                total_knowledge += knowledge

                image_text_pairs = None
                passage_prompt = None
                if args.image_search:
                    passage_prompt = "Here is relevant pairs of image and their corresponding description.\n"
                    start_ret = time.time()
                    image_text_pairs, _ = retriever.search(
                        {
                            "image_path": image_path,
                            "caption": f"Question: {question}\n{prediction}\n",
                        },
                        top_k=int(args.topk // 2) if args.ask else args.topk,
                    )
                    total_retrieval_time += time.time() - start_ret
                    if args.ask:
                        # Each follow-up question adds up to topk // 4 more pairs.
                        if question1:
                            start_ret = time.time()  # start timing retrieval
                            additional_pairs, _ = retriever.search(
                                {
                                    "image_path": image_path,
                                    "caption": f"Question: {question1}\n",
                                }
                            )
                            total_retrieval_time += time.time() - start_ret
                            image_text_pairs += additional_pairs[: int(args.topk // 4)]
                        if question2:
                            start_ret = time.time()  # start timing retrieval
                            additional_pairs, _ = retriever.search(
                                {
                                    "image_path": image_path,
                                    "caption": f"Question: {question2}\n",
                                }
                            )
                            image_text_pairs += additional_pairs[: int(args.topk // 4)]
                            total_retrieval_time += time.time() - start_ret
                    for pair in image_text_pairs:
                        total_knowledge += f"Passage: {pair['caption']}\n"

                prompt = f"Question: {question}\n{knowledge}\n\n{option_prompt}\n"
                prompt += "Based on image and knowledge, summarize correct and relevant information with image and question.\n"
                start_api = time.time()
                prediction = model_wrapper.get_prediction(
                    image_path=image_path,
                    prompt=prompt,
                    passages=image_text_pairs,
                    passage_prompt=passage_prompt,
                )
                total_api_time += time.time() - start_api
                total_prediction += f"Reasoning Record #{iter_num} :" + prediction
                if args.verbose:
                    print(prediction)

            # Final answer, conditioned on every reasoning record gathered above.
            prompt = f"Please answer the following question using the provided information and image.\n\nQuestion: {question}\nRelevant Knowledge: {total_prediction}\n\nBased on the information, provide a detailed answer to the question.\n"

            # NOTE(review): this call is not added to total_api_time.
            prediction = model_wrapper.get_prediction(
                image_path=image_path,
                prompt=prompt,
            )

            if args.verbose:
                print(prediction)

        predictions.append(
            {
                "question": question,
                "image_path": image_path,
                "label": label if label else "",
                "entity_text": entity_text if entity_text else "",
                "answer": answer,
                "answer_eval": answer_eval if answer_eval else "",
                "knowledge": total_knowledge,
                "total_pred": total_prediction,
                "prediction": prediction,
            }
        )

        elapsed = time.time() - start_time
        print(f"[Timer] MIRAG {i} iterations took {elapsed:.2f} seconds.")
        print(
            "[Timer] Total retrieval time: {:.2f} seconds.".format(total_retrieval_time)
        )
        print("[Timer] Total API time: {:.2f} seconds.".format(total_api_time))

        # Checkpoint every five items so an interrupted run can be resumed.
        if (i + 1) % 5 == 0:
            save_intermediate_results(predictions, temp_path)
            print(f"Saved intermediate results at index {i + 1}.")

    elapsed = time.time() - start_time
    print(f"[Timer] MIRAG iterations took {elapsed:.2f} seconds.")
    print("[Timer] Total retrieval time: {:.2f} seconds.".format(total_retrieval_time))
    print("[Timer] Total API time: {:.2f} seconds.".format(total_api_time))

    # Evaluate accuracy and add correctness flags
    accuracy, correctness_flags = evaluate_accuracy(predictions)
    print(f"Accuracy: {accuracy:.2%}")
    # Add correctness flags to predictions
    for prediction, is_correct in zip(predictions, correctness_flags):
        prediction["correct"] = is_correct

    # Retrieval score: only the aggregate is reported, the flags are not stored.
    accuracy, correctness_flags = evaluate_retrieval(predictions)
    print(f"Retrieval R@15: {accuracy:.2%}")

    save_intermediate_results(predictions, temp_path)
    print("Saved intermediate results at last.")
    # Save predictions to CSV
    save_predictions(predictions, output_path)
    print(f"Predictions saved to {output_path}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


