import argparse
from collections.abc import Iterable
import re
import pandas as pd
import string
import ast
import os
import sys
import pickle
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch.nn import functional as F


def save_predictions(predictions, output_path):
    """Write ``predictions`` to ``output_path`` as CSV, without the index column.

    ``predictions`` is handed straight to ``pd.DataFrame``, so anything that
    constructor accepts (e.g. a list of record dicts) is valid here.
    """
    pd.DataFrame(predictions).to_csv(output_path, index=False)


# Module-level slots for the BEM (answer-equivalence) model and tokenizer.
# Both stay ``None`` until initialize_bem_model_and_transformers() fills them.
bem_model = bem_tokenizer = None

# Token-level normalisation table used by preprocess_answer(): number words
# (and "none") become digit strings, NLI / boolean labels become "yes" / "no".
_DIGIT_MAP = {
    "none": "0",
    # "zero".."ten" in order, so each word's position is its numeric value.
    **{
        word: str(value)
        for value, word in enumerate(
            (
                "zero",
                "one",
                "two",
                "three",
                "four",
                "five",
                "six",
                "seven",
                "eight",
                "nine",
                "ten",
            )
        )
    },
    "entailment": "yes",
    "true": "yes",
    "contradiction": "no",
    "false": "no",
}
# Maps apostrophe-less (or mis-apostrophised) contractions to their canonical
# spelling. Looked up one whitespace-separated token at a time by
# preprocess_answer().
#
# NOTE: preprocess_answer() lowercases its input and, with the default
# punctuation set, strips apostrophes before this lookup, so keys that contain
# an apostrophe or an uppercase letter cannot match on that path. They are
# kept so the table stays usable with other punctuation settings.
_CONTRACTIONS = {
    "aint": "ain't",
    "arent": "aren't",
    "cant": "can't",
    "couldve": "could've",
    "couldnt": "couldn't",
    "couldn'tve": "couldn't've",
    "couldnt've": "couldn't've",
    "didnt": "didn't",
    "doesnt": "doesn't",
    "dont": "don't",
    "hadnt": "hadn't",
    "hadnt've": "hadn't've",
    "hadn'tve": "hadn't've",
    "hasnt": "hasn't",
    "havent": "haven't",
    "hed": "he'd",
    "hed've": "he'd've",
    "he'dve": "he'd've",
    "hes": "he's",
    "howd": "how'd",
    "howll": "how'll",
    "hows": "how's",
    "Id've": "I'd've",
    "I'dve": "I'd've",
    "Im": "I'm",
    "Ive": "I've",
    "isnt": "isn't",
    "itd": "it'd",
    "itd've": "it'd've",
    "it'dve": "it'd've",
    "itll": "it'll",
    "let's": "let's",
    "maam": "ma'am",
    "mightnt": "mightn't",
    "mightnt've": "mightn't've",
    "mightn'tve": "mightn't've",
    "mightve": "might've",
    "mustnt": "mustn't",
    "mustve": "must've",
    "neednt": "needn't",
    "notve": "not've",
    "oclock": "o'clock",
    "oughtnt": "oughtn't",
    "ow's'at": "'ow's'at",
    "'ows'at": "'ow's'at",
    "'ow'sat": "'ow's'at",
    "shant": "shan't",
    "shed've": "she'd've",
    "she'dve": "she'd've",
    "she's": "she's",
    "shouldve": "should've",
    "shouldnt": "shouldn't",
    "shouldnt've": "shouldn't've",
    "shouldn'tve": "shouldn't've",
    # Was inverted ("somebody'd" -> "somebodyd"), which removed the apostrophe
    # instead of restoring it like every other entry in this table.
    "somebodyd": "somebody'd",
    "somebodyd've": "somebody'd've",
    "somebody'dve": "somebody'd've",
    "somebodyll": "somebody'll",
    "somebodys": "somebody's",
    "someoned": "someone'd",
    "someoned've": "someone'd've",
    "someone'dve": "someone'd've",
    "someonell": "someone'll",
    "someones": "someone's",
    "somethingd": "something'd",
    "somethingd've": "something'd've",
    "something'dve": "something'd've",
    "somethingll": "something'll",
    "thats": "that's",
    "thered": "there'd",
    "thered've": "there'd've",
    "there'dve": "there'd've",
    "therere": "there're",
    "theres": "there's",
    "theyd": "they'd",
    "theyd've": "they'd've",
    "they'dve": "they'd've",
    "theyll": "they'll",
    "theyre": "they're",
    "theyve": "they've",
    "twas": "'twas",
    "wasnt": "wasn't",
    "wed've": "we'd've",
    "we'dve": "we'd've",
    "weve": "we've",
    "werent": "weren't",
    "whatll": "what'll",
    "whatre": "what're",
    "whats": "what's",
    "whatve": "what've",
    "whens": "when's",
    "whered": "where'd",
    "wheres": "where's",
    "whereve": "where've",
    "whod": "who'd",
    "whod've": "who'd've",
    "who'dve": "who'd've",
    "wholl": "who'll",
    "whos": "who's",
    "whove": "who've",
    "whyll": "why'll",
    "whyre": "why're",
    "whys": "why's",
    "wont": "won't",
    "wouldve": "would've",
    "wouldnt": "wouldn't",
    "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've",
    "yall": "y'all",
    "yall'll": "y'all'll",
    "y'allll": "y'all'll",
    "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've",
    "y'all'dve": "y'all'd've",
    "youd": "you'd",
    "youd've": "you'd've",
    "you'dve": "you'd've",
    "youll": "you'll",
    "youre": "you're",
    "youve": "you've",
}

