import ast
import re
from typing import Any, List, Optional

from Levenshtein import distance

from llm_mcts.tasks.math_vista.problem import MathVistaProblem


# The following code is based on this: https://github.com/lupantech/MathVista/blob/99fa993d4e3f659f8d93b7786502e9109e94d273/evaluation/build_query.py
#
# Few-shot demonstration examples. Each entry is a dict with these keys:
#   "question" - the question text (always present)
#   "choices"  - list of answer options (only on multiple-choice examples)
#   "caption"  - textual description of the image
#   "ocr"      - string form of a list of ([x, y], text) detections (not on every example)
#   "solution" - worked natural-language solution
#   "code"     - Python program that prints the answer
# These are the same keys that the demo-building loop in `create_one_query`
# reads from its `examples` argument.
SHOT_EXAMPLES = [
    # PID: 799
    {
        "question": "How much money does Ruth need to buy a baking dish, a casserole dish, and an ice cream scoop? (Unit: $)",
        "caption": "The image shows a table with a variety of items on it, including a baking dish, ice cream scoop, casserole dish, and rolling pin. The text in the image says:\n\n```\nbaking dish\n$4.00\nice cream scoop\n$6.00\ncasserole dish\n$3.00\nrolling pin\n$4.00\n```",
        "ocr": "[([5, 3], 'baking dish'), ([177, 5], '$4.00'), ([7, 41], 'ice cream scoop'), ([177, 37], '$6.00'), ([9, 69], 'casserole dish'), ([177, 69], '$3.00'), ([5, 98], 'rolling pin'), ([177, 101], '$4.00')]",
        "solution": """
Find the total cost of a baking dish, a casserole dish, and an ice cream scoop.\n\n$4.00 + $3.00 + $6.00 = $13.00\n\nRuth needs $13.00.
""",
        "code": """
baking_dish_price = 4.00
casserole_dish_price = 3.00
ice_cream_scoop_price = 6.00

ans = baking_dish_price + casserole_dish_price + ice_cream_scoop_price
print(ans)
""",
    },
    # PID: 681 (multiple choice; this example has no "ocr" entry)
    {
        "question": "What is the largest city in the nation where this plane is headquartered?",
        "choices": ["hong kong", "osaka", "shanghai", "tokyo"],
        "caption": 'The image shows a large passenger jet parked on a tarmac at an airport. The jet is white with red trim and has a red tail. It is sitting on top of a tarmac next to a building. The jet is being loaded with passengers and cargo. The text on the image says "Japan. Endless Discovery".',
        "solution": """
The caption mentions that the text on the image says "Japan. Endless Discovery". This indicates that the plane is headquartered in Japan.

Among the Japanese cities, Tokyo is the largest city.

Thus, the answer is D (tokyo).
""",
        "code": """
def largest_city(caption, choices):
    countries_largest_cities = {
        'Japan': 'tokyo',
        'China': 'shanghai'
    }

    if "Japan" in caption:
        country = 'Japan'
    elif "China" in caption:
        country = 'China'

    for choice in choices:
        if choice == countries_largest_cities[country]:
            return choice
    return ""

choices = ['hong kong', 'osaka', 'shanghai', 'tokyo']
caption = "The image shows a large passenger jet parked on a tarmac at an airport. The jet is white with red trim and has a red tail. It is sitting on top of a tarmac next to a building. The jet is being loaded with passengers and cargo. The text on the image says 'Japan. Endless Discovery'."

print(largest_city(caption, choices))
""",
    },
    # PID: 615 (multiple choice)
    {
        "question": "If two sides of a triangle measure 12 and 7, which of the following cannot be the perimeter of the triangle?",
        "choices": ["29", "34", "37", "38"],
        "caption": "The image shows a triangle with two sides labeled 7 and 12. The triangle is drawn on a white background. There is no text other than the labels.",
        "ocr": "[([70, 74], '7'), ([324, 74], '12')]",
        "solution": """
To determine which of the given perimeters cannot be possible for the triangle, we apply the triangle inequality theorem. The sum of any two sides of a triangle must be greater than the third side.

For the maximum possible value of the third side:
12 + 7 = 19

The minimum possible value for the third side:
12 - 7 = 5

The third side for each option:
(A) 29 - 12 - 7 = 10 (valid)
(B) 34 - 12 - 7 = 15 (valid)
(C) 37 - 12 - 7 = 18 (valid)
(D) 38 - 12 - 7 = 19 (invalid because it should be less than 19)

Thus, the answer is D.
""",
        "code": """
def is_valid_triangle(a, b, perimeter):
    # Given a and b, find the third side
    third_side = perimeter - a - b

    # Check triangle inequality
    if (a + b > third_side) and (a + third_side > b) and (b + third_side > a):
        return True
    return False

# Given sides
a = 12
b = 7

# Given perimeters
perimeters = [29, 34, 37, 38]

# Check which perimeter is not valid
for p in perimeters:
    if not is_valid_triangle(a, b, p):
        print(p)
""",
    },
]


