import json
import re
import numpy as np
import transformers
from transformers import AutoTokenizer
import vllm
from itertools import combinations
from tqdm import tqdm
from typing import List
import datasets
import torch
from collections import defaultdict
import random
import math
from transformers import AutoModelForSequenceClassification, AutoTokenizer



def run_experiment(model, tokenizer, ds, inference_func, batch_size, tasks):
    """Run ``inference_func`` over each requested split of ``ds``.

    Args:
        model: Model object forwarded unchanged to ``inference_func``.
        tokenizer: Tokenizer forwarded unchanged to ``inference_func``.
        ds: Mapping of split name -> split exposing ``"problem"`` and
            ``"model_output"`` columns.
        inference_func: Callable ``(problems, outputs, model, tokenizer,
            batch_size=...)`` returning the rewards for one split.
        batch_size: Forwarded to ``inference_func`` as a keyword argument.
        tasks: Split names to run, or ``None`` to run every split in ``ds``
            (in ``ds.keys()`` order).

    Returns:
        Dict mapping each processed split name to whatever ``inference_func``
        returned for it.
    """
    all_split_names = list(ds.keys())
    selected = all_split_names if tasks is None else tasks

    results = {}
    for split_name in selected:
        print(f"---------- RUNNING INFERENCE: {split_name} ----------")
        split = ds[split_name]
        results[split_name] = inference_func(
            split["problem"],
            split["model_output"],
            model,
            tokenizer,
            batch_size=batch_size,
        )
    return results


def run_inference_skywork(problems, outputs, model, tokenizer, batch_size=16, device=None):
    """Score (problem, response) pairs with a sequence-classification reward model.

    Each pair is rendered as a two-turn chat (user = problem, assistant =
    response) with the tokenizer's chat template, and the model's logits for
    each sequence are flattened into the returned score list.

    Args:
        problems: Prompt strings, one per sample.
        outputs: Response strings, aligned element-wise with ``problems``.
        model: Reward model whose forward output exposes ``.logits``. One
            logit per sequence is assumed (the caller loads it with
            ``num_labels=1``) -- TODO confirm for other reward models.
        tokenizer: Tokenizer providing ``apply_chat_template`` and padding
            (a pad token must be set).
        batch_size: Number of sequences per forward pass.
        device: Device to place input tensors on. Defaults to
            ``model.device``. Previously this function silently relied on a
            module-level ``device`` global that only exists when the file is
            run as a script.

    Returns:
        List of float scores in the same order as the inputs.

    Raises:
        ValueError: if ``problems`` and ``outputs`` differ in length.
    """
    if len(problems) != len(outputs):
        raise ValueError(
            f"problems ({len(problems)}) and outputs ({len(outputs)}) must have the same length"
        )
    if device is None:
        device = model.device

    conversations = [
        [{"role": "user", "content": p}, {"role": "assistant", "content": r}]
        for p, r in zip(problems, outputs)
    ]

    predictions_store = []
    with torch.no_grad():
        for start in tqdm(range(0, len(conversations), batch_size)):
            batch = conversations[start:start + batch_size]
            # Tokenize per batch so each batch is padded only to its own
            # longest sequence, not the longest in the whole dataset.
            # Render to text first, then tokenize, so we also get an
            # attention mask. add_special_tokens=False because the chat
            # template already inserts special tokens. This mirrors what
            # apply_chat_template(tokenize=True) does internally.
            texts = tokenizer.apply_chat_template(batch, tokenize=False)
            encoded = tokenizer(
                texts,
                padding="longest",
                add_special_tokens=False,
                return_tensors="pt",
            ).to(device)
            logits = model(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
            ).logits
            predictions_store.extend(logits.cpu().reshape(-1).tolist())
    return predictions_store


def compute_BoN_acc_simple(rewards_store, ds, lengths_dict, M=8, N_list=(2, 4, 8)):
    """Compute Best-of-N accuracy per task from reward-model scores.

    Rows are grouped into contiguous chunks of ``M``: rows ``i*M`` through
    ``(i+1)*M - 1`` form group ``i``. For each ``N`` in ``N_list``, the
    sample with the highest reward among the first ``N`` rows of a group is
    selected. On ties the lowest index wins. The selected sample's
    correctness label is then summed across groups.

    Args:
        rewards_store: Mapping task name -> flat sequence of rewards (list or
            any indexable sequence supporting slicing).
        ds: Mapping task name -> split exposing a ``"final_answer_correct"``
            column aligned row-for-row with the rewards.
        lengths_dict: Mapping task name -> row count used for the
            denominator. Accuracy is ``num_correct / (lengths_dict[name] / M)``,
            so it may differ from ``len(rewards)``.
        M: Number of samples per group.
        N_list: Values of N to evaluate. An N larger than M behaves like
            N = M.

    Returns:
        Dict task name -> list of accuracies rounded to 3 decimals, one per N
        in ``N_list`` order.

    Raises:
        ValueError: if a task's reward count is not a multiple of ``M``, does
            not match its correctness column, or if any N is less than 1.
    """
    bon_acc_store = {}
    for name, rewards in rewards_store.items():
        correctness = ds[name]["final_answer_correct"]
        if len(rewards) % M != 0:
            raise ValueError(
                f"{name}: number of rewards ({len(rewards)}) is not a multiple of M={M}"
            )
        if len(correctness) != len(rewards):
            raise ValueError(
                f"{name}: {len(rewards)} rewards but {len(correctness)} correctness labels"
            )

        groups = [
            (rewards[start:start + M], correctness[start:start + M])
            for start in range(0, len(rewards), M)
        ]
        denominator = lengths_dict[name] / M

        bon_reward_list = []
        for N in N_list:
            if N < 1:
                raise ValueError(f"N must be >= 1, got {N}")
            num_correct = 0
            for group_rewards, group_correct in groups:
                k = min(N, len(group_rewards))
                # max() returns the first maximal index, i.e. lowest index on ties.
                best = max(range(k), key=group_rewards.__getitem__)
                num_correct += group_correct[best]
            bon_reward_list.append(round(num_correct / denominator, 3))
        bon_acc_store[name] = bon_reward_list
    return bon_acc_store



if __name__ == "__main__":
    # ---- Experiment configuration ----
    tasks = None  # list of tasks. Set tasks = None to run on all tasks.
    M = 64  # number of samples per question in the dataset
    N_list = [2, 4, 8, 16, 32, 64]  # BoN list. Must satisfy max(N_list) <= M.
    # Fail fast, before any download or model load. An N > M would silently
    # be evaluated as N = M but still reported under the larger N.
    if max(N_list) > M:
        raise ValueError(f"max(N_list)={max(N_list)} exceeds M={M}")

    # Kept as a module-level name: run_inference_skywork may read it as a global.
    device = "cuda"
    model_name = "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2"
    experiment_name = "skywork_8b_bon_64"
    # Alternative reward model: "Skywork/Skywork-Reward-Gemma-2-27B-v0.2"

    # ---- Data ----
    ds = datasets.load_dataset("prometheus-eval/filtered_bon_setting_64")
    # Row counts of the unfiltered splits are used as the BoN accuracy
    # denominator (compute_BoN_acc_simple divides by lengths_dict[name] / M).
    ds_no_filter = datasets.load_dataset("prometheus-eval/bon_setting_64")
    lengths_dict = {key: val.num_rows for key, val in ds_no_filter.items()}

    # ---- Reward-model inference ----
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=device,
        attn_implementation="flash_attention_2",
        num_labels=1,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    skywork_rewards = run_experiment(model, tokenizer, ds, run_inference_skywork, batch_size=16, tasks=tasks)

    with open(f"{experiment_name}_BoN_rewards.json", "w", encoding="utf-8") as f:
        json.dump(skywork_rewards, f, indent=4)

    # ---- Best-of-N accuracy ----
    bon_acc_store = compute_BoN_acc_simple(skywork_rewards, ds, lengths_dict, M, N_list)

    with open(f"{experiment_name}_BoN_accs.json", "w", encoding="utf-8") as f:
        json.dump(bon_acc_store, f, indent=4)

    