import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F
from datasets import load_dataset
from tqdm import tqdm
import json
import numpy as np
import os

def make_step_rewards(logits, token_masks):
    """Return, per sample, the label-1 probability at every masked position.

    Args:
        logits: Tensor indexed as (bs, seq_len, num_labels); softmax is taken
            over the last dimension.
        token_masks: Tensor of shape (bs, seq_len); non-zero / True entries
            mark the positions whose score should be returned.

    Returns:
        A list with one entry per sample. Each entry is a list of floats: the
        softmax probability of label index 1 at each selected position, in
        sequence order. A sample with no selected position yields ``[]``.
    """
    probabilities = F.softmax(logits, dim=-1)  # bs, seq_len, num_labels

    # Select rows with the mask itself instead of zeroing masked rows and then
    # filtering on `!= 0`: a probability that underflows to exactly 0 at a
    # selected position would otherwise be dropped, shifting or breaking the
    # pairing of the remaining values.
    selected = token_masks.to(device=probabilities.device, dtype=torch.bool)

    all_scores_res = []
    for sample_probs, sample_mask in zip(probabilities, selected):
        positive_probs = sample_probs[sample_mask][:, 1]  # valid_tokens
        all_scores_res.append(positive_probs.cpu().tolist())
    return all_scores_res


# Checkpoint to evaluate. Alternatives that were previously selected here:
#   "Qwen/Qwen2.5-Math-PRM-7B"
#   "Qwen/Qwen2.5-Math-7B-PRM800K"
model_name = "Qwen/Qwen2.5-Math-PRM-72B"

# Last path component of the hub id (the part after the final "/").
short_model_name = model_name.rsplit("/", 1)[-1]
device = "auto"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)

# Load the weights in bfloat16 with automatic device placement, then switch
# the module to evaluation mode.
model = AutoModel.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model = model.eval()

dataset = load_dataset("prometheus-eval/filtered_bon_setting_64")


# Score every response in every split of `dataset` with the reward model and
# write one JSONL file per split.

# The id of the step-separator token is the same for every sample, so it is
# resolved once instead of once per sample.
step_sep_id = tokenizer.encode("<extra_0>")[0]

for name, ds in dataset.items():
    results = []
    for d in tqdm(ds):
        problem = d['problem']
        model_output = d['model_output']
        # Steps are the blank-line-separated paragraphs of the raw output.
        steps = model_output.split("\n\n")

        # Every step, including the last one, is followed by "<extra_0>";
        # the model's output at those positions is read as the step score.
        messages = [
            {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
            {"role": "user", "content": problem},
            {"role": "assistant", "content": "<extra_0>".join(steps) + "<extra_0>"},
        ]
        conversation_str = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        input_ids = tokenizer.encode(
            conversation_str,
            return_tensors="pt",
        ).to(model.device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids)

        token_masks = (input_ids == step_sep_id)
        # Batch size is 1, so take the single sample's score list.
        step_rewards = make_step_rewards(outputs[0], token_masks)[0]

        results.append({
            "id": d['id'],
            "generator": d['generator'],
            "problem": problem,
            "steps": steps,
            "final_answer_correct": d['final_answer_correct'],
            "model_output": model_output,
            "step_scores": step_rewards,
            "score": min(step_rewards)  # Aggregation: min
        })

    # NOTE(review): the output directory name is fixed and does not follow
    # `model_name`; confirm this is intended when switching checkpoints.
    os.makedirs("prometheus_output_prm72b", exist_ok=True)
    # Records are dumped with ensure_ascii=False, so the file must be opened
    # as UTF-8 explicitly rather than with the platform default encoding.
    with open(f"prometheus_output_prm72b/{name}.jsonl", "w", encoding="utf-8") as file:
        for item in results:
            file.write(json.dumps(item, ensure_ascii=False) + "\n")