
from lmms_eval.filters.extraction import ExtendedRegexFilter
from lmms_eval.filters.transformation import MapFilter
import re
import numpy as np


import logging
eval_logger = logging.getLogger("lmms-eval")

def get_task_instruction(dataset):
    if dataset in ['analogy', 'attribute', 'plot_code', 'visual_chain', 'sightseeing']:
        instr = 'Answer with a single word.'
    elif dataset in ['codeu', 'food', 'image_jigsaw']:
        instr = 'Answer with the option symbol.'
    elif dataset in ['arxiv']:
        instr = 'Answer with the paper title.'
    elif dataset in ['count']:
        instr = 'Answer with a single number.'
    elif dataset in ['3d_scene']:
        instr = 'The following images are different views of the same 3D scene. Answer with a single number.'
    
    return instr

def mirb_doc_to_text(doc, model_specific_prompt_kwargs=None):
    subset, question = doc['subset'], doc["questions"]
    task_instruction = get_task_instruction(subset)
    post_prompt = model_specific_prompt_kwargs["post_prompt"]
    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    return f"{pre_prompt}{task_instruction}{question}{post_prompt}"

def mirb_doc_to_visual(doc):
    image_list = [image.convert("RGB") for image in doc["image_list"]]
    return image_list 

def mirb_doc_to_target(doc):
    return doc["answers"]


def extract_numbers(string):
    """
    Exact all forms of numbers from a string with regex.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L100
    """
    # Pattern for numbers with commas
    pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
    # Pattern for scientific notation
    pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
    # Pattern for simple numbers without commas
    pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"

    # Extract numbers with commas
    numbers_with_commas = re.findall(pattern_commas, string)
    # Extract numbers in scientific notation
    numbers_scientific = re.findall(pattern_scientific, string)
    # Extract simple numbers without commas
    numbers_simple = re.findall(pattern_simple, string)

    # Combine all extracted numbersz
    all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
    return all_numbers

def check_is_number(string):
    """
    Check if the given string a number.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L65
    """
    try:
        float(string.replace(",", ""))
        return True
    except ValueError:
        # check if there's comma inside
        return False

def normalize_str(string):
    """
    Normalize the str to lower case and make them float numbers if possible.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L76
    """
    # check if characters in the string

    # if number, numerize it.
    string = string.strip()

    is_number = check_is_number(string)

    if is_number:
        string = string.replace(",", "")
        string = float(string)
        # leave 2 decimal
        string = round(string, 2)
        return [string]
    else:  # it's likely to be a string
        # lower it
        string = string.lower()
        if len(string) == 1:
            return [" " + string, string + " "]  # avoid trivial matches
        return [string]

def parse_multi_choice_response(response):
    # here, we assume we have a list, in which each element is
    # a list of model responses for some particular input/target pair.
    # so we process each of these (same input/target response sets)
    # independently (and keep them a list.)
    filtered_resps = []
    option_letter_regex = re.compile(r"^\s*([A-Z])\.")
    match = option_letter_regex.match(response)
    if match:
        # If a match is found, append the matched letter
        filtered_resps = match.group(1)
    else:
        # If no match, return the original response
        filtered_resps = response
    return filtered_resps


def parse_open_response(response):
    """
    Parse the prediction from the generated response.
    Return a list of predicted strings or numbers.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L122
    """

    # content = content.strip("\n").strip(".").strip(" ")
    def get_key_subresponses(response):
        key_responses = []
        response = response.strip().strip(".").lower()
        sub_responses = re.split(r"\.\s(?=[A-Z])|\n", response)
        indicators_of_keys = [
            "could be ",
            "so ",
            "is ",
            "thus ",
            "therefore ",
            "final ",
            "answer ",
            "result ",
        ]
        key_responses = []
        for index, resp in enumerate(sub_responses):
            # if last one, accept it's an equation (the entire response can be just one sentence with equation)
            if index == len(sub_responses) - 1:
                indicators_of_keys.extend(["="])
            shortest_key_response = None  # the shortest response that may contain the answer (tail part of the response)
            for indicator in indicators_of_keys:
                if indicator in resp:
                    if not shortest_key_response:
                        shortest_key_response = resp.split(indicator)[-1].strip()
                    else:
                        if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
                            shortest_key_response = resp.split(indicator)[-1].strip()
                    # key_responses.append(resp.split(indicator)[1].strip())

            if shortest_key_response:
                # and it's not trivial
                if shortest_key_response.strip() not in [
                    ":",
                    ",",
                    ".",
                    "!",
                    "?",
                    ";",
                    ":",
                    "'",
                ]:
                    key_responses.append(shortest_key_response)
        if len(key_responses) == 0:  # did not found any
            return [response]
        return key_responses

    # pdb.set_trace()
    key_responses = get_key_subresponses(response)

    pred_list = key_responses.copy()  # keep the original string response
    for resp in key_responses:
        pred_list.extend(extract_numbers(resp))

    tmp_pred_list = []
    for i in range(len(pred_list)):
        tmp_pred_list.extend(normalize_str(pred_list[i]))
    pred_list = tmp_pred_list

    # remove duplicates
    pred_list = list(set(pred_list))

    return pred_list

