import re
import numpy as np
from numbers import Number
from collections import defaultdict
from utils.cot_voting import majority_voting


def group_answers(responses):
    """Group responses that are identical or numerically equivalent.

    Responses are processed in order:

    - A response equal to an existing group key joins that group.
    - Otherwise, if the response can be ``eval``-ed, its value is compared with
      the cached value of every earlier group key that could itself be
      evaluated, in insertion order. The response joins the first group whose
      value differs from it by less than 1e-6 (absolute).
    - Anything left unmatched starts a new group keyed by the response itself.

    Responses that cannot be evaluated, or whose values cannot be
    subtracted/compared (strings, lists, ...), are therefore grouped by exact
    equality only.

    NOTE(security): every response is passed to ``eval``. If responses come
    from a model or another untrusted source, this can execute arbitrary code.

    :param responses: A list of hashable responses (normally strings).
    :return: A ``defaultdict(list)`` mapping each group's first-seen response
        to the indices of all responses in that group.
    """
    grouped = defaultdict(list)
    # Evaluated value of each group key that could be eval-ed. Caching it means
    # each distinct response is evaluated once, instead of re-evaluating every
    # existing key for every new response.
    key_values = {}

    for idx, resp in enumerate(responses):
        if resp in grouped:
            grouped[resp].append(idx)
            continue

        try:
            value = eval(resp)
        except Exception:
            # Not an evaluable expression: it can only match by exact equality,
            # which was already checked above.
            grouped[resp].append(idx)
            continue

        for existing_key, existing_value in key_values.items():
            try:
                is_close = abs(value - existing_value) < 1e-6
            except Exception:
                # Incomparable values (e.g. a list vs. a number): no match.
                continue
            if is_close:
                grouped[existing_key].append(idx)
                break
        else:
            # No numerically equivalent group found: start a new one.
            grouped[resp].append(idx)
            key_values[resp] = value

    return grouped


def remove_prefix(text, prefix):
    """Return *text* with a leading *prefix* stripped once.

    If *text* does not start with *prefix*, it is returned unchanged.
    """
    if not text.startswith(prefix):
        return text
    cut = len(prefix)
    return text[cut:]


def match_answers(answers, expected_answers):
    """
    Compare the given answers to the expected value. Return a boolean list indicating
    whether each answer matches the expected value.

    The expected answer is ``eval``-ed once, unless its string form fully matches
    ``DATA_RE_TEMPLATE``. Such strings (e.g. "12/25/2023") would otherwise be
    evaluated as a division.

    If the result is a number, each answer is ``eval``-ed and compared with an
    absolute tolerance of 1e-6. Each answer that cannot be evaluated or compared
    falls back, individually, to string comparison. String comparison accepts
    either ``str()`` of the (possibly evaluated) expected value or the original
    expected text.

    NOTE(security): both the expected answer and the answers are passed to
    ``eval``; do not use this on untrusted input without sandboxing.

    :param answers: Either a single string or a list of strings representing user answers.
    :param expected_answers: A string or a numeric value representing the expected answer.
    :return: If 'answers' is a list, returns a list of booleans; if 'answers' is a single string, returns a single boolean.
    """
    expected_text = f"{expected_answers}"
    expected = expected_answers
    # Evaluate the expected answer exactly once (the scalar path used to recurse
    # and evaluate an already-evaluated value a second time).
    if re.search(rf"^{DATA_RE_TEMPLATE}$", expected_text) is None:
        try:
            expected = eval(expected_answers)
        except Exception:
            pass

    # Accept the raw expected text as well as str() of the evaluated value:
    # e.g. "1,000" evals to the tuple (1, 0), whose str() "(1, 0)" would never
    # equal an identical answer "1,000".
    string_targets = (str(expected), expected_text)
    numeric_expected = isinstance(expected, Number)

    def _matches(ans):
        # Numeric comparison first; fall back per answer so one malformed
        # answer does not force every other answer onto string comparison.
        if numeric_expected:
            try:
                return abs(eval(ans) - expected) < 1e-6
            except Exception:
                pass
        return ans in string_targets

    if not isinstance(answers, list):
        return _matches(answers)
    return [_matches(ans) for ans in answers]


def compute_metric(pred_answers, final_answer):
    """Score a collection of predictions against the reference answer.

    :param pred_answers: The predicted answers. The whole collection goes to
        ``majority_voting``; each element goes to ``match_answers``.
    :param final_answer: The reference answer to match against.
    :return: A 4-tuple ``(per_sample_correct, majority_results,
        majority_corrects, majority_count)``:

        - ``per_sample_correct``: the ``match_answers`` result for every prediction.
        - ``majority_results`` and ``majority_count``: the two values returned
          by ``majority_voting``. Presumably the voted answer(s) and the vote
          count; confirm against ``utils.cot_voting``.
        - ``majority_corrects``: the ``match_answers`` result for ``majority_results``.
    """
    voted, votes = majority_voting(pred_answers)

    sample_hits = [match_answers(candidate, final_answer) for candidate in pred_answers]
    voted_hit = match_answers(voted, final_answer)

    return sample_hits, voted, voted_hit, votes

# Regex for a "dd/dd/dddd" date such as "12/25/2023". match_answers uses it to
# keep such expected answers as strings rather than eval-ing them as a division.
# Raw string so "\d" and "\/" are regex escapes, not invalid string escapes.
DATA_RE_TEMPLATE = r"(\d\d\/\d\d\/\d\d\d\d)"

