import ast
import re
import copy as cp
from .prime_math import grade

def list_to_dict(lst):
    """Key the items of *lst* by consecutive capital letters.

    The first item gets ``'A'``, the second ``'B'``, and so on (keys are
    ``chr(ord('A') + position)``).
    """
    lettered = {}
    for offset, value in enumerate(lst):
        lettered[chr(ord("A") + offset)] = value
    return lettered

def last_boxed_only_string(string):
    """Return the last ``\\boxed{...}`` span of *string*, command included.

    ``\\fbox`` is searched for only when no ``\\boxed`` occurs at all. The
    span runs from the command to the ``}`` that brings the brace depth back
    to zero. Returns None when there is no such command or the braces never
    balance.
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]
    # Ran off the end without closing the box.
    return None


def remove_boxed(s):
    """Strip the leading ``\\boxed{`` and trailing ``}`` from *s*.

    Args:
        s: Typically the output of ``last_boxed_only_string``; may be None.

    Returns:
        The text between ``\\boxed{`` and the final ``}``, or None when *s*
        is not a string of exactly that form (e.g. None, or an ``\\fbox{...}``
        span).
    """
    left = "\\boxed{"
    # Explicit checks rather than ``assert``: asserts are removed under
    # ``python -O``, which would make non-matching input return a bogus slice.
    if isinstance(s, str) and s.startswith(left) and s.endswith("}"):
        return s[len(left):-1]
    return None


def extract_boxed_answer(solution: str) -> str:
    """Return the contents of the last LaTeX ``\\boxed{...}`` in *solution*.

    Returns None when no balanced box is found, or when the last box is not
    written as ``\\boxed{...}`` (for example an ``\\fbox``).
    """
    last_box = last_boxed_only_string(solution)
    return remove_boxed(last_box)


def extract_answer(passage: str) -> str:
    """Return the final boxed answer in *passage*, or None.

    Only passages containing ``\\boxed`` are inspected; the extraction
    itself is delegated to ``extract_boxed_answer``.
    """
    if "\\boxed" not in passage:
        return None
    return extract_boxed_answer(passage)

def extract_options(text):
    """Extract multiple-choice options from a question string.

    Three layouts are tried in order and the first that yields anything wins:

    1. Inline parenthesised letters, ``(A) foo (B) bar`` -- only letters
       A-H (either case) are kept; returns the stripped descriptions.
    2. Line-leading bracketed lists, ``A: [1, 2]`` -- each list becomes a
       list of ints, or of stripped strings if any item is not an int.
    3. Line-leading dotted letters, ``A. foo`` -- returns the stripped,
       non-empty texts (which may be an empty list if every match is blank).

    Falls back to ``['A', 'B', 'C', 'D']`` when nothing matches. *text* is
    passed through ``str`` first.
    """
    text = str(text)
    allowed_letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']

    # Layout 1: "(A) ... (B) ..."
    inline = [
        description.strip()
        for letter, description in re.findall(r'\((\w)\)\s([^()]+)', text)
        if letter.upper() in allowed_letters
    ]
    if inline:
        return inline

    # Layout 2: "A: [1, 2]" on its own line.
    bracketed = []
    for body in re.findall(r"^[A-Z]:\s*\[(.*?)\]", text, re.MULTILINE):
        items = body.split(',')
        try:
            bracketed.append([int(item.strip()) for item in items])
        except ValueError:
            # Not all integers: keep the raw strings instead.
            bracketed.append([str(item.strip()) for item in items])
    if bracketed:
        return bracketed

    # Layout 3: "A. XX\nB. YY\n..."
    dotted = re.findall(r"^[A-Z]\.\s+(.*)", text, re.MULTILINE)
    if dotted:
        return [entry.strip() for entry in dotted if entry.strip()]
    return ['A', 'B', 'C', 'D']

def can_infer(answer, choices):
    """Map a free-text *answer* to the content of the option it names.

    Args:
        answer: Model answer; coerced with ``str``.
        choices: Mapping from option letter to option content (e.g. the
            dict built by ``list_to_dict``).

    Returns:
        ``choices[letter]`` when ``can_infer_option`` picks out a letter that
        is a key of *choices*; otherwise the stringified answer itself.
    """
    answer = str(answer)
    letter = can_infer_option(answer, choices)
    # can_infer_option returns the sentinel 'Z' for refusal messages. 'Z' is
    # normally not a key of *choices*, so indexing with it raised KeyError;
    # only look up letters that actually exist.
    if letter and letter in choices:
        return choices[letter]
    return answer  # presumably the option's content, written out in full

def can_infer_option(answer, choices):
    """Try to read a single option letter out of a free-text *answer*.

    Args:
        answer: Model answer (a string).
        choices: Iterable of option letters, e.g. the keys of the dict built
            by ``list_to_dict``.

    Returns:
        ``'Z'`` if the answer contains one of the known refusal messages.
        ``False`` for API-failure text; when zero or several option letters
        appear as whole tokens; or when a standalone ``'A'`` appears in an
        answer of more than three tokens (presumably because it may be the
        English article rather than an option).
        Otherwise the single matching letter.
    """
    if 'Failed to obtain answer via API' in answer:
        return False

    refusal_markers = (
        "Sorry, I can't help with images of people yet.",
        "I can't process this file.",
        "I'm sorry, but without the image provided",
        'Cannot determine the answer',
    )
    if any(marker in answer for marker in refusal_markers):
        return 'Z'

    def tally(tokens, letters, prefix='', suffix=''):
        # How many letters (optionally decorated) appear as a whole token.
        return sum(1 for letter in letters if prefix + letter + suffix in tokens)

    # Turn punctuation into spaces so e.g. "(B)." or "**C**" become bare letters.
    cleaned = answer
    for punct in '.()[],:;!*#{}':
        cleaned = cleaned.replace(punct, ' ')
    tokens = [piece.strip() for piece in cleaned.split()]

    if tally(tokens, choices) != 1:
        return False
    # A lone "A" inside a longer answer is ambiguous; refuse to guess.
    if 'A' in tokens and len(tokens) > 3:
        return False
    for letter in choices:
        if letter in tokens:
            return letter
    return False

def compute_score(solution_str: str, ground_truth: str, extra_info=None):
    """Score a model solution against *ground_truth*.

    The prediction is the content of the last ``\\boxed{...}`` in
    *solution_str* (see ``extract_answer``).

    Args:
        solution_str: Raw model output.
        ground_truth: Reference answer. For multi-choice items this is
            presumably an option letter or the option text -- TODO confirm
            against the dataset.
        extra_info: Optional dict. ``category`` selects the grading mode
            (``'multi-choice'``; anything else is graded free-form). For
            multi-choice, ``options`` (a list, or its ``repr`` string) and
            ``question`` are read. ``None`` is treated as an empty dict.

    Returns:
        dict with ``score`` (1.0 or 0.0), ``format_score`` (0.0 only when no
        boxed answer could be extracted), ``acc`` (bool), ``extracted_gt``
        (the ground truth as given) and ``extracted_pred`` (the extracted
        answer as given).

    Raises:
        NotImplementedError: a multi-choice item ships a non-empty
            ``options`` list; only options parsed from the question text are
            supported.
        Exceptions from ``ast.literal_eval`` (e.g. ValueError, SyntaxError)
            when ``options`` is a malformed non-empty string.
    """
    model_answer = extract_answer(solution_str)
    if model_answer is None:
        return {
            "score": 0.,
            "format_score": 0.,
            "acc": False,
            "extracted_gt": ground_truth,
            "extracted_pred": None,
        }

    # The default None would otherwise crash on .get().
    extra_info = extra_info or {}
    category = extra_info.get('category', '')

    if category == 'multi-choice':
        options = extra_info.get("options")
        if isinstance(options, str):
            # Stored serialized; an empty string means "no options given"
            # (ast.literal_eval("") would raise SyntaxError).
            options = ast.literal_eval(options) if options.strip() else []
        elif options is None:
            options = []
        if options != []:
            raise NotImplementedError(
                "multi-choice grading with explicit options is not supported"
            )

        # No explicit options: recover them from the question text and key
        # them 'A', 'B', ... in order.
        choices = list_to_dict(extract_options(extra_info.get("question", "")))

        # 1) Map the prediction to an option's content and compare with the label.
        # str() because parsed options may be lists, which have no .lower().
        res = can_infer(model_answer, choices)
        is_correct = str(res).lower() == ground_truth.lower()
        # 2) Fall back to the generic answer grader.
        if not is_correct:
            is_correct = grade(model_answer, ground_truth)
        # 3) Label is an option letter: compare the prediction with that
        #    option's text. Key membership (instead of the substring test
        #    `in 'ABCD'`) avoids KeyError for '', 'AB' or letters beyond the
        #    parsed options, and also covers options E and later.
        if not is_correct and ground_truth in choices:
            is_correct = str(choices[ground_truth]).lower() == model_answer.lower()
        # 4) Last resort: ignore newlines and commas. Use locals so the
        #    reported extracted_gt / extracted_pred stay untouched.
        if not is_correct:
            norm_pred = model_answer.replace('\n', '').replace(',', '')
            norm_gt = ground_truth.replace('\n', '').replace(',', '')
            is_correct = norm_pred.lower() == norm_gt.lower()
    else:
        # free-form
        is_correct = grade(model_answer, ground_truth)
        if not is_correct:
            # Fall back to a case-insensitive exact string match.
            is_correct = model_answer.lower() == ground_truth.lower()

    return {
        "score": 1. if is_correct else 0.,
        "format_score": 1.,
        "acc": bool(is_correct),
        "extracted_gt": ground_truth,
        "extracted_pred": model_answer,
    }
    