# Characters treated as punctuation by preprocess_answer(): ASCII punctuation
# followed by curly single quotes, the acute accent, a backtick and "_".
_PUNCTUATION_CHARACTERS = "".join((string.punctuation, "‘’´`_"))


def preprocess_answer(
    answer,
    punctuation_characters=_PUNCTUATION_CHARACTERS,
    replacement_character="",
):
    """Normalize a VQA answer string for lenient string matching.

    The steps run in this order: lowercase and flatten newlines/tabs, drop the
    ``<extra_id_0> `` span marker, replace punctuation, blank out articles and
    the phrase "the answer is", map number words and contractions token by
    token, and finally collapse whitespace.

    Args:
        answer: Raw answer text.
        punctuation_characters: Characters to replace.
        replacement_character: Text substituted for each punctuation
            character (and for the span marker).

    Returns:
        The normalized answer.
    """
    span_prefix = "<extra_id_0> "
    punctuation = set(punctuation_characters)

    text = answer.lower().replace("\n", " ").replace("\t", " ").strip()

    # Only triggered when the marker leads the string, but then every
    # occurrence of it is replaced, not just the leading one.
    if text.startswith(span_prefix):
        text = text.replace(span_prefix, replacement_character)

    text = "".join(
        replacement_character if char in punctuation else char for char in text
    )
    text = re.sub(r"\b(the answer is|a|an|the)\b", " ", text)

    # Digit words first, then contractions, one whitespace token at a time.
    tokens = []
    for token in text.split():
        token = _DIGIT_MAP.get(token, token)
        tokens.append(_CONTRACTIONS.get(token, token))
    standardized = " ".join(tokens)

    # Final pass removes any superfluous whitespace.
    return " ".join(standardized.split())


def slice_knowledge_by_iter(knowledge: str, iter_n: "int | str") -> str:
    """Return the knowledge block(s) selected by ``iter_n``.

    ``knowledge`` is split at every literal ``"Passage #1 Title:"`` marker.
    The text before the first marker is discarded and the marker itself is
    not part of the returned blocks. Blocks are indexed from 0.

    Args:
        knowledge: Full knowledge string.
        iter_n: A single block index (``2`` or ``"2"``) or a half-open
            ``"start:end"`` range. An empty side of the range means "from the
            first block" / "through the last block".

    Returns:
        The selected blocks joined with newlines, or ``""`` when the selection
        falls outside the available blocks.

    Raises:
        ValueError: If ``iter_n`` does not contain valid integers.
    """
    # Accept ints as well as strings; the CLI passes --iter as a string.
    selector = str(iter_n)
    if ":" in selector:
        start_text, end_text = selector.split(":")[:2]
        start = int(start_text) if start_text.strip() else 0
        end = int(end_text) if end_text.strip() else None
    else:
        start = int(selector)
        end = start + 1
    blocks = knowledge.split("Passage #1 Title:")[1:]
    return "\n".join(blocks[start:end])


