import json
import random

try:
    from math_verify.errors import TimeoutException
    from math_verify.metric import math_metric
    from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
except ImportError:
    print("To use Math-Verify, please install it first by running `pip install math-verify`.")

import verl.utils.reward_score.math as verl_math


# def compute_pass_at(data, k, test_times=10):
#     n = len(data)
#     pass_k = 0
#     for item in data:
#         pp = 0
#         for _ in range(test_times):
#             for i in range(k):
#                 idx = random.randint(0, len(item['output']) - 1)
#                 if item['pass'][idx]:
#                     pp += 1
#                     break
#         pass_k += pp / test_times
#     return pass_k / n

def compute_pass_at(data, k, test_times=10):
    """Estimate pass@k by repeatedly sampling from pre-generated outputs.

    For every item, ``test_times`` independent trials are run. In each trial
    up to ``k`` outputs are drawn uniformly at random *with replacement*; the
    trial succeeds as soon as one drawn output is judged correct. The item's
    score is the fraction of successful trials, and the function returns the
    mean of these scores over all items.

    Args:
        data: list of dicts, each holding
              - 'output': list of model output strings
              - 'gound_truth' (historical spelling) or 'ground_truth':
                the reference answer
        k: number of draws per trial.
        test_times: number of trials per item (must be >= 1).

    Returns:
        float, the estimated pass@k. ``0.0`` when ``data`` is empty or
        ``k <= 0``.

    Raises:
        ValueError: if ``test_times`` is smaller than 1.
        KeyError: if an item carries neither ground-truth key.
    """
    if test_times < 1:
        raise ValueError(f"test_times must be >= 1, got {test_times}")
    if not data or k <= 0:
        # Nothing to evaluate; avoid dividing by zero below.
        return 0.0

    verify_func = math_metric(
        gold_extraction_target=(LatexExtractionConfig(),),
        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    )

    def _is_correct(gold, output):
        # Judge a single output; any extraction/verification failure is
        # reported and counted as an incorrect answer instead of crashing or
        # silently reusing the previous sample's result.
        try:
            boxed = verl_math.last_boxed_only_string(output)
            # Outputs without a boxed answer fall back to "0".
            answer = verl_math.remove_boxed(boxed) if boxed is not None else "0"
            score, _ = verify_func([gold], [f'\\[{answer}\\]'])
        except Exception as exc:
            print(f"Failed to score output: {exc!r}")
            return False
        return score >= 1

    pass_k = 0.0
    for item in data:
        ground_truth = item['gound_truth'] if 'gound_truth' in item else item['ground_truth']
        outputs = item['output']
        if not outputs:
            # No candidates to sample from: the item contributes a score of 0.
            continue

        gold = f'\\[{ground_truth}\\]'
        # Each output is verified at most once; later draws of the same index
        # reuse the cached verdict.
        verdicts = {}
        successes = 0
        for _ in range(test_times):
            for _ in range(k):
                idx = random.randint(0, len(outputs) - 1)
                if idx not in verdicts:
                    verdicts[idx] = _is_correct(gold, outputs[idx])
                if verdicts[idx]:
                    successes += 1
                    break
        pass_k += successes / test_times

    return pass_k / len(data)


def compute_cons_at(data, k):
    """
    Compute cons@k (majority-vote accuracy) for a dataset.

    For every item the last boxed expression of each of the first ``k``
    outputs is extracted, the most frequent extraction is selected (ties go
    to the one seen first), and that answer is verified against the
    reference. The function returns the mean verification score.

    Args:
        data: list of dicts, each item should have
              - 'output': list of model outputs
              - 'gound_truth' (historical spelling) or 'ground_truth':
                the reference answer
        k: number of samples to consider; if an item has fewer than k
           outputs, all of them are used

    Returns:
        cons_k: float, the cons@k score (0.0 when data is empty or k <= 0)

    Raises:
        KeyError: if an item carries neither ground-truth key.
    """
    from collections import Counter

    if not data or k <= 0:
        # Nothing to evaluate; avoid dividing by zero below.
        return 0.0

    verify_func = math_metric(
        gold_extraction_target=(LatexExtractionConfig(),),
        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    )

    cons_k = 0.0
    for item in data:
        ground_truth = item['gound_truth'] if 'gound_truth' in item else item['ground_truth']

        votes = []
        # Slicing takes at most k outputs and never indexes past the end.
        for output in item['output'][:k]:
            try:
                votes.append(verl_math.last_boxed_only_string(output))
            except Exception as exc:
                # An output that cannot be parsed casts no vote.
                print(f"Failed to extract boxed answer: {exc!r}")

        if not votes:
            continue

        # Pick the most frequent answer; Counter gives a deterministic
        # first-seen tie-break and counts in a single pass.
        majority = Counter(votes).most_common(1)[0][0]
        try:
            # A winning "no boxed answer" (None) falls back to "0".
            answer = verl_math.remove_boxed(majority) if majority is not None else "0"
            score, _ = verify_func([f'\\[{ground_truth}\\]'], [f'\\[{answer}\\]'])
        except Exception as exc:
            # A failed verification scores 0 for this item rather than
            # reusing the previous item's score.
            print(f"Failed to verify majority answer: {exc!r}")
            continue
        cons_k += score

    return cons_k / len(data)

        
import argparse

if __name__ == "__main__":
    # Command-line entry point: load a JSON results file and report
    # Pass@1 / Pass@10 / Pass@32 / Cons@32 plus their average.
    parser = argparse.ArgumentParser(
        description="Compute pass@k and cons@k for a JSON file of model outputs."
    )
    parser.add_argument("--file", type=str, required=True)
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="optional random seed to make the pass@k sampling reproducible",
    )
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)

    # Read the data and close the file before the (slow) evaluation starts.
    with open(args.file, "r", encoding="utf-8") as f:
        data = json.load(f)
    print(f"total: {len(data)}")

    pass_at_1 = compute_pass_at(data, 1)
    pass_at_10 = compute_pass_at(data, 10)
    pass_at_32 = compute_pass_at(data, 32)
    cons_at_32 = compute_cons_at(data, 32)

    print(f"Pass@1: {pass_at_1: .4f}")
    print(f"Pass@10: {pass_at_10: .4f}")
    print(f"Pass@32: {pass_at_32: .4f}")
    print(f"Cons@32: {cons_at_32: .4f}")
    print(f"Avg: {(pass_at_1 + pass_at_10 + pass_at_32 + cons_at_32) / 4 : .4f}")