import pandas as pd
from math_verify import parse, verify
from math import prod

def is_math_equiv(ref, pred):
    r"""Return True if *ref* and *pred* are judged equivalent by math_verify.

    Three comparisons are tried in order, stopping at the first success:

    1. both strings wrapped in ``$...$`` (presumably so math_verify parses
       them as inline LaTeX),
    2. both strings as given,
    3. *pred* with LaTeX inline-math delimiters ``\(`` / ``\)`` removed.

    Each comparison is isolated: if one raises (unparsable input, etc.), the
    remaining ones are still tried instead of the whole call returning False.
    The original note says this handles answer choices such as ``A`` vs.
    ``(A)``. NOTE(review): plain parentheses are not stripped here, so that
    case relies on math_verify itself — confirm.

    Args:
        ref: Reference (gold) answer; converted with ``str``.
        pred: Predicted answer; converted with ``str``.

    Returns:
        True if any comparison verifies, otherwise False.
    """
    # Non-string inputs (e.g. ints from literal_eval) previously crashed on
    # ``pred.replace`` and were reported as "not equivalent".
    ref, pred = str(ref), str(pred)
    pred_no_delims = pred.replace("\\(", "").replace("\\)", "")

    candidates = (
        (f"${ref}$", f"${pred}$"),
        (ref, pred),
        (ref, pred_no_delims),
    )
    for gold_text, pred_text in candidates:
        try:
            if verify(parse(gold_text), parse(pred_text)):
                return True
        except Exception:
            # Best-effort: a failing strategy must not mask the others.
            continue
    return False

def compute_answer_scores(answers, confidences, *, equiv=None):
    """
    Score each distinct answer (up to equivalence) plus a ``None`` outcome.

    Answers are first grouped. Each answer joins the first existing group
    whose representative (the first answer placed in it) is identical to it
    or ``equiv``-equivalent to it. Otherwise it starts a new group.

    For each group x (keyed by its representative):
      score(x) = (∏ c_i for i in group x)
               * (∏ ((1 - c_j) / U) for j not in group x)

    And for None:
      score(None) = ∏ ((1 - c_i) / U)  over all confidences,

    where U = len(set(answers)), the number of distinct raw answer values.
    Finally, the scores are normalized so that they sum to 1. If every score
    is 0, they are distributed uniformly instead.

    Args:
        answers: Candidate answers (must be hashable; typically strings).
        confidences: One confidence per answer, same length as ``answers``.
            Presumably each is in [0, 1]; this is not validated here.
        equiv: Optional predicate ``equiv(rep, ans) -> bool`` used for
            grouping. Defaults to ``is_math_equiv``.

    Returns:
        dict mapping each group representative, and ``None``, to its
        normalized score.

    Raises:
        ValueError: if ``answers`` and ``confidences`` differ in length.
    """
    if len(answers) != len(confidences):
        raise ValueError("answers and confidences must be the same length")
    if equiv is None:
        equiv = is_math_equiv

    # NOTE(review): U counts distinct raw answers, not equivalence groups,
    # so e.g. "1/2" and "0.5" count as two — confirm this is intended.
    U = len(set(answers))

    # Each entry: (representative answer, confidences of all its members).
    groups: list[tuple[object, list[float]]] = []
    for ans, conf in zip(answers, confidences):
        for rep, member_confs in groups:
            # Identical answers always share a group, even when the
            # equivalence check fails on them (e.g. unparsable text).
            if rep == ans or equiv(rep, ans):
                member_confs.append(conf)
                break
        else:  # no existing group matched
            groups.append((ans, [conf]))

    scores = {}
    for g, (rep, member_confs) in enumerate(groups):
        # Other groups are identified by position, not by representative
        # value, so no two groups can shadow each other.
        others = [
            c
            for h, (_, other_confs) in enumerate(groups)
            if h != g
            for c in other_confs
        ]
        prod_penalty = prod((1 - c) / U for c in others) if others else 1.0
        scores[rep] = prod(member_confs) * prod_penalty

    # Special None case: the penalty is taken over *all* confidences.
    # NOTE(review): if an answer is itself None, this overwrites its group score.
    scores[None] = prod((1 - c) / U for c in confidences) if confidences else 1.0

    total = sum(scores.values())
    if total > 0:
        return {k: v / total for k, v in scores.items()}
    # Every score is zero; fall back to a uniform distribution.
    uniform = 1.0 / len(scores)
    return {k: uniform for k in scores}

def most_common_math(answers, n=None, *, equiv=None):
    """
    Group answers up to math equivalence and count the members of each group.

    Each answer joins the first existing group whose representative (the
    first answer placed in it) is identical or ``equiv``-equivalent to it.
    Otherwise it starts a new group with itself as representative. The
    answers themselves are not normalized or rewritten.

    Args:
        answers: Iterable of candidate answers.
        n: If given, return only the first ``n`` groups after sorting.
        equiv: Optional predicate ``equiv(rep, ans) -> bool``. Defaults to
            ``is_math_equiv``.

    Returns:
        List of ``(representative, count)`` tuples sorted by count,
        descending. Ties keep first-appearance order because Python's sort
        is stable.
    """
    if equiv is None:
        equiv = is_math_equiv

    reps = []
    counts = []
    for ans in answers:
        for idx, rep in enumerate(reps):
            # Identical answers always share a group, even when the
            # equivalence check fails on them (e.g. unparsable text).
            if rep == ans or equiv(rep, ans):
                counts[idx] += 1
                break
        else:  # no existing group matched
            reps.append(ans)
            counts.append(1)

    groups = sorted(zip(reps, counts), key=lambda pair: pair[1], reverse=True)
    return groups if n is None else groups[:n]


import ast

# Self-consistency results to evaluate, and the largest number of sampled
# answers (LLM calls) to try; 16 matches "budget16" in the file name.
RESULTS_CSV = "./Results/skills_3/MATH500/self_con_seed63_budget16_acc82.6_models[Qwen]_v2.csv"
MAX_CALLS = 16

df = pd.read_csv(RESULTS_CSV)

# Each 'all_answers' cell holds a Python literal, presumably the list of
# sampled answers for one question. literal_eval only evaluates literals,
# so it is safe to apply to this file's contents.
all_final_ans = [ast.literal_eval(cell) for cell in df['all_answers']]
all_gold_ans = list(df['gold_answer'])

# Majority-vote accuracy (up to math equivalence) when only the first `i`
# answers of each question are used.
for i in range(1, MAX_CALLS + 1):
    final_predictions = []
    for answers in all_final_ans:
        if answers:
            final_predictions.append(most_common_math(answers[:i])[0][0])
        else:
            # No answers for this question: predict the empty string.
            final_predictions.append("")

    correctness = [
        is_math_equiv(final_answer, str(gold))
        for final_answer, gold in zip(final_predictions, all_gold_ans)
    ]
    acc = round(sum(correctness) / len(correctness) * 100, 2)

    print(f"Number of LLM Calls: {i}, ACC: {acc}")