def refine_caption(caption: Any) -> str:
    """Remove known assistant boilerplate phrases from an image caption.

    Args:
        caption: The raw caption. Anything that is not a ``str`` is treated
            as "no caption".

    Returns:
        The caption with every listed phrase removed and surrounding
        whitespace trimmed, or ``""`` when ``caption`` is not a string.
    """
    if not isinstance(caption, str):
        return ""

    boilerplate = (
        "Sure. ",
        "Sure, I can do that.",
        "Sorry, I can't help with images of people yet.",
        "I can't process this file.",
        "I'm unable to help you with that, as I'm only a language model and don't have the necessary information or abilities.",
        "I'm not programmed to assist with that.",
        "Please let me know if you have any other questions.",
        "I hope this is helpful!",
        "I hope this helps!",
    )

    cleaned = caption
    for phrase in boilerplate:
        # Trim after each removal so later phrases see the trimmed text.
        cleaned = cleaned.replace(phrase, "")
        cleaned = cleaned.strip()

    # One pass only: every non-overlapping pair of spaces becomes one space.
    return cleaned.replace("  ", " ").strip()


def refine_ocr(ocr: str) -> str:
    """Condense a raw OCR dump to ``(first box point, text)`` pairs.

    The input is the string form of a list of detections, each shaped like::

        (
            [[161, 39], [766, 39], [766, 120], [161, 120]],
            'The spring force does',
            0.912845069753024
        )

    and the output is the string form of a list that keeps only the first
    point of each box (coerced to ints) and the detected text::

        [([161, 39], 'The spring force does'), ...]

    Args:
        ocr: String containing a Python literal as described above.

    Returns:
        The condensed list rendered with ``str()``, or ``""`` when the list
        is empty or the input cannot be parsed into the expected shape.
    """
    try:
        # SECURITY: this used to call eval(), which executes arbitrary code
        # embedded in the OCR string. ast.literal_eval only accepts Python
        # literals, which is all a well-formed OCR dump contains.
        detections = ast.literal_eval(ocr)
        if len(detections) == 0:
            return ""
        condensed = [
            ([int(entry[0][0][0]), int(entry[0][0][1])], entry[1])
            for entry in detections
        ]
        return str(condensed)
    except Exception:
        # Deliberate best-effort: any malformed input degrades to "no OCR".
        # Unlike a bare ``except:``, KeyboardInterrupt/SystemExit propagate.
        return ""


def _format_choices(choices: List[str]) -> str:
    """Render options as a ``Choices:`` header plus one ``(A) ...`` line each."""
    lines = ["Choices:"]
    lines.extend(
        f"({chr(ord('A') + idx)}) {choice}" for idx, choice in enumerate(choices)
    )
    return "\n".join(lines)


def _solution_hint(question_type: str, answer_type: str, precision) -> str:
    """Pick the answer-format hint for a ``solution``-style query."""
    if question_type == "multi_choice":
        assert answer_type == "text"
        return "Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end."

    assert answer_type in ["integer", "float", "list"]
    if answer_type == "integer":
        return "Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end."
    if answer_type == "float" and precision == 1:
        return "Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end."
    if answer_type == "float" and precision == 2:
        return "Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end."
    if answer_type == "list":
        return "Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end."

    # Previously this case left the hint unassigned and surfaced later as an
    # UnboundLocalError; fail here with a message that names the cause.
    raise ValueError(
        f"No hint is defined for answer_type={answer_type!r} with precision={precision!r}"
    )


