import torch
from transformers import AutoTokenizer
from datasets import load_dataset
import json
from collections import defaultdict
import math

# Split names to evaluate, of the form "<benchmark>_<model>". They are
# generated benchmark-major: for each benchmark, all three model suffixes
# appear in the order listed in _MODELS.
_BENCHMARKS = (
    "math",
    "amc",
    "aime",
    "olympiadbench",
    "minerva_math",
    "leetcode",
    "gpqa",
)
_MODELS = (
    "Eurus2_7B_sft",
    "Meta_Llama_3.1_70B_Instruct",
    "Qwen2.5_7B_Instruct",
)
KEYS = [
    f"{benchmark}_{model}"
    for benchmark in _BENCHMARKS
    for model in _MODELS
]
# Loaded once at import time; the evaluation loop indexes it by each entry of
# KEYS and only uses the length of the selected split.
full_ds = load_dataset("prometheus-eval/bon_setting_64")

# Directory prefix for the per-key result files, opened as "{PATH}/{key}.jsonl".
# NOTE: with the empty default this resolves to "/<key>.jsonl" (filesystem
# root) -- set it to the directory that holds the scored outputs.
PATH = ""
# Best-of-N is reported for N = 2**0 .. 2**log2N. 2**log2N is also the divisor
# used to turn a split's row count into a number of problems.
log2N = 6

def sigmoid(x):
    """Return the logistic function ``1 / (1 + e**-x)`` of ``x``.

    The value is computed in a numerically stable way: the exponential is
    always taken of a non-positive number, so very negative inputs return
    0.0 instead of raising ``OverflowError`` from ``math.exp``.

    Args:
        x: A real number (``inf``, ``-inf`` and ``nan`` are accepted).

    Returns:
        A float in ``[0, 1]`` (``nan`` if ``x`` is ``nan``).
    """
    if x >= 0:
        # exp(-x) <= 1 here, so it cannot overflow.
        return 1 / (1 + math.exp(-x))
    # For negative x use the algebraically equal form e**x / (1 + e**x),
    # where exp(x) < 1 and at worst underflows to 0.0.
    exp_x = math.exp(x)
    return exp_x / (1 + exp_x)

def relu(x):
    """Return ``x`` when it is strictly positive, otherwise the integer 0."""
    return x if x > 0 else 0

def compute_score(logprobs, strategy):
    """Collapse a sequence of per-step scores into a single number.

    Despite the parameter name, the 'prod', 'mean_logit' and 'mean_odd'
    strategies treat each entry as a probability in [0, 1] (they take its
    log or its odds), not as a log-probability.

    Args:
        logprobs: Sequence of numeric per-step scores.
        strategy: One of
            'min'        -- smallest score;
            'max'        -- largest score;
            'prod'       -- mean of log(score), i.e. the log of the
                            geometric mean;
            'mean'       -- arithmetic mean;
            'mean_logit' -- sigmoid of the mean logit log(p / (1 - p));
            'mean_odd'   -- relu of the mean odds p / (1 - p);
            'last'       -- the final score.

    Returns:
        The aggregated score, or ``-inf`` when ``logprobs`` is empty.
        For 'mean_logit', a sequence containing both a score of 0 and a
        score >= 1 averages ``-inf`` with ``+inf`` and yields ``nan``.

    Raises:
        ValueError: If ``strategy`` is not one of the names above, or if a
            negative score is passed to a log-based strategy.
    """
    if not logprobs:
        return float('-inf')

    count = len(logprobs)
    if strategy == 'min':
        return min(logprobs)
    if strategy == 'max':
        return max(logprobs)
    if strategy == 'prod':
        # log(0) is a domain error in `math`; its limit is -inf, which is
        # the right score for a chain containing an impossible step.
        logs = [math.log(lp) if lp != 0 else float('-inf') for lp in logprobs]
        return sum(logs) / count
    if strategy == 'mean':
        return sum(logprobs) / count
    if strategy == 'mean_logit':
        logits = []
        for lp in logprobs:
            if lp >= 1:
                logits.append(float('inf'))
            elif lp == 0:
                # logit(0) -> -inf; avoids math.log(0) raising.
                logits.append(float('-inf'))
            else:
                logits.append(math.log(lp / (1 - lp)))
        return sigmoid(sum(logits) / count)
    if strategy == 'mean_odd':
        # A score of exactly 1 has infinite odds; avoids ZeroDivisionError.
        odds = [float('inf') if lp == 1 else lp / (1 - lp) for lp in logprobs]
        return relu(sum(odds) / count)
    if strategy == 'last':
        return logprobs[-1]
    # Previously fell through and returned None, which only failed later
    # when scores were compared.
    raise ValueError(f"unknown strategy: {strategy!r}")

# key_scores[N] collects, for every key, the best-of-2**N accuracy in percent.
key_scores = [[] for _ in range(log2N + 1)]

for key in KEYS:
    print()
    print(key)
    ds = full_ds[key]
    # cnts[N]: problems whose selected response among the first 2**N is correct.
    cnts = [0 for _ in range(log2N + 1)]

    # Group the scored responses by problem, preserving file order.
    id_to_data = defaultdict(list)
    with open(f"{PATH}/{key}.jsonl", encoding="utf-8") as file:
        for line in file:
            if not line.strip():
                # Tolerate blank lines instead of failing in json.loads.
                continue
            data = json.loads(line)
            # Drop only the trailing "/<response idx>" component. Keeping the
            # remaining separators prevents distinct problems from collapsing
            # onto one id (e.g. "a/bc/0" vs "ab/c/0").
            problem_id = data["id"].rsplit("/", 1)[0]
            id_to_data[problem_id].append(data)

    total_problems = len(ds) // (2 ** log2N)
    print(len(id_to_data), total_problems)

    for responses in id_to_data.values():
        # Score every response once; each budget then only inspects a prefix.
        scores = [compute_score(r["step_scores"], "min") for r in responses]
        for N in range(log2N + 1):
            BoN = 2 ** N
            pool_size = min(BoN, len(responses))
            # max() keeps the earliest index on ties, so the first
            # top-scoring response in file order is selected.
            best_idx = max(range(pool_size), key=scores.__getitem__)
            if responses[best_idx]["final_answer_correct"]:
                cnts[N] += 1

    for N in range(log2N + 1):
        score = cnts[N] / total_problems * 100
        print(2 ** N, score)
        key_scores[N].append(score)

print("\nOVERALL AVERAGE:")
for N in range(log2N + 1):
    avg_score = sum(key_scores[N]) / len(KEYS)
    print(2 ** N, avg_score)