import json
import re
import numpy as np
import transformers
from transformers import AutoTokenizer
import vllm
from itertools import combinations
from tqdm import tqdm
from typing import List
import datasets
import torch
from collections import defaultdict
import random
import math
from transformers import AutoModelForSequenceClassification, AutoTokenizer



def run_experiment(model, tokenizer, ds, inference_func, batch_size, tasks, experiment_name=None):
    """Run reward inference over selected dataset splits and save rewards per split.

    Args:
        model: Forwarded unchanged to ``inference_func``.
        tokenizer: Forwarded unchanged to ``inference_func``.
        ds: Mapping of split name -> dataset exposing ``"problem"`` and
            ``"model_output"`` columns.
        inference_func: Called as
            ``inference_func(problems, outputs, model, tokenizer, batch_size=batch_size)``
            and expected to return a JSON-serialisable list of rewards.
        batch_size: Forwarded to ``inference_func``.
        tasks: Split names to run, in order; ``None`` means every split in ``ds``.
        experiment_name: Prefix of the per-split output files
            ``{experiment_name}_{name}_rewards.json``. When omitted, the
            module-level ``experiment_name`` global (set by this script's
            ``__main__`` section) is used, preserving the original behaviour.

    Returns:
        dict mapping split name -> rewards returned by ``inference_func``.

    Raises:
        NameError: if no ``experiment_name`` is passed and none is defined at
            module level. This is checked before any (expensive) inference runs.
    """
    if experiment_name is None:
        # Backward compatibility: historically this read a module global.
        experiment_name = globals().get("experiment_name")
        if experiment_name is None:
            raise NameError(
                "experiment_name must be passed explicitly or defined at module level"
            )

    if tasks is None:
        tasks = list(ds.keys())

    rewards_store = {}
    for name in tasks:
        print(f"---------- RUNNING INFERENCE: {name} ----------")
        problems = ds[name]["problem"]
        outputs = ds[name]["model_output"]
        rewards = inference_func(problems, outputs, model, tokenizer, batch_size=batch_size)
        rewards_store[name] = rewards

        # Save rewards for each dataset separately, so partial progress survives
        # a crash in a later split.
        with open(f"{experiment_name}_{name}_rewards.json", "w", encoding="utf-8") as f:
            json.dump({name: rewards}, f, indent=4)

    return rewards_store


def run_inference_skywork(problems, outputs, model, tokenizer, batch_size=16, device=None):
    """Score (problem, response) pairs with a sequence-classification reward model.

    Each pair is rendered as a two-turn chat (user = problem, assistant =
    response) using the tokenizer's chat template. The model's logits for each
    row are flattened into the returned reward list.

    Tokenization happens per batch, so each batch is padded only to its own
    longest sequence. The ``attention_mask`` is passed to the model so padding
    tokens cannot influence the scores (the previous version padded the whole
    dataset at once and sent only ``input_ids``).

    Args:
        problems: Prompts, one per response.
        outputs: Model responses, aligned index-by-index with ``problems``.
        model: Called as ``model(input_ids=..., attention_mask=...)``; must
            return an object with a ``.logits`` tensor. Assumes one logit per
            row (``num_labels=1``). TODO confirm for other reward models.
        tokenizer: Tokenizer providing ``apply_chat_template`` and a pad token.
        batch_size: Number of pairs scored per forward pass.
        device: Device for the input tensors; defaults to ``model.device``.

    Returns:
        list of float rewards in input order (empty list for empty input).
    """
    assert len(problems) == len(outputs)
    if device is None:
        device = model.device

    conversations = [
        [{"role": "user", "content": p}, {"role": "assistant", "content": r}]
        for p, r in zip(problems, outputs)
    ]

    predictions_store = []
    with torch.no_grad():
        for start in tqdm(range(0, len(conversations), batch_size)):
            batch = conversations[start:start + batch_size]
            texts = tokenizer.apply_chat_template(batch, tokenize=False)
            # The rendered template already contains any BOS token, so don't add
            # special tokens again (this mirrors apply_chat_template(tokenize=True)).
            encoded = tokenizer(
                texts,
                return_tensors="pt",
                padding="longest",
                add_special_tokens=False,
            )
            input_ids = encoded["input_ids"].to(device)
            attention_mask = encoded["attention_mask"].to(device)
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            # Upcast before leaving the GPU so bf16 logits become plain Python floats.
            predictions_store.extend(logits.float().cpu().reshape(-1).tolist())
    return predictions_store


def compute_BoN_acc_simple(rewards_store, ds, lengths_dict, M=8, N_list=(2, 4, 8), experiment_name=None):
    """Compute Best-of-N accuracy per task from flat reward lists.

    Rewards and correctness labels are laid out as consecutive groups of ``M``
    samples per question. For each ``N`` in ``N_list``, the sample with the
    highest reward among the first ``N`` of each group is selected (ties go to
    the earliest sample). Its correctness label is then counted.

    Args:
        rewards_store: Mapping task name -> flat sequence of rewards whose length
            is a multiple of ``M``.
        ds: Mapping task name -> dataset with a ``"final_answer_correct"``
            column aligned index-by-index with the rewards.
        lengths_dict: Mapping task name -> sample count used for the
            denominator. Accuracy is ``num_correct / (lengths_dict[name] / M)``.
        M: Number of samples per question.
        N_list: Candidate-pool sizes; each must satisfy ``1 <= N <= M``.
        experiment_name: Prefix of per-task output files
            ``{experiment_name}_{name}_BoN_accs.json``. When omitted, the
            module-level ``experiment_name`` global is used (original behaviour).

    Returns:
        dict mapping task name -> list of accuracies (rounded to 3 decimals),
        one per entry of ``N_list``.

    Raises:
        ValueError: if any ``N`` in ``N_list`` is outside ``[1, M]``.
        NameError: if no ``experiment_name`` is available.
        AssertionError: if reward / label lengths are inconsistent.
    """
    if experiment_name is None:
        # Backward compatibility: historically this read a module global.
        experiment_name = globals().get("experiment_name")
        if experiment_name is None:
            raise NameError(
                "experiment_name must be passed explicitly or defined at module level"
            )

    # N > M would silently fall back to scoring the whole group.
    bad_ns = [N for N in N_list if not 1 <= N <= M]
    if bad_ns:
        raise ValueError(f"N_list values {bad_ns} must lie in [1, M={M}]")

    bon_acc_store = {}
    for name, rewards in rewards_store.items():
        # Normalise to plain lists: dataset columns / arrays may not support
        # list.index or list-style slicing.
        rewards = list(rewards)
        correctness = list(ds[name]["final_answer_correct"])
        assert len(rewards) % M == 0
        assert len(correctness) == len(rewards)

        groups = [
            (rewards[i:i + M], correctness[i:i + M])
            for i in range(0, len(rewards), M)
        ]
        num_questions = lengths_dict[name] / M

        bon_reward_list = []
        for N in N_list:
            num_correct = 0
            for group_rewards, group_correct in groups:
                # index() returns the first occurrence, i.e. the earliest
                # maximal sample within the first N.
                best = group_rewards.index(max(group_rewards[:N]))
                num_correct += group_correct[best]
            bon_reward_list.append(round(num_correct / num_questions, 3))
        bon_acc_store[name] = bon_reward_list

        # Save BoN accuracy for each dataset separately
        with open(f"{experiment_name}_{name}_BoN_accs.json", "w", encoding="utf-8") as f:
            json.dump({name: bon_reward_list}, f, indent=4)

    return bon_acc_store



if __name__ == "__main__":
    # ---- Experiment configuration ----
    # Splits of the filtered dataset to score.
    tasks = [
        "aime_Eurus2_7B_sft", "amc_Eurus2_7B_sft", "gpqa_Eurus2_7B_sft",
        "leetcode_Eurus2_7B_sft", "math_Eurus2_7B_sft", "minerva_math_Eurus2_7B_sft",
    ]
    M = 64  # number of samples per question in the dataset
    N_list = [2, 4, 8, 16, 32, 64]  # BoN pool sizes; each must satisfy 1 <= N <= M

    # Fail fast: an out-of-range N would otherwise only surface (or silently
    # skew results) after the full, expensive reward-model pass.
    bad_ns = [n for n in N_list if not 1 <= n <= M]
    if bad_ns:
        raise ValueError(f"N_list values {bad_ns} must lie in [1, M={M}]")

    # NOTE: `device` and `experiment_name` are assigned at module scope here;
    # functions in this file may read them as globals.
    device = "cuda"
    model_name = "Skywork/Skywork-Reward-Gemma-2-27B-v0.2"
    experiment_name = "skywork_27b_bon_64"

    # ---- Data ----
    ds = datasets.load_dataset("prometheus-eval/filtered_bon_setting_64")
    ds_no_filter = datasets.load_dataset("prometheus-eval/bon_setting_64")
    # Per-split row counts of the *unfiltered* data, used as BoN accuracy
    # denominators (presumably so filtered-out samples count as incorrect —
    # TODO confirm).
    lengths_dict = {key: val.num_rows for key, val in ds_no_filter.items()}

    # ---- Model ----
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=device,
        attn_implementation="flash_attention_2",
        num_labels=1,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # ---- Inference + scoring ----
    skywork_rewards = run_experiment(model, tokenizer, ds, run_inference_skywork, batch_size=6, tasks=tasks)

    # Combined rewards across all tasks (per-task files are written by run_experiment).
    with open(f"{experiment_name}_BoN_rewards.json", "w", encoding="utf-8") as f:
        json.dump(skywork_rewards, f, indent=4)

    bon_acc_store = compute_BoN_acc_simple(skywork_rewards, ds, lengths_dict, M, N_list)

    # Combined accuracies across all tasks (per-task files are written by compute_BoN_acc_simple).
    with open(f"{experiment_name}_BoN_accs.json", "w", encoding="utf-8") as f:
        json.dump(bon_acc_store, f, indent=4)