import os
import numpy as np
import json
from tqdm import tqdm
from openai import OpenAI
from datasets import load_dataset


def main():
    """Score each step of each dataset response with a process reward model.

    For every key of the ``prometheus-eval/filtered_bon_setting`` dataset
    (presumably its split names), each response in ``model_output`` is split
    into steps on blank lines (``"\\n\\n"``). The steps are sent one at a time,
    as a growing multi-turn chat, to the ``Llama3.1-8B-PRM-Mistral-Data`` model
    behind the OpenAI-compatible endpoint at ``http://localhost:8000/v1``.
    Each step's score is the probability of ``"+"`` as the next token.

    One JSON object per response is written to
    ``prometheus_output_mistral/<key>.jsonl``. It holds the original fields,
    the per-step scores and their minimum as the aggregate ``score``.

    The API key is taken from the ``OPENAI_API_KEY`` environment variable,
    falling back to the placeholder ``"EMPTY"`` when it is unset.
    """
    client = OpenAI(
        base_url="http://localhost:8000/v1",
        # Credentials must not live in source control; read from the environment.
        api_key=os.environ.get("OPENAI_API_KEY", "EMPTY"),
    )

    def single_process(d):
        """Return one PRM score per blank-line-separated step of d["model_output"].

        The first user turn is the problem statement followed by the first step.
        Each later user turn is a single step. After every step an assistant
        turn of "+" is appended before the next step is scored.
        """
        steps = d["model_output"].split("\n\n")
        messages = []
        scores = []
        for sdx, step in enumerate(steps):
            content = d["problem"] + "\n\n" + step if sdx == 0 else step
            messages.append({"role": "user", "content": content})
            completion = client.chat.completions.create(
                model="Llama3.1-8B-PRM-Mistral-Data",
                messages=messages,
                n=1,
                temperature=0.0,
                max_tokens=1,
                logprobs=True,
                top_logprobs=10,
            )
            top = completion.choices[0].logprobs.content[0].top_logprobs
            score_dict = {x.token: x.logprob for x in top}
            # If "+" is not among the top-10 tokens, use a huge negative
            # logprob so its probability becomes 0.0 after exp(). Tokens other
            # than +/- are assumed to have zero probability.
            plus_logprob = score_dict.get("+", -1e10)
            scores.append(float(np.exp(plus_logprob)))

            # NOTE(review): the original author questioned this. Replaying "+"
            # as the assistant turn conditions every later step on all earlier
            # steps having been accepted. TODO: confirm it matches the PRM's
            # training conversation format.
            messages.append({"role": "assistant", "content": "+"})
        return scores

    output_dir = "prometheus_output_mistral"
    # Create the directory before any inference, so a missing directory cannot
    # crash the script after all the (expensive) PRM calls have completed.
    os.makedirs(output_dir, exist_ok=True)

    dataset = load_dataset("prometheus-eval/filtered_bon_setting")
    for config in dataset.keys():
        input_data = dataset[config]

        results = []
        for d in tqdm(input_data, desc=f"Processing {config}", dynamic_ncols=True):
            step_scores = single_process(d)
            results.append({
                "id": d["id"],
                "generator": d["generator"],
                "problem": d["problem"],
                "steps": d["steps"],
                "final_answer_correct": d["final_answer_correct"],
                "model_output": d["model_output"],
                "step_scores": step_scores,
                # Aggregation: min. str.split always yields at least one step,
                # so step_scores is never empty.
                "score": min(step_scores),
            })

        output_path = os.path.join(output_dir, f"{config}.jsonl")
        # ensure_ascii=False writes raw non-ASCII characters, so pin UTF-8
        # rather than relying on the platform's default encoding.
        with open(output_path, "w", encoding="utf-8") as file:
            file.writelines(
                json.dumps(item, ensure_ascii=False) + "\n" for item in results
            )

if __name__ == "__main__":
    # Run the scoring pipeline only when executed as a script, not on import.
    main()