from utils.load_data import load_json_data, extract_answer, write_json_data
from utils.eval import is_equiv
import os 
# import os
# import numpy as np 
# import pandas as pd
# import json
# import random
# from tqdm import tqdm
# from collections import Counter
# from utils.metrics import draw_box, draw_line, draw_bar, draw_heat
# def list_all_files(directory):
#     file_paths = []
#     for root, _, files in os.walk(directory):
#         for file in files:
#             file_paths.append(os.path.join(root, file))
#     return file_paths




# def extract_mcts_answer(path, select='sc'):
    
#     def travel_traces(traces, num, target):
#         tree = traces[str(num)]
#         node = tree['0']
#         while node['children']:
#             children = [tree[str(child)] for child in node['children']]
#             node = max(children, key=lambda x:x[target])
#         return node["answer"]

#     cor = 0
#     cnt = 0
#     data = load_json_data(path)
#     for item in tqdm(data):
#         if 'acc' in item.keys():
#             item['acc'] = cor / cnt
#             break 
#         response = [tup for tup in item['response']]    
#         answer = [extract_answer(item['content'], dataset='math') for item in response]
#         if select == 'sc':
#             pred = max(answer, key=answer.count)
#         elif select == 'q_value':
#             solution =  max(response, key=lambda x: x['q_value'])['content']
#             pred = extract_answer(solution, 'math')
#         elif select == 'reward':
#             solution = max(response, key=lambda x: x['reward'])['content']
#             pred = extract_answer(solution, 'math')
#         elif select == 'reward_sc':
#             coef = {}
#             for i in range(len(answer)):
#                 if answer[i] not in coef.keys():
#                     coef[answer[i]] = response[i]['reward']
#                 else:
#                     coef[answer[i]] += response[i]['reward']
#             pred = max(coef, key=lambda x: coef[x])
#         else:
#             if select == 'bestN':
#                 solution = [travel_traces(item['trace'], i, 'N') for i in range(len(response))]
#             else:
#                 solution = [travel_traces(item['trace'], i, 'Q') for i in range(len(response))]
#             answer = [extract_answer(content, 'math') for content in solution]
#             pred = answer[-1]
#             item['answer'] = answer
#             item['corrects'] = [is_equiv(item['label'], ans, 'math') for ans in answer]
        
#         item['pred'] = pred 
#         if is_equiv(item['label'], pred, 'math'):
#             item['cor_flag'] = True
#             cor += 1
#         else:
#             item['cor_flag'] = False
#         del item['trace']
#         cnt += 1
#     path = path.replace('reward', select)
#     write_json_data(path, data)
    



# # extract_mcts_answer()
# # Example usage
# # model_name = 'Qwen2_5_3b_chat'
# # dataset = 'math'
# # result_dir =  f'./result/{dataset}/{model_name}/sc128_e3_500.json'
# # path = '/mnt/userdata/ljc/code/o1_cot/result/math/Qwen2_5_3b_chat/best128_skywork_e3_500.json'


# # result_dir =  f'./result/{dataset}/'
# # all_paths = list_all_files(result_dir)
# # for path in all_paths:
# #     if 'eval' in path:
# #         continue
# #     data = load_json_data(path)
# # data = load_json_data(path)
# # cor = 0
# # cnt = 0

    
# # acc = 0
# # cnt = 0

# # data = load_json_data(result_dir)

# # cnt = 0
# # for item in data[:-1]:
# #     correct_idx = [idx for idx in range(128) if item['corrects'][idx]]
# #     if not correct_idx or len(correct_idx) > 3:
# #         continue
# #     print(item['id'])
# #     cnt += 1
# #     for idx in range(128):
# #         if idx in correct_idx:
# #             print(item['question'])
# #             print(item['response'][idx])
            
# # print(cnt)


# # model_name = 'Qwen2_5_3b_chat'
# # dataset = 'math'
# # # result_dir =  f'./result/{dataset}/{model_name}/sc128_e3_500.json'
# # # path = '/mnt/userdata/ljc/code/o1_cot/result/math/Qwen2_5_3b_chat/beam_e3_200.json'
# # # def create()

# # result_dir =  f'./result/math/Qwen2_5_3b_chat'
# # all_paths = list_all_files(result_dir)
# # for path in all_paths:
# #     if 'mcts' not in path or 'reward' not in path:
# #         continue
# #     extract_mcts_answer(path, select='q_value')
# #     extract_mcts_answer(path, select='sc')
# #     # data = load_json_data(path)[:200]
# #     # acc = [item['cor_flag'] for item in data].count(True) / 200
# #     # data.append({'acc':acc})
# #     # path = path.replace('e3_500', 'e3_200')
# #     # write_json_data(path, data)



# if __name__ == '__main__':
#     # model = 'Llama3_1_8b_chat'
#     # model = 'Qwen2_5_3b_chat'
#     # dataset = 'math'
#     # if model == 'Llama3_1_8b_chat':
#     #     roll_num = 100
#     #     n_samples = 200
#     # else:
#     #     roll_num = 128
#     #     n_samples = 500
#     # # beam_path = f'./result/math/Qwen2_5_3b_chat/beam_e3_200.json'
#     # bestn1_path = f'./result/{dataset}/{model}/best{roll_num}_skywork_e3_{n_samples}.json'
#     # bestn2_path = f'./result/{dataset}/{model}/best{roll_num}_shepherd_e3_{n_samples}.json'
#     # bestn3_path = f'./result/{dataset}/{model}/best{roll_num}_skyworko1_e3_{n_samples}.json'
#     # bestn4_path = f'./result/{dataset}/{model}/best{roll_num}_armorm_e3_{n_samples}.json'
#     # # mcts_path = f'./result/{dataset}/Qwen2_5_3b_chat/mcts_e3_200.json'
#     # sc_path = f'./result/{dataset}/{model}/sc{roll_num}_e3_{n_samples}.json'
#     # # slm_path = f'./result/math/{model}/best{roll_num}_{model}_e3_{n_samples}.json'
#     # reward_process_path = f'./result/math/Qwen2_5_3b_chat/sc128_e3_500_skyworko1.json'
#     # golden_reward_process_path = f'./result/math/Qwen2_5_3b_chat/sc128_e3_500_skyworko1_golden.json'

#     # # beam_data = load_json_data(beam_path)[:-1]
#     # bestn1_data = load_json_data(bestn1_path)[:-1]
#     # bestn2_data = load_json_data(bestn2_path)[:-1]
#     # bestn3_data = load_json_data(bestn3_path)[:-1]
#     # bestn4_data = load_json_data(bestn4_path)[:-1]
#     # # mcts_data = load_json_data(mcts_path)[:-1]
#     # sc_data = load_json_data(sc_path)[:-1]
#     # # slm_data = load_json_data(slm_path)[:-1]
#     # reward_process_data = load_json_data(reward_process_path)
#     # golden_reward_process_data = load_json_data(golden_reward_process_path)

        
#     # data_dic = {'sc':sc_data,'skywork':bestn1_data, 'skyworko1':bestn3_data[:100], 'armnorm':bestn4_data[:100]}
#     # # data_ls = [bestn1_data, sc_data]

#     # cor = 0
#     # method_ls = []
#     # score_ls = []
#     # for item in reward_process_data:
#     #     if not any(item['corrects']):
#     #         continue
#     #     scores = 
#     #     if isinstance(item['response'][0], dict):
#     #         scores = [item['response'][idx]['score'] for idx in index]
#     #         best_idx = np.argmax(np.array(scores))    
#     #         response = [item['response'][idx]['content'] for idx in index][best_idx]
#     #     else:
#     #         answer = [item['answer'][idx] for idx in index]
#     #         best_answer = max(answer, key=answer.count)
#     #         best_idx = answer.index(best_answer)
#     #         response = [item['response'][idx] for idx in index][best_idx]
#     #     # score = eval(reward[best_idx]) / len(response.split('.'))
#     #     cor_flag = cor_flags[best_idx]
#     #     if not cor_flag:
#     #         print(f'Method:\n{name}')
#     #         print(f'Question:\n{item["question"]}')
#     #         print(f'Response:\n{response}')
#     #         print(f'Golden:\n{item["reason"]}')
#     #         method_ls.append(name)
#     #             # score_ls.append(score)
#     # data = {'method':method_ls, 'score':score_ls}
#     # data = pd.DataFrame(data, columns=[ 'method', 'score'])
#     # path = 'fig/process.png'
#     # draw_box(data, path)
#     #     # answer = [data[i]['pred'] for data in data_ls]
#     #     corrects = [data[i]['cor_flag'] for data in data_ls]
#     #     # for pred in answer:
#     #     #     if is_equiv(data_ls[0][i]['label'], pred, 'math'):
#     #             # cor += 1
#     #             # break
#     #     if True in corrects:
#     #         cor += 1
#     #     # pred = max(answer, key=answer.count)
#     #     # if is_equiv(data_ls[0][i]['label'], pred, 'math'):
#     #         # cor += 1
#     # print(cor/200)
#     random.seed(17)
#     acc = []
#     nums = []
#     rm = []
# # for name, data in data_dic.items():
# #     cor = 0
# #     cnt = 0
# #     for k in range(2, roll_num):
# #         index = range(k)
# #         for item in data:
# #             if isinstance(item['response'][0], dict):
# #                 scores = [item['response'][i]['score'] for i in index]
                
# #             corrects = [item['corrects'][i] for i in index]
# #             if True not in corrects or False not in corrects:
# #                 continue
# #             true_score = [scores[i] for i in range(k) if corrects[i]][0]
# #             false_score = [scores[i] for i in range(k) if not corrects[i]][0]
# #             if true_score > false_score:
# #                 cor += 1
# #             cnt += 1
# #         acc.append(cor/cnt)
# #         nums.append(k)
# #         rm.append(name)

# # data = {'roll_num':nums, 'score':acc, 'rm':rm}
# # data = pd.DataFrame(data, columns=['roll_num', 'score', 'rm'])
# # path = f'fig/{model}_test.png'
# # draw_line(data, path)

# # random.seed(17)
# # acc = []
# # nums = []
# # rm = []
# # for name, data in data_dic.items():
# #     old_corrects = [item['corrects'][0] for item in data]
# #     for k in range(2, roll_num, 8):
# #         corrects = []
# #         for item in data:
# #             if isinstance(item['response'][0], dict):
# #                 scores = [item['response'][i]['score'] for i in range(k)]
# #                 idx = np.argmax(np.array(scores)) 
# #                 cor_flags = item['corrects'][:k]
# #                 corrects.append(cor_flags[idx])
# #             else:
# #                 answer = item['answer'][:k]
# #                 pred = max(answer, key=answer.count)
# #                 corrects.append(is_equiv(pred, item['label'], dataset))
# #         cor_cnt = [1 if corrects[i] and not old_corrects[i] else 0 for i in range(n_samples)].count(1)
# #         # acc.append(cor_cnt)
# #         # nums.append(k)
# #         # rm.append(f'good_{name}')
# #         cnt = [1 if not corrects[i] and old_corrects[i] else 0 for i in range(n_samples)].count(1)
# #         score = cor_cnt - cnt
# #         # if abs(score) < 3:
# #         #     score = 0
# #         acc.append(score)
# #         nums.append(k)
# #         rm.append(f'pure_{name}')
# #         old_corrects = corrects
# # old_corrects = [item['corrects'][0] for item in data]
# # for k in range(2, roll_num, 8):
# #     corrects = [any(item['corrects'][:k]) for item in data]
# #     cnt = [1 if corrects[i] and not old_corrects[i] else 0 for i in range(n_samples)].count(1)
# #     # if abs(cnt) < 3:
# #     #     cnt = 0
# #     acc.append(cnt)
# #     nums.append(k)
# #     rm.append(f'oracle')
# #     old_corrects = corrects

# # data = {'roll_num':nums, 'score':acc, 'rm':rm}
# # data = pd.DataFrame(data, columns=['roll_num', 'score', 'rm'])
# # path = f'fig/{model}_roll_num.pdf'
# # draw_bar(data, path)
# roll_num = 16
# labels = []
# steps = []
# scores = []

# def split_list_evenly(lst, k):
#     n = len(lst)
#     # If the list is empty or k is 0, return k empty lists
#     if n == 0 or k == 0:
#         return [[] for _ in range(k)]
    
#     # Compute the size of each part
#     avg_size = n // k
#     extra = n % k  # number of leftover elements

#     result = []
#     start = 0

#     for i in range(k):
#         # End index of the current part
#         end = start + avg_size + (1 if i < extra else 0)  # the first `extra` parts each get one additional element
#         result.append(np.mean(np.array(lst[start:end])))
#         start = end

#     return result




# # count_scores = {}
# # for item in reward_process_data:
# #     golden_score = golden_reward_process_data[reward_process_data.index(item)]['step_scores_golden'][0]
# #     golden_score = split_list_evenly(golden_score,5)
# #     if not any(item['corrects']):
# #         continue
# #     count = item['corrects'].count(True)
# #     if count not in count_scores:
# #         count_scores[count] = []
# #     score = [item['step_scores'][idx] for idx in range(roll_num) if item['corrects'][idx]]
# #     for s in score:
# #         s = split_list_evenly(s,5)
# #         # print(golden_score)
# #         s = [a - b for a,b in zip(golden_score, s)]
# #         count_scores[count].append(s)
# # for count, score in count_scores.items():
# #     level = count.bit_length()
# #     for s in score:
# #         steps += [1,2,3,4,5]
# #         scores += s
# #         labels += [level] * 5
    
#     # if item['cor_flag']:
#     #     good_score = item['step_scores'][item['corrects'].index(True)]
#     #     steps += [1,2,3,4,5]
#     #     scores += split_list_evenly(good_score, 5)
#     #     labels += ['correct'] * 5 
#     #     good_score = item['step_scores'][item['corrects'].index(False)]
#     #     steps += [1,2,3,4,5]
#     #     scores += split_list_evenly(good_score, 5)
#     #     labels += ['incorrect'] * 5 
#     # else: 
#     #     if not any(item['corrects']):
#     #         continue
#     #     good_score = item['step_scores'][item['corrects'].index(True)]
#     #     steps += [1,2,3,4,5]
#     #     scores += split_list_evenly(good_score, 5)

#     #     labels += ['recall'] * 5 
#     #     bad_score = item['step_scores'][item['answer'].index(item['pred'])]
#     #     steps += [1,2,3,4,5]
#     #     scores += split_list_evenly(bad_score, 5)
#     #     labels += ['hard'] * 5
        
# # for k in range(1,129):
# #     correct_counts = []
# #     incorrect_counts = []
# #     for item in sc_data:
# #         count = len(set(item['answer'][:k]))
# #         correct_count = any(item['corrects'][:k])
# #         incorrect_count = count - correct_count
# #         correct_counts.append(correct_count)
# #         incorrect_counts.append(incorrect_count)
# #     steps.append(k)
# #     scores.append(np.mean(np.array(correct_counts)))
# #     labels.append('correct')
# #     steps.append(k)
# #     scores.append(np.mean(np.array(incorrect_counts)))
# #     labels.append('incorrect')

# def split_difficulty(dataset):
#     difficulty_dic = {}
#     sc_path = f'./result/{dataset}/{model}/sc10_e3_{n_samples}.json'
#     sc_result = load_json_data(sc_path)[:-1]
#     for item in sc_result:
#         id = item['id']
#         difficulty = 6 - item['corrects'].count(True) // 2
#         if difficulty in difficulty_dic.keys():
#             difficulty_dic[difficulty].append(id)
#         else:
#             difficulty_dic[difficulty] = [id]
#     # difficulty_dic['all'] = [item['id'] for item in sc_result]
#     return difficulty_dic   


# def cal_pred(answer, score1, score2, roll_num):
#     coef1 = {}
#     coef2 = {}
#     for i in range(roll_num):
#         if answer[i] not in coef1.keys():
#             coef1[answer[i]] = [score1[i]]
#             coef2[answer[i]] = [score2[i]]
#             # coef3[answer[i]] = [score3[i]
#         else:
#             coef1[answer[i]].append(score1[i])
#             coef2[answer[i]].append(score2[i])
#             # coef3[answer[i]].append(score3[i])
    
#     # counts = Counter(answer)
#     # max_answer = max(counts, key=lambda x:counts[x])
#     # if max_answer == None:
#     #     del counts[None]
#     # if not counts:
#     #     return None 
#     # max_count = max(counts.values())
    
#     # for k, v in coef1.items():
#     #     if counts[k] < 0.1 * max_count:
#     #         coef1[k] = -100
#     #     else:
#     #         coef1[k] = np.sum(np.array(sorted(v)))
            
#     # # for k, v in coef2.items():
#     # #     if counts[k] < 0.1 * max_count:
#     # #         coef2[k] = -100
#     # #     else:
#     # #         coef2[k] = np.max(np.array(sorted(v)))
            
#     # # for k, v in coef3.items():
#     # #     if k == None:
#     # #         coef3[k] = -100
#     # #     else:
#     # #         coef3[k] = np.max(np.array(sorted(v)))
#     # # idx = np.argmax(np.array(scores))
#     # pred1 = max(coef1, key=lambda x: coef1[x])
#     # # pred2 = max(coef2, key=lambda x: coef2[x])
#     # # pred3 = max(coef3, key=lambda x: coef3[x])
#     # return pred1
#     # preds = [pred1, pred2]
#     # lens = [counts[pred] for pred in preds]
#     # best_idx = np.argmax(np.array(lens))
    
#     # return preds[best_idx]  

# # data_dic = {'skywork':bestn1_data, 'skyworko1':bestn3_data, 'self':bestn2_data}
 
# # difficulty_dic = split_difficulty('math')

# # # for name, data in data_dic.items():
# # cnt = 0
# # cor = 0
# # roll_num = 128
# # for dif, index in difficulty_dic.items():    
# #     result1 = [item for item in bestn1_data if item['id'] in index]
# #     result2 = [item for item in bestn3_data if item['id'] in index]
# #     result3 = [item for item in bestn4_data if item['id'] in index]
# #     for item in result1:
# #         answer = item['answer'][:roll_num]
# #         score1 = [item['response'][i]['score'] for i in range(roll_num)] 
# #         score2 = [result2[result1.index(item)]['response'][i]['score'] for i in range(roll_num)] 
# #         score3 = [result3[result1.index(item)]['response'][i]['score'] for i in range(roll_num)] 
       
# #         # if dif <= 2:
# #         #     pred = max(answer, key=answer.count)
# #         # else:
# #         pred = cal_pred(answer, score1, score2, roll_num)
# #         if is_equiv(pred, item['label'], 'math'):
# #             cor += 1
# #         cnt += 1
# # # print(name)
# # print(cor/cnt)
# # # difficulty_dic = split_difficulty('math')
# # for dif, index in difficulty_dic.items():
# #     steps = []
# #     scores = []
# #     labels = []
# #     data = [item for item in bestn3_data if item['id'] in index]
# #     for item in data:
# #         for k in range(2,129):
# #             index = random.sample(range(128), k)
# #             c_scores = [item['response'][idx]['score'] for idx in index if item['corrects'][idx]]
            
# #             w_scores = [item['response'][idx]['score'] for idx in index if not item['corrects'][idx]]
# #             if not c_scores or not w_scores:
# #                 continue
# #             steps += [k] * 3
# #             scores.append(np.max(np.array(c_scores)))
# #             labels.append('correct')
# #             scores.append(np.max(np.array(w_scores)))
# #             labels.append('incorrect')
# #             scores.append(np.max(np.array(c_scores)) - np.max(np.array(w_scores)))
# #             labels.append('gap')

# #     data = {'step':steps, 'score':scores, 'type':labels}
# #     data = pd.DataFrame(data, columns=['step', 'score', 'type'])
# #     path = f'fig/math/{model}/{dif}/skyworko1_gap.pdf'
# #     draw_line(data, path)
#         # print(good_scores)
#         # print(bad_scores)






# # for item in sc_data:
# #     if True in item['corrects'] and not item['cor_flag']:
# #         print(item['id'])
# # from tqdm import tqdm
# # files = list_all_files('/mnt/userdata/ljc/code/o1_cot/result/math/Qwen2_5_3b_chat')
# # for path in tqdm(files):  
# #     if 'mcts' not in path or 'reward' not in path:
# #         continue
# #     data = load_json_data(path)
# #     cor = 0
# #     cnt = 0    
# #     for item in data:
# #         if 'acc' in item.keys():
# #             item['acc'] = cor / cnt 
# #             break
# #         answer = item['answer']
# #         pred = item['pred']
# #         if isinstance(item['response'][0], dict):
# #             answer = [extract_answer(res['content'], 'math') for res in item['response']]
# #         else:
# #             answer = [extract_answer(res, 'math') for res in item['response']]
# #         item['corrects'] = [is_equiv(ans, item['label'], 'math') for ans in answer]
# #         item['answer'] = answer 
# #         if 'best' in path:
# #             scores = [tup['score'] for tup in item['response']]
# #             pred = answer[np.argmax(np.array(scores))]
# #         elif 'beam' in path:
# #             pred = answer[0]
# #         elif 'reward' in path:
# #             if 'trace' not in item.keys() or not item['trace']:
# #                 solution =  max(item['response'][:-1], key=lambda x: x['reward'])['content']
# #             else:
# #                 solution =  max(item['response'], key=lambda x: x['reward'])['content']
# #             pred = extract_answer(solution, dataset='math')
# #         elif 'q_value' in path:
# #             if 'trace' not in item.keys() or not item['trace']:
# #                 solution =  max(item['response'][:-1], key=lambda x: x['q_value'])['content']
# #             else:
# #                 solution =  max(item['response'], key=lambda x: x['q_value'])['content']
# #             pred = extract_answer(solution, dataset='math')
# #         else:
# #             pred = max(answer, key=answer.count)
# #         item['cor_flag'] = is_equiv(pred, item['label'], 'math')
# #         cor += int(item['cor_flag'])
# #         if 'old_answer' in item.keys():
# #             del item['old_answer']
# #         item['pred'] = pred
# #         cnt += 1
# #     write_json_data(path, data)

# roll_num = 16
# # 

# # cor = 0
# # for i in range(200):
# #     if mcts_data[i]['cor_flag'] or bestn_data[i]['cor_flag']:
# #         cor += 1

# # for i in range(n_samples):
# #     # if True in reward_sc_data[i]['corrects'] and not reward_sc_data[i]['cor_flag']:
# #     #     print(mcts_data[i]['id'])
# #     mcts_flg = mcts_data[i]['cor_flag']
# #     bestn_flg = bestn_data[i]['cor_flag']
# #     if mcts_flg and not bestn_flg:
# #         print(f"mcts win:{mcts_data[i]['id']}")
# #     elif not mcts_flg and bestn_flg:
# #         print(f"bestn win:{bestn_data[i]['id']}")
        
# path = '/mnt/userdata/ljc/code/o1_cot/result/math/Qwen2_5_3b_chat/mcts64_t0.7_d8_w0.1_3_8_self-Qwen2_5_3b_chat_reward_e3_200.json'
# list = ['reward_sc', 'sc', 'q_value', 'bestN', 'bestQ']
# for select in list:
#     extract_mcts_answer(path, select)
# # print(cor / 200)
# # data_ls = []
# # list = ['sc', 'reward']
# # for select in list:
# #     path = f'/mnt/userdata/ljc/code/o1_cot/result/math/Qwen2_5_3b_chat/mcts16_t0.7_d5_w5.0_5_1_skywork_{select}_e3_200.json'
# #     data = load_json_data(path)[:-1]
# #     data_ls.append(data)
# # sc_path = f'./result/math/Qwen2_5_3b_chat/sc128_e3_200.json'
# # data_ls.append(load_json_data(sc_path)[:-1])
# # bestn_path = f'./result/math/Qwen2_5_3b_chat/best128_skywork_e3_200.json'
# # data_ls.append(load_json_data(bestn_path)[:-1])

# # cor_flag = 0
# # for i in range(200):
# #     cor = any([data[i]['cor_flag'] for data in data_ls])
# #     cor_flag += int (cor)
# # print(cor_flag / 200)
import numpy as np 
def _zscore_weighted_vote(answers: list, scores: list):
    """Return the answer whose candidates have the largest summed z-score.

    ``scores`` is standardised over all of its entries; candidate ``i`` then
    contributes ``scores[i]`` to the tally of ``answers[i]``. Ties go to the
    answer that appears first in ``answers``.
    """
    raw = np.array(scores)
    mean_score = np.mean(raw)
    std_score = np.std(raw)
    if std_score > 0:
        weights = [(score - mean_score) / std_score for score in scores]
    else:
        # All scores identical (or a single candidate): dividing by a zero
        # std would only produce NaN weights, so give every candidate the
        # same neutral weight instead.
        weights = [0.0] * len(scores)

    tally = {}
    for i, answer in enumerate(answers):
        tally[answer] = tally.get(answer, 0) + weights[i]
    return max(tally, key=tally.get)


def get_acc(result: list[dict], n_samples: int, method: str, dataset) -> list[dict]:
    """Re-select a prediction from the first ``n_samples`` candidates of every
    record and compute the resulting accuracy.

    Args:
        result: Per-question records. Each one is read for the keys
            ``'response'``, ``'answer'``, ``'corrects'`` and ``'label'``
            (and ``'trace'`` when ``method == 'ds'``).
        n_samples: How many leading candidates of each record to keep.
        method: Selection strategy.
            ``'sc'``   - most frequent answer (first one wins ties);
            ``'bon'``  - candidate with the highest ``response['score']``;
            ``'mcts'`` - candidate with the highest ``response['reward']``;
            ``'ds'``   - z-score weighted vote, scores read from
            ``item['trace']``;
            any other value - z-score weighted vote, scores read from the
            kept responses.
        dataset: Forwarded unchanged to ``is_equiv`` by the weighted-vote
            strategies.

    Returns:
        One trimmed record per input record (``'response'``, ``'answer'``,
        ``'corrects'``, ``'label'``, ``'cor_flag'``) followed by a final
        ``{'acc': ...}`` record. ``acc`` is ``0.0`` for an empty ``result``.

    Raises:
        IndexError: If a record holds fewer than ``n_samples`` candidates.
    """
    # Field read by the "pick the single best-scored candidate" strategies.
    best_of_key = {'bon': 'score', 'mcts': 'reward'}

    correct = 0
    new_results = []
    for item in result:
        # Explicit indexing (not slicing) so a short record fails loudly.
        responses = [item['response'][idx] for idx in range(n_samples)]
        answers = [item['answer'][idx] for idx in range(n_samples)]
        corrects = [item['corrects'][idx] for idx in range(n_samples)]

        if method == 'sc':
            majority = max(answers, key=answers.count)
            cor_flag = corrects[answers.index(majority)]
        elif method in best_of_key:
            field = best_of_key[method]
            best_idx = max(enumerate(responses), key=lambda pair: pair[1][field])[0]
            cor_flag = corrects[best_idx]
        else:
            if method == 'ds':
                # NOTE(review): any falsy score (None, but also exactly 0)
                # is replaced by -1 here - confirm that 0 is never a
                # legitimate score.
                scores = [tup['score'] if tup['score'] else -1 for tup in item['trace'].values()]
            else:
                scores = [tup['score'] for tup in responses]
            pred = _zscore_weighted_vote(answers, scores)
            cor_flag = is_equiv(pred, item['label'], dataset)

        correct += int(cor_flag)
        new_results.append({'response':responses, 'answer':answers, 'corrects':corrects, 'label':item['label'], 'cor_flag':cor_flag})

    # Guard the summary against an empty input instead of dividing by zero.
    accuracy = correct / len(result) if result else 0.0
    new_results.append({'acc': accuracy})
    return new_results
from tqdm import tqdm
# Derived result files to (re)build on this run, processed in list order.
# Order matters: 'orm_sc' reads the file that 'bon_orm' writes.
methods = ['bon_orm', 'orm_sc']
datasets = ['olympiadbench']
models = ['Llama3_1_8b_chat']
for dataset in datasets:
    for model in models:
        # method name -> (output path, input path, candidates kept, get_acc strategy)
        recipes = {
            'sc': (
                f'./result/{dataset}/{model}/sc32_t0.7_e3_200.json',
                f'./result/{dataset}/{model}/sc128_e3_200.json',
                32, 'sc',
            ),
            'cot': (
                f'./result/{dataset}/{model}/cot_t0.7_e3_200.json',
                f'./result/{dataset}/{model}/sc32_t0.7_e3_200.json',
                1, 'sc',
            ),
            'bon_orm': (
                f'./result/{dataset}/{model}/best32_t0.7_skywork_e3_200.json',
                f'./result/{dataset}/{model}/best32_t0.7_skywork_e3_500.json',
                32, 'bon',
            ),
            'bon_prm': (
                f'./result/{dataset}/{model}/best32_t0.7_skyworko1_e3_500.json',
                f'./result/{dataset}/{model}/best128_t0.7_skyworko1_e3_500.json',
                32, 'bon',
            ),
            'orm_sc': (
                f'./result/{dataset}/{model}/sc32_t0.7_skywork_e3_200.json',
                f'./result/{dataset}/{model}/best32_t0.7_skywork_e3_200.json',
                32, 'sc_bon',
            ),
            'prm_sc': (
                f'./result/{dataset}/{model}/sc32_t0.7_skyworko1_e3_200.json',
                f'./result/{dataset}/{model}/best32_t0.7_skyworko1_e3_200.json',
                32, 'sc_bon',
            ),
            'mcts_orm': (
                f'./result/{dataset}/{model}/mcts16_t0.7_d5_w5.0_5_1_skywork_reward_e3_500.json',
                f'./result/{dataset}/{model}/mcts32_t0.7_d5_w5.0_5_1_skywork_reward_e3_500.json',
                16, 'mcts',
            ),
            'mcts_prm': (
                f'./result/{dataset}/{model}/mcts16_t0.7_d5_w0.1_5_1_skyworko1_reward_e3_500.json',
                f'./result/{dataset}/{model}/mcts32_t0.7_d5_w0.1_5_1_skyworko1_reward_e3_500.json',
                16, 'mcts',
            ),
            'ds': (
                f'./result/{dataset}/{model}/woc10_w1_d2_skywork_e3_100.json',
                f'./result/{dataset}/{model}/woc10_w1_d2_skywork_e3_100_old.json',
                10, 'ds',
            ),
        }
        for method in tqdm(methods):
            recipe = recipes.get(method)
            if recipe is None:
                # Unknown method names are silently skipped.
                continue
            new_path, old_path, keep, strategy = recipe
            # The last record of the input is dropped before re-scoring
            # (presumably the trailing {'acc': ...} summary - confirm).
            # Outputs are always regenerated, even if new_path already exists.
            new_result = get_acc(load_json_data(old_path)[:-1], keep, strategy, dataset)
            write_json_data(new_path, new_result)