import torch
from transformers import AutoModel, AutoTokenizer

# Local snapshot directory of the Qwen2.5-Math-PRM-72B checkpoint.
model_name = "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-Math-PRM-72B/snapshots/9df429b02adb5f764cd6e30e76a0cca16d501ae1/"
# Forwarded as `device_map` when the model is loaded.
device = "auto"

# The tokenizer and the model are loaded from the same snapshot; both opt in
# to the custom code shipped with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map=device,
)
# Inference only: switch the module to evaluation mode.
model = model.eval()
import pandas as pd

# Results table that drives the rest of the script; the columns read later are
# 'question', 'all_predictions', 'all_answers' and 'gold_answer'.
df = pd.read_csv("./Results/skills_2/MATH500/self_con_seed21_budget8_acc62.4_models['tinyQwen']_v2.csv")

import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F
import ast

def make_step_rewards(logits, token_masks):
    """Return the positive-class probability at every masked token position.

    Args:
        logits: Tensor indexed as (batch, seq_len, num_labels). A softmax is
            taken over the last axis and column 1 is read as the positive
            label.
        token_masks: Tensor of shape (batch, seq_len) that is truthy at the
            positions whose probability should be reported.

    Returns:
        One list per batch element containing the positive-class
        probabilities (Python floats) of the masked positions, in sequence
        order. A sample with no masked position yields an empty list.
    """
    probabilities = F.softmax(logits, dim=-1)  # bs, seq_len, num_labels
    # Select rows with the mask itself rather than by filtering out zeros:
    # a probability that underflows to exactly 0 would otherwise be dropped
    # and break (or silently misalign) the per-label pairing.
    selected = token_masks.bool().to(probabilities.device)  # bs, seq_len

    all_scores_res = []
    for sample_probs, sample_mask in zip(probabilities, selected):
        positive_probs = sample_probs[sample_mask][:, 1]  # valid_tokens
        all_scores_res.append(positive_probs.cpu().tolist())
    return all_scores_res

# Score every sampled response with the reward model.
# Assumes `model`, `tokenizer` and `df` are already loaded and that the model
# sits on the correct device.

# NOTE(review): this is a plain string (no f-string / .format call), so the
# model literally sees "\boxed{{X}}" with doubled braces - confirm that this
# matches the prompt that was used when the responses were generated.
SYSTEM_PROMPT = (
    "Provide your step-by-step reasoning first, "
    "and then print \"The answer is \\boxed{{X}}\", "
    "where X is the final answer, at the end of your response."
)

# Id of the separator token whose positions are scored; it does not depend on
# the question or response, so it is looked up once instead of per response.
step_sep_id = tokenizer.encode("<extra_0>")[0]

all_rewards = []

for question, raw_predictions in zip(df['question'], df['all_predictions']):
    # The column stores the Python repr of a list of generated responses.
    preds = ast.literal_eval(raw_predictions)
    row_rewards = []

    for response in preds:
        # prepare the chat prompt; a single separator is appended after the
        # whole response, so the response is scored as one step
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": question},
            {"role": "assistant", "content": response + "<extra_0>"},
        ]
        conversation_str = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        input_ids = tokenizer.encode(
            conversation_str,
            return_tensors="pt"
        ).to(model.device)

        # run model
        with torch.no_grad():
            outputs = model(input_ids=input_ids)

        # mask out everything except the <extra_0> separator tokens
        token_masks = (input_ids == step_sep_id)

        # compute and collect the reward sequence for this prediction
        # (batch size is 1, hence the [0])
        reward_list = make_step_rewards(outputs.logits, token_masks)[0]
        row_rewards.append(reward_list)

    all_rewards.append(row_rewards)

# all_rewards[i] holds one reward list per prediction of question i

from math_verify import parse, verify
from math import prod

def is_math_equiv(ref, pred):
    """Return True if `ref` and `pred` are judged mathematically equivalent.

    Three comparison forms are tried in order and the first success wins:
      1. both strings wrapped in ``$...$``,
      2. both strings as given,
      3. `pred` with the ``\\(`` / ``\\)`` delimiters stripped.
    This also handles answer choices, e.g. ``A`` vs. ``(A)``.

    Each form is attempted independently: a form that raises (for example
    because an input is not a string) counts as "no match" for that form only
    and does not hide a match found by another form.

    Args:
        ref: Reference answer.
        pred: Predicted answer.

    Returns:
        True if any comparison form verifies, otherwise False. Never raises
        for ordinary errors; interpreter-level signals such as
        KeyboardInterrupt are allowed to propagate.
    """
    attempts = (
        lambda: verify(parse(f"${ref}$"), parse(f"${pred}$")),
        lambda: verify(parse(ref), parse(pred)),
        lambda: verify(parse(ref), parse(pred.replace("\\(", "").replace("\\)", ""))),
    )
    for attempt in attempts:
        try:
            if attempt():
                return True
        except Exception:
            # Best effort: an unparseable/unsupported input in this form is
            # simply treated as not equivalent.
            continue
    return False

def compute_answer_scores(answers, confidences):
    """
    Combine per-sample answers and confidences into one normalized score per
    distinct answer.

    Answers are first clustered: an answer joins the first existing group
    whose representative is equal to it or is `is_math_equiv` to it, otherwise
    it starts a new group and becomes that group's representative (the key
    used in the result).

    For each group x:
      score(x) = (∏ c_i for samples i inside x)
               * (∏ ((1 - c_j) / U) for samples j outside x)

    And for None:
      score(None) = ∏ ((1 - c_i) / U)  over all confidences.

    U is the number of distinct raw answers, i.e. ``len(set(answers))``.

    Finally, the scores are normalized so that they sum to 1; if the total is
    not positive, every key receives the same share.

    Args:
        answers: Final answers, one per sample (must be hashable).
        confidences: Confidence of each sample, aligned with `answers`.

    Returns:
        Dict mapping each group representative, plus the key None, to its
        normalized score.

    Raises:
        ValueError: If `answers` and `confidences` differ in length.
    """
    if len(answers) != len(confidences):
        raise ValueError("answers and confidences must be the same length")

    # NOTE(review): U counts distinct *raw* answers, not equivalence groups,
    # so "0.5" and "1/2" contribute 2 here even though they share one group.
    # Left unchanged to keep scores comparable - confirm this is intended.
    U = len(set(answers))

    # (representative answer, confidences of all samples in the group)
    equiv_groups: list[tuple[object, list[float]]] = []
    for ans, conf in zip(answers, confidences):
        for rep, confs in equiv_groups:
            # Exact equality is checked first so identical answers always
            # share a group, even when is_math_equiv cannot recognise an
            # answer as equivalent to itself (e.g. unparseable text).
            if rep == ans or is_math_equiv(rep, ans):
                confs.append(conf)
                break
        else:
            equiv_groups.append((ans, [conf]))

    scores = {}
    for pos, (rep, confs) in enumerate(equiv_groups):
        # Every confidence that belongs to a different group; groups are told
        # apart by position rather than by comparing representatives.
        others = [
            c
            for other_pos, (_, other_confs) in enumerate(equiv_groups)
            if other_pos != pos
            for c in other_confs
        ]
        prod_penalty = prod((1 - c) / U for c in others) if others else 1.0
        scores[rep] = prod(confs) * prod_penalty

    # special None-case: penalty over *all* confidences.
    # NOTE(review): if an answer is itself None, its group score is replaced
    # by this value because both use the key None.
    scores[None] = prod((1 - c) / U for c in confidences) if confidences else 1.0

    # normalize
    total = sum(scores.values())
    if total > 0:
        for k in scores:
            scores[k] /= total
    else:
        # if somehow all zero, distribute uniformly
        uniform = 1.0 / len(scores)
        for k in scores:
            scores[k] = uniform

    return scores

# Keep only the first reward of every prediction as that prediction's
# confidence (the scoring loop appends a single "<extra_0>" per response).
all_rewards_v1 = [[rewards[0] for rewards in row] for row in all_rewards]
# The column stores the Python repr of the list of extracted final answers.
all_final_ans = [ast.literal_eval(raw) for raw in df['all_answers']]

import numpy as np

# A question stops receiving additional samples once the best normalized
# score (including the None entry) reaches this value.
CONFIDENCE_THRESHOLD = 0.999

# Start with one sample per question.
choices_2_scores = [
    compute_answer_scores(final_ans[:1], reward_score[:1])
    for final_ans, reward_score in zip(all_final_ans, all_rewards_v1)
]
indices = np.arange(len(df))

# Every question consumes its first sample.
num_llm_calls = len(indices)

# Simulate adaptive sampling: in each round, every question that is still
# below the threshold is re-scored with one more of its stored samples.
# Assumes every question has as many samples as the first one.
for num_samples in range(2, len(all_rewards_v1[0]) + 1):
    mask = [max(scores.values()) < CONFIDENCE_THRESHOLD for scores in choices_2_scores]
    indices = np.nonzero(np.array(mask))[0]
    if len(indices) == 0:
        break
    num_llm_calls += len(indices)

    for index in indices:
        choices_2_scores[index] = compute_answer_scores(
            all_final_ans[index][:num_samples],
            all_rewards_v1[index][:num_samples],
        )

# Pick the highest-scoring real answer for each question (ties go to the
# answer seen first). If every sampled answer was None there is nothing to
# pick, so the prediction is None instead of crashing on an empty argmax.
preds = []
for answer_scores in choices_2_scores:
    candidates = {ans: score for ans, score in answer_scores.items() if ans is not None}
    preds.append(max(candidates, key=candidates.get) if candidates else None)

# Bare expressions are discarded outside a notebook, so report explicitly.
avg_llm_calls = num_llm_calls / len(df)
print(f"average LLM calls per question: {avg_llm_calls}")

# A missing prediction is counted as wrong without consulting the verifier.
correctness = [
    final_answer is not None and is_math_equiv(final_answer, str(gold))
    for final_answer, gold in zip(preds, df['gold_answer'])
]
acc = round(sum(correctness) / len(correctness) * 100, 2)
print(f"accuracy: {acc}")