import json
import os
import random
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
from transformers import HfArgumentParser

"""
If we use multiple VLLM processes to accelerate the generation, we need to use this script to merge them.
"""

def process_judgebench(ds):
    """Compute per-domain and overall means of ``check`` and ``num_tokens``.

    Args:
        ds: Column-indexable results table (e.g. a ``datasets.Dataset``)
            exposing equally long ``domain``, ``check`` and ``num_tokens``
            columns.

    Returns:
        A ``(res, tokens)`` pair of dicts. ``res`` maps every listed domain
        that has at least one row to its mean ``check`` and ``'overall'`` to
        the mean over all rows; ``tokens`` does the same for ``num_tokens``
        under the key ``'average'``. Domains with no rows are omitted instead
        of triggering a division by zero.

    Raises:
        ZeroDivisionError: If ``ds`` contains no rows at all.
    """
    domains = ['knowledge', 'reasoning', 'math', 'coding']

    # Read each column once and group in memory rather than running one
    # ``filter`` pass over the dataset per domain.
    row_domains = list(ds['domain'])
    checks = list(ds['check'])
    num_tokens = list(ds['num_tokens'])

    res = {}
    tokens = {}
    for domain in domains:
        picked = [
            (check, n_tok)
            for row_domain, check, n_tok in zip(row_domains, checks, num_tokens)
            if row_domain == domain
        ]
        if not picked:
            # Partial result files may not cover every domain.
            continue
        res[domain] = sum(check for check, _ in picked) / len(picked)
        tokens[domain] = sum(n_tok for _, n_tok in picked) / len(picked)

    res['overall'] = sum(checks) / len(checks)
    tokens['average'] = sum(num_tokens) / len(num_tokens)
    return res, tokens

def process_reward_bench_v2(ds):
    """Compute per-domain and overall means of ``check`` and ``num_tokens``.

    Args:
        ds: Column-indexable results table (e.g. a ``datasets.Dataset``)
            exposing equally long ``domain``, ``check`` and ``num_tokens``
            columns.

    Returns:
        A ``(res, tokens)`` pair of dicts. ``res`` maps every listed domain
        that has at least one row to its mean ``check`` and ``'overall'`` to
        the mean over all rows (including rows whose domain is not listed);
        ``tokens`` does the same for ``num_tokens`` under the key
        ``'average'``. Domains with no rows are omitted instead of triggering
        a division by zero.

    Raises:
        ZeroDivisionError: If ``ds`` contains no rows at all.
    """
    domains = ['Factuality', 'Focus', 'Math', 'Precise IF', 'Safety']

    # Read each column once and group in memory rather than running one
    # ``filter`` pass over the dataset per domain.
    row_domains = list(ds['domain'])
    checks = list(ds['check'])
    num_tokens = list(ds['num_tokens'])

    res = {}
    tokens = {}
    for domain in domains:
        picked = [
            (check, n_tok)
            for row_domain, check, n_tok in zip(row_domains, checks, num_tokens)
            if row_domain == domain
        ]
        if not picked:
            # Partial result files may not cover every domain.
            continue
        res[domain] = sum(check for check, _ in picked) / len(picked)
        tokens[domain] = sum(n_tok for _, n_tok in picked) / len(picked)

    res['overall'] = sum(checks) / len(checks)
    tokens['average'] = sum(num_tokens) / len(num_tokens)
    return res, tokens

def process_chatbot_arena(ds):
    """Return the dataset-wide mean ``check`` and mean ``num_tokens``.

    Args:
        ds: Results table supporting ``len(ds)`` and column access for
            ``check`` and ``num_tokens``.

    Returns:
        ``({'overall': mean check}, {'average': mean num_tokens})``.
    """
    n_rows = len(ds)
    accuracy = {'overall': sum(ds['check']) / n_rows}
    token_stats = {'average': sum(ds['num_tokens']) / n_rows}
    return accuracy, token_stats

def process_mixture(ds):
    """Return the dataset-wide mean ``check`` and mean ``num_tokens``.

    Args:
        ds: Results table supporting ``len(ds)`` and column access for
            ``check`` and ``num_tokens``.

    Returns:
        ``({'overall': mean check}, {'average': mean num_tokens})``.
    """
    total = len(ds)
    mean_check = sum(ds['check']) / total
    mean_tokens = sum(ds['num_tokens']) / total
    return {'overall': mean_check}, {'average': mean_tokens}