def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior.

    Returns:
        An ``argparse.Namespace`` with ``csv``, ``pkl``, ``bem``, ``verbose``,
        ``vqa_score``, ``only_multimodal``, ``only_text`` and ``iter``.
    """
    parser = argparse.ArgumentParser(description="Evaluate CSV or PKL Predictions")
    parser.add_argument(
        "--csv",
        type=str,
        help="Path to the CSV file with predictions to evaluate",
    )
    parser.add_argument(
        "--pkl",
        type=str,
        help="Path to the PKL file with saved predictions to evaluate",
    )
    parser.add_argument(
        "--bem",
        action="store_true",
        help="Use BEM model (kortukov/answer-equivalence-bem) for items failing initial heuristic checks.",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Verbose output during evaluation"
    )
    parser.add_argument(
        "--vqa-score",
        action="store_true",
        help="Use VQA scoring metric when 'answer' is a list. Metric: Acc(ans) = min(#humans that said ans / 3, 1). If 10 human answers provided, score is averaged over 10c9 subsets.",
    )
    parser.add_argument(
        "--only-multimodal",
        action="store_true",
        help="Only evaluate multimodal knowledge recall",
    )
    parser.add_argument(
        "--only-text",
        action="store_true",
        help="Only evaluate text knowledge recall",
    )
    # Kept as a string because it may be a single index or a START:END range.
    parser.add_argument(
        "--iter",
        type=str,
        default=None,
        help="Retrieval iteration(s) to evaluate: a 0-based block index N or a START:END range",
    )
    return parser.parse_args(argv)


def load_predictions(args):
    """Load predictions from ``args.pkl`` (checked first) or ``args.csv``.

    A pickle is returned exactly as stored; a CSV is read with pandas and
    converted to a list of per-row dicts.

    Raises:
        ValueError: If neither ``args.pkl`` nor ``args.csv`` is set.
    """
    if args.pkl:
        # NOTE: unpickling can execute arbitrary code - only load trusted files.
        with open(args.pkl, "rb") as handle:
            predictions = pickle.load(handle)
        print(f"Loaded {len(predictions)} predictions from pickle: {args.pkl}")
        return predictions

    if not args.csv:
        raise ValueError(
            "Please provide either --csv or --pkl with a path to load predictions."
        )

    predictions = pd.read_csv(args.csv).to_dict(orient="records")
    print(f"Loaded {len(predictions)} predictions from CSV: {args.csv}")
    return predictions


def extract_text(knowledge, passage_range=(11, 20)):
    """Extract text passages whose number falls inside ``passage_range``.

    A text passage starts at ``"Passage #<n> Title:"``, must contain
    ``"Text:"``, and runs until the next ``"Passage #"`` / ``"Passage:"``
    marker or the end of the string. Matching is case-insensitive and spans
    multiple lines.

    Args:
        knowledge: Knowledge string to scan.
        passage_range: Inclusive ``(low, high)`` bounds on the passage number.
            The default ``(11, 20)`` is the window this script labels as
            query generation; ``(1, 10)`` is the one it labels as query
            expansion. ``None`` keeps every passage.

    Returns:
        The selected passages joined with newlines (``""`` when none match).
    """
    # The trailing (?=...) is a lookahead: it ends the passage just before the
    # next marker without consuming it, so consecutive passages all match.
    pattern = r"Passage #(\d+) Title:.*?Text:.*?(?=(?:Passage #|Passage:|\Z))"

    selected = []
    for match in re.finditer(pattern, knowledge, re.DOTALL | re.IGNORECASE):
        passage_num = int(match.group(1))
        if passage_range is not None:
            low, high = passage_range
            if not low <= passage_num <= high:
                continue
        selected.append(match.group(0))
    return "\n".join(selected)


def extract_multimodal(
    knowledge, subset="generation", group_size=9, expansion_count=5
):
    """Extract multimodal passage lines from the knowledge string.

    A line counts as multimodal when, after lowercasing and stripping leading
    whitespace, it starts with ``"passage:"`` and does not contain
    ``"title:"``. Such lines are numbered from 0 in order of appearance; within
    every run of ``group_size`` lines the first ``expansion_count`` are treated
    as expansion passages and the remainder as generation passages.

    Args:
        knowledge: Knowledge string, processed line by line.
        subset: ``"generation"`` (default, the original behavior),
            ``"expansion"`` or ``"all"``.
        group_size: Number of multimodal lines per group.
        expansion_count: Leading lines of each group assigned to expansion.

    Returns:
        The selected lines (unmodified) joined with newlines.

    Raises:
        ValueError: If ``subset`` is unknown or ``group_size`` is not positive.
    """
    if subset not in ("generation", "expansion", "all"):
        raise ValueError(
            f"subset must be 'generation', 'expansion' or 'all', got {subset!r}"
        )
    if group_size <= 0:
        raise ValueError("group_size must be a positive integer")

    selected = []
    passage_count = 0
    for line in knowledge.splitlines():
        clean_line = line.lower().lstrip()
        if not clean_line.startswith("passage:") or "title:" in clean_line:
            continue
        is_expansion = passage_count % group_size < expansion_count
        passage_count += 1
        if subset == "all":
            selected.append(line)
        elif subset == "expansion" and is_expansion:
            selected.append(line)
        elif subset == "generation" and not is_expansion:
            selected.append(line)
    return "\n".join(selected)

def _parse_retrieval_field(raw):
    """Parse ``raw`` as a Python literal when it contains "[", else return it.

    Text that merely contains a bracket but is not a valid literal is returned
    unchanged instead of raising.
    """
    if "[" not in raw:
        return raw
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return raw


def _any_in_knowledge(candidates, knowledge):
    """Return True if any non-empty candidate occurs in ``knowledge``.

    Candidates are stringified and lowercased; ``knowledge`` is expected to be
    lowercased by the caller.
    """
    for candidate in candidates:
        text = str(candidate).lower()
        # The empty string is a substring of everything; never count it.
        if text and text in knowledge:
            return True
    return False


def evaluate_retrieval(predictions, args):
    """Score whether each item's retrieved knowledge contains its target.

    For every prediction the ``knowledge`` field is optionally narrowed
    (``args.iter``, ``args.only_multimodal``, ``args.only_text``), stripped of
    ``"Passage:"`` / ``"Passage #<n>"`` markers and lowercased. The item is a
    hit when, in priority order:

    1. ``entity_text`` is a list and any entity is found (an ``a|b`` entity
       requires every ``|``-separated name to be found);
    2. ``entity_text`` is a plain string found in the knowledge;
    3. ``answer_eval`` is set: any element of a list ``answer_eval`` is found;
    4. otherwise any element of a list ``answer``, or the string ``answer``,
       is found.

    Items with empty or ``"nan"`` knowledge are always misses.

    Args:
        predictions: Iterable of dict-like records.
        args: Namespace providing ``iter``, ``only_multimodal``, ``only_text``.

    Returns:
        ``(recall, flags)`` where ``recall`` is hits / number of predictions
        (0 for no predictions) and ``flags`` is the per-item list of booleans.
    """
    correct = 0
    correctness_flags = []

    for prediction in predictions:
        knowledge = str(prediction.get("knowledge", ""))
        if args.iter is not None:
            knowledge = slice_knowledge_by_iter(knowledge, args.iter)
        if args.only_multimodal and args.only_text:
            # Both flags: multimodal lines followed by the text passages. The
            # newline keeps the last and first words from fusing together.
            text_knowledge = extract_text(knowledge)
            knowledge = extract_multimodal(knowledge) + "\n" + text_knowledge
        elif args.only_multimodal:
            knowledge = extract_multimodal(knowledge)
        elif args.only_text:
            knowledge = extract_text(knowledge)

        knowledge = knowledge.replace("Passage:", "")
        # Strip the whole marker in one pass. Replacing "Passage #1",
        # "Passage #2", ... in turn turned "Passage #12" into a stray "2"
        # that could falsely match numeric answers.
        knowledge = re.sub(r"Passage #\d+", "", knowledge)
        knowledge = knowledge.lower().strip()

        if not knowledge or knowledge == "nan":
            correctness_flags.append(False)
            continue

        entity_text = _parse_retrieval_field(
            str(prediction.get("entity_text", "")).lower()
        )
        answer = _parse_retrieval_field(str(prediction.get("answer", "")))
        answer_eval = _parse_retrieval_field(str(prediction.get("answer_eval", "")))

        is_correct = False
        if isinstance(entity_text, list):  # Encyclopedic: list of entities
            for entity in entity_text:
                entity = str(entity)
                if "|" in entity:
                    # 2-hop case: every name joined by "|" must be retrieved.
                    matched = all(
                        name.lower() in knowledge for name in entity.split("|")
                    )
                else:
                    matched = bool(entity) and entity.lower() in knowledge
                if matched:
                    # Any matching entity is a hit; stop so that a later
                    # entity cannot overwrite the success.
                    is_correct = True
                    break
        elif (
            isinstance(entity_text, str)
            and entity_text
            and entity_text != "nan"
            and entity_text in knowledge
        ):
            is_correct = True  # Infoseek: single entity string
        elif answer_eval and answer_eval != "nan":
            # No entity hit: fall back to pseudo relevance recall.
            # NOTE(review): a non-list answer_eval leaves the item a miss and
            # skips the `answer` fallbacks below - confirm this is intended.
            if isinstance(answer_eval, list):
                is_correct = _any_in_knowledge(answer_eval, knowledge)
        elif isinstance(answer, list):
            is_correct = _any_in_knowledge(answer, knowledge)
        elif isinstance(answer, str):
            # Missing ("") and NaN ("nan") answers carry no signal.
            is_correct = (
                bool(answer) and answer != "nan" and answer.lower() in knowledge
            )

        correctness_flags.append(is_correct)
        if is_correct:
            correct += 1

    accuracy = correct / len(predictions) if predictions else 0
    return accuracy, correctness_flags


def _parse_accuracy_field(raw, markers):
    """Parse ``raw`` as a Python literal if it contains any char in ``markers``.

    Text that contains a marker but is not a valid literal is returned
    unchanged instead of raising.
    """
    if not any(marker in raw for marker in markers):
        return raw
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return raw


def evaluate_accuracy(predictions, args):
    """Score each prediction against its reference answer(s).

    Per item, after stripping ``"Passage:"`` / ``"Passage #<n>"`` markers from
    the prediction, the following checks run in order:

    * label match: the prediction starts with the label or contains
      ``"<label>:"``;
    * numeric range: ``answer_eval`` is a dict with ``"range"`` and some number
      in the prediction lies inside it;
    * ``answer_eval`` list: any element occurs in the prediction;
    * ``answer``: VQA soft score (``args.vqa_score`` with a list answer), list
      matching (substring, ``&&`` multi-answers, numeric match within 10%
      when ``answer_eval`` is empty), or plain substring for a string answer;
    * BEM fallback (``args.bem``) for items still marked incorrect;
    * predictions containing ``"cannot be"`` (any case) or ``"Error"`` are
      forced incorrect.

    Args:
        predictions: Iterable of dict-like records.
        args: Namespace providing ``vqa_score`` and ``bem``.

    Returns:
        ``(accuracy, flags)``. With ``args.vqa_score`` the fractional VQA score
        is added to the accuracy numerator while the item's flag stays False.
    """
    number_pattern = re.compile(r"[-+]?\d*\.?\d+|\d+")
    correct = 0
    correctness_flags = []
    error_count = 0

    for pred in predictions:
        is_correct = False
        label = pred.get("label")
        question = str(pred.get("question", ""))
        prediction = str(pred.get("prediction", ""))
        answer = _parse_accuracy_field(str(pred.get("answer", "")), "[")
        ans_eval = _parse_accuracy_field(str(pred.get("answer_eval", "")), "[{")

        prediction = prediction.replace("Passage:", "")
        # Strip the whole marker in one pass. Replacing "Passage #1",
        # "Passage #2", ... in turn turned "Passage #12" into a stray "2"
        # that the numeric checks below could pick up.
        prediction = re.sub(r"Passage #\d+", "", prediction)
        prediction_lower = prediction.lower()

        # Label match
        if pd.notna(label) and label:
            if prediction.startswith(str(label)) or f"{label}:" in prediction:
                is_correct = True

        # Numeric range eval for InfoSeek Human
        if isinstance(ans_eval, dict) and "range" in ans_eval:
            low = ans_eval["range"][0]
            high = ans_eval["range"][1]
            for num in number_pattern.findall(prediction):
                try:
                    if low <= float(num) <= high:
                        is_correct = True
                        break
                except (TypeError, ValueError):
                    # Non-numeric bounds: skip this number, keep best effort.
                    pass

        # List of acceptable answers
        if isinstance(ans_eval, list):
            if any(str(ans).lower() in prediction_lower for ans in ans_eval):
                is_correct = True

        if args.vqa_score and isinstance(answer, list) and answer:  # For OK-VQA
            # Leave-one-out VQA score: for each held-out human answer, count
            # how many of the remaining answers occur in the prediction.
            hits = [preprocess_answer(h) in prediction_lower for h in answer]
            total_hits = sum(hits)
            subset_scores = [min(1.0, (total_hits - hit) / 3.0) for hit in hits]
            correct += sum(subset_scores) / len(subset_scores)
            # The fractional score is already counted; keep the flag False so
            # the item is not counted a second time below.
            is_correct = False
        elif isinstance(answer, list):
            for ans in answer:
                # Very short alphabetic answers match by accident too easily.
                if isinstance(ans, str) and ans.isalpha() and len(ans) < 4:
                    continue
                if isinstance(ans, str) and "&&" in ans:  # Multi-answer case
                    sub_answers = ans.split("&&")
                    found = sum(
                        1
                        for sub_ans in sub_answers
                        if preprocess_answer(sub_ans) in prediction_lower
                    )
                    if found / len(sub_answers) > 0.5:
                        is_correct = True
                        break

                if str(ans).lower() in prediction_lower:
                    is_correct = True

                # Numeric match when the answer is numeric (InfoSeek leaves
                # answer_eval empty in that case).
                if ans_eval == "":
                    try:
                        ans_val = float(ans)
                    except (ValueError, TypeError):
                        continue  # Non-numeric answer: keep the verdict so far
                    # abs() keeps the 10% tolerance positive for negative answers.
                    tolerance = 0.1 * abs(ans_val)
                    # The numeric verdict replaces the verdict reached so far,
                    # presumably so that "5" does not match inside "15".
                    is_correct = any(
                        abs(float(num) - ans_val) <= tolerance
                        for num in number_pattern.findall(prediction)
                    )
        elif isinstance(answer, str) and answer:
            if answer.lower() in prediction_lower:
                is_correct = True
        else:
            print("Not should happen...")

        # BEM is a fallback: only consult it when the heuristics above failed.
        # NOTE(review): with --vqa-score the flag is always False here, so a
        # positive BEM verdict adds 1 on top of the fractional score - confirm.
        if args.bem and not is_correct:
            is_correct = bool(run_bem_evaluation(question, answer, prediction))

        if "cannot be" in prediction_lower:  # Model refused to answer
            is_correct = False
        # Case-sensitive on purpose: the previous check looked for "Error" in
        # the lowercased text and therefore could never fire.
        if "Error" in prediction:  # Model output includes an error message
            is_correct = False
            error_count += 1

        correctness_flags.append(is_correct)
        if is_correct:
            correct += 1

    print(f"VLM API error in final answer : {error_count}")

    accuracy = correct / len(predictions) if predictions else 0
    return accuracy, correctness_flags


def initialize_bem_model_and_transformers():
    """Load the BEM tokenizer and model into the module-level globals.

    Both are fetched with ``from_pretrained`` for
    ``kortukov/answer-equivalence-bem``; the model is moved to CUDA when
    ``torch.cuda.is_available()``. Any failure is reported on stdout and then
    re-raised unchanged.
    """
    global bem_tokenizer, bem_model

    checkpoint = "kortukov/answer-equivalence-bem"
    print(
        "Initializing BEM model and tokenizer (kortukov/answer-equivalence-bem)... This may take a moment."
    )
    try:
        # Tokenizer first, then the model, each assigned as soon as it loads.
        bem_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        bem_model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
        if torch.cuda.is_available():
            bem_model = bem_model.to("cuda")
        print("BEM model and tokenizer initialized successfully.")
    except Exception as exc:
        print(f"Error initializing BEM model from Hugging Face: {exc}")
        raise


def run_bem_evaluation(question, reference, candidate):
    """Check answer equivalence with the BEM model.

    The candidate is encoded as ``"[CLS] {candidate} [SEP]"`` and paired with
    ``"{reference} [SEP] {question} [SEP]"``; special tokens are therefore
    written by hand and not added by the tokenizer.

    Args:
        question: Question text.
        reference: Reference answer (interpolated with ``str`` formatting).
        candidate: Candidate answer to judge.

    Returns:
        ``True`` when the softmax probability of class index 1 (taken to be
        the "equivalent" class) exceeds 0.5, else ``False``.

    Raises:
        RuntimeError: If the BEM tokenizer or model has not been initialized.
    """
    # Explicit None checks avoid relying on the objects' truthiness.
    if bem_tokenizer is None or bem_model is None:
        raise RuntimeError("BEM components not initialized properly.")

    text = f"[CLS] {candidate} [SEP]"
    text_pair = f"{reference} [SEP] {question} [SEP]"

    inputs = bem_tokenizer(
        text=text,
        text_pair=text_pair,
        add_special_tokens=False,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Run on whichever device the model was placed on.
    inputs = {k: v.to(bem_model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = bem_model(**inputs)
    probabilities = F.softmax(outputs.logits, dim=-1)
    # .item() yields a Python float, so the result is a plain bool.
    return probabilities[0][1].item() > 0.5


def main():
    """Run the evaluation end to end.

    Parses the CLI arguments, loads the predictions, optionally initializes
    the BEM model, prints answer accuracy (plus details of the correct items
    with ``--verbose``) and retrieval recall, and finally exports pickle
    predictions to a sibling ``.csv`` file if that file does not exist yet.
    """
    args = parse_args()
    predictions = load_predictions(args)

    if args.bem:
        initialize_bem_model_and_transformers()

    # Evaluate answer accuracy
    acc, flags = evaluate_accuracy(predictions, args)
    print(f"Accuracy: {acc:.2%}")
    if args.verbose:
        for p, flag in zip(predictions, flags):
            if flag:  # Only items judged correct are printed.
                print(
                        f"##Q: {p.get('question','N/A')}\n##Image Path:{p.get('image_path','N/A')}\n##Answer: {p.get('answer','N/A')} \n##Answer_eval: {p.get('answer_eval','N/A')} \n##Correct: {flag} \n##Reasoning Records: {p.get('total_pred','N/A')}\n##Pred: {p.get('prediction','N/A')}"
                )

    # Evaluate retrieval performance
    r15, _ = evaluate_retrieval(predictions, args)
    print(f"R@15: {r15:.2%}")

    if args.pkl:
        # Swap only the file extension; str.replace would also rewrite any
        # ".pkl" that appears earlier in the path (e.g. in a directory name).
        csv_path = os.path.splitext(args.pkl)[0] + ".csv"
        if not os.path.exists(csv_path):
            save_predictions(predictions, csv_path)


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
