import json
import argparse
from verl.utils.reward_score.qa_em import extract_solution, em_check
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
import torch


def parse():
    """Parse the command-line flags for the evaluation script.

    All flags are required: the evaluation cannot run without an input file,
    the number of rollouts stored per question, and the sample sizes to score.
    Omitting one makes argparse print a usage error and exit with status 2,
    instead of failing later with an opaque TypeError.

    Returns:
        argparse.Namespace with ``input_file`` (str), ``max_n`` (int) and
        ``n_list`` (list[int]).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--input-file",
        type=str,
        required=True,
        help="file saved with torch.save containing a 'non_tensor_batch' entry",
    )
    parser.add_argument(
        "--max-n",
        type=int,
        required=True,
        help="number of consecutive rollouts stored per question",
    )
    parser.add_argument(
        "--n-list",
        type=int,
        nargs="+",
        required=True,
        help="sample sizes n to evaluate (each must be <= --max-n)",
    )

    return parser.parse_args()


def _load_rollouts(path):
    """Load per-rollout messages, TPRM step scores and gold answers.

    Args:
        path: File written with ``torch.save`` whose top-level object has a
            ``"non_tensor_batch"`` entry containing indexable ``"messages"``,
            ``"tprm_scores"`` and ``"reward_model"`` sequences.

    Returns:
        ``(messages, tprm_scores, answers)`` lists with one entry per rollout;
        the rollout count is taken from ``"messages"``.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only point --input-file at files you produced yourself.
    data = torch.load(path, weights_only=False, map_location="cpu")[
        "non_tensor_batch"
    ]
    count = len(data["messages"])
    messages = [data["messages"][k] for k in range(count)]
    tprm_scores = [data["tprm_scores"][k] for k in range(count)]
    answers = [
        data["reward_model"][k]["ground_truth"]["target"] for k in range(count)
    ]
    return messages, tprm_scores, answers


def _tprm_score(step_scores):
    """Return the geometric mean of one rollout's TPRM step scores."""
    # exp(mean(log(x))) rather than prod(x) so long trajectories are not
    # penalised just for having more steps.
    return np.exp(np.mean(np.log(np.array(step_scores))))


def _score_group(messages, answers, tprm_scores, start, n):
    """Score the first ``n`` rollouts of the question group starting at ``start``.

    Returns:
        ``(majority_hit, pass_hit, tprm_hit)``:
        - ``majority_hit`` is the ``em_check`` result of the most common
          extracted answer.
        - ``pass_hit`` is 1 if any rollout matches, else 0.
        - ``tprm_hit`` is the ``em_check`` result of the rollout with the
          highest geometric-mean TPRM score.

        A rollout whose answer could not be extracted (``None``) counts as 0.
    """
    end = start + n
    gold = answers[start]
    preds = [extract_solution(messages[j][-1]["content"]) for j in range(start, end)]

    # Self-consistency: the most frequent prediction wins; ties go to the
    # earliest-seen one. A failed extraction (None) can win the vote and is
    # then scored as wrong.
    majority_pred = Counter(preds).most_common(1)[0][0]
    majority_hit = 0 if majority_pred is None else em_check(majority_pred, gold)

    # pass@n: correct if any single rollout matches its gold answer.
    pass_hit = int(
        any(
            pred is not None and em_check(pred, answer)
            for pred, answer in zip(preds, answers[start:end])
        )
    )

    # Best-of-n by TPRM. The strict '>' keeps the earliest rollout on ties,
    # and a NaN score never replaces the current best.
    best_offset, best_score = 0, -1
    for offset in range(n):
        score = _tprm_score(tprm_scores[start + offset])
        if score > best_score:
            best_score, best_offset = score, offset
    tprm_pred = preds[best_offset]
    tprm_hit = 0 if tprm_pred is None else em_check(tprm_pred, gold)

    return majority_hit, pass_hit, tprm_hit


def run():
    """Evaluate self-consistency, pass@n and TPRM best-of-n accuracy.

    Assumes the rollouts are stored contiguously in groups of ``--max-n`` per
    question, all sharing the gold answer of the group's first rollout (TODO
    confirm against the producer of the input file). For every ``n`` in
    ``--n-list``, only the first ``n`` rollouts of each group are used. One
    line with percentages is printed per ``n``.

    Raises:
        ValueError: if ``--max-n`` or an entry of ``--n-list`` is not a
            positive integer, if any ``n`` exceeds ``--max-n`` (which would
            mix rollouts from neighbouring questions), if the input has no
            rollouts, or if the last group is too short for some ``n``.
    """
    args = parse()
    max_n, n_list = args.max_n, args.n_list

    if max_n <= 0:
        raise ValueError(f"--max-n must be positive, got {max_n}")
    bad = [n for n in n_list if n <= 0 or n > max_n]
    if bad:
        raise ValueError(f"every --n-list value must be in [1, {max_n}], got {bad}")

    messages, tprm_scores, answers = _load_rollouts(args.input_file)
    total = len(messages)
    if total == 0:
        raise ValueError(f"no rollouts found in {args.input_file}")

    # Per-n lists of per-question results: hits[k] belongs to n_list[k].
    majority_hits = [[] for _ in n_list]
    pass_hits = [[] for _ in n_list]
    tprm_hits = [[] for _ in n_list]

    for start in range(0, total, max_n):
        for k, n in enumerate(n_list):
            if start + n > total:
                raise ValueError(
                    f"group starting at rollout {start} has only "
                    f"{total - start} rollouts, need {n}"
                )
            majority_hit, pass_hit, tprm_hit = _score_group(
                messages, answers, tprm_scores, start, n
            )
            majority_hits[k].append(majority_hit)
            pass_hits[k].append(pass_hit)
            tprm_hits[k].append(tprm_hit)

    for k, n in enumerate(n_list):
        sc = sum(majority_hits[k]) / len(majority_hits[k]) * 100
        pass_n = sum(pass_hits[k]) / len(pass_hits[k]) * 100
        tprm_acc = sum(tprm_hits[k]) / len(tprm_hits[k]) * 100
        print(f"n: {n}, sc: {sc}, pass_n: {pass_n}, tprm_acc: {tprm_acc}")

if __name__ == "__main__":
    # Script entry point. run() returns None, so the interpreter exits with
    # status 0 exactly as if the call had simply fallen off the end.
    raise SystemExit(run())