def create_one_query(
    problem: MathVistaProblem,
    examples: List[dict],
    shot_num: int = 0,
    shot_type: str = "solution",
    use_caption: bool = False,
    use_ocr: bool = False,
) -> str:
    """Build the text query for one problem.

    The query consists of the question (with its unit, if any), the lettered
    choices (if any), a hint describing the expected answer format, and a
    trailing ``Solution:`` / ``Python code:`` cue.

    Args:
        problem: Object exposing ``question``, ``unit``, ``choices``,
            ``caption``, ``ocr``, ``precision``, ``question_type`` and
            ``answer_type`` attributes.
        examples: Few-shot example dicts (``question``, optional ``choices``,
            ``caption``, ``ocr``, ``solution``, ``code``).
        shot_num: Number of few-shot examples to prepend. Only ``0`` is
            currently accepted.
        shot_type: ``"solution"`` or ``"code"``.
        use_caption: Not supported yet; must be ``False``.
        use_ocr: Not supported yet; must be ``False``.

    Returns:
        The assembled query, stripped of leading/trailing whitespace.

    Raises:
        NotImplementedError: If ``use_caption`` or ``use_ocr`` is set, or
            ``shot_num`` is not ``0``.
        AssertionError: If ``shot_type`` is neither ``"solution"`` nor
            ``"code"``, or the question/answer type combination is unexpected.
        ValueError: If no hint exists for the answer type / precision pair
            (e.g. a float answer whose precision is not 1 or 2).
    """
    if use_caption:  # TODO: add support for caption
        raise NotImplementedError("use_caption is not supported yet")
    if use_ocr:  # TODO: add support for OCR
        raise NotImplementedError("use_ocr is not supported yet")
    if shot_num != 0:  # TODO: check if this is rational
        raise NotImplementedError(
            "Few-shot examples are contained in the testmini subset"
        )

    ### [1] Demo prompt
    # NOTE: with the guards above this branch always yields "", the few-shot
    # path is kept so it is ready once those features are enabled.
    if shot_num == 0:
        demo_prompt = ""
    else:
        demos = []
        shot_num = min(shot_num, len(examples))
        for example in examples[:shot_num]:
            prompt = f"Question: {example['question']}"

            if "choices" in example:
                prompt += "\n" + _format_choices(example["choices"])

            if use_caption:
                caption = example.get("caption", "")
                if caption != "":
                    prompt += "\n" + f"Image description: {caption}"

            if use_ocr:
                ocr = example.get("ocr", "")
                if ocr != "":
                    prompt += "\n" + f"Image detected text: {ocr}"

            if shot_type == "solution":
                solution = example["solution"].strip()
                prompt += "\n" + f"Solution: {solution}"

            if shot_type == "code":
                code = example["code"].strip()
                prompt += "\n" + f"Python code: {code}"

            demos.append(prompt)

        demo_prompt = "\n\n".join(demos)

    ### [2] Test query
    # problem info
    question = problem.question
    unit = problem.unit
    choices = problem.choices
    caption = problem.caption
    ocr = problem.ocr
    precision = problem.precision
    question_type = problem.question_type
    answer_type = problem.answer_type

    # hint
    if shot_type == "solution":
        hint_text = _solution_hint(question_type, answer_type, precision)
    else:
        assert shot_type == "code"
        hint_text = "Hint: Please generate a python code to solve the problem"

    # question
    question_text = f"Question: {question}"
    if unit:
        question_text += f" (Unit: {unit})"

    # choices, e.g. "(A) 1.2", "(B) 1.3", ... one per line
    choices_text = _format_choices(choices) if choices else ""

    # caption
    caption_text = ""
    if use_caption and caption != "":
        caption_text = f"Image description: {caption}"

    # ocr
    ocr_text = ""
    if use_ocr and ocr != "":
        ocr_text = f"Image detected text: {ocr}"

    # trailing cue; its trailing space is removed by the final strip()
    if shot_type == "solution":
        prompt = "Solution: "
    else:
        assert shot_type == "code"
        prompt = "Python code: "

    elements = [question_text, choices_text, caption_text, ocr_text, hint_text, prompt]
    test_query = "\n".join([e for e in elements if e != ""])

    ### [3] Final query
    query = demo_prompt + "\n\n" + test_query
    return query.strip()