def mirb_process_results(doc, results):
    pred = results[0]
    subset, answer = doc["subset"], doc["answers"]
    if answer in ['A', 'B', 'C', 'D', 'E']: # MCQ tasks
        parsed_pred = parse_multi_choice_response(pred)
    else:
        parsed_pred = parse_open_response(pred)
    task_type = doc["subset"]
    data_dict = {"question_id": doc["question_id"], "subset": task_type, "pred_answer": parsed_pred, "answers": doc["answers"]}
    return {f"mirb_score": data_dict}

def eval_multi_choice(gold_i, pred_i):
    """
    Evaluate a multiple choice instance.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L175
    """
    correct = False
    # only they are exactly the same, we consider it as correct
    if isinstance(gold_i, list):
        for answer in gold_i:
            if answer == pred_i:
                correct = True
                break
    else:  # gold_i is a string
        if gold_i == pred_i:
            correct = True
    return correct


def eval_open(gold_i, pred_i):
    """
    Evaluate an open question instance
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L191
    """
    correct = False
    if isinstance(gold_i, list):
        # use float to avoid trivial matches
        norm_answers = []
        for answer in gold_i:
            norm_answers.extend(normalize_str(answer))
    else:
        norm_answers = normalize_str(gold_i)
    for pred in pred_i:  # pred is already normalized in parse response phase
        if isinstance(pred, str):  # if it's a string, then find if ans in the pred_i
            for norm_ans in norm_answers:
                # only see if the string answer in the string pred
                if isinstance(norm_ans, str) and norm_ans in pred:
                    if not correct:
                        correct = True
                    break
        else:  # it's a float number
            if pred in norm_answers:
                if not correct:
                    correct = True
                break
    return correct

def mirb_aggregation(results):
    task_num = {}
    score = 0
    task_score = {}
    for result in results:
        if result["subset"] not in task_score:
            task_score[result["subset"]] = 0
            task_num[result["subset"]] = 0

        if result['answers'] in ['A', 'B', 'C', 'D', 'E']: # MCQ tasks
            correct = eval_multi_choice(result["answers"], result["pred_answer"])
            task_score[result["subset"]] += correct
            score += correct
        else:
            correct = eval_open(result["answers"], result["pred_answer"])
            task_score[result["subset"]] += correct
            score += correct
        task_num[result["subset"]] += 1
    avg_score = score / len(results)

    task_score = {k : v / task_num[k] for k,v in task_score.items()}

    print("Performances for different subsets:")
    print("=" * 50)
    for k, v in task_score.items():
        print(f"{k} : {v:.2f}")
    print("=" * 50)

    # print across evaluation dimension
    groups = {
        "Knowledge": ["food", "sightseeing"],
        "Reasoning": ["codeu", "plot_code", "analogy", "3d_scene"],
        "Perception": ["image_jigsaw", "count", "attribute"],
        "Multi-Hop": ["visual_chain", "arxiv"]
    }

    # Compute the averages for each group
    averages_dict = compute_averages_from_task_scores(task_score, groups)

    # Print group performances
    print("Performances for different dimensions:")
    print("=" * 50)
    for group, score in averages_dict.items():
        print(f"{group} : {score:.2f}")
    print("=" * 50)
    
    return avg_score


# Compute averages for each group based on task scores
def compute_averages_from_task_scores(task_score, groups):
    averages = {}
    for group, features in groups.items():
        values = [task_score[feature] for feature in features]
        averages[group] = np.mean(values)
    return averages