import json
import os

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

# Hugging Face checkpoint used for both the tokenizer and the PRM weights.
PRM_CHECKPOINT = 'peiyi9979/math-shepherd-mistral-7b-prm'

# Tokens whose logits are compared at each step boundary, and the marker
# appended after every step so those boundaries can be located in the input.
good_token = '+'
bad_token = '-'
step_tag = 'ки'

tokenizer = AutoTokenizer.from_pretrained(PRM_CHECKPOINT)

# Encode "+ -" and drop the first id (presumably the BOS token) -> [648, 387]
candidate_tokens = tokenizer.encode(f"{good_token} {bad_token}")[1:]

# The last id of the encoded tag is what gets matched in the PRM input -> 12902
step_tag_id = tokenizer.encode(f"{step_tag}")[-1]

model = AutoModelForCausalLM.from_pretrained(PRM_CHECKPOINT)
model = model.eval().to('cuda')

dataset = load_dataset("prometheus-eval/filtered_bon_setting")

# Directory that receives one JSONL file per dataset split.
OUTPUT_DIR = "prometheus_output"


def _score_steps(problem, steps):
    """Score each reasoning step of one solution with the PRM.

    The PRM input is the problem, a space, then every step suffixed with
    `` {step_tag}\\n``. Logits are restricted to ``candidate_tokens``
    (good_token, bad_token) and softmaxed, so each score is the probability
    of ``good_token`` relative to ``bad_token``.

    Args:
        problem: Problem statement text.
        steps: Step strings, in order.

    Returns:
        A list of floats, one per input position whose token id equals
        ``step_tag_id``. This count is not guaranteed to equal
        ``len(steps)`` (e.g. if the problem text itself contains the tag
        token), so callers should check it.
    """
    output = "".join(f"{step} {step_tag}\n" for step in steps)
    input_for_prm = f"{problem} {output}"
    input_id = torch.tensor([tokenizer.encode(input_for_prm)]).to('cuda')
    with torch.no_grad():
        logits = model(input_id).logits[:, :, candidate_tokens]
        scores = logits.softmax(dim=-1)[:, :, 0]
        step_scores = scores[input_id == step_tag_id]
    return step_scores.tolist()


def _write_jsonl(path, rows):
    """Write ``rows`` to ``path`` as UTF-8 JSON Lines, creating parent dirs."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters,
    # which would fail under a non-UTF-8 default locale encoding.
    with open(path, "w", encoding="utf-8") as file:
        file.writelines(
            json.dumps(row, ensure_ascii=False) + "\n" for row in rows
        )


for name in dataset.keys():
    ds = dataset[name]
    results = []
    for d in tqdm(ds):
        problem = d['problem']
        ids = d['id']
        generator = d['generator']
        model_output = d['model_output']
        final_answer_correct = d['final_answer_correct']
        # Steps are the "\n\n"-separated chunks of the model output.
        # NOTE(review): str.split keeps empty chunks (e.g. from a trailing
        # "\n\n"), and each still receives a step tag -- confirm intended.
        steps = model_output.split("\n\n")

        step_scores = _score_steps(problem, steps)
        if len(steps) != len(step_scores):
            # Report and continue rather than pausing on input(): an
            # unattended run would hang, or die with EOFError without stdin.
            tqdm.write(
                f"[warn] {name} id={ids}: {len(steps)} steps but "
                f"{len(step_scores)} step-tag scores"
            )

        results.append({
            "id": ids,
            "generator": generator,
            "problem": problem,
            "steps": steps,
            "final_answer_correct": final_answer_correct,
            "model_output": model_output,
            "step_scores": step_scores,
            # Aggregation: min over steps. None when no step tag was found,
            # because min() of an empty list raises ValueError.
            "score": min(step_scores) if step_scores else None,
        })

    _write_jsonl(os.path.join(OUTPUT_DIR, f"{name}.jsonl"), results)