# The following code is based on this: https://github.com/lupantech/MathVista/blob/99fa993d4e3f659f8d93b7786502e9109e94d273/evaluation/prompts/ext_ans.py
# pids = 852,  104,  824,  506,  540
#
# Few-shot demonstration for answer extraction. It holds five worked
# "Hint / Question / Model response / Extracted answer" examples, one per
# answer format: integer, one-decimal float, two-decimal float, Python list,
# and multiple-choice letter.
# NOTE: this is a regular (non-raw) string, so the "\n" escapes in the
# multiple-choice question become real newlines at runtime.
EXTRACTION_DEMO_PROMPT = """
Please read the following example. Then extract the answer from the model response and type it at the end of the prompt.

Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end.
Question: Which number is missing?

Model response: The number missing in the sequence is 14.

Extracted answer: 14

Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end.
Question: What is the fraction of females facing the camera?

Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera.

Extracted answer: 0.6

Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end.
Question: How much money does Luca need to buy a sour apple candy and a butterscotch candy? (Unit: $)

Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.

Extracted answer: 1.45

Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.
Question: Between which two years does the line  graph saw its maximum peak?

Model response: The line graph saw its maximum peak between 2007 and 2008.

Extracted answer: [2007, 2008]

Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.
Question: What fraction of the shape is blue?\nChoices:\n(A) 3/11\n(B) 8/11\n(C) 6/11\n(D) 3/5

Model response: The correct answer is (B) 8/11.

Extracted answer: B
"""


# The following code is based on this: https://github.com/lupantech/MathVista/blob/99fa993d4e3f659f8d93b7786502e9109e94d273/evaluation/extract_answer.py
def create_test_prompt(demo_prompt: str, query: str, response: str) -> str:
    """Assemble the full answer-extraction prompt.

    The stripped demo prompt, the query, and the response are separated by
    blank lines, followed by a trailing ``Extracted answer: `` cue.

    Args:
        demo_prompt: Few-shot demonstration text; surrounding whitespace is
            removed before use.
        query: The query that was posed.
        response: The response to extract an answer from.

    Returns:
        The assembled prompt, ending with ``"Extracted answer: "``.
    """
    sections = (
        demo_prompt.strip(),
        f"{query}\n\n{response}",
        "Extracted answer: ",
    )
    return "\n\n".join(sections)


def create_extract_answer_prompt(
    response: str, problem: MathVistaProblem, quick_extract: bool = False
) -> str:
    """Extract the answer directly when possible, else build an extraction prompt.

    Shortcuts are tried in this order; the first that applies wins:

    1. An empty response yields ``""``.
    2. For a multiple-choice problem, a response equal to one of the choices
       is returned unchanged.
    3. For an integer / float answer type, a response that parses as such is
       returned in its canonical string form.
    4. With ``quick_extract``, the text inside ``The answer is "...".`` is
       returned.

    Otherwise the full few-shot extraction prompt is returned.

    Args:
        response: The response to extract from.
        problem: Object exposing ``question_type``, ``answer_type``,
            ``choices`` and ``query`` attributes.
        quick_extract: Enable the regex shortcut (step 4).

    Returns:
        Either the directly extracted answer or the extraction prompt. The
        two cases are not distinguished by the return type.
    """
    if response == "":
        return ""

    question_type = problem.question_type
    answer_type = problem.answer_type
    choices = problem.choices

    # ``choices`` may be None; without the guard the membership test raises
    # TypeError instead of falling through to the other strategies.
    if (
        question_type == "multi_choice"
        and choices is not None
        and response in choices
    ):
        return response

    if answer_type == "integer":
        try:
            return str(int(response))
        except (TypeError, ValueError, OverflowError):
            pass  # not a plain integer; try the remaining strategies

    if answer_type == "float":
        try:
            return str(float(response))
        except (TypeError, ValueError, OverflowError):
            pass  # not a plain float; try the remaining strategies

    # quick extraction: The answer is "text". -> "text"
    # (non-string responses cannot be searched, so they skip this step)
    if quick_extract and isinstance(response, str):
        match = re.search(r'The answer is "(.*)"\.', response)
        if match:
            return match.group(1)

    # general extraction
    return create_test_prompt(EXTRACTION_DEMO_PROMPT, problem.query, response)


# The following code is based on this: https://github.com/lupantech/MathVista/blob/99fa993d4e3f659f8d93b7786502e9109e94d273/evaluation/calculate_score.py
def get_most_similar(prediction, choices):
    """Return the choice closest to ``prediction`` by Levenshtein (edit) distance.

    When several choices are equally close, the one that appears first in
    ``choices`` is returned.
    """
    return min(choices, key=lambda candidate: distance(prediction, candidate))


def normalize_extracted_answer(
    extraction: Any,
    choices: Optional[List[str]],
    question_type: str,
    answer_type: str,
    precision: Optional[int],
    ignore_empty_extractions: bool = False,
) -> Any:
    """Normalize an extracted answer to the form expected for its answer type.

    Args:
        extraction: The raw extracted answer.
        choices: Answer options; read only when ``question_type`` is
            ``"multi_choice"``, where it must be a non-empty list.
        question_type: ``"multi_choice"`` selects choice matching; any other
            value dispatches on ``answer_type``.
        answer_type: ``"integer"``, ``"float"`` or ``"list"`` for
            non-multiple-choice questions.
        precision: Number of decimals to round a float answer to.
        ignore_empty_extractions: For multiple choice, return ``None`` for an
            empty extraction instead of picking the most similar choice.

    Returns:
        For multiple choice, the matching element of ``choices``. Otherwise
        the normalized value as a string, or ``None`` when the extraction
        cannot be converted or ``answer_type`` is not one of the supported
        values.
    """
    # Default for every "could not normalize" outcome. It also covers answer
    # types outside integer/float/list, which previously raised
    # UnboundLocalError at the return statement.
    normalized_extraction = None

    if question_type == "multi_choice":
        # make sure the extraction is a string
        if isinstance(extraction, str):
            extraction = extraction.strip()
        else:
            try:
                extraction = str(extraction)
            except Exception:
                extraction = ""

        # if the extraction is empty, return None
        if ignore_empty_extractions and not extraction:
            return None

        # extract "A" from "(A) text"
        letters = re.findall(r"\(([a-zA-Z])\)", extraction)
        if letters:
            extraction = letters[0].upper()

        sequential_characters = [chr(ord("A") + i) for i in range(len(choices))]

        # if the model output an option letter, use it to index the choices
        if extraction in sequential_characters:
            normalized_extraction = choices[sequential_characters.index(extraction)]
        else:
            # otherwise fall back to the most similar option
            normalized_extraction = get_most_similar(extraction, choices)
        assert normalized_extraction in choices

    elif answer_type == "integer":
        try:
            # go through float so that inputs such as "3.0" are accepted
            normalized_extraction = str(int(float(extraction)))
        except Exception:
            normalized_extraction = None

    elif answer_type == "float":
        try:
            normalized_extraction = str(round(float(extraction), precision))
        except Exception:
            normalized_extraction = None

    elif answer_type == "list":
        try:
            normalized_extraction = str(extraction)
        except Exception:
            normalized_extraction = None

    return normalized_extraction


def safe_equal(prediction, answer):
    """Return whether ``prediction == answer`` holds, never raising.

    Any exception raised by the comparison, or by evaluating its result as a
    boolean, is treated as "not equal".
    """
    try:
        return bool(prediction == answer)
    except Exception:
        return False


def get_acc_with_contion(res_pd, key, value):
    """Compute accuracy over the rows of ``res_pd`` that match a condition.

    Args:
        res_pd: Table of results with a ``"true_false"`` column; presumably a
            pandas DataFrame (column ``.apply`` and boolean-mask indexing are
            used).
        key: Column to filter on. For ``"skills"`` a row matches when
            ``value`` is contained in the cell; for any other column the cell
            must equal ``value``.
        value: Value to match.

    Returns:
        ``(correct, total, accuracy)`` for the matching rows. ``accuracy`` is
        ``0.0`` when no row matches (this used to raise ZeroDivisionError).
    """
    if key == "skills":
        mask = res_pd[key].apply(lambda skills: value in skills)
    else:
        mask = res_pd[key] == value
    total_pd = res_pd[mask]

    # Element-wise comparison on purpose: builds a boolean mask over the column.
    correct_pd = total_pd[total_pd["true_false"] == True]  # noqa: E712

    correct = len(correct_pd)
    total = len(total_pd)
    acc = correct / total if total else 0.0

    return correct, total, acc
