import ujson as json

from tqdm import tqdm
import numpy as np
from transformers import AutoTokenizer


MODEL_PATH = "Qwen/QwQ-32B"
# Directory holding the per-question input caches and receiving the outputs.
CACHE_DIR = "qwq32b/cache"
# AIME24 question indices processed by default (one cache file per index).
QUESTION_INDICES = (
    0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 22, 23, 24, 26,
)
# Text placed before the answer; its combined token length with the answer
# decides how many trailing items of each rollout are dropped.
ANSWER_PREFIX = "\n\n**Final Answer**\n\\boxed{"


def answer_nll(item_logprobs, n_answer_tokens):
    """Return the negated mean of the last ``n_answer_tokens`` log-probabilities.

    Args:
        item_logprobs: Per-token log-probabilities for one scored item. The
            answer tokens are assumed to be the trailing entries — TODO
            confirm against whatever wrote the cache.
        n_answer_tokens: How many trailing entries to average; must be >= 1.
            If it exceeds ``len(item_logprobs)``, every entry is averaged.

    Returns:
        The negated mean as a Python ``float``. An empty ``item_logprobs``
        yields NaN (NumPy warns about the empty mean).

    Raises:
        ValueError: If ``n_answer_tokens`` is not positive. Without this
            check, ``seq[-0:]`` would silently select the whole sequence.
    """
    if n_answer_tokens <= 0:
        raise ValueError(
            f"n_answer_tokens must be positive, got {n_answer_tokens}"
        )
    return float(-np.mean(item_logprobs[-n_answer_tokens:]))


def rollout_target_logprobs(rollout, n_answer_tokens, n_trailing_items):
    """Reduce one rollout to a list of answer NLLs, one per scored item.

    Args:
        rollout: Sequence of dicts, each carrying a ``"logprobs"`` list.
        n_answer_tokens: Forwarded to :func:`answer_nll`.
        n_trailing_items: Number of items to drop from the end of the
            result. The reason for dropping them is not visible here
            (presumably those items already contain the forced answer —
            TODO confirm).

    Returns:
        List of floats of length ``max(len(rollout) - n_trailing_items, 0)``.
    """
    # NOTE: a disabled variant stopped at the first item whose
    # ``prompt_ids`` (minus the trailing answer-suffix tokens) ended with the
    # ``</think>`` token; every item is processed here.
    nlls = [answer_nll(item["logprobs"], n_answer_tokens) for item in rollout]
    # Explicit stop index instead of ``[:-n]``, which would return [] for n == 0.
    return nlls[:max(len(nlls) - n_trailing_items, 0)]


def main(model_path=MODEL_PATH, question_indices=QUESTION_INDICES,
         cache_dir=CACHE_DIR):
    """Convert each cached logprob file into a per-prefix answer-NLL file.

    For every ``q_idx``, this reads
    ``{cache_dir}/qwq32b_aime24_logprobs_{q_idx}.json`` and writes
    ``{cache_dir}/qwq32b_aime24_target_logprobs_{q_idx}.json``. The output
    has keys ``question``, ``answer`` and ``logprobs``; ``logprobs`` holds
    one list of floats per input rollout.

    Raises:
        ValueError: If a question's answer encodes to zero tokens.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    for q_idx in tqdm(question_indices):
        in_path = f"{cache_dir}/qwq32b_aime24_logprobs_{q_idx}.json"
        with open(in_path, "r", encoding="utf-8") as rfile:
            data = json.load(rfile)

        answer = data["answer"]
        n_answer_tokens = len(
            tokenizer.encode(answer, add_special_tokens=False)
        )
        if n_answer_tokens == 0:
            raise ValueError(
                f"question {q_idx}: answer {answer!r} encodes to no tokens"
            )
        n_extended_tokens = len(
            tokenizer.encode(ANSWER_PREFIX + answer, add_special_tokens=False)
        )

        outputs = {
            "question": data["question"],
            "answer": answer,
            "logprobs": [
                rollout_target_logprobs(
                    rollout, n_answer_tokens, n_extended_tokens
                )
                for rollout in data["logprobs"]
            ],
        }

        out_path = f"{cache_dir}/qwq32b_aime24_target_logprobs_{q_idx}.json"
        with open(out_path, "w", encoding="utf-8") as wfile:
            json.dump(outputs, wfile)


if __name__ == "__main__":
    main()