import asyncio
import json
import os
import argparse
from verl.utils.reward_score import default_compute_score, prime_code, sandbox_fusion
from verl.utils.reward_score.prime_code import apps_check_correctness
from verl.workers.reward_manager.prime import parallel_compute_score_async, run_reward_scoring
from utils import save_dataset
import datasets


def load_jsonl(jsonl_path):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line is parsed with ``json.loads``. Blank or
    whitespace-only lines (e.g. a stray empty line) are skipped instead of
    raising ``JSONDecodeError``. The file is decoded as UTF-8 and is closed
    before returning.

    Args:
        jsonl_path: Path to the ``.jsonl`` file.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.
    """
    # Iterating the file object streams lines instead of building an extra
    # list with readlines(); the with-block guarantees the handle is closed.
    with open(jsonl_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def load_single_dataset(dataset_path: str, dataset_split: str = None) -> datasets.Dataset:
    """Load a dataset from a local file, a dataset directory, or a hub name.

    Args:
        dataset_path: Either a path to a local file ending in ``jsonl``,
            ``json`` or ``parquet``, or anything else accepted by
            ``datasets.load_dataset`` (falling back to
            ``datasets.load_from_disk`` if that raises ``ValueError``).
        dataset_split: Split to select. It is forwarded to ``load_dataset``
            for non-file paths. For parquet files and ``load_from_disk``
            results, it indexes the returned ``DatasetDict``. It has no effect
            on JSON/JSONL files, which always yield a single ``Dataset``.

    Returns:
        The loaded dataset. When ``dataset_split`` is None this may be a
        ``DatasetDict`` rather than a ``Dataset``. A JSON file whose top-level
        value is not a list is returned as the raw parsed object.

    Raises:
        RuntimeError: If ``dataset_path`` is a file with an unsupported
            extension.
    """
    # Local file: dispatch on the file-name suffix.
    if os.path.isfile(dataset_path):
        if dataset_path.endswith("jsonl"):
            dataset = load_jsonl(dataset_path)
        elif dataset_path.endswith("json"):
            # Use a with-block so the handle is closed (the original leaked it).
            with open(dataset_path, "r", encoding="utf-8") as f:
                dataset = json.load(f)
        elif dataset_path.endswith("parquet"):
            # Without split=, load_dataset returns a DatasetDict; the split is
            # selected below.
            dataset = datasets.load_dataset("parquet", data_files=dataset_path)
        else:
            raise RuntimeError(f"No support file type for {dataset_path.split('.')[-1]}")
        if isinstance(dataset, list):
            dataset = datasets.Dataset.from_list(dataset)

    # Directory or hub identifier.
    else:
        try:
            return datasets.load_dataset(dataset_path, split=dataset_split)
        except ValueError:
            # NOTE(review): assumes ValueError here means "not a loadable
            # dataset script/repo", e.g. a save_to_disk directory — confirm.
            dataset = datasets.load_from_disk(dataset_path)

    # Narrow a DatasetDict down to the requested split.
    if dataset_split is not None and isinstance(dataset, datasets.DatasetDict):
        dataset = dataset[dataset_split]
    return dataset
        

def prepare_eval_row(row):
    """Project one dataset row onto the fields the reward scorer consumes.

    Reads ``row["response"]``, ``row["reward_model"]["ground_truth"]`` and
    ``row["data_source"]`` and returns them under the keys
    ``sequences_str``, ``ground_truth`` and ``data_sources`` respectively.
    Missing keys raise ``KeyError``.
    """
    response = row["response"]
    truth = row["reward_model"]["ground_truth"]
    source = row["data_source"]
    return dict(sequences_str=response, ground_truth=truth, data_sources=source)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--data', type=str, required=True)
    parser.add_argument("--dataset_split", type=str, required=False, default="train")
    parser.add_argument("--save_file", type=str, required=False)
    args = parser.parse_args()
    test_set = load_single_dataset(args.data, dataset_split=args.dataset_split)
    test_set_split = test_set.map(prepare_eval_row)
    scores = run_reward_scoring(
        default_compute_score,
        test_set_split["sequences_str"], 
        test_set_split["ground_truth"],
        test_set_split["data_sources"],
        num_processes=64
        )
    
    score_dict = {}
    count_dict = {}
    for task, score in zip(test_set_split["data_sources"], scores):
        count_dict[task] = count_dict.get(task, 0) + 1
        if score > 0.5:
            score_dict[task] = score_dict.get(task, 0) + 1
    print(score_dict)
    print(count_dict)
    for task, cnt in count_dict.items():
        print(task, round(score_dict.get(task, 0) / cnt * 100, 2), "%")
    print("all", sum(score_dict.values()) / sum(count_dict.values()))

    if args.save_file is not None:
        test_set = test_set.add_column("score", scores)
        save_dataset(test_set, args.save_file)
        
