from datasets import load_dataset
from tqdm import tqdm
from examples.reward_score import gsm8k, math
import argparse
import numpy as np


def _select_rm_score_fn(reward_function):
    if reward_function == 'gsm8k':
        return gsm8k.compute_score
    elif reward_function == 'math':
        return math.compute_score
    else:
        raise NotImplementedError


def parse_arguments():
    """Build the command-line parser and return the parsed namespace.

    Recognised options:
        --dataset: string, defaults to ``None``.
        --remote_dir: string, defaults to ``None``.
        --reward_function: string, defaults to ``"gsm8k"``.
        --n: integer, defaults to ``8``.
    """
    # (flag, value type, default) for every supported option.
    option_specs = (
        ("--dataset", str, None),
        ("--remote_dir", str, None),
        ("--reward_function", str, "gsm8k"),
        ("--n", int, 8),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()


def evaluate(remote_dir, dataset, reward_name, reward_function, n):
    """Score ``n`` sampled responses per example and push the results.

    For every sample index ``i`` in ``range(n)`` the ``response_{i}`` column
    is scored against ``reward_model['ground_truth']``, the mean score is
    printed, and the per-example scores are stored in a new ``eval_{i}``
    column. Afterwards one list is printed whose k-th entry is the fraction
    of examples whose scores summed over samples ``0..k`` are positive.
    Finally the augmented dataset is uploaded with ``push_to_hub``.

    Args:
        remote_dir: Passed unchanged to ``dataset.push_to_hub``.
        dataset: Object supporting column access (``dataset[name]``),
            ``add_column`` and ``push_to_hub``; ``main`` passes the
            ``'train'`` split returned by ``load_dataset``.
        reward_name: Reward scheme name. ``'gsm8k'`` selects the call form
            with ``method='flexible'``, keeping the last element of the
            result.
        reward_function: Scoring callable. For non-gsm8k rewards it must
            return a 2-tuple whose second element is the score.
        n: Number of ``response_{i}`` columns to evaluate.
    """
    # Ground truths do not depend on the sample index, so read them once.
    # Column access also avoids rebuilding a full row dict per lookup, which
    # is what row indexing (``dataset[d]``) does.
    ground_truths = [rm['ground_truth'] for rm in dataset['reward_model']]
    use_flexible = reward_name == 'gsm8k'

    all_scores = []

    for i in tqdm(range(n)):
        responses = dataset[f'response_{i}']
        scores = []
        for response, truth in zip(responses, ground_truths):
            if use_flexible:
                cs = reward_function(response, truth, method='flexible')[-1]
            else:
                _, cs = reward_function(response, truth)
            scores.append(cs)
        scores = np.array(scores)
        print(scores.mean())
        all_scores.append(scores)
        dataset = dataset.add_column(f"eval_{i}", scores)

    if all_scores:
        # Row k is True where the running score total over samples 0..k is
        # positive for that example.
        solved_within_k = np.cumsum(np.array(all_scores), axis=0) > 0
        print(solved_within_k.mean(axis=1).tolist())
    else:
        # n <= 0: nothing was scored; a 1-D empty array has no axis 1.
        print([])

    # save
    dataset.push_to_hub(remote_dir)


def main():
    """Parse CLI options, load the train split, then score and upload it."""
    cli_args = parse_arguments()

    # Only the 'train' split of the requested dataset is evaluated.
    splits = load_dataset(cli_args.dataset, trust_remote_code=True)
    train_split = splits['train']

    # Resolve the scoring callable from its command-line name.
    scorer = _select_rm_score_fn(cli_args.reward_function)

    evaluate(
        cli_args.remote_dir,
        train_split,
        cli_args.reward_function,
        scorer,
        cli_args.n,
    )

# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()