import ast
import copy as cp
import traceback
from .prime_math import extract_answer, grade

def list_to_dict(lst):
    """Label each item of ``lst`` with a consecutive character starting at 'A'.

    The first item is keyed by ``'A'`` (code point 65), the second by ``'B'``,
    and so on; iteration order of ``lst`` is preserved in the result.
    """
    labelled = {}
    for offset, item in enumerate(lst):
        labelled[chr(65 + offset)] = item
    return labelled

def can_infer(answer, choices):
    """Resolve a model answer to the content of the option it selects.

    Args:
        answer: The model's answer; it is converted to ``str`` before use.
        choices: Mapping from option label (e.g. ``'A'``) to option content.

    Returns:
        ``choices[label]`` when ``can_infer_option`` identifies a label that
        is present in ``choices``; otherwise the stringified answer itself.
    """
    answer = str(answer)
    copt = can_infer_option(answer, choices)
    # can_infer_option may return the sentinel 'Z' for refusals, which is
    # normally not a key of ``choices``; indexing blindly would raise KeyError.
    if copt and copt in choices:
        return choices[copt]  # the content of the option
    return answer

def can_infer_option(answer, choices):
    """Infer which option label a free-form answer string selects.

    Args:
        answer: The answer text to inspect (a string).
        choices: Iterable of option labels; in this file it is a dict whose
            keys are the labels.

    Returns:
        * ``False`` if the answer contains the API-failure marker, or if no
          single label can be identified.
        * ``'Z'`` if the answer contains one of the known refusal messages.
        * The matching label when exactly one label from ``choices`` appears
          as a standalone token in the answer.
    """
    if 'Failed to obtain answer via API' in answer:
        return False

    reject_to_answer = (
        "Sorry, I can't help with images of people yet.",
        "I can't process this file.",
        "I'm sorry, but without the image provided",
        'Cannot determine the answer',
    )
    if any(err in answer for err in reject_to_answer):
        return 'Z'

    # Turn punctuation into spaces so labels such as "(B)" or "B." become
    # standalone tokens; one translate pass replaces the chained replaces.
    punctuation = '.()[],:;!*#{}'
    to_spaces = str.maketrans(punctuation, ' ' * len(punctuation))
    tokens = answer.translate(to_spaces).split()

    matched = [label for label in choices if label in tokens]
    if len(matched) != 1:
        # Zero matches or an ambiguous answer naming several labels.
        return False

    # A bare "A" inside a longer answer is ambiguous with the English
    # article, so it is not accepted as an option selection.
    if 'A' in tokens and len(tokens) > 3:
        return False
    return matched[0]

def compute_score(solution_str: str, ground_truth: str, extra_info=None):
    """Score a model solution against the ground truth.

    The answer is pulled out of ``solution_str`` with ``extract_answer``. For
    multiple-choice questions the answer is first resolved to option content
    via ``can_infer`` and compared case-insensitively with ``ground_truth``;
    if that does not match, or for any other category, the verdict is
    delegated to ``grade``.

    Args:
        solution_str: Raw model output.
        ground_truth: Reference answer.
        extra_info: Optional mapping. ``'category'`` selects the question
            type (default ``'multi_choice'``); ``'options'`` holds the option
            contents either as a sequence or as the string form of a Python
            list literal (default ``'[]'``).

    Returns:
        A dict with keys ``score``, ``format_score``, ``acc``,
        ``extracted_gt`` and ``extracted_pred``. ``format_score`` is 0 only
        when no answer could be extracted.
    """
    model_answer = extract_answer(solution_str)
    if model_answer is None:
        return {
            "score": 0.,
            "format_score": 0.,
            "acc": False,
            "extracted_gt": ground_truth,
            "extracted_pred": None,
        }

    # extra_info defaults to None, so it must not be dereferenced blindly.
    info = extra_info if extra_info is not None else {}
    question_type = info.get('category', 'multi_choice')
    if question_type == 'multi_choice':
        raw_options = info.get("options", '[]')
        try:
            # Only a string needs parsing; an already-materialised sequence
            # is used as-is (literal_eval would reject it).
            if isinstance(raw_options, str):
                raw_options = ast.literal_eval(raw_options)
            choices = list_to_dict(raw_options if raw_options is not None else [])
        except (ValueError, SyntaxError, TypeError) as e:
            # Malformed options: report and continue without choices so the
            # answer can still be judged by grade() below.
            print(f"Error parsing choices: {e}")
            traceback.print_exc()
            choices = {}

        res = can_infer(model_answer, choices)
        # Option contents may be non-string (e.g. numbers), so stringify
        # before the case-insensitive comparison.
        is_correct = str(res).lower() == str(ground_truth).lower()
        if not is_correct:
            is_correct = grade(model_answer, ground_truth)
    else:
        is_correct = grade(model_answer, ground_truth)

    acc = bool(is_correct)
    return {
        "score": 1. if acc else 0.,
        "format_score": 1.,
        "acc": acc,
        "extracted_gt": ground_truth,
        "extracted_pred": model_answer,
    }