import pandas as pd
from math_verify import parse, verify
from math import prod
import argparse


def is_math_equiv(ref, pred):
    """Return True if ``pred`` is judged mathematically equivalent to ``ref``.

    Equivalence is tested with math_verify's ``parse``/``verify`` in three
    forms; a match in any of them counts. Per the original author, this can
    also handle answer choices, e.g. ``A`` vs. ``(A)``.

    Args:
        ref: Reference answer.
        pred: Predicted answer; must support ``str.replace``.

    Returns:
        True if any comparison succeeds. False if none does, or if any of
        the comparisons raises (for example when ``pred`` is not a string).
    """
    try:
        # All three forms are evaluated eagerly, so a failure in any one of
        # them makes the whole check return False.
        checks = [
            # Both sides wrapped in $...$.
            verify(parse(f"${ref}$"), parse(f"${pred}$")),
            # Both sides exactly as given.
            verify(parse(ref), parse(pred)),
            # Prediction with \( and \) delimiters stripped.
            verify(parse(ref), parse(pred.replace("\\(", "").replace("\\)", ""))),
        ]
    except Exception:
        # Deliberately not a bare ``except``: KeyboardInterrupt and
        # SystemExit must propagate instead of being reported as
        # "not equivalent".
        return False
    return any(checks)

def compute_answer_scores(answers, confidences):
    """
    Score each distinct answer from per-sample confidences, then normalize.

    Answers are grouped first: an answer joins the first existing group whose
    representative is equal to it or is judged equivalent by is_math_equiv;
    otherwise it starts a new group and becomes that group's representative.
    With U the number of groups, for each group representative x:

      score(x) = (∏ c_i for i inside x's group)
               * (∏ ((1 - c_j) / U) for j outside x's group)

    And for None:
      score(None) = ∏ ((1 - c_i) / U)  over all confidences.

    Finally, normalize so that all scores sum to 1. If every raw score is
    zero, the mass is spread uniformly over all keys instead.

    Args:
        answers: Sequence of answers, one per sample. Representatives are
            used as dict keys, so they must be hashable.
        confidences: Sequence of confidences, aligned with ``answers``.

    Returns:
        Dict mapping each group representative, plus the key None, to its
        normalized score. If None itself appears in ``answers``, its group
        score is replaced by the special None score.

    Raises:
        ValueError: If ``answers`` and ``confidences`` differ in length.
    """
    if len(answers) != len(confidences):
        raise ValueError("answers and confidences must be the same length")

    # Each entry is (representative answer, confidences of its members).
    groups = []
    for ans, conf in zip(answers, confidences):
        for rep, confs in groups:
            # Identical answers always share a group, even when they cannot
            # be parsed and is_math_equiv would therefore reject them.
            if rep == ans or is_math_equiv(rep, ans):
                confs.append(conf)
                break
        else:
            groups.append((ans, [conf]))

    # U counts equivalence groups so it agrees with the grouping above;
    # counting raw distinct strings would treat "1/2" and "0.5" as two.
    num_groups = len(groups)

    scores = {}
    for idx, (rep, confs) in enumerate(groups):
        # Select the other groups by position, not by representative value,
        # so no group can be dropped from the penalty by a value collision.
        others = [
            c
            for j, (_, other_confs) in enumerate(groups)
            if j != idx
            for c in other_confs
        ]
        penalty = prod(((1 - c) / num_groups for c in others), start=1.0)
        scores[rep] = prod(confs) * penalty

    # Special None case: penalty over *all* confidences. With no samples the
    # generator is empty, so the division by num_groups is never evaluated.
    scores[None] = prod(((1 - c) / num_groups for c in confidences), start=1.0)

    total = sum(scores.values())
    if total > 0:
        for key in scores:
            scores[key] /= total
    else:
        # All raw scores are zero: fall back to a uniform distribution.
        uniform = 1.0 / len(scores)
        for key in scores:
            scores[key] = uniform

    return scores

def most_common_math(
    answers,
    n = None
):
    """
    Tally answers, merging those that is_math_equiv judges equivalent.

    Each answer is compared against the existing group representatives in
    first-seen order and joins the first one it matches; an answer matching
    none starts a new group and becomes its representative (stored as given).

    Returns a list of (representative, count) tuples sorted by count,
    descending; groups with equal counts keep first-seen order. When ``n``
    is not None, only the first ``n`` entries are returned.
    """
    # Mutable [representative, count] pairs, in first-seen order.
    tallies = []

    for candidate in answers:
        match = next(
            (entry for entry in tallies if is_math_equiv(entry[0], candidate)),
            None,
        )
        if match is None:
            tallies.append([candidate, 1])
        else:
            match[1] += 1

    # sorted() is stable, so ties keep their first-seen order.
    ranked = sorted(
        ((rep, count) for rep, count in tallies),
        key=lambda pair: pair[1],
        reverse=True,
    )
    if n is None:
        return ranked
    return ranked[:n]


# df = pd.read_csv("./Results/skills_3/AIME24/self_con_seed42_budget16_acc80.0_models[QwenR1]_v2.csv")

import ast


def _vote_with_early_stop(answers, w):
    """Majority-vote over ``answers`` taken in windows of ``w``, stopping early.

    Answers are consumed window by window. Consumption stops after the first
    window whose top group (per most_common_math) has exactly ``w`` members,
    i.e. a full, unanimous window.

    Returns:
        (prediction, num_used): the representative of the largest group among
        the consumed answers, and how many answers were consumed. The
        prediction is None when ``answers`` is empty.
    """
    used = []
    for start in range(0, len(answers), w):
        window = answers[start:start + w]
        used.extend(window)
        if most_common_math(window)[0][1] == w:
            break
    # An empty answer list has no majority; report None instead of crashing.
    prediction = most_common_math(used)[0][0] if used else None
    return prediction, len(used)


def main():
    """Report average LLM calls and accuracy for each requested window size.

    Reads a CSV with an ``all_answers`` column (a Python-literal list per
    row) and a ``gold_answer`` column, then prints one summary line per
    window size.
    """
    parser = argparse.ArgumentParser(description='Analyze thresholds for model switching')
    parser.add_argument('csv_file', help='Path to the CSV file')
    parser.add_argument('--window_size', nargs='+', type=int, default=[3],
                        help='List of window sizes to evaluate (default: 3)')
    args = parser.parse_args()
    if any(w < 1 for w in args.window_size):
        parser.error("--window_size values must be positive integers")

    df = pd.read_csv(args.csv_file)

    all_final_ans = [ast.literal_eval(cell) for cell in df['all_answers']]
    all_gold_ans = list(df['gold_answer'])
    if not all_final_ans:
        raise SystemExit(f"No rows found in {args.csv_file}")

    for w in args.window_size:
        final_predictions = []
        num_llm_calls = 0
        for answers in all_final_ans:
            prediction, num_used = _vote_with_early_stop(answers, w)
            final_predictions.append(prediction)
            num_llm_calls += num_used

        # Gold answer goes first: is_math_equiv(ref, pred) strips \( \) from
        # its second argument only, which is meant for the model output.
        correctness = [
            is_math_equiv(str(gold), prediction)
            for gold, prediction in zip(all_gold_ans, final_predictions)
        ]
        acc = round(sum(correctness) / len(correctness) * 100, 2)

        print(f"Average Number of LLM Calls: {num_llm_calls / len(all_final_ans)}, ACC: {acc}, Window Size: {w}")


if __name__ == "__main__":
    main()