"""Assign rewards to model responses for several tasks: SQL, BIRD, GSM8K, Math and Code."""

import argparse
from tqdm import tqdm

from src.data_construct.verifier import SQLVerifier, BIRDVerifier, GSM8KMathVerifier, MathVerifier, CodeVerifier
from src.utils.utils import load_json, write_json

# Maps each task name (the value of the ``--task`` CLI option) to the verifier
# class that is instantiated to score responses for that task.
TASK2API = dict(
    sql=SQLVerifier,
    bird=BIRDVerifier,
    gsm8k=GSM8KMathVerifier,
    math=MathVerifier,
    code=CodeVerifier,
)

def rewarding(argv=None):
    """Score deduplicated model responses against gold answers and save the result.

    Command-line options (parsed from ``argv``, or ``sys.argv[1:]`` when None):
        --data_path    JSON file with a list of instances. Each instance needs the
                       ground-truth field named by ``--gold_key`` and a
                       ``'response'`` list.
        --output_path  Where the annotated list is written.
        --gold_key     Instance key holding the ground truth (default ``'output'``).
        --task         Key into ``TASK2API`` selecting the verifier (default ``'sql'``).
        --config       Optional JSON file loaded and passed to the verifier constructor.

    Each instance is updated in place. ``'response'`` is replaced by its
    deduplicated responses, in first-occurrence order. ``'rewards'`` holds the
    verifier's per-response results, in the same order.

    For verifiers without ``parallel_execute_and_verify``, the ground truth is
    executed first. If it fails (``exec_status`` is falsy), the instance gets
    empty ``'response'`` and ``'rewards'`` lists.

    Args:
        argv: Optional argument list. It is passed straight to
            ``ArgumentParser.parse_args``, which allows programmatic use.
    """
    parser = argparse.ArgumentParser(description='Reward model responses with a task verifier.')
    parser.add_argument('--data_path', type=str, required=True)
    parser.add_argument('--output_path', type=str, required=True)
    parser.add_argument('--gold_key', type=str, default='output')
    # Derived from TASK2API so the allowed tasks cannot drift from the registry.
    parser.add_argument('--task', type=str, default='sql', choices=list(TASK2API))
    parser.add_argument('--config', type=str, default=None)
    args = parser.parse_args(argv)

    # Expected input shape: [{<gold_key>: ..., 'prompt': ..., 'response': [xx, xx]}, ...]
    reward_data = load_json(args.data_path)
    config = load_json(args.config) if args.config else None

    verifier = TASK2API[args.task](config, pbar=False)
    for instance in tqdm(reward_data, total=len(reward_data), desc='Rewarding'):
        ground_truth = instance[args.gold_key]
        # Order-preserving dedup. Iterating a set of strings yields a
        # per-process order (hash randomization), which made the output
        # file non-reproducible between runs.
        responses = list(dict.fromkeys(instance['response']))
        if hasattr(verifier, 'parallel_execute_and_verify'):
            exec_requests = [(idx, ground_truth, response) for idx, response in enumerate(responses)]
            rewards = verifier.parallel_execute_and_verify(exec_requests)
        else:
            # Sequential path (original note: "just for our SQL").
            gold_result = verifier.execute(ground_truth)
            if not gold_result['exec_status']:
                # The ground truth itself is incorrect, so nothing can be
                # scored. Record empty results and skip the instance.
                instance['response'] = []
                instance['rewards'] = []
                continue
            # Execute every response first, then verify each against the gold result.
            exec_results = [verifier.execute(response) for response in responses]
            rewards = [verifier.verify(gold_result, result) for result in exec_results]
        instance['response'] = responses
        instance['rewards'] = rewards

    write_json(args.output_path, reward_data)

if __name__ == "__main__":
    # Script entry point: read CLI options from sys.argv and process the dataset.
    rewarding()