from utils.reward import reward_factory
import argparse
from utils.load_data import load_json_data, write_json_data
from transformers import set_seed
from tqdm import tqdm
import random 
import os 
random.seed(17)


def build_result_path(dataset, model_name, method, sc_num, n_examples, n_samples):
    """Return the path of the generation-result JSON file for a run.

    For the ``'sc'`` method the file name is tagged ``sc<sc_num>``; for any
    other method the method name itself is used as the tag.
    """
    tag = f'sc{sc_num}' if method == 'sc' else method
    return f'./result/{dataset}/{model_name}/{tag}_e{n_examples}_{n_samples}.json'


def select_correct_indices(corrects, limit, roll_num):
    """Return up to ``roll_num`` indices of truthy entries in ``corrects``.

    Only the first ``limit`` entries are inspected. Unlike indexing over
    ``range(limit)``, this does not raise ``IndexError`` when ``corrects``
    holds fewer than ``limit`` entries.
    """
    hits = [idx for idx, is_correct in enumerate(corrects[:limit]) if is_correct]
    return hits[:roll_num]


def extract_outputs(responses):
    """Return the texts to score for a list of responses.

    When the first response is a dict, the ``'content'`` field of every
    response is returned; otherwise the list is returned unchanged.
    """
    if responses and isinstance(responses[0], dict):
        return [res['content'] for res in responses]
    return responses


def main():
    """Score stored model responses with a reward model and save the scores.

    With ``--golden`` each item's ``'reason'`` field is scored; otherwise up
    to ``--roll_num`` responses flagged in ``'corrects'`` are scored per item.
    Results are written as JSON under ``./result/<dataset>/<model>/``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='Llama3_1_8b_chat')
    parser.add_argument('--n_samples', type=int, default=200)
    parser.add_argument('--n_examples', type=int, default=3)
    parser.add_argument('--dataset', type=str, default='math')
    parser.add_argument('--method', type=str, default='sc')
    parser.add_argument('--roll_num', type=int, default=5)
    parser.add_argument('--reward', type=str, default=None)
    parser.add_argument('--remote', action='store_true')
    parser.add_argument('--golden', action='store_true')
    args = parser.parse_args()
    set_seed(17)

    model_name = args.model
    n_samples = args.n_samples
    dataset = args.dataset
    roll_num = args.roll_num
    reward = args.reward

    # Number of samples per item stored in the result file for this model.
    sc_num = 100 if model_name == 'Llama3_1_8b_chat' else 128

    result_path = build_result_path(
        dataset, model_name, args.method, sc_num, args.n_examples, n_samples
    )
    # The final entry of the loaded list is dropped -- presumably a summary
    # record rather than a sample; TODO confirm against the result writer.
    result_data = load_json_data(result_path)[:-1]
    reward_model = reward_factory(reward, args.remote, dataset)
    reward_results = []

    if args.golden:
        # Golden scores are stored next to the result file they were computed from.
        reward_path = result_path.replace('.json', '') + f'_{reward}_golden.json'
        for item in tqdm(result_data):
            step_scores = reward_model.score(
                question=item['question'], responses=[item['reason']]
            )
            reward_results.append({
                'id': item['id'],
                'question': item['question'],
                'response': item['reason'],
                'step_scores_golden': step_scores,
            })
    else:
        reward_path = f'./result/{dataset}/{model_name}/step_reward_{reward}_{roll_num}_{n_samples}.json'
        for item in tqdm(result_data):
            index = select_correct_indices(item['corrects'], sc_num, roll_num)
            if not index:
                # Nothing correct to score for this item (or roll_num is 0).
                continue
            responses = [item['response'][idx] for idx in index]
            outputs = extract_outputs(responses)
            try:
                step_score = reward_model.score(
                    question=item['question'], responses=outputs
                )
            except Exception as exc:
                # Best effort: report the failure, checkpoint what has been
                # scored so far, and move on to the next item.
                tqdm.write(
                    f"[warn] scoring failed for item {item.get('id')}: {exc!r}; "
                    f"checkpointing {len(reward_results)} results to {reward_path}"
                )
                write_json_data(reward_path, reward_results)
                continue
            reward_results.append({
                'id': item['id'],
                'question': item['question'],
                'idx': index,
                'response': responses,
                'step_score': step_score,
            })
    write_json_data(reward_path, reward_results)


if __name__ == '__main__':
    main()
    