def process_rm_bench(ds):
    """Compute per-domain, per-difficulty and overall means for RM-Bench rows.

    Args:
        ds: Column-indexable results table (e.g. a ``datasets.Dataset``)
            exposing equally long ``domain``, ``difficulty``, ``check`` and
            ``num_tokens`` columns. ``difficulty`` is matched against the
            keys 1, 2 and 3.

    Returns:
        A ``(res, tokens)`` pair of dicts.

        ``res`` holds, for each listed domain with at least one row, the mean
        ``check`` of the rows whose ``domain`` string contains that name;
        ``'Easy'``/``'Normal'``/``'Hard'`` as the mean over domains of the
        per-domain mean ``check`` at that difficulty; ``'average'`` as the
        unweighted mean of the per-domain scores; and ``'overall'`` as the
        mean over all rows.

        ``tokens`` holds the per-domain mean ``num_tokens`` plus
        ``'average'`` over all rows.

        Empty domains and empty difficulty buckets are skipped instead of
        triggering a division by zero; ``'average'`` is omitted if no listed
        domain has rows.

    Raises:
        ZeroDivisionError: If ``ds`` contains no rows at all.
    """
    domains = ["chat", "math", "code", "safety"]
    diff_map = {1: "Easy", 2: "Normal", 3: "Hard"}

    # Read each column once and group in memory rather than running nested
    # ``filter`` passes over the dataset.
    row_domains = list(ds["domain"])
    row_diffs = list(ds["difficulty"])
    checks = list(ds["check"])
    num_tokens = list(ds["num_tokens"])

    res, tokens = {}, {}
    diff_scores = {name: [] for name in diff_map.values()}
    present = []

    for domain in domains:
        # Substring match: a row belongs to ``domain`` when the name occurs
        # anywhere in its ``domain`` string.
        idx = [i for i, row_domain in enumerate(row_domains) if domain in row_domain]
        if not idx:
            continue
        present.append(domain)
        res[domain] = sum(checks[i] for i in idx) / len(idx)
        tokens[domain] = sum(num_tokens[i] for i in idx) / len(idx)
        for d, name in diff_map.items():
            diff_checks = [checks[i] for i in idx if row_diffs[i] == d]
            if diff_checks:
                diff_scores[name].append(sum(diff_checks) / len(diff_checks))

    # average per difficulty (only difficulties that occurred at least once)
    res.update({k: sum(v) / len(v) for k, v in diff_scores.items() if v})
    # average across the domains that are present
    if present:
        res["average"] = sum(res[d] for d in present) / len(present)
    res['overall'] = sum(checks) / len(checks)
    tokens['average'] = sum(num_tokens) / len(num_tokens)
    return res, tokens


@dataclass
class ScriptArguments:
    """Command-line options for the results-summary script."""

    # Path of the JSON results file that the script loads and summarises.
    res_path: Optional[str] = field(
        metadata=dict(help="the location of the output file"),
        default="",
    )


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

path = script_args.res_path
# The summary is written next to the input file. ``dirname`` is empty for a
# bare file name, so fall back to the current directory instead of treating
# the file name itself as a folder.
base_folder = os.path.dirname(path) or "."

ds = load_dataset("json", data_files=path, split='train')
# Each row's ``num_tokens`` and ``check`` hold a sequence of values; replace
# each sequence with its mean so the columns become one number per row.
ds = ds.map(lambda row: {"num_tokens": sum(row["num_tokens"])/len(row["num_tokens"])})
ds = ds.map(lambda row: {"check": sum(row["check"])/len(row["check"])})

# The benchmark is inferred from a keyword in the path; the first matching
# keyword wins, so the order of the checks matters.
if 'judgebench' in path:
    res, tokens = process_judgebench(ds)
elif 'rm_bench' in path:
    res, tokens = process_rm_bench(ds)
elif 'chatbot_arena' in path:
    res, tokens = process_chatbot_arena(ds)
elif 'mixture' in path:
    res, tokens = process_mixture(ds)
elif 'reward_bench_v2' in path:
    res, tokens = process_reward_bench_v2(ds)
else:
    raise NotImplementedError(
        f"Cannot infer the benchmark from path {path!r}; expected it to contain one of "
        "'judgebench', 'rm_bench', 'chatbot_arena', 'mixture' or 'reward_bench_v2'."
    )


save_summary = [res, tokens]   # put in a list

# Pretty-print JSON string
json_string = json.dumps(save_summary, indent=4)
print(json_string)

# Save to file (same serialisation as the printed summary)
with open(os.path.join(base_folder, "res.json"), "w", encoding="utf-8") as f:
    f.write(json_string)




