from argparse import ArgumentParser
import random

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from datasets import load_dataset
import json
import re
import os
from tqdm import tqdm

def _find_subsequence(seq, sub):
    """Return the index of the first occurrence of ``sub`` in ``seq``, or -1.

    Both arguments are lists of token ids; this is a plain sliding-window scan.
    """
    width = len(sub)
    for i in range(len(seq) - width + 1):
        if seq[i:i + width] == sub:
            return i
    return -1


def _extract_logprobs(prompt_logprobs, token_ids):
    """Return the log-probability assigned to each actual prompt token.

    ``prompt_logprobs`` is vLLM's per-position list of ``{token_id: Logprob}``
    dicts, aligned with ``token_ids``. Positions where vLLM reports ``None``
    (it does so for the first prompt token) map to ``None`` in the result.
    """
    return [
        None if entry is None else entry[tok_id].logprob
        for entry, tok_id in zip(prompt_logprobs, token_ids)
    ]


def main():
    """Score every reasoning prefix of sampled correct rollouts for one question.

    The question keyed ``--ii`` is read from ``qwq32b_aime24_rollouts.json``,
    and up to ``--num-rollouts`` of its correct rollouts are sampled. Each
    rollout is tokenized together with the chat-formatted prompt. For every
    truncation point after the ``<think>`` marker, the forced
    "**Final Answer** \\boxed{<answer>" suffix is appended, and vLLM's prompt
    log-probabilities for the whole sequence are recorded. The results are
    written to ``qwq32b_aime24_logprobs_{ii}.json``.
    """
    parser = ArgumentParser()
    parser.add_argument("--ii", type=str, required=True,
                        help="key of the question in the rollouts JSON file")
    parser.add_argument("--num-rollouts", type=int, default=4,
                        help="maximum number of correct rollouts to score")
    parser.add_argument("--seed", type=int, default=None,
                        help="seed for rollout sampling (default: unseeded)")
    args = parser.parse_args()
    ii = args.ii
    if args.num_rollouts < 1:
        parser.error("--num-rollouts must be at least 1")

    # Load and validate the input before paying for the (slow) model start-up.
    with open("qwq32b_aime24_rollouts.json", "r", encoding="utf-8") as rfile:
        rollouts = json.load(rfile)
    if ii not in rollouts:
        parser.error(f"question {ii!r} not found in qwq32b_aime24_rollouts.json")
    question_info = rollouts[ii]

    correct = [r for r in question_info["rollouts"] if r["correct"]]
    # Sample without replacement, but never ask for more rollouts than exist.
    rng = random.Random(args.seed)
    sampled = rng.sample(correct, k=min(args.num_rollouts, len(correct)))
    if len(sampled) < args.num_rollouts:
        print(f"warning: question {ii!r} has only {len(correct)} correct "
              f"rollout(s); scoring {len(sampled)}")

    MODEL_PATH = "Qwen/QwQ-32B"
    model = LLM(MODEL_PATH, tensor_parallel_size=2, gpu_memory_utilization=0.7)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

    prompt_template = "{question}\n\nPlease reason step by step, and put your final answer within \\boxed{{}}"
    messages = [
        {"role": "user", "content": prompt_template.format(question=question_info["question"])}
    ]
    input_prefix = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Forced answer suffix appended after every truncation point; str() guards
    # against answers stored as JSON numbers.
    answer_tokens = tokenizer.encode(
        "\n\n**Final Answer**\n\\boxed{" + str(question_info["answer"]),
        add_special_tokens=False,
    )
    think_token_ids = tokenizer.encode("<think>", add_special_tokens=False)
    # One greedy token is generated only because vLLM needs a generation step;
    # the prompt log-probs are what is recorded.
    sampling_params = SamplingParams(temperature=0, prompt_logprobs=1, max_tokens=1)

    prompt2logprobs = {
        "question": question_info["question"],
        "answer": question_info["answer"],
        "logprobs": [],
    }
    for rollout in sampled:
        tokens = tokenizer.encode(input_prefix + rollout["prediction"], add_special_tokens=False)
        think_pos = _find_subsequence(tokens, think_token_ids)
        if think_pos == -1:
            raise ValueError("'<think>' marker not found in tokenized prompt + rollout")

        # First prefix ends right after the full marker, which may span
        # more than one token.
        first_end = think_pos + len(think_token_ids)
        prefixes = [tokens[:end] + answer_tokens for end in range(first_end, len(tokens) + 1)]

        outputs = model.generate(
            [{"prompt_token_ids": prefix} for prefix in prefixes],
            sampling_params,
            use_tqdm=True,
        )
        prompt2logprobs["logprobs"].append([
            {
                "prompt_ids": prefix,
                "logprobs": _extract_logprobs(output.prompt_logprobs, prefix),
            }
            for output, prefix in zip(outputs, prefixes)
        ])

    with open(f"qwq32b_aime24_logprobs_{ii}.json", "w", encoding="utf-8") as wfile:
        json.dump(prompt2logprobs, wfile, indent=4)


if __name__ == "__main__":
    main()
            