import json
import random

import numpy as np
from transformers import AutoTokenizer
from vllm import SamplingParams, LLM
from nltk import sent_tokenize


random.seed(42)
model = "Qwen/Qwen2.5-32B-Instruct"

# Cached rollouts to analyse: a JSON object mapping a question index to a dict
# with "question", "answer" and a "rollouts" list whose items carry
# "prediction" (str) and "correct" fields.
ROLLOUTS_PATH = "qwen3_32b/cache/qwen3_32b_livemathbench_rollouts.json"
# NOTE(review): the output directory is named "aime24" although the input is a
# livemathbench file -- confirm this is intentional before enabling the
# summarization stage.
SUMMARIES_PATH = "qwen3_32b/cache/aime24_summarized_trajectories/summarized_trajectories_x4.json"

NUM_ROLLOUTS = 16  # rollouts expected per question (the k in Avg@k / Pass@k)
NUM_SUMMARIES = 4  # summaries sampled per trajectory
MAX_TRAJECTORY_TOKENS = 24576  # trajectories are truncated to this many tokens

# The script used to stop unconditionally right after printing the statistics.
# Keep that as the default; set to True to also run the vLLM summarization.
RUN_SUMMARIZATION = False

SUMMARIZE_PROMPT = "{trajectory}\n\n\nSummarize the aforementioned reasoning process and not explicitly include the final conclusion and answer. Only provide the English summary."


def extract_reasoning(prediction):
    """Return the reasoning part of *prediction*.

    Everything from the first ``</think>`` onwards is dropped, then everything
    from the first ``\\n\\n**Final Answer**`` onwards. A prediction containing
    neither marker is returned unchanged.
    """
    return prediction.split("</think>")[0].split("\n\n**Final Answer**")[0]


def group_responses(responses, group_size):
    """Split *responses* into consecutive lists of *group_size* items.

    The last group is shorter when ``len(responses)`` is not a multiple of
    *group_size*.

    Raises:
        ValueError: if *group_size* is not positive.
    """
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size}")
    return [
        responses[start:start + group_size]
        for start in range(0, len(responses), group_size)
    ]


def main():
    """Print rollout statistics and optionally summarize every trajectory.

    Loads the rollouts from ``ROLLOUTS_PATH``, prints Avg@k, Pass@k and the
    average reasoning length in tokens. When ``RUN_SUMMARIZATION`` is true it
    additionally samples ``NUM_SUMMARIES`` summaries per trajectory with vLLM
    and writes them to ``SUMMARIES_PATH``.

    Raises:
        ValueError: if the rollout file is empty, a question does not have
            exactly ``NUM_ROLLOUTS`` rollouts, or vLLM returns an unexpected
            number of completions.
    """
    tokenizer = AutoTokenizer.from_pretrained(model)

    # Explicit encoding: do not depend on the platform's locale default.
    with open(ROLLOUTS_PATH, "r", encoding="utf-8") as rfile:
        data = json.load(rfile)
    if not data:
        raise ValueError(f"no questions found in {ROLLOUTS_PATH}")

    trajectories = []
    correct_lst = []
    num_tokens_lst = []
    for question_idx, question_info in data.items():
        rollouts = question_info["rollouts"]
        # Checked up front (and not with `assert`, which -O strips): the
        # statistics below need a rectangular array and the response grouping
        # assumes a fixed number of trajectories per question.
        if len(rollouts) != NUM_ROLLOUTS:
            raise ValueError(
                f"question {question_idx!r} has {len(rollouts)} rollouts, "
                f"expected {NUM_ROLLOUTS}"
            )

        # Tokenize each rollout once; the ids serve both the truncated text
        # and the (untruncated) token count.
        token_ids_per_rollout = [
            tokenizer.encode(extract_reasoning(item["prediction"]))
            for item in rollouts
        ]
        trajectories.append({
            "question": question_info["question"],
            "answer": question_info["answer"],
            "trajectory": [
                tokenizer.decode(token_ids[:MAX_TRAJECTORY_TOKENS])
                for token_ids in token_ids_per_rollout
            ],
        })
        correct_lst.append([item["correct"] for item in rollouts])
        num_tokens_lst.append([len(token_ids) for token_ids in token_ids_per_rollout])

    correct_arr = np.array(correct_lst)
    num_tokens_arr = np.array(num_tokens_lst)
    # Avg@k: mean per-question accuracy; Pass@k: share of questions with at
    # least one correct rollout.
    print(f"Avg@{NUM_ROLLOUTS}: {correct_arr.mean(1).mean(0)}")
    print(f"Pass@{NUM_ROLLOUTS}: {(correct_arr.sum(1) > 0).mean(0)}")
    print(f"avg num_token: {num_tokens_arr.mean(0).mean(0)}")

    if not RUN_SUMMARIZATION:
        return

    # Build the engine only when it is needed; `model` itself is just the
    # model id string and has no `generate` method.
    llm = LLM(model, trust_remote_code=True, tensor_parallel_size=2, gpu_memory_utilization=0.7)

    # One single-turn conversation per trajectory, in question-major order.
    messages = [
        [
            {"role": "user", "content": SUMMARIZE_PROMPT.format(trajectory=traj)},
        ]
        for item in trajectories
        for traj in item["trajectory"]
    ]
    prompts = [
        tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
        for message in messages
    ]
    sampling_params = SamplingParams(n=NUM_SUMMARIES, max_tokens=8192, seed=42, temperature=1.0)
    request_outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
    responses = [
        output.text
        for request_output in request_outputs
        for output in request_output.outputs
    ]
    if len(responses) != len(prompts) * NUM_SUMMARIES:
        raise ValueError(
            f"expected {len(prompts) * NUM_SUMMARIES} completions, got {len(responses)}"
        )

    # Responses are flat in prompt order, so each question owns a contiguous
    # run of NUM_ROLLOUTS * NUM_SUMMARIES completions.
    grouped = group_responses(responses, NUM_ROLLOUTS * NUM_SUMMARIES)
    summarized_trajectories = {
        str(i): {
            "question": item["question"],
            "answer": item["answer"],
            "trajectory": item["trajectory"],
            "summarized_trajectory": grouped[i],
        }
        for i, item in enumerate(trajectories)
    }
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # written as UTF-8 regardless of the locale.
    with open(SUMMARIES_PATH, "w", encoding="utf-8") as wfile:
        json.dump(summarized_trajectories, wfile, indent=4, ensure_ascii=False)


if __name__ == "__main__":
    main()