from utils.load_data import load_json_data, extract_answer, write_json_data
from utils.eval import is_equiv
from utils.metrics import draw_line, draw_bar, draw_scatter, draw_heat, draw_box, draw_sub_bar
from collections import Counter, defaultdict
import matplotlib.pyplot as plt
import argparse
import random
import pandas as pd
import numpy as np 
import os
from tqdm import tqdm
from cal_js_div import compute_js_divergence_matrix
from scipy.stats import entropy
from utils.model import ModelWrapper
# Seed the global RNG so that the random.sample-based roll-out subsampling
# performed by the functions in this file is reproducible across runs.
random.seed(17)

def get_roll_process_reward(result:list[dict], n_samples:int, golden: bool = False):
    """Collect the minimum step score of sampled roll-outs, split by correctness.

    Args:
        result: Items to scan. With ``golden=True`` each item must provide
            ``'step_scores_golden'``; otherwise ``'corrects'`` and
            ``'step_scores'``, both indexable by roll-out position.
        n_samples: Number of roll-outs drawn at random (without replacement)
            from ``range(roll_num)`` for every item. Ignored when ``golden``.
        golden: When true, take the minimum of ``'step_scores_golden'`` for
            each item and file it under the "correct" list.

    Returns:
        ``(correct_reward, wrong_reward)``: two lists of minimum step scores.
        ``wrong_reward`` is always empty when ``golden`` is true.
    """
    rewards_for_correct, rewards_for_wrong = [], []
    for entry in result:
        if golden:
            rewards_for_correct.append(np.min(np.array(entry['step_scores_golden'])))
            continue

        # Relies on the module-level ``roll_num`` (total roll-outs per item).
        picked = random.sample(list(range(0, roll_num)), n_samples)
        flags = [entry['corrects'][pos] for pos in picked]
        step_scores = [entry['step_scores'][pos] for pos in picked]
        for is_correct, steps in zip(flags, step_scores):
            bucket = rewards_for_correct if is_correct else rewards_for_wrong
            bucket.append(np.min(np.array(steps)))
    return rewards_for_correct, rewards_for_wrong


def get_roll_n_acc(result:list[dict], n_samples:int, golden:bool) -> float:
    """Return the fraction of items whose selected roll-out is correct.

    The roll-out positions are chosen once and reused for every item:
    if the first response of the first item is a string or has length 2, the
    positions are a random sample of ``n_samples`` out of ``range(roll_num)``;
    otherwise the first ``n_samples`` positions are used.

    Selection rule per item:
      * ``golden``: the first correct roll-out, or position 0 if none is correct;
      * dict responses: the highest ``'reward'`` when the dict has 3 entries,
        otherwise the highest ``'score'``;
      * anything else: the most frequent answer (majority vote).

    Args:
        result: Items with ``'response'``, ``'answer'`` and ``'corrects'`` lists.
        n_samples: Number of roll-outs considered per item.
        golden: Use the oracle selection rule.

    Returns:
        Number of correctly selected items divided by ``len(result)``.
    """
    first_response = result[0]['response'][0]
    if isinstance(first_response, str) or len(first_response) == 2:
        # Relies on the module-level ``roll_num`` (total roll-outs per item).
        picked = random.sample(list(range(0, roll_num)), n_samples)
    else:
        picked = range(0, n_samples)

    n_correct = 0
    for entry in result:
        responses = [entry['response'][pos] for pos in picked]
        answers = [entry['answer'][pos] for pos in picked]
        corrects = [entry['corrects'][pos] for pos in picked]

        if golden:
            best_idx = corrects.index(True) if True in corrects else 0
        elif isinstance(responses[0], dict):
            field = 'reward' if len(responses[0]) == 3 else 'score'
            # max() keeps the first maximal element, so ties go to the earliest.
            best_idx = max(enumerate(responses), key=lambda pair: pair[1][field])[0]
        else:
            majority = max(answers, key=answers.count)
            best_idx = answers.index(majority)
        n_correct += int(corrects[best_idx])
    return n_correct / len(result)


def get_roll_topk_hit(result:list[dict], n_samples:int, topk:int) -> tuple[float, list[int]]:
    """Measure how often a correct roll-out is ranked within the top ``topk``.

    For every item, ``n_samples`` roll-outs are drawn at random from
    ``range(roll_num)``. Items without any correct sampled roll-out are skipped
    (they still count in the denominator). The remaining roll-outs are scored:

      * dict responses with 3 entries: the summed ``'score'`` of all sampled
        roll-outs sharing the same answer;
      * other dict responses: the response's own ``'score'``;
      * non-dict responses: the frequency of the answer among the samples.

    The rank of an item starts at 1 and grows by one for each distinct score
    level met among the incorrect roll-outs that precede the first correct one
    in descending score order.

    Args:
        result: Items with ``'response'``, ``'answer'`` and ``'corrects'`` lists.
        n_samples: Number of roll-outs sampled per item.
        topk: Rank threshold counted as a hit.

    Returns:
        ``(hit_rate, ranks)`` where ``hit_rate`` is the number of items with
        ``rank <= topk`` divided by ``len(result)`` and ``ranks`` holds the rank
        of every non-skipped item.
    """
    hits = 0
    ranks = []
    for item in result:
        # Relies on the module-level ``roll_num`` (total roll-outs per item).
        index = random.sample(list(range(0, roll_num)), n_samples)
        responses = [item['response'][idx] for idx in index]
        has_reward = isinstance(responses[0], dict)
        answers = [item['answer'][idx] for idx in index]
        corrects = [item['corrects'][idx] for idx in index]
        if not any(corrects):
            continue

        if has_reward:
            if len(responses[0]) == 3:
                # Aggregate the scores of roll-outs that agree on the answer.
                coef = {}
                for ans, resp in zip(answers, responses):
                    if ans not in coef:
                        coef[ans] = resp['score']
                    else:
                        coef[ans] += resp['score']
                scores = [{"element": ans, "score": coef[ans]} for ans in answers]
            else:
                scores = responses
        else:
            count_dict = Counter(answers)
            scores = [{"element": ans, "score": count_dict[ans]} for ans in answers]

        ordered = sorted(enumerate(scores), key=lambda x: x[1]["score"], reverse=True)
        rank = 1
        last_score = None
        # ``entry`` (not ``item``) so the outer loop variable is not shadowed.
        for idx, entry in ordered:
            if corrects[idx]:
                break
            if last_score != entry['score']:
                rank += 1
                last_score = entry['score']
        ranks.append(rank)
        hits += int(rank <= topk)
    return hits / len(result), ranks


def get_roll_n_hit(result:list[dict], n_samples:int, reward:bool) -> float:
    """Return the mean rank of the first correct roll-out among the first ``n_samples``.

    Roll-outs are ordered by descending score (stable for ties). With
    ``reward`` the score is each response's ``'score'`` entry; otherwise it is
    the frequency of the roll-out's answer among the considered roll-outs.
    The rank of an item is the 1-based position of its first correct roll-out
    in that order. Items without a correct roll-out contribute nothing to the
    sum but are still counted in the denominator.

    Args:
        result: Items with ``'response'``, ``'answer'`` and ``'corrects'`` lists.
        n_samples: Number of leading roll-outs considered per item.
        reward: Whether responses are dicts carrying a ``'score'``.

    Returns:
        Sum of ranks divided by ``len(result)``.
    """
    rank_total = 0
    index = range(0, n_samples)
    for item in result:
        responses = [item['response'][idx] for idx in index]
        answers = [item['answer'][idx] for idx in index]
        corrects = [item['corrects'][idx] for idx in index]
        if not any(corrects):
            continue
        if reward:
            scores = responses
        else:
            count_dict = Counter(answers)
            scores = [{"element": ans, "score": count_dict[ans]} for ans in answers]
        ordered = sorted(enumerate(scores), key=lambda x: x[1]["score"], reverse=True)
        # any(corrects) holds here, so a correct entry is always found.
        rank = next(
            position
            for position, (idx, _) in enumerate(ordered, start=1)
            if corrects[idx]
        )
        rank_total += rank
        if rank <= 2:
            # Diagnostic output kept from the original implementation.
            print(scores[:2])
    return rank_total / len(result)


def acc_with_diff(data_dic, fig_name=False):
    """Draw a bar chart of accuracy per difficulty level for each method.

    Difficulty buckets come from ``split_difficulty`` (module-level ``dataset``
    and ``model``). Every method in ``data_dic`` is evaluated on the first 32
    roll-outs of the items in each bucket. When ``fig_name`` is falsy an extra
    'Oracle' series is computed from ``data_dic['sc']``.

    Args:
        data_dic: Mapping of method name to its list of result items.
        fig_name: Optional prefix of the output file name.
    """
    def get_acc(result:list[dict], n_samples:int, method:str) -> float:
        """Accuracy of ``method`` over ``result`` using the first ``n_samples`` roll-outs."""
        picked = range(0, n_samples)
        hits = 0
        for entry in result:
            responses = [entry['response'][pos] for pos in picked]
            answers = [entry['answer'][pos] for pos in picked]
            corrects = [entry['corrects'][pos] for pos in picked]
            method_key = method.lower()
            if 'oracle' in method_key:
                # Oracle: first correct roll-out, else fall back to position 0.
                best_idx = corrects.index(True) if True in corrects else 0
            elif 'sc' in method_key:
                # Self-consistency: majority answer.
                best_idx = answers.index(max(answers, key=answers.count))
            else:
                field = 'reward' if len(responses[0]) == 3 else 'score'
                best_idx = max(enumerate(responses), key=lambda pair: pair[1][field])[0]
            hits += int(corrects[best_idx])
        return hits / len(result)

    difficulty_dic = split_difficulty(dataset, all=False, model=model)
    records = []  # (difficulty, accuracy, method) triples in plotting order
    for difficulty, index in difficulty_dic.items():
        for name, data in data_dic.items():
            subset = [entry for entry in data if entry['id'] in index]
            records.append((difficulty, get_acc(subset, 32, name), name))
        if not fig_name:
            subset = [entry for entry in data_dic['sc'] if entry['id'] in index]
            records.append((difficulty, get_acc(subset, 32, 'oracle'), 'Oracle'))

    dir_path = f'fig/{dataset}/{model}/'
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    file_name = f'{fig_name}_diff_acc.pdf' if fig_name else 'diff_acc.pdf'
    path = dir_path + file_name

    columns = {
        'difficulty': [rec[0] for rec in records],
        'accuracy': [rec[1] for rec in records],
        'method': [rec[2] for rec in records],
    }
    frame = pd.DataFrame(columns, columns=list(columns.keys()))
    draw_bar(frame, path)

def acc_with_nums(data_dic, fig_name=None, recall=False):
    """Draw accuracy as a function of the number of roll-outs N.

    For N in ``1..roll_num`` every method in ``data_dic`` is scored with
    ``get_roll_n_acc``. With ``recall`` the oracle accuracy of the 'SC',
    'MCTS_ORM' and 'MCTS_PRM' entries (when present) is added as 'Oracle',
    'Oracle_ORM' and 'Oracle_PRM' series.

    The output file name and the third argument passed to ``draw_line`` depend
    on whether 'MCTS_PRM' is among the methods and on ``fig_name``.

    Args:
        data_dic: Mapping of method name to its list of result items.
        fig_name: Optional prefix of the output file name.
        recall: Whether to add the oracle series.
    """
    oracle_series = (
        ('SC', 'Oracle'),
        ('MCTS_ORM', 'Oracle_ORM'),
        ('MCTS_PRM', 'Oracle_PRM'),
    )

    n_col, method_col, acc_col = [], [], []
    for num in range(1, roll_num+1):
        if recall:
            for source, label in oracle_series:
                if source in data_dic:
                    n_col.append(num)
                    method_col.append(label)
                    acc_col.append(get_roll_n_acc(data_dic[source], num, True))
        for name, data in data_dic.items():
            n_col.append(num)
            method_col.append(name)
            acc_col.append(get_roll_n_acc(data, num, False))

    dir_path = f'fig/{dataset}/{model}/'
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    # Pick the file name and the flag forwarded to draw_line.
    if 'MCTS_PRM' in data_dic:
        if fig_name:
            file_name = f'{fig_name}_search_acc.pdf'
        elif recall:
            file_name = 'search_recall.pdf'
        else:
            file_name = 'search_acc.pdf'
        line_flag = False
    elif fig_name:
        file_name = f'{fig_name}_acc.pdf'
        line_flag = True
    else:
        file_name = 'sample_acc.pdf'
        line_flag = False

    frame = pd.DataFrame(
        {'N': n_col, 'accuracy': acc_col, 'method': method_col},
        columns=['N', 'accuracy', 'method'],
    )
    draw_line(frame, dir_path + file_name, line_flag)


def hit_rank_with_nums(dataset='math'):
    """Plot, per difficulty bucket, the mean hit rank against the roll-out count.

    For every bucket returned by ``split_difficulty`` the self-consistency
    results and one result file per entry of the module-level ``rewards`` are
    loaded (the last element of each loaded list is dropped), filtered to the
    bucket's ids and scored with ``get_roll_n_hit`` for every roll-out count in
    ``range(2, roll_num)``. One figure is written per bucket.

    Args:
        dataset: Dataset name used in the result and figure paths.
    """
    for difficulty, index in split_difficulty(dataset, model=model).items():
        roll_out_nums = range(2, roll_num)
        x_col, method_col, y_col = [], [], []

        def collect(label, json_path, use_reward):
            # Load one result file, keep the bucket's items, score every N.
            rows = [row for row in load_json_data(json_path)[:-1] if row['id'] in index]
            for num in roll_out_nums:
                x_col.append(num)
                method_col.append(label)
                y_col.append(get_roll_n_hit(rows, num, use_reward))

        collect('sc', f'./result/{dataset}/{model}/sc{roll_num}_e3_{n_samples}.json', False)
        for reward in rewards:
            collect(
                reward,
                f'./result/{dataset}/{model}/best{roll_num}_{reward}_e3_{n_samples}.json',
                True,
            )

        path = f'fig/{dataset}_{difficulty}.png'
        frame = pd.DataFrame(
            {'roll_out_num': x_col, 'accuracy': y_col, 'method': method_col},
            columns=['roll_out_num', 'accuracy', 'method'],
        )
        draw_line(frame, path)


def hit_rank_with_topk(data_dic):
    """Plot the top-k hit rate against the roll-out count, per difficulty and method.

    For each difficulty bucket and each method, ``get_roll_topk_hit`` is
    evaluated for every roll-out count in ``range(2, roll_num)`` and every k in
    ``[1, 5, 10, 20, 50]``; one figure per (bucket, method) pair is written.

    Args:
        data_dic: Mapping of method name to its list of result items.
    """
    top_ks = [1,5,10,20,50]
    for difficulty, index in split_difficulty(dataset, model=model).items():
        roll_out_nums = range(2, roll_num)
        for name, data in data_dic.items():
            subset = [entry for entry in data if entry['id'] in index]
            x_col, k_col, y_col = [], [], []
            for num in roll_out_nums:
                for k in top_ks:
                    # Only the hit rate is plotted; the per-item ranks are unused.
                    hit_rate, _ranks = get_roll_topk_hit(subset, num, k)
                    x_col.append(num)
                    k_col.append(k)
                    y_col.append(hit_rate)

            dir_path = f'fig/{dataset}/{model}/d{difficulty}/'
            if not os.path.exists(dir_path):
                os.makedirs(dir_path)
            path = dir_path + f'{name}_hit.pdf'
            frame = pd.DataFrame(
                {'roll_out_num': x_col, 'accuracy': y_col, 'topk': k_col},
                columns=['roll_out_num', 'accuracy', 'topk'],
            )
            draw_line(frame, path)


def split_difficulty(dataset, all=False, model='Qwen2_5_3b_chat'):
    """Group item ids into difficulty buckets based on self-consistency results.

    The file ``./result/{dataset}/{model}/sc10_e3_{n_samples}.json`` is loaded
    (``n_samples`` is a module-level value) and its last element is dropped.
    An item's difficulty is ``5 - (number of True among its first 10
    'corrects') // 2``; a result of 0 is mapped to 1.

    Args:
        dataset: Dataset name used in the result path.
        all: When true, add an ``'all'`` bucket holding every item id.
        model: Model name used in the result path.

    Returns:
        Dict mapping difficulty (and optionally ``'all'``) to a list of ids.
    """
    sc_result = load_json_data(f'./result/{dataset}/{model}/sc10_e3_{n_samples}.json')[:-1]

    buckets = {}
    for row in sc_result:
        item_id = row['id']
        solved = row['corrects'][:10].count(True)
        level = 5 - solved // 2
        level = 1 if level == 0 else level
        buckets.setdefault(level, []).append(item_id)

    if all:
        buckets['all'] = [row['id'] for row in sc_result]
    return buckets

def acc_with_models(data_dic, dataset):
    """Draw grouped bars comparing method accuracy across models.

    Every method of every model is scored with ``get_roll_n_acc`` on 10
    roll-outs; an additional 'Oracle' entry per model is computed from that
    model's ``'SC'`` results.

    Args:
        data_dic: Mapping ``model -> {method name -> result items}``.
        dataset: Dataset name used in the output directory.
    """
    model_col, method_col, acc_col = [], [], []
    for model_name, per_method in data_dic.items():
        for name, result in per_method.items():
            model_col.append(model_name)
            method_col.append(name)
            acc_col.append(get_roll_n_acc(result, 10, False))
        model_col.append(model_name)
        method_col.append('Oracle')
        acc_col.append(get_roll_n_acc(per_method['SC'], 10, True))

    out_dir = f'fig/{dataset}/'
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    path = out_dir + 'model_acc.pdf'
    frame = pd.DataFrame(
        {'model': model_col, 'accuracy': acc_col, 'method': method_col},
        columns=['model', 'accuracy', 'method'],
    )
    draw_sub_bar(frame, path)



def process_with_difficulty():
    """Plot minimum process-reward scores per difficulty bucket on the 'math' dataset.

    Loads the sampled and the golden step-score files for the module-level
    ``model``, then for every difficulty bucket and every roll-out count in
    ``1..roll_num`` gathers the values returned by ``get_roll_process_reward``
    as three series ('correct', 'wrong', 'golden') and draws one line plot per
    bucket.
    """
    dataset = 'math'
    roll_out_nums = range(1, roll_num+1)
    stem = f'./result/{dataset}/{model}/sc{roll_num}_e3_{n_samples}_skyworko1'
    sc_result = load_json_data(stem + '.json')
    golden_results = load_json_data(stem + '_golden.json')

    for difficulty, index in split_difficulty(dataset, model=model).items():
        sampled = [entry for entry in sc_result if entry['id'] in index]
        golden = [entry for entry in golden_results if entry['id'] in index]

        x_col, method_col, y_col = [], [], []
        for num in roll_out_nums:
            correct_score, wrong_score = get_roll_process_reward(sampled, num)
            golden_score, _ = get_roll_process_reward(golden, num, True)
            series = (
                ('correct', correct_score),
                ('wrong', wrong_score),
                ('golden', golden_score),
            )
            for label, values in series:
                x_col += len(values) * [num]
                method_col += len(values) * [label]
                y_col += values

        path = f'fig/{model}_{dataset}_{difficulty}_process.png'
        frame = pd.DataFrame(
            {'roll_out_num': x_col, 'accuracy': y_col, 'method': method_col},
            columns=['roll_out_num', 'accuracy', 'method'],
        )
        draw_line(frame, path)


def find_valid_solution_nodes(root_node):
    """Collect the valid solution nodes of the tree rooted at ``root_node``.

    The tree is walked depth-first in child order. A node for which
    ``is_valid_solution_node()`` is true is collected and its subtree is not
    explored further; nodes with no children end the descent.

    Args:
        root_node: Object exposing ``is_valid_solution_node()`` and an iterable
            ``children`` attribute (falsy when the node is a leaf).

    Returns:
        List of valid solution nodes in depth-first order. (The original
        implementation built this list but never returned it.)
    """
    valid_solution_nodes = []

    def recursion(node):
        if node.is_valid_solution_node():
            valid_solution_nodes.append(node)
            return

        if not node.children:  # leaf without a valid solution
            return

        for child in node.children:
            recursion(child)

    recursion(root_node)
    return valid_solution_nodes


def get_reward_acc(result, n_samples):
    """Check, per item, whether the best correct roll-out outscores the best wrong one.

    ``n_samples`` roll-outs are drawn at random from ``range(roll_num)`` for
    every item. Items whose sampled roll-outs are all correct or all incorrect
    are skipped.

    Args:
        result: Items with ``'response'`` (dicts carrying ``'score'``) and
            ``'corrects'`` lists.
        n_samples: Number of roll-outs sampled per item.

    Returns:
        List with one 0/1 flag per non-skipped item: 1 when the maximum score
        among correct roll-outs is strictly greater than the maximum score
        among incorrect ones.
    """
    outcomes = []
    for entry in result:
        # Relies on the module-level ``roll_num`` (total roll-outs per item).
        picked = random.sample(list(range(0,roll_num)), n_samples)
        responses = [entry['response'][pos] for pos in picked]
        corrects = [entry['corrects'][pos] for pos in picked]

        if all(corrects) or not any(corrects):
            continue

        right_scores = [resp['score'] for resp, ok in zip(responses, corrects) if ok]
        wrong_scores = [resp['score'] for resp, ok in zip(responses, corrects) if not ok]
        best_right = np.max(np.array(right_scores))
        best_wrong = np.max(np.array(wrong_scores))
        outcomes.append(int(best_right > best_wrong))

    return outcomes


def get_llm_acc(data_dic):
    """Plot reward accuracy per difficulty bucket; alias of ``reward_acc``.

    The previous body was a line-for-line copy of ``reward_acc`` (same
    computation, same output path ``fig/{dataset}/{model}/{difficulty}/reward_acc.pdf``).
    Delegating keeps the two entry points from drifting apart.

    Args:
        data_dic: Mapping of reward-model name to its list of result items.
    """
    reward_acc(data_dic)



def reward_acc(data_dic):
    """Plot, per difficulty bucket, how often each reward model ranks a correct roll-out first.

    For every bucket from ``split_difficulty`` and every entry of
    ``data_dic``, ``get_reward_acc`` is evaluated for each roll-out count in
    ``range(2, roll_num)``; the resulting 0/1 flags are plotted with
    ``draw_line`` into ``fig/{dataset}/{model}/{difficulty}/reward_acc.pdf``.

    Args:
        data_dic: Mapping of reward-model name to its list of result items.
    """
    for difficulty, index in split_difficulty(dataset, model=model).items():
        x_col, reward_col, y_col = [], [], []
        roll_out_nums = range(2, roll_num)
        for name, data in data_dic.items():
            subset = [entry for entry in data if entry['id'] in index]
            for num in roll_out_nums:
                flags = get_reward_acc(subset, num)
                y_col.extend(flags)
                reward_col.extend([name] * len(flags))
                x_col.extend([num] * len(flags))

        dir_path = f'fig/{dataset}/{model}/{difficulty}/'
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)
        path = dir_path + 'reward_acc.pdf'
        frame = pd.DataFrame(
            {'roll_out_num': x_col, 'accuracy': y_col, 'reward': reward_col},
            columns=['roll_out_num', 'accuracy', 'reward'],
        )
        draw_line(frame, path)


def get_cluster_count(result, n_samples):
    """Return how often the first correct roll-out outscores the first incorrect one.

    ``n_samples`` roll-outs are drawn at random from ``range(roll_num)`` for
    every item; items whose samples are all correct or all incorrect are
    skipped.

    Args:
        result: Items with ``'response'`` (dicts carrying ``'score'``) and
            ``'corrects'`` lists.
        n_samples: Number of roll-outs sampled per item.

    Returns:
        Mean of the per-item 0/1 flags (``np.mean`` of an empty array when
        every item is skipped).
    """
    outcomes = []
    for entry in result:
        # Relies on the module-level ``roll_num`` (total roll-outs per item).
        picked = random.sample(list(range(0,roll_num)), n_samples)
        responses = [entry['response'][pos] for pos in picked]
        corrects = [entry['corrects'][pos] for pos in picked]

        if all(corrects) or not any(corrects):
            continue

        first_right = responses[corrects.index(True)]['score']
        first_wrong = responses[corrects.index(False)]['score']
        outcomes.append(int(first_right > first_wrong))

    return np.mean(np.array(outcomes))


def get_cluster_avg_score(result, n_samples):
    """Return the same statistic as ``get_cluster_count``.

    The previous body was a line-for-line copy of ``get_cluster_count``
    (fraction of items whose first correct sampled roll-out has a strictly
    higher ``'score'`` than the first incorrect one), so it now delegates to
    it instead of duplicating the code.

    NOTE(review): the name suggests an average cluster score was intended;
    confirm whether a distinct metric should be implemented here.

    Args:
        result: Items with ``'response'`` and ``'corrects'`` lists.
        n_samples: Number of roll-outs sampled per item.

    Returns:
        The value returned by ``get_cluster_count(result, n_samples)``.
    """
    return get_cluster_count(result, n_samples)



def reward_distribution(data_dic, roll_num, reward):
    """Draw one scatter plot per item showing its reward-score distribution.

    With two (or more) names in ``reward`` the scores of the first two reward
    models are plotted against each other; items without any correct roll-out
    among the first ``roll_num`` are skipped. With a single name the score is
    plotted against the frequency of each roll-out's answer.

    Args:
        data_dic: Mapping of reward-model name to its list of result items.
        roll_num: Number of leading roll-outs considered per item (shadows the
            module-level value of the same name inside this function).
        reward: Sequence of reward-model names (keys of ``data_dic``).
    """
    difficulty_dic = split_difficulty(dataset, model=model)
    for difficulty, index in difficulty_dic.items():
        dir_path = f'fig/{dataset}/{model}/{difficulty}/'
        if len(reward) >= 2:
            result1 = [item for item in data_dic[reward[0]] if item['id'] in index]
            result2 = [item for item in data_dic[reward[1]] if item['id'] in index]
            # Pair items by position. The previous ``result1.index(item)`` was a
            # linear, equality-based search that returned the first *equal* item,
            # which is quadratic overall and wrong when two items compare equal.
            for position, item in enumerate(result1):
                item_id = item['id']
                correct = item['corrects'][:roll_num]
                if not any(correct):
                    continue
                clusters = item['answer'][:roll_num]
                scores1 = [tup['score'] for tup in item['response'][:roll_num]]
                scores2 = [tup['score'] for tup in result2[position]['response'][:roll_num]]
                styles = ['cr' if correct[idx] else 'ic' for idx in range(roll_num)]
                if not os.path.exists(dir_path):
                    os.makedirs(dir_path)
                path = dir_path + f'{item_id}_{roll_num}_distribution.pdf'
                frame = pd.DataFrame(
                    {reward[0]: scores1, reward[1]: scores2, 'cluster': clusters, 'style': styles},
                    columns=[reward[0], reward[1], 'cluster', 'style'],
                )
                draw_scatter(frame, path)
        else:
            result = [item for item in data_dic[reward[0]] if item['id'] in index]
            for item in result:
                item_id = item['id']
                correct = item['corrects'][:roll_num]
                # Slicing copies the list, so the cleaning below does not touch
                # the caller's data.
                clusters = item['answer'][:roll_num]
                for i in range(roll_num):
                    if not clusters[i]:
                        clusters[i] = 'None'
                    else:
                        # Strip '$' and backslashes from the answer label.
                        clusters[i] = clusters[i].replace('$', '').replace('\\','')
                scores = [tup['score'] for tup in item['response'][:roll_num]]
                element_counts = Counter(clusters)
                cnts = [element_counts[element] for element in clusters]
                styles = ['cr' if correct[idx] else 'ic' for idx in range(roll_num)]
                if not os.path.exists(dir_path):
                    os.makedirs(dir_path)
                path = dir_path + f'{item_id}_{reward[0]}_{roll_num}_distribution.pdf'
                frame = pd.DataFrame(
                    {reward[0]: scores, 'count': cnts, 'cluster': clusters, 'style': styles},
                    columns=[reward[0], 'count', 'cluster', 'style'],
                )
                draw_scatter(frame, path, True)


def dataset_dif_stat(data_dic):
    """Print per-dataset, per-difficulty accuracy of every method.

    For each dataset the difficulty buckets (including ``'all'``) are obtained
    from ``split_difficulty``; each method's accuracy over ``roll_num``
    roll-outs is printed, followed by the number of ``'sc'`` items in the
    bucket.

    Args:
        data_dic: Mapping ``dataset -> {method name -> result items}``; every
            inner mapping must contain an ``'sc'`` entry.
    """
    def get_acc(result:list[dict], n_samples:int, method:str) -> float:
        """Accuracy of ``method`` over ``result`` using the first ``n_samples`` roll-outs."""
        picked = range(0, n_samples)
        hits = 0
        for entry in result:
            responses = [entry['response'][pos] for pos in picked]
            answers = [entry['answer'][pos] for pos in picked]
            corrects = [entry['corrects'][pos] for pos in picked]
            method_key = method.lower()
            if 'oracle' in method_key:
                # Oracle: first correct roll-out, else fall back to position 0.
                best_idx = corrects.index(True) if True in corrects else 0
            elif 'sc' in method_key:
                # Self-consistency: majority answer.
                best_idx = answers.index(max(answers, key=answers.count))
            else:
                field = 'reward' if len(responses[0]) == 3 else 'score'
                best_idx = max(enumerate(responses), key=lambda pair: pair[1][field])[0]
            hits += int(corrects[best_idx])
        return hits / len(result)

    for dataset, method_data in data_dic.items():
        print(dataset)
        buckets = split_difficulty(dataset, all=True, model=model)
        for difficulty, index in buckets.items():
            print(difficulty)
            for name, data in method_data.items():
                subset = [entry for entry in data if entry['id'] in index]
                acc = get_acc(subset, roll_num, name)
                print(f'{name}: {acc}')
            result = [entry for entry in method_data['sc'] if entry['id'] in index]
            print(f'counts: {len(result)}')
            # print(f"oracle: {get_acc(result, roll_num, 'oracle')}")
               
 

def acc_with_temperature(data_dic):
    """Draw a bar chart of method accuracy for each sampling temperature.

    Every method is scored with ``get_roll_n_acc`` on 16 roll-outs; an extra
    'Oracle' bar per temperature is computed from that temperature's ``'SC'``
    results.

    Args:
        data_dic: Mapping ``temperature -> {method name -> result items}``.
    """
    temp_col, method_col, acc_col = [], [], []
    for t, per_method in data_dic.items():
        for name, result in per_method.items():
            temp_col.append(t)
            method_col.append(name)
            acc_col.append(get_roll_n_acc(result, 16, False))
        temp_col.append(t)
        method_col.append('Oracle')
        acc_col.append(get_roll_n_acc(per_method['SC'], 16, True))

    dir_path = f'fig/{dataset}/{model}/'
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    path = dir_path + 'temp_acc.pdf'
    frame = pd.DataFrame(
        {'T': temp_col, 'accuracy': acc_col, 'method': method_col},
        columns=['T', 'accuracy', 'method'],
    )
    draw_bar(frame, path)


def distribution_with_temperature(data_dic):
    """Box-plot the best correct / best incorrect scores per temperature.

    For each difficulty bucket (including ``'all'``) and each entry of
    ``data_dic``, every item contributes the maximum score among its correct
    roll-outs (label 'correct', when any exists) and the maximum score among
    its incorrect roll-outs (label 'incorrect', when any exists), looking at
    the first ``roll_num`` roll-outs.

    Args:
        data_dic: Mapping ``name -> {temperature -> result items}``.
    """
    for difficulty, index in split_difficulty(dataset, all=True, model=model).items():
        for name, per_temperature in data_dic.items():
            temp_col, score_col, label_col = [], [], []
            for t, rows in per_temperature.items():
                for entry in (row for row in rows if row['id'] in index):
                    correct = entry['corrects']
                    score = [tup['score'] for tup in entry['response']]
                    if any(correct):
                        best_right = max([score[i] for i in range(roll_num) if correct[i]])
                        score_col.append(best_right)
                        temp_col.append(t)
                        label_col.append('correct')
                    if not all(correct):
                        best_wrong = max([score[i] for i in range(roll_num) if not correct[i]])
                        score_col.append(best_wrong)
                        temp_col.append(t)
                        label_col.append('incorrect')

            dir_path = f'fig/{dataset}/{model}/{difficulty}/'
            if not os.path.exists(dir_path):
                os.makedirs(dir_path)
            path = dir_path + f'{name}_temp_distribution.pdf'
            columns = {'temperature': temp_col, 'score': score_col, 'correct': label_col}
            frame = pd.DataFrame(columns, columns=list(columns.keys()))
            draw_box(frame, path)


# def reward_with_temperature(data_dic):
#     difficulty_dic = split_difficulty(dataset, model=model)
    
#     for difficulty, index in difficulty_dic.items():
#         for name, data in data_dic.items():
#             result = [item for item in data if item['id'] in index]
#             scores = []
#             labels = []
#             temperatures = []
#             for item in result:
#                 correct = item['corrects']
#                 correct_index = [idx for idx in range(len(correct)) if correct[idx]]
#                 score = [tup['score'] for tup in item['response']]
#                 for i in range(0, len(score), roll_num):
#                     cnt = len([score[idx] for idx in range(i, i+roll_num) if idx in correct_index])
#                     if cnt == 0 or cnt == roll_num:
#                         continue
#                     correct_score = max([score[idx] for idx in range(i, i+roll_num) if idx in correct_index])
#                     incorrect_score = max([score[idx] for idx in range(i, i+roll_num) if idx not in correct_index])
#                     scores.append(correct_score-incorrect_score)
#                     temperatures += [(i // roll_num + 1) * 0.1] * 4
#             dir_path = f'fig/{dataset}/{model}/{difficulty}/'
#             if not os.path.exists(dir_path):
#                 os.makedirs(dir_path)
#             path = dir_path + f'{name}_temp_stat.pdf'
#             data = {'temperature':temperatures, 'score':scores, 'label':labels}
#             data = pd.DataFrame(data, columns=list(data.keys()))
#             draw_line(data, path, True)


def reward_with_temperature(data_dic):
    """Bar-plot, per difficulty bucket, how well rewards separate correct roll-outs.

    Each item yields one value: 1 when all roll-outs are correct, 0 when none
    is, otherwise 1 if the best-scoring correct roll-out (among the first
    ``roll_num``) has a strictly higher score than the best-scoring incorrect
    one, else 0. Values are grouped by temperature and reward name.

    Args:
        data_dic: Mapping ``name -> {temperature -> result items}``.
    """
    def separation_flag(entry):
        # 1/0 outcome for a single item, as described in the docstring above.
        correct = entry['corrects']
        if all(correct):
            return 1
        if not any(correct):
            return 0
        score = [tup['score'] for tup in entry['response']]
        best_right = max([score[i] for i in range(roll_num) if correct[i]])
        best_wrong = max([score[i] for i in range(roll_num) if not correct[i]])
        return int(best_right > best_wrong)

    for difficulty, index in split_difficulty(dataset, all=True, model=model).items():
        score_col, temp_col, label_col = [], [], []
        for name, temp_data in data_dic.items():
            for t, rows in temp_data.items():
                for entry in [row for row in rows if row['id'] in index]:
                    score_col.append(separation_flag(entry))
                    temp_col.append(t)
                    label_col.append(name)

        dir_path = f'fig/{dataset}/{model}/{difficulty}/'
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)
        path = dir_path + 'temp_reward.pdf'
        columns = {'temperature': temp_col, 'score': score_col, 'label': label_col}
        frame = pd.DataFrame(columns, columns=list(columns.keys()))
        draw_bar(frame, path)
        

def sample_with_temperature(data_dic):
    """Box-plot the entropy of incorrect answers for each temperature.

    Items whose roll-outs are all correct are skipped. For the others, the
    non-empty answers of the incorrect roll-outs (among the first
    ``roll_num``) are counted, normalised to frequencies and their base-2
    entropy is recorded together with the temperature and the series name.

    Args:
        data_dic: Mapping ``name -> {temperature -> result items}``.
    """
    entropy_col, temp_col, label_col = [], [], []
    for name, temp_data in data_dic.items():
        for t, rows in temp_data.items():
            for entry in rows:
                correct = entry['corrects']
                if all(correct):
                    continue
                wrong_answers = [
                    entry['answer'][i]
                    for i in range(roll_num)
                    if not correct[i] and entry['answer'][i]
                ]
                frequencies = list(Counter(wrong_answers).values())
                # list / numpy scalar -> numpy array of relative frequencies
                values = frequencies / np.sum(frequencies)
                entropy_col.append(entropy(values, base=2))
                temp_col.append(t)
                label_col.append(name)

    dir_path = f'fig/{dataset}/'
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    path = dir_path + 'temp_sample.pdf'
    columns = {'T': temp_col, 'entropy': entropy_col, 'label': label_col}
    frame = pd.DataFrame(columns, columns=list(columns.keys()))
    draw_box(frame, path)



def rollnum_timeline(data_dic, index_dic):
    """Draw, per method, a heat map of correctness as roll-outs accumulate.

    For each selected item and each k in ``1..roll_num`` the prediction is
    made from the first k roll-outs: for dict responses the roll-out with the
    highest ``'reward'`` (3-entry dicts) or ``'score'`` is taken and its
    ``'corrects'`` flag recorded; otherwise the majority answer is compared to
    the item's ``'label'`` with ``is_equiv``.

    Args:
        data_dic: Mapping of method name to its list of result items.
        index_dic: Mapping of method name to the positions of the items to plot.
    """
    for name, data in data_dic.items():
        timeline = []
        selected = [data[i] for i in index_dic[name]]
        for entry in selected:
            flags = []
            for k in range(1, roll_num+1):
                first = entry['response'][0]
                if isinstance(first, dict):
                    field = 'reward' if len(first) == 3 else 'score'
                    prefix_scores = [entry['response'][i][field] for i in range(k)]
                    best = np.argmax(np.array(prefix_scores))
                    flags.append(int(entry['corrects'][best]))
                else:
                    answer = entry['answer'][:k]
                    pred = max(answer, key=answer.count)
                    flags.append(int(is_equiv(pred, entry['label'], dataset)))
            timeline.append(flags)

        dir_path = f'fig/{dataset}/{model}/all/'
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)
        path = dir_path + f'{name}_timeline.pdf'
        # NOTE(review): y_labels is fixed to 1..5 regardless of how many items
        # index_dic selects — confirm callers always pass five positions.
        draw_heat(y_labels=range(1,6), x_labels=range(1,roll_num+1), scores=timeline, path=path)



def mcts_reward_distribution(data_dic):
    """Scatter-plot the reward of every rollout, one figure per sample.

    Samples are grouped by the buckets returned from
    ``split_difficulty(dataset, model=model)``.  For each sample the plot
    frame holds the rollout position (``1..roll_num``), the ``'reward'`` of
    each response, the cleaned answer string used as cluster label, and a
    style tag (``'cr'`` for correct, ``'ic'`` for incorrect rollouts).
    Figures are written to
    ``fig/{dataset}/{model}/{difficulty}/{id}_{name}_distribution.pdf``.

    The answers are cleaned on a copy: the input items are left untouched so
    that later analyses on the same ``data_dic`` still see the raw answers.

    Args:
        data_dic: mapping ``name -> list of result items``; each item is read
            through the keys ``'id'``, ``'corrects'``, ``'response'`` and
            ``'answer'``.
    """
    difficulty_dic = split_difficulty(dataset, model=model)
    for difficulty, index in difficulty_dic.items():
        for name, data in data_dic.items():
            result = [item for item in data if item['id'] in index]
            for item in result:
                item_id = item['id']
                scores = [tup['reward'] for tup in item['response']]
                cnt = list(range(1, roll_num + 1))
                # Build a cleaned copy instead of rewriting item['answer'].
                cluster = [
                    ans.replace('$', '').replace('\\', '') if ans else 'None'
                    for ans in item['answer'][:roll_num]
                ]
                styles = ['cr' if ok else 'ic' for ok in item['corrects'][:roll_num]]

                dir_path = f'fig/{dataset}/{model}/{difficulty}/'
                os.makedirs(dir_path, exist_ok=True)
                path = dir_path + f'{item_id}_{name}_distribution.pdf'
                columns = {'roll_num': cnt, 'score': scores, 'cluster': cluster, 'style': styles}
                frame = pd.DataFrame(columns, columns=list(columns.keys()))
                draw_scatter(frame, path, text=True)


                
def mcts_node_stat(data_dic, method='reward'):
    """Plot, per sample, a node-by-simulation grid of Q-values and visit counts.

    Samples are grouped by the buckets returned from
    ``split_difficulty(dataset, model=model)``.  For each sample every entry
    of ``item['trace']`` except the ``'-1'`` key is treated as one simulation
    step mapping node ids to dicts with ``'N'`` and ``'Q'``; nodes missing
    from a step are filled with 0.  Figures are written to
    ``fig/{dataset}/{model}/{difficulty}/{id}_{name}_node_stat.pdf``.

    The ``'-1'`` entry is skipped on a filtered copy rather than deleted, so
    the input items are not modified and the function can be called
    repeatedly on the same data.

    Args:
        data_dic: mapping ``name -> list of result items``.
        method: unused; kept for interface compatibility.
    """

    def draw_node_stat_fig(visits, q_values, nodes, simulations, path):
        """Render the grid: colour encodes scaled Q, text shows visit count.

        ``visits`` and ``q_values`` are indexed as ``[simulation, node]``.
        """
        # Create the figure
        fig, ax = plt.subplots(figsize=(15, 8))
        global_min = q_values.min()
        global_max = q_values.max()
        span = global_max - global_min

        # Global min-max scaling; a constant matrix would divide by zero,
        # so it is mapped to all zeros instead.
        if span == 0:
            q_values = np.zeros(q_values.shape, dtype=float)
        else:
            q_values = (q_values - global_min) / span

        # Loop through nodes and simulation steps to plot
        for i in range(nodes):
            for j in range(simulations):
                # Plot a rectangle for each node-step pair
                ax.add_patch(plt.Rectangle((j, i), 1, 1,
                                            color=plt.cm.viridis(q_values[j, i]),
                                            alpha=0.7,
                                            linewidth=0))
                # Encode visits as text
                ax.text(j + 0.5, i + 0.5, str(visits[j, i]),
                        color="white", ha="center", va="center", fontsize=4)

        # Set axis labels and ticks
        ax.set_xticks(np.arange(simulations) + 0.5)
        ax.set_xticklabels([f"Step {x}" for x in range(simulations)], rotation=90)
        ax.set_yticks(np.arange(nodes) + 0.5)
        ax.set_yticklabels([f"Node {x}" for x in range(nodes)])

        # Add legend for Q-values
        sm = plt.cm.ScalarMappable(cmap="viridis", norm=plt.Normalize(0, 1))
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax, orientation="vertical", shrink=0.8)
        cbar.set_label("Q-Values")

        # Set limits and titles
        ax.set_xlim(0, simulations)
        ax.set_ylim(0, nodes)
        ax.set_title("Heatmap of Nodes with Q-Values and Visits")
        ax.set_xlabel("Simulation Step")
        ax.set_ylabel("Nodes")

        plt.tight_layout()
        plt.savefig(path)
        plt.close()

    difficulty_dic = split_difficulty(dataset, model=model)
    for difficulty, index in difficulty_dic.items():
        for name, data in data_dic.items():
            result = [item for item in data if item['id'] in index]
            for item in tqdm(result):
                # Work on a filtered view; never delete from the caller's data.
                traces = {step: tree for step, tree in item['trace'].items() if step != '-1'}
                node_ids = sorted({nid for tree in traces.values() for nid in tree}, key=int)
                visits = [
                    [tree[nid]['N'] if nid in tree else 0 for nid in node_ids]
                    for tree in traces.values()
                ]
                values = [
                    [tree[nid]['Q'] if nid in tree else 0 for nid in node_ids]
                    for tree in traces.values()
                ]

                dir_path = f'fig/{dataset}/{model}/{difficulty}/'
                os.makedirs(dir_path, exist_ok=True)
                path = dir_path + f"{item['id']}_{name}_node_stat.pdf"
                draw_node_stat_fig(visits=np.array(visits), q_values=np.array(values),
                                   simulations=roll_num, nodes=len(node_ids), path=path)
                
           
                
def mcts_stat(data_dic, method='reward'):
    """Plot the ``method`` field of every response against rollout position.

    Samples are grouped by the buckets returned from
    ``split_difficulty(dataset, model=model)``.  Within a bucket the values
    ``tup[method]`` of all responses of all sources are pooled into a single
    series labelled ``method`` and drawn with ``draw_line``.  The output file
    is ``fig/{dataset}/{model}/{difficulty}/{name}_{method}.pdf`` where
    ``name`` is the last key iterated from ``data_dic``.
    """
    difficulty_dic = split_difficulty(dataset, model=model)

    for difficulty, index in difficulty_dic.items():
        columns = {'nums': [], 'score': [], 'label': []}
        for name, data in data_dic.items():
            selected = [item for item in data if item['id'] in index]
            for item in selected:
                columns['score'].extend([tup[method] for tup in item['response']])
                columns['label'].extend([method] * roll_num)
                columns['nums'].extend(range(1, roll_num + 1))

        dir_path = f'fig/{dataset}/{model}/{difficulty}/'
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)
        # ``name`` deliberately refers to the last source of the loop above.
        path = dir_path + f'{name}_{method}.pdf'
        frame = pd.DataFrame(columns, columns=list(columns.keys()))
        draw_line(frame, path, True)

def mcts_call_acc(result, method, num):
    """Return the accuracy of a selection strategy using the first ``num`` rollouts.

    Args:
        result: list of items; each is read through ``'response'``,
            ``'answer'`` and ``'corrects'`` (plus ``'trace'`` and ``'label'``
            for the greedy strategies).
        method: ``'oracle'`` (first correct rollout, else the first one),
            ``'sc'`` (most frequent answer), ``'reward'`` / ``'q_value'``
            (rollout with the highest value of that field), ``'N_greedy'``
            (follow the child with the largest ``'N'`` from the root), or
            anything else (follow the child with the largest ``'Q'``).
        num: number of rollouts to consider; must be at least 1.

    Returns:
        Fraction of items whose selected rollout is correct, or ``0.0`` when
        ``result`` is empty.

    Raises:
        ValueError: if ``num`` is smaller than 1.
    """
    def travel_traces(traces, num, target):
        """Walk tree ``traces[str(num)]`` from node '0', always taking the child with the largest ``target``."""
        tree = traces[str(num)]
        node = tree['0']
        while node['children']:
            children = [tree[str(child)] for child in node['children']]
            node = max(children, key=lambda x: x[target])
        return node["answer"]

    if num < 1:
        raise ValueError(f"num must be at least 1, got {num}")
    if not result:
        # An empty bucket has no defined accuracy; report 0.0 instead of
        # failing with ZeroDivisionError.
        return 0.0

    cor_flag = 0
    for item in result:
        responses = item['response'][:num]
        answers = item['answer'][:num]
        corrects = item['corrects'][:num]
        if method == 'oracle':
            best_idx = corrects.index(True) if True in corrects else 0
        elif method == 'sc':
            best_idx = answers.index(max(answers, key=answers.count))
        elif method == 'reward':
            best_idx = max(enumerate(responses), key=lambda x: x[1]['reward'])[0]
        elif method == 'q_value':
            best_idx = max(enumerate(responses), key=lambda x: x[1]['q_value'])[0]
        else:
            # Only the tree after the last considered rollout decides the
            # outcome, so walk and score that single tree.
            target = 'N' if method == 'N_greedy' else 'Q'
            solution = travel_traces(item['trace'], num - 1, target)
            answer = extract_answer(solution, 'math')
            cor_flag += int(is_equiv(answer, item['label'], dataset))
            continue
        cor_flag += int(corrects[best_idx])
    return cor_flag / len(result)

def append_mcts_acc(result, method):
    """Return the accuracy of a selection strategy over all rollouts of each item.

    ``method`` selects how the final prediction is chosen:

    * ``'Oracle'``: the first rollout flagged correct (counts as wrong when
      there is none);
    * ``'Maj_vote'``: the first rollout carrying the most frequent answer;
    * ``'Reward'`` / ``'Q_value'``: the rollout with the highest ``'reward'``
      / ``'q_value'``;
    * ``'N_greedy'``: the stored solution ``item['trace']['-1']['bestN']``;
    * anything else: the stored solution ``item['trace']['-1']['bestQ']``.

    For the two stored-solution strategies the answer is extracted with
    ``extract_answer`` and compared to ``item['label']`` via ``is_equiv``.
    """
    hits = 0
    total = 0
    for item in result:
        responses = item['response']
        answers = item['answer']
        flags = item['corrects']

        picked = -1          # index into ``flags``; -1 means "use ``fallback``"
        fallback = False
        if method == 'Oracle':
            if True in flags:
                picked = flags.index(True)
        elif method == 'Maj_vote':
            majority = max(answers, key=answers.count)
            picked = answers.index(majority)
        elif method == 'Reward':
            picked = max(enumerate(responses), key=lambda pair: pair[1]['reward'])[0]
        elif method == 'Q_value':
            picked = max(enumerate(responses), key=lambda pair: pair[1]['q_value'])[0]
        else:
            stored_key = 'bestN' if method == 'N_greedy' else 'bestQ'
            solution = item['trace']["-1"][stored_key]
            extracted = extract_answer(solution, 'math')
            fallback = is_equiv(extracted, item['label'], dataset)

        hits += int(flags[picked]) if picked >= 0 else int(fallback)
        total += 1
    return hits / total



def mcts_select_acc(data_dic, fig_name):
    """Plot accuracy versus rollout count for several selection strategies.

    Each key of ``data_dic`` is used both as the curve label and as the
    ``method`` passed to ``mcts_call_acc``; the accuracy is evaluated for
    every odd rollout count ``1, 3, ...`` up to ``roll_num``.  One figure per
    difficulty bucket is written to
    ``fig/{dataset}/{model}/{difficulty}/{fig_name}_mcts_select_acc.pdf``.
    """
    # NOTE(review): the buckets are always requested for 'math' while the
    # output path uses the global ``dataset`` -- confirm this is intended.
    difficulty_dic = split_difficulty('math', True)

    for difficulty, index in difficulty_dic.items():
        columns = {'roll_out_num': [], 'accuracy': [], 'select': []}
        for name, data in data_dic.items():
            subset = [item for item in data if item['id'] in index]
            for num in range(1, roll_num + 1, 2):
                columns['accuracy'].append(mcts_call_acc(subset, name, num))
                columns['roll_out_num'].append(num)
                columns['select'].append(name)
        # for num in range(1, roll_num+1, 2):
        #     scores.append(mcts_call_acc(data_dic['reward'], 'oracle', num))
        #     nums.append(num)
        #     methods.append('oracle')
        dir_path = f'fig/{dataset}/{model}/{difficulty}/'
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)
        path = dir_path + f'{fig_name}_mcts_select_acc.pdf'
        frame = pd.DataFrame(columns, columns=list(columns.keys()))
        draw_line(frame, path, True)


def mcts_acc_with_nums(data_dic, method='reward', fig_name=''):
    """Plot accuracy versus rollout count, one curve per data source.

    For every source in ``data_dic`` the accuracy of ``method`` (as computed
    by ``mcts_call_acc``) is evaluated for each rollout count from 1 to
    ``roll_num``.  One figure per difficulty bucket is written to
    ``fig/{dataset}/{model}/{difficulty}/{fig_name}.pdf``.

    Args:
        data_dic: mapping ``name -> list of result items``.
        method: selection strategy forwarded to ``mcts_call_acc``.
        fig_name: stem of the output file.
    """
    # NOTE(review): the buckets are always requested for 'math' while the
    # output path uses the global ``dataset`` -- confirm this is intended.
    difficulty_dic = split_difficulty('math', True)

    for difficulty, index in difficulty_dic.items():
        scores = []
        nums = []
        methods = []
        for name, data in data_dic.items():
            result = [item for item in data if item['id'] in index]
            # Start at 1: with zero rollouts there is nothing to select from
            # and mcts_call_acc cannot produce a value.
            for num in range(1, roll_num + 1):
                scores.append(mcts_call_acc(result, method, num))
                nums.append(num)
                methods.append(name)
        dir_path = f'fig/{dataset}/{model}/{difficulty}/'
        os.makedirs(dir_path, exist_ok=True)
        path = dir_path + f'{fig_name}.pdf'
        columns = {'rollout_num': nums, 'accuracy': scores, 'reward': methods}
        frame = pd.DataFrame(columns, columns=list(columns.keys()))
        draw_line(frame, path, True)


# def mcts_reward_acc(data_dic):
#     difficulty_dic = split_difficulty(dataset, all=True, model=model)
    
#     for difficulty, index in difficulty_dic.items():
#         for name, data in data_dic.items():
#             result = [item for item in data if item['id'] in index]
#             nums = []
#             scores = []
#             labels = []

#             for item in result:
#                 correct = item['corrects'][:roll_num]
#                 # labels += correct
#                 score = [tup['reward'] for tup in item['response'][:roll_num]]
#                 max_score = max(score)
#                 min_score = min(score)
#                 score = [(x - min_score) / (max_score - min_score) for x in score]
                
#                 for k in range(2, roll_num+1):
#                     cor_score = None 
#                     inc_score = None 
#                     if all(correct[:k]) or not any(correct[:k]):
#                         continue
#                     # if any(correct[:k]):
#                     cor_score = max([score[i] for i in range(k) if correct[i]])
#                     scores.append(cor_score)
#                     nums.append(k)
#                     labels.append('correct')
#                     # if not all(correct[:k]):
#                     inc_score = max([score[i] for i in range(k) if not correct[i]])
#                     scores.append(inc_score)
#                     nums.append(k)
#                     labels.append('incorrect')
#                     # if cor_score and inc_score:
#                     scores.append(cor_score-inc_score)
#                     nums.append(k)
#                     labels.append('gap')
#             # scores += score 
#                 # nums += list(range(1,roll_num+1))
            
#             # labels += ['gap'] * roll_num
#             # # scores += gap_score
#             # nums += list(range(1,roll_num+1))   
#             dir_path = f'fig/{dataset}/{model}/{difficulty}/'
#             if not os.path.exists(dir_path):
#                 os.makedirs(dir_path)
#             path = dir_path + f'{name}_mcts_reward_acc.pdf'
#             data = {'rollout_num':nums, 'score':scores, 'label':labels}
#             data = pd.DataFrame(data, columns=list(data.keys()))
#             draw_line(data, path, True)
            
            
def mcts_reward_acc(data_dic, comp=None):
    """Plot how often the best correct rollout outscores the best incorrect one.

    For every item and every prefix length ``num`` in ``2..roll_num`` the
    scores of the first ``num`` responses (``'score'`` for 2-key response
    dicts, ``'reward'`` otherwise) are split by ``item['corrects']``.  When
    both groups are non-empty, 1 is recorded if the maximum correct score is
    strictly greater than the maximum incorrect score, else 0.  One figure
    per bucket of ``split_difficulty(dataset, all=True, model=model)`` is
    written to ``fig/{dataset}/{model}/{difficulty}/``, named
    ``mcts_reward_acc.pdf`` or ``{comp}_reward_acc.pdf`` when ``comp`` is set.
    """
    difficulty_dic = split_difficulty(dataset, all=True, model=model)

    for difficulty, index in difficulty_dic.items():
        columns = {'rollout_num': [], 'accuracy': [], 'label': []}
        for name, data in data_dic.items():
            selected = [item for item in data if item['id'] in index]
            for item in selected:
                flags = item['corrects']
                for num in range(2, roll_num + 1):
                    field = 'score' if len(item['response'][0]) == 2 else 'reward'
                    right = [item['response'][i][field] for i in range(num) if flags[i]]
                    wrong = [item['response'][i][field] for i in range(num) if not flags[i]]
                    # A comparison needs at least one rollout of each kind.
                    if not (right and wrong):
                        continue
                    columns['accuracy'].append(int(max(right) > max(wrong)))
                    columns['rollout_num'].append(num)
                    columns['label'].append(name)

        dir_path = f'fig/{dataset}/{model}/{difficulty}/'
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)
        path = dir_path + (f'{comp}_reward_acc.pdf' if comp else f'mcts_reward_acc.pdf')
        frame = pd.DataFrame(columns, columns=list(columns.keys()))
        draw_line(frame, path, True)
        

def mcts_reward_incorrect(data_dic, comp=None):
    """Plot the highest normalised score among incorrect rollouts per prefix length.

    For each item the scores of the first ``roll_num`` responses (``'score'``
    for 2-key response dicts, ``'reward'`` otherwise) are min-max normalised
    within the item.  For every prefix length ``num`` in ``2..roll_num`` that
    contains at least one incorrect rollout, the maximum normalised score of
    the incorrect rollouts is recorded.  One figure per bucket of
    ``split_difficulty(dataset, all=True, model=model)`` is written to
    ``fig/{dataset}/{model}/{difficulty}/``, named ``mcts_inc_score.pdf`` or
    ``{comp}_inc_score.pdf`` when ``comp`` is set.

    Args:
        data_dic: mapping ``name -> list of result items``.
        comp: optional file-name prefix replacing ``mcts``.
    """
    difficulty_dic = split_difficulty(dataset, all=True, model=model)

    for difficulty, index in difficulty_dic.items():
        nums = []
        scores = []
        labels = []
        for name, data in data_dic.items():
            result = [item for item in data if item['id'] in index]
            for item in result:
                correct = item['corrects']
                field = 'score' if len(item['response'][0]) == 2 else 'reward'
                raw = [item['response'][i][field] for i in range(roll_num)]
                low = min(raw)
                span = max(raw) - low
                if span == 0:
                    # All rollouts share one score: min-max scaling is
                    # undefined, so map every value to 0.0 rather than
                    # dividing by zero.
                    score = [0.0] * len(raw)
                else:
                    score = [(x - low) / span for x in raw]
                for num in range(2, roll_num + 1):
                    if all(correct[:num]):
                        continue
                    inc_score = max(score[i] for i in range(num) if not correct[i])
                    scores.append(inc_score)
                    nums.append(num)
                    labels.append(name)

        dir_path = f'fig/{dataset}/{model}/{difficulty}/'
        os.makedirs(dir_path, exist_ok=True)
        if not comp:
            path = dir_path + 'mcts_inc_score.pdf'
        else:
            path = dir_path + f'{comp}_inc_score.pdf'
        columns = {'rollout_num': nums, 'score': scores, 'label': labels}
        frame = pd.DataFrame(columns, columns=list(columns.keys()))
        draw_line(frame, path, True)

def longtail_stat(data_dic, mcts=None):
    """Plot how frequent the top-scored incorrect answer is among the rollouts.

    For every item that has at least one incorrect rollout within the first
    ``roll_num``, the incorrect rollout with the highest score (``'score'``
    for 2-key response dicts, ``'reward'`` otherwise) is located and the
    number of rollouts sharing its answer is determined.  Per data source the
    function counts how many items fall on each such frequency and draws the
    counts with ``draw_line``.  The figure goes to ``fig/{dataset}/{model}/``
    as ``mcts_longtail_stat.pdf`` when ``mcts`` is truthy, else
    ``longtail_stat.pdf``.

    Args:
        data_dic: mapping ``name -> list of result items``.
        mcts: truthy to use the ``mcts_`` file-name prefix.
    """
    nums = []
    inc_counts = []
    labels = []
    for name, data in data_dic.items():
        score_counts = defaultdict(int)
        for item in data:
            correct = item['corrects'][:roll_num]
            if all(correct):
                continue
            answers = item['answer'][:roll_num]
            field = 'score' if len(item['response'][0]) == 2 else 'reward'
            scores = [tup[field] for tup in item['response'][:roll_num]]

            # Pick the best incorrect rollout by its position in the full
            # list, so that the position can be used to look up its answer.
            wrong_positions = [i for i, ok in enumerate(correct) if not ok]
            top_wrong = max(wrong_positions, key=lambda i: scores[i])
            frequency = answers.count(answers[top_wrong])

            score_counts[frequency] += 1
        for frequency, n_items in score_counts.items():
            nums.append(frequency)
            inc_counts.append(n_items)
            labels.append(name)

    dir_path = f'fig/{dataset}/{model}/'
    os.makedirs(dir_path, exist_ok=True)
    file_name = 'mcts_longtail_stat.pdf' if mcts else 'longtail_stat.pdf'
    path = dir_path + file_name
    columns = {'frequency': nums, 'counts': inc_counts, 'label': labels}
    frame = pd.DataFrame(columns, columns=list(columns.keys()))
    draw_line(frame, path, False)


def timeline_stat(data_dic):
    """Plot the cumulative number of correct-to-incorrect flips per rollout count.

    For each item a 0/1 flag is computed for every prefix length ``k`` in
    ``1..roll_num``: for dict responses the ``item['corrects']`` flag of the
    best-scored response among the first ``k`` (``'reward'`` for 3-key dicts,
    ``'score'`` otherwise); for other responses whether the majority answer
    of the first ``k`` matches ``item['label']`` via ``is_equiv``.  Walking
    these flags, a counter is increased each time a 1 is followed by a 0, and
    its running value is recorded for ``N = 2..roll_num``.  All series are
    drawn with ``draw_line`` into ``fig/{dataset}/{model}/timeline_stat.pdf``.
    """
    columns = {'N': [], 'counts': [], 'label': []}
    for name, data in data_dic.items():
        for item in data:
            prefix_flags = []
            for k in range(1, roll_num + 1):
                head = item['response'][0]
                if not isinstance(head, dict):
                    prefix = item['answer'][:k]
                    voted = max(prefix, key=prefix.count)
                    prefix_flags.append(int(is_equiv(voted, item['label'], dataset)))
                    continue
                field = 'reward' if len(head) == 3 else 'score'
                prefix_scores = [item['response'][i][field] for i in range(k)]
                best = np.argmax(np.array(prefix_scores))
                prefix_flags.append(int(item['corrects'][best]))

            # Running count of positions where a correct prefix turns wrong.
            flips = 0
            running = []
            for before, after in zip(prefix_flags[:-1], prefix_flags[1:]):
                if before and not after:
                    flips += 1
                running.append(flips)

            columns['counts'].extend(running)
            columns['N'].extend(range(2, roll_num + 1))
            columns['label'].extend([name] * len(running))

    dir_path = f'fig/{dataset}/{model}/'
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    path = dir_path + f'timeline_stat.pdf'
    frame = pd.DataFrame(columns, columns=list(columns.keys()))
    draw_line(frame, path, False)


          
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='Qwen2_5_3b_chat')
    parser.add_argument('--n_samples', type=int, default=500)
    parser.add_argument('--roll_num', type=int, default=128)
    parser.add_argument('--dataset', type=str, default='math')
    parser.add_argument('--width', type=int, default=5)
    parser.add_argument('--stat', type=str, default=None)
    args = parser.parse_args()
    
    model = args.model
    n_samples = args.n_samples
    roll_num = args.roll_num
    dataset = args.dataset
    width = args.width
    stat = args.stat
    
    # datasets = ['gsm8k', 'math', 'proofwriter', 'folio','siqa', 'wino', 'csqa', 'aqua', 'prontoqa']
    # rewards = ['skywork', 'shepherd', 'armorm', 'skyworko1', model]
    data_dic = {}
# for model in models:
    # mcts_node_stat(data_dic)
   
    # reward_with_temperature(data_dic)
    # distribution_with_temperature(data_dic)
    # mcts_stat(data_dic, method='q_value')
    # roll_num_timeline(data_dic, 32)
    # mcts_reward_distribution(data_dic)
    # acc_with_nums(data_dic)
    # reward_distribution(data_dic, roll_num=32, reward=['skywork'])
    # mcts_reward_acc(data_dic)
    # data_dic = {}
    # for method in ['sc', 'reward', 'q_value', 'bestN', 'bestQ']:
    #     path = f'./result/math/Qwen2_5_3b_chat/mcts16_t0.7_d5_w0.1_5_1_skyworko1_{method}_e3_200.json'
    #     data = load_json_data(path)[:-1]
    #     data_dic[method] = data
    # data_dic = {'mcts_skywork':mcts1_data, 'mcts_skyworko1':mcts2_data, 'mcts_sc':mcts3_data}
    # mcts_select_acc(data_dic, fig_name='prm_mcts')
    # mcts_acc_with_nums(data_dic)
    if stat == 'mcts_select':
        mcts_path_dic = {'orm':f'w5.0_{width}_1_skywork', 'prm':f'w0.1_{width}_1_skyworko1', 'sc':f'w0.1_3_5_self-{model}'}
  
        for reward in ['orm', 'prm', 'sc']:
            path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_{mcts_path_dic[reward]}_reward_e3_{n_samples}.json'
            data = load_json_data(path)[:-1]
            data_dic = {}
            for name in ['reward', 'sc', 'N_greedy', 'Q_greedy']:
                data_dic[name] = data 
            mcts_select_acc(data_dic,f'{reward}')
    elif stat == 'mcts_reward':
        mcts_path_dic = {'orm':f'w5.0_{width}_1_skywork', 'prm':f'w0.1_{width}_1_skyworko1', 'sc':f'w0.1_3_5_self-{model}'}
        data_dic = {}
        for reward in ['orm', 'prm', 'sc']:
            path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_{mcts_path_dic[reward]}_reward_e3_{n_samples}.json'
            data = load_json_data(path)[:-1]
            data_dic[reward] = data
        mcts_reward_acc(data_dic)
    elif stat == 'search_acc':
        data_dic = {}
        orm_path = f'./result/{dataset}/{model}/best128_skywork_e3_{n_samples}.json'
        prm_path = f'./result/{dataset}/{model}/best32_t0.7_skyworko1_e3_{n_samples}.json'
        # sc_path = f'./result/{dataset}/{model}/sc128_e3_{n_samples}.json'
        orm_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w5.0_5_1_skywork_reward_e3_{n_samples}.json'
        prm_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w0.1_5_1_skyworko1_reward_e3_{n_samples}.json'
        # sc_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w0.1_3_5_self-{model}_reward_e3_{n_samples}.json'
        # data_dic['sc'] = load_json_data(sc_mcts_path)[:-1]
        data_dic['MCTS_ORM'] = load_json_data(orm_mcts_path)[:-1]
        data_dic['MCTS_PRM'] = load_json_data(prm_mcts_path)[:-1]
        
        acc_with_nums(data_dic, recall=True)
    elif stat == 'longtail_stat':
        data_dic = {}
        rewards = ['skywork', 'shepherd', 'armorm', 'skyworko1']
        for reward in rewards:
            # if model == 'Llama3_1_8b_chat':
            #     path = f'./result/{dataset}/{model}/best100_{reward}_e3_{n_samples}.json'
            # else:
            path = f'./result/{dataset}/{model}/best128_t0.7_{reward}_e3_500.json'
            data_dic[reward] = load_json_data(path)[:n_samples]
        longtail_stat(data_dic)
    elif stat == 'early_exp':
        data_dic = {}
        models = ['Gemma2-9B', 'Llama3.1-8B', 'Qwen2.5-3B', 'Qwen2.5-14B']
        rewards = ['Shepherd', 'ArmoRM', 'Skywork', 'Skyworko1']
        model_path_map = {
            'Gemma2-9B': 'Gemma2_9b_chat',
            'Llama3.1-8B': 'Llama3_1_8b_chat',
            'Qwen2.5-3B': 'Qwen2_5_3b_chat',
            'Qwen2.5-14B': 'Qwen2_5_14b_chat'
        }
        for model in models:
            
            sc_path = f'./result/math/{model_path_map[model]}/sc10_e3_500.json'
            data_dic[model] = {'SC':load_json_data(sc_path)[:-1]}
            for reward in rewards:
                if reward.endswith('o1'):
                    path = f'./result/math/{model_path_map[model]}/best10_t0.7_{reward.lower()}_e3_500.json'
                else:
                    path = f'./result/math/{model_path_map[model]}/best10_{reward.lower()}_e3_500.json'
                data_dic[model][reward] = load_json_data(path)[:-1]
        acc_with_models(data_dic, 'math')
    elif stat == 'sample_acc':
        data_dic = {}
        orm_path = f'./result/{dataset}/{model}/best128_t0.7_skywork_e3_500.json'
        prm_path = f'./result/{dataset}/{model}/best128_t0.7_skyworko1_e3_500.json'
        sc_path = f'./result/{dataset}/{model}/sc128_t0.7_e3_500.json'
        data_dic['SC'] = load_json_data(sc_path)[:n_samples]
        data_dic['ORM'] = load_json_data(orm_path)[:n_samples]
        data_dic['PRM'] = load_json_data(prm_path)[:n_samples]
        
        acc_with_nums(data_dic, recall=True)
    elif stat == 'rollnum_timeline':
        data_dic = {}
        index_dic = {}
        orm_path = f'./result/{dataset}/{model}/best128_t0.7_skywork_e3_{n_samples}.json'
        prm_path = f'./result/{dataset}/{model}/best128_t0.7_skyworko1_e3_{n_samples}.json'
        sc_path = f'./result/{dataset}/{model}/sc128_e3_{n_samples}.json'
        # data_dic['sc'] = load_json_data(sc_path)[:-1]
        # index_dic['sc'] = range(50)
        data_dic['orm'] = load_json_data(orm_path)[:-1]
        index_dic['orm'] = [14,21,23,27,39]
        data_dic['prm'] = load_json_data(prm_path)[:-1]
        index_dic['prm'] = [0,17,27,38,45]
        rollnum_timeline(data_dic, index_dic)
    elif stat == 'temp_acc':
        data_dic = defaultdict(dict)
        for method in ['SC', 'ORM', 'PRM']:
            for t in [0.4, 0.7, 0.9, 1.2, 1.5]:
                if method == 'ORM':
                    path = f'./result/{dataset}/{model}/best{roll_num}_t{t}_skywork_e3_{n_samples}.json'
                elif method == 'SC':
                    path = f'./result/{dataset}/{model}/sc{roll_num}_t{t}_e3_{n_samples}.json'
                else:
                    path = f'./result/{dataset}/{model}/best{roll_num}_t{t}_skyworko1_e3_{n_samples}.json'
                data = load_json_data(path)[:-1]
                data_dic[t][method] = data
        acc_with_temperature(data_dic)
    elif stat == 'temp_sample':
        data_dic = defaultdict(dict)
        for model in ['Qwen2.5-3B', 'Llama3.1-8B']:
            for t in [0.4, 0.7, 0.9, 1.2, 1.5]:
                if model.startswith('Llama'):
                    path = f'./result/{dataset}/Llama3_1_8b_chat/sc{roll_num}_t{t}_e3_{n_samples}.json'
                else:
                    path = f'./result/{dataset}/Qwen2_5_3b_chat/sc{roll_num}_t{t}_e3_{n_samples}.json'
                # if method == 'orm':
                #     path = f'./result/{dataset}/{model}/best{roll_num}_t{t}_skywork_e3_{n_samples}.json'
                # else:
                #     path = f'./result/{dataset}/{model}/best{roll_num}_t{t}_skyworko1_e3_{n_samples}_old.json'
                data = load_json_data(path)[:-1]
                data_dic[model][t] = data
            # print(data_dic)
        sample_with_temperature(data_dic)
    elif stat == 'search_recall':
        data_dic = {}
        sc_path = f'./result/{dataset}/{model}/sc128_e3_{n_samples}.json'
        orm_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d8_w5.0_{width}_1_skywork_reward_e3_{n_samples}.json'
        prm_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d8_w0.1_{width}_1_skyworko1_reward_e3_{n_samples}.json'
        sc_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d8_w0.1_3_5_self-{model}_reward_e3_{n_samples}.json'
        data_dic['sc'] = load_json_data(sc_path)[:-1]
        data_dic['sc_mcts'] = load_json_data(sc_mcts_path)[:-1]
        data_dic['orm_mcts'] = load_json_data(orm_mcts_path)[:-1]
        data_dic['prm_mcts'] = load_json_data(prm_mcts_path)[:-1]
        # print(data_dic)
        acc_with_nums(data_dic, None, True)
    elif stat == 'mcts_longtail_stat':
        data_dic = {}
        # orm_path = f'./result/{dataset}/{model}/best128_skywork_e3_{n_samples}.json'
        # prm_path = f'./result/{dataset}/{model}/best32_t0.7_skyworko1_e3_{n_samples}.json'
        sc_path = f'./result/{dataset}/{model}/sc128_e3_{n_samples}.json'
        orm_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d8_w5.0_{width}_1_skywork_reward_e3_{n_samples}.json'
        prm_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d8_w0.1_{width}_1_skyworko1_reward_e3_{n_samples}.json'
        sc_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w0.1_3_5_self-{model}_reward_e3_{n_samples}.json'
        # data_dic['sc'] = load_json_data(sc_mcts_path)[:-1]
        # data_dic['orm'] = load_json_data(orm_path)[:-1]
        # data_dic['prm'] = load_json_data(prm_path)[:-1]
        data_dic['ORM'] = load_json_data(orm_mcts_path)[:-1]
        data_dic['PRM'] = load_json_data(prm_mcts_path)[:-1]
        
        longtail_stat(data_dic, mcts=True)
    elif stat == 'mcts_reward_comp':
        data_dic = {}
        orm_path = f'./result/{dataset}/{model}/best128_skywork_e3_{n_samples}.json'
        prm_path = f'./result/{dataset}/{model}/best32_t0.7_skyworko1_e3_{n_samples}.json'
        
        # sc_path = f'./result/{dataset}/{model}/sc128_e3_{n_samples}.json'
        orm_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w5.0_{width}_1_skywork_reward_e3_{n_samples}.json'
        prm_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w0.1_{width}_1_skyworko1_reward_e3_{n_samples}.json'
        sc_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w0.1_3_5_self-{model}_reward_e3_{n_samples}.json'
        # data_dic['direct'] = load_json_data(orm_path)[:-1]
        # data_dic['orm'] = load_json_data(orm_path)[:-1]
        data_dic['direct'] = load_json_data(prm_path)[:-1]
        # data_dic['mcts'] = load_json_data(orm_mcts_path)[:-1]
        data_dic['mcts'] = load_json_data(prm_mcts_path)[:-1]
        
        mcts_reward_acc(data_dic, comp='prm')
    elif stat == 'mcts_shape_acc':
        data_dic = {}
        mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w5.0_5_1_skywork_reward_e3_{n_samples}.json'
        for width in ['3', '5', '8']:
            new_mcts_path = mcts_path.replace('w5.0_5', f'w5.0_{width}')
            data_dic[f'width={width}'] = load_json_data(new_mcts_path)[:-1]
        acc_with_nums(data_dic, fig_name='orm_mcts_width')
        data_dic = {}
        mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w0.1_5_1_skyworko1_reward_e3_{n_samples}.json'
        for width in ['3', '5', '8']:
            new_mcts_path = mcts_path.replace('w0.1_5', f'w0.1_{width}')
            data_dic[f'width={width}'] = load_json_data(new_mcts_path)[:-1]
        acc_with_nums(data_dic, fig_name='prm_mcts_width')
        data_dic = {}
        mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w5.0_5_1_skywork_reward_e3_{n_samples}.json'
        for depth in ['3', '5', '8']:
            new_mcts_path = mcts_path.replace('d5', f'd{depth}')
            data_dic[f'depth={depth}'] = load_json_data(new_mcts_path)[:-1]
        acc_with_nums(data_dic, fig_name='orm_mcts_depth')
        data_dic = {}
        mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w0.1_5_1_skyworko1_reward_e3_{n_samples}.json'
        for depth in ['3', '5', '8']:
            new_mcts_path = mcts_path.replace('d5', f'd{depth}')
            data_dic[f'depth={depth}'] = load_json_data(new_mcts_path)[:-1]
        acc_with_nums(data_dic, fig_name='prm_mcts_depth')

    elif stat == 'mcts_shape_reward':
        mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w5.0_5_1_skywork_reward_e3_{n_samples}.json'
        data_dic['w5_d5'] = load_json_data(mcts_path)[:-1]
        for width in ['3', '8']:
            new_mcts_path = mcts_path.replace('w5.0_5', f'w5.0_{width}')
            data_dic[f'w{width}_d5'] = load_json_data(new_mcts_path)[:-1]
        for depth in ['3', '8']:
            new_mcts_path = mcts_path.replace('d5', f'd{depth}')
            data_dic[f'w5_d{depth}'] = load_json_data(new_mcts_path)[:-1]
        mcts_reward_acc(data_dic, comp='shape')
    elif stat == 'mcts_explore_acc':
        data_dic = {}
        for w in [ 'c=0.1', 'c=1.0', 'c=10.0']:
            path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w{w.split("=")[-1]}_3_1_skyworko1_reward_e3_{n_samples}.json'
            data_dic[w] = load_json_data(path)[:-1]
        acc_with_nums(data_dic, fig_name='prm_explore')
        data_dic = {}
        map = {'c=0.1':'5.0', 'c=1.0':'50.0', 'c=10.0':'500.0'}
        for w in ['c=0.1', 'c=1.0', 'c=10.0']:
            path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w{map[w]}_3_1_skywork_reward_e3_{n_samples}.json'
            data_dic[w] = load_json_data(path)[:-1]
        acc_with_nums(data_dic, fig_name='orm_explore')
    elif stat == 'mcts_inc_score':
        data_dic = {}
        orm_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w5.0_{width}_1_skywork_reward_e3_{n_samples}.json'
        prm_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w0.1_{width}_1_skyworko1_reward_e3_{n_samples}.json'
        sc_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w0.1_3_5_self-{model}_reward_e3_{n_samples}.json'
        data_dic['sc'] = load_json_data(sc_mcts_path)[:-1]
        data_dic['orm'] = load_json_data(orm_mcts_path)[:-1]
        data_dic['prm'] = load_json_data(prm_mcts_path)[:-1]
        
        mcts_reward_incorrect(data_dic)
        
    elif stat == 'bon_dif_acc':
        data_dic = {}
        orm_path = f'./result/{dataset}/{model}/best128_t0.7_skywork_e3_{n_samples}.json'
        prm_path = f'./result/{dataset}/{model}/best128_t0.7_skyworko1_e3_{n_samples}.json'
        sc_path = f'./result/{dataset}/{model}/sc128_t0.7_e3_{n_samples}.json'
        data_dic['SC'] = load_json_data(sc_path)[:-1]
        data_dic['ORM'] = load_json_data(orm_path)[:-1]
        data_dic['PRM'] = load_json_data(prm_path)[:-1]
        
        acc_with_diff(data_dic, fig_name='bon')
    elif stat == 'mcts_dif_acc':
        data_dic = {}
        bon_path = f'./result/{dataset}/{model}/best128_t0.7_skywork_e3_500.json'
        # prm_path = f'./result/{dataset}/{model}/best128_t0.7_skyworko1_e3_{n_samples}.json'
        sc_path = f'./result/{dataset}/{model}/sc128_t0.7_e3_500.json'
        mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w5.0_{width}_1_skywork_reward_e3_{n_samples}.json'
        data_dic['SC'] = load_json_data(sc_path)[:n_samples]
        data_dic['BoN'] = load_json_data(bon_path)[:n_samples]
        data_dic['MCTS-SC'] = load_json_data(mcts_path)[:n_samples]
        data_dic['MCTS-RM'] = load_json_data(mcts_path)[:n_samples]
        
        acc_with_diff(data_dic, fig_name='orm_mcts')
        
        data_dic = {}
        bon_path = f'./result/{dataset}/{model}/best128_t0.7_skyworko1_e3_500.json'
        # prm_path = f'./result/{dataset}/{model}/best128_t0.7_skyworko1_e3_{n_samples}.json'
        # sc_path = f'./result/{dataset}/{model}/sc128_e3_500.json'
        sc_path = f'./result/{dataset}/{model}/sc128_t0.7_e3_500.json'
        mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w0.1_{width}_1_skyworko1_reward_e3_{n_samples}.json'
        data_dic['SC'] = load_json_data(sc_path)[:n_samples]
        data_dic['BoN'] = load_json_data(bon_path)[:n_samples]
        data_dic['MCTS-SC'] = load_json_data(mcts_path)[:n_samples]
        data_dic['MCTS-RM'] = load_json_data(mcts_path)[:n_samples]
        
        acc_with_diff(data_dic, fig_name='prm_mcts')
    elif stat == 'llama_dif_acc':
        data_dic = {}
        bon_path = f'./result/{dataset}/{model}/best128_t0.7_skywork_e3_500.json'
        # prm_path = f'./result/{dataset}/{model}/best128_t0.7_skyworko1_e3_{n_samples}.json'
        sc_path = f'./result/{dataset}/{model}/sc128_t0.7_e3_500.json'
        mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w5.0_{width}_1_skywork_reward_e3_{n_samples}.json'
        data_dic['SC'] = load_json_data(sc_path)[:n_samples]
        data_dic['BoN'] = load_json_data(bon_path)[:n_samples]
        data_dic['MCTS_SC'] = load_json_data(mcts_path)[:n_samples]
        data_dic['MCTS_Reward'] = load_json_data(mcts_path)[:n_samples]
        
        acc_with_diff(data_dic, fig_name='llama_orm')
        
        data_dic = {}
        bon_path = f'./result/{dataset}/{model}/best128_t0.7_skyworko1_e3_200.json'
        # prm_path = f'./result/{dataset}/{model}/best128_t0.7_skyworko1_e3_{n_samples}.json'
        sc_path = f'./result/{dataset}/{model}/sc128_t0.7_e3_500.json'
        mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w0.1_{width}_1_skyworko1_reward_e3_{n_samples}.json'
        data_dic['SC'] = load_json_data(sc_path)[:n_samples]
        data_dic['BoN'] = load_json_data(bon_path)[:n_samples]
        data_dic['MCTS_SC'] = load_json_data(mcts_path)[:n_samples]
        data_dic['MCTS_Reward'] = load_json_data(mcts_path)[:n_samples]
        
        acc_with_diff(data_dic, fig_name='llama_prm')
    elif stat == 'timeline_stat':
        data_dic = {}
        orm_path = f'./result/{dataset}/{model}/best128_t0.7_skywork_e3_500.json'
        prm_path = f'./result/{dataset}/{model}/best128_t0.7_skyworko1_e3_500.json'
        sc_path = f'./result/{dataset}/{model}/sc128_e3_500.json'
        # mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w5.0_{width}_1_skywork_reward_e3_{n_samples}.json'
        data_dic['SC'] = load_json_data(sc_path)[:n_samples]
        data_dic['ORM'] = load_json_data(orm_path)[:n_samples]
        data_dic['PRM'] = load_json_data(prm_path)[:n_samples]
        orm_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w5.0_5_1_skywork_reward_e3_{n_samples}.json'
        prm_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w0.1_5_1_skyworko1_reward_e3_{n_samples}.json'
        # sc_mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w0.1_3_5_self-{model}_reward_e3_{n_samples}.json'
        # data_dic['sc'] = load_json_data(sc_mcts_path)[:-1]
        data_dic['MCTS_ORM'] = load_json_data(orm_mcts_path)[:-1]
        data_dic['MCTS_PRM'] = load_json_data(prm_mcts_path)[:-1]
        
        # data_dic['MCTS_Reward'] = load_json_data(mcts_path)[:n_samples]
        timeline_stat(data_dic)
    elif stat == 'dataset_stat':
        data_dic = defaultdict(dict)
        # for dataset in ['math']:
        orm_path = f'./result/{dataset}/{model}/best32_t0.7_skywork_e3_{n_samples}.json'
        prm_path = f'./result/{dataset}/{model}/best32_t0.7_skyworko1_e3_{n_samples}.json'
        sc_path = f'./result/{dataset}/{model}/sc128_t0.7_e3_{n_samples}.json'
        # mcts_path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_w5.0_{width}_1_skywork_reward_e3_{n_samples}.json'
        data_dic[dataset]['sc'] = load_json_data(sc_path)[:-1]
        data_dic[dataset]['orm'] = load_json_data(orm_path)[:-1]
        data_dic[dataset]['prm'] = load_json_data(prm_path)[:-1]
      
        
        # data_dic['MCTS_Reward'] = load_json_data(mcts_path)[:n_samples]
        dataset_dif_stat(data_dic)
    elif stat == 'main_result':
        data_dic = defaultdict(dict)
        for dataset in ['olympiadbench']:
            for model in ['Qwen2_5_3b_chat']:
                for method in ['cot', 'sc', 'bon_orm', 'orm_sc', 'mcts_orm', 'woc_orm', 'bon_prm', 'prm_sc', 'mcts_prm', 'beam', 'woc_prm']:
                    if method == 'cot':
                        path = f'./result/{dataset}/{model}/cot_t0.7_e3_500.json'
                    elif method == 'sc':
                        path = f'./result/{dataset}/{model}/sc32_t0.7_e3_500.json'
                    elif method == 'bon_orm':
                        path = f'./result/{dataset}/{model}/best32_t0.7_skywork_e3_500.json'
                    elif method == 'orm_sc':
                        path = f'./result/{dataset}/{model}/sc32_t0.7_skywork_e3_500.json'
                    elif method == 'mcts_orm':
                        path = f'./result/{dataset}/{model}/mcts16_t0.7_d5_w5.0_5_1_skywork_reward_e3_500.json'
                    elif method == 'woc_orm':
                        path = f'./result/{dataset}/{model}/woc32_w1_d3_skywork_e3_500.json'
                    elif method == 'bon_prm':
                        path = f'./result/{dataset}/{model}/best32_t0.7_skyworko1_e3_500.json'
                    elif method == 'prm_sc':
                        path = f'./result/{dataset}/{model}/sc32_t0.7_skyworko1_e3_500.json'
                    elif method == 'mcts_prm':
                        path = f'./result/{dataset}/{model}/mcts16_t0.7_d5_w0.1_5_1_skyworko1_reward_e3_500.json'
                    elif method == 'beam':
                        path = f'./result/{dataset}/{model}/beam8_t0.7_d5_3_skyworko1_e3_500.json'
                    elif method == 'woc_prm':
                        path = f'./result/{dataset}/{model}/woc32_w2_d3_skyworko1_e3_500.json'

                    if os.path.exists(path):
                        # print(load_json_data(path)[:n_samples])
                        acc = [item['cor_flag'] for item in load_json_data(path)[:n_samples]].count(True) / n_samples
                    else:
                        path = path.replace('500', '200')
                        if os.path.exists(path):
                            acc = load_json_data(path)[-1]['acc']
                        else:
                            acc = -1
                    print(f'{dataset}\t{model}\t{method}\tAcc:{acc}')
    elif stat == 'append_dataset':
        data_dic = {}
        datasets = ['prontoqa', 'proofwriter', 'wino', 'csqa', 'aqua']
        models = ['Llama2-13B', 'Mistral-7B', 'Gemma2-9B', 'Llama3.1-8B', 'Qwen2.5-3B', 'Qwen2.5-14B']
        rewards = ['Shepherd', 'ArmoRM', 'Skywork']
        model_path_map = {
            'Llama2-13B': 'Llama2_13b_chat',
            'Mistral-7B': 'Mistral_7b_chat',
            'Gemma2-9B': 'Gemma2_9b_chat',
            'Llama3.1-8B': 'Llama3_1_8b_chat',
            'Qwen2.5-3B': 'Qwen2_5_3b_chat',
            'Qwen2.5-14B': 'Qwen2_5_14b_chat'
        }
        for dataset in datasets:
            for model in models:
                
                sc_path = f'./result/{dataset}/{model_path_map[model]}/sc10_e3_500.json'
                data_dic[model] = {'SC':load_json_data(sc_path)[:-1]}
                for reward in rewards:
                    if reward.endswith('o1'):
                        path = f'./result/{dataset}/{model_path_map[model]}/best10_t0.7_{reward.lower()}_e3_500.json'
                    else:
                        path = f'./result/{dataset}/{model_path_map[model]}/best10_{reward.lower()}_e3_500.json'
                    if not os.path.exists(path):
                        print(path)
                    data_dic[model][reward] = load_json_data(path)[:-1]
            acc_with_models(data_dic, dataset)
    elif stat == 'append_diff_est':
        data_dic = {}
        models = ['Llama2-13B', 'Mistral-7B', 'Llama3.1-8B', 'Qwen2.5-3B', 'Qwen2.5-14B']
        model_path_map = {
            'Llama2-13B': 'Llama2_13b_chat',
            'Mistral-7B': 'Mistral_7b_chat',
            'Gemma2-9B': 'Gemma2_9b_chat',
            'Llama3.1-8B': 'Llama3_1_8b_chat',
            'Qwen2.5-3B': 'Qwen2_5_3b_chat',
            'Qwen2.5-14B': 'Qwen2_5_14b_chat'
        }
        types = ['Length', 'Count', 'Null']
    
        for type in types:
            difficulty_ls = []
            model_ls = []
            scores = []
            for model in models:
                difficulty_dic = split_difficulty(dataset, model=model_path_map[model])
                sc_path = f'./result/{dataset}/{model_path_map[model]}/sc{roll_num}_e3_{n_samples}.json'
                sc_result = load_json_data(sc_path)[:-1]
                for difficulty, index in difficulty_dic.items():
                    result = [item for item in sc_result if item['id'] in index]
                    if type == 'Divergence':
                        modelwrapper = ModelWrapper(model_path_map[model])
                        tokenizer = modelwrapper.tokenizer
                        lm_model = modelwrapper.model       
                    for item in tqdm(result):
                        if type == 'Divergence':
                            responses = item['response']
                            js_divergences = compute_js_divergence_matrix(responses, tokenizer, lm_model)
                            score = np.mean(np.array(js_divergences))
                        elif type == 'Length':
                            responses = item['response']
                            score = np.mean(np.array([len(res) for res in responses]))
                        elif type == 'Null':
                            answers = item['answer']
                            score = answers.count(None)
                        else:
                            answers = item['answer']
                            score = len(set(answers))
                        difficulty_ls.append(difficulty)
                        model_ls.append(model)
                        scores.append(score)
                    if type == 'Divergence':
                        del modelwrapper
                        del tokenizer
                        del lm_model 
            dir_path = f'fig/{dataset}/'
            if not os.path.exists(dir_path):
                os.makedirs(dir_path)    
            path = dir_path + f'{type}_diff.pdf'
            data = {'difficulty':difficulty_ls, f'{type}':scores, 'models':model_ls}
            data = pd.DataFrame(data, columns=list(data.keys()))
            draw_bar(data, path)
    elif stat == 'append_mcts_early':
        mcts_path_dic = {'Skywork':'w5.0_5_1_skywork', 'Skyworko1':'w0.1_5_1_skyworko1', 'ArmoRM':'w0.1_5_1_armorm', 'Self':f'w0.1_3_5_self-{model}'}
        rewards = ['Self', 'ArmoRM', 'Skywork', 'Skyworko1']
        data_dic = defaultdict(dict)
        methods = ['SC',  'Reward', 'Maj_vote', 'Q_value', 'N_greedy', 'Q_greedy', 'Oracle']
        for rm in rewards:
            for method in methods:
                if method == 'SC':
                    sc_path = f'./result/math/{model}/sc128_e3_{n_samples}.json'
                    data_dic[rm]['SC'] = load_json_data(sc_path)[:-1]
                else:
                    path = f'./result/{dataset}/{model}/mcts{roll_num}_t0.7_d5_{mcts_path_dic[rm]}_reward_e3_{n_samples}.json'
                    data_dic[rm][method] = load_json_data(path)[:-1]
        
        methods = []
        model_ls = []
        scores = []
        for rm, data in data_dic.items():
            for name, result in data.items():
                model_ls.append(rm)
                methods.append(name)
                if name == 'SC':
                    scores.append(get_roll_n_acc(result, 16, False))
                else:
                    scores.append(append_mcts_acc(result, name))
        
        dir =  f'fig/{dataset}/{model}/'
        if not os.path.exists(dir):
            os.makedirs(dir)
        path = dir + 'append_mcts_early.pdf'
        data = {'reward model':model_ls, 'accuracy':scores, 'method':methods}
        data = pd.DataFrame(data, columns=list(data.keys()))
        draw_bar(data, path, True)
    elif stat == 'append_olympiad':
        def get_acc(result:list[dict], n_samples:int, method:str) -> float:
            """Return the fraction of items whose selected response is correct.

            For every item the first ``n_samples`` responses (or all of them when
            fewer are stored) are considered, and one of them is selected according
            to ``method`` (case-insensitive substring match, checked in this order):

            * contains ``'sc'``   -- majority vote over the answers;
            * contains ``'mcts'`` -- response with the highest ``'reward'``;
            * contains ``'ours'`` -- each answer is weighted by the z-normalised
              ``item['trace'][str(idx + 1)]['Q']`` of its responses and the answer
              with the largest summed weight wins;
            * anything else       -- response with the highest ``'score'``.

            Args:
                result: Items providing ``'response'``, ``'answer'`` and
                    ``'corrects'`` lists (plus ``'trace'`` for the ``'ours'``
                    selection).
                n_samples: Maximum number of leading responses considered per item.
                method: Name of the selection strategy.

            Returns:
                Number of correctly answered items divided by ``len(result)``, or
                ``0.0`` when ``result`` is empty. Items that cannot be scored (no
                responses, or every Q value is ``None``) count as incorrect.
            """
            if not result:
                # An empty evaluation set has no correct items; avoid 0 / 0.
                return 0.0

            method_key = method.lower()
            cor_flag = 0
            for item in result:
                # Never index past the responses that were actually stored.
                index = range(min(n_samples, len(item['response'])))
                responses = [item['response'][idx] for idx in index]
                answers = [item['answer'][idx] for idx in index]
                corrects = [item['corrects'][idx] for idx in index]
                if not responses:
                    # Nothing to choose from: the item counts as incorrect.
                    continue

                if 'sc' in method_key:
                    # Most frequent answer; ties resolve to the earliest one.
                    best_idx = answers.index(max(answers, key=answers.count))
                elif 'mcts' in method_key:
                    best_idx = max(enumerate(responses), key=lambda x: x[1]['reward'])[0]
                elif 'ours' in method_key:
                    q_values = [item['trace'][str(idx + 1)]['Q'] for idx in index]
                    # Only None marks a missing Q value; 0.0 is a real score and
                    # must take part in the statistics and in the vote.
                    known_q = [q for q in q_values if q is not None]
                    if not known_q:
                        continue
                    mean_q = np.mean(known_q)
                    std_q = np.std(known_q)
                    votes = defaultdict(float)
                    for ans, q in zip(answers, q_values):
                        if not ans or q is None:
                            continue
                        # With zero spread every z-score is 0 by definition;
                        # dividing would only produce NaN weights.
                        votes[ans] += ((q - mean_q) / std_q) if std_q > 0 else 0.0
                    if votes:
                        # Highest summed weight; ties resolve to the answer seen first.
                        best_idx = answers.index(max(votes, key=votes.get))
                    else:
                        best_idx = 0
                else:
                    best_idx = max(enumerate(responses), key=lambda x: x[1]['score'])[0]

                cor_flag += int(corrects[best_idx])
            return cor_flag / len(result)
            
        
        dataset = 'olympiadbench'
        model = 'Qwen2_5_3b_chat'
        for rm in ['orm', 'prm']:
            method_ls = []
            num_ls = []
            score_ls = []
            if rm == 'orm':
                methods = ['SC', 'BoN','BoN-Weighted', 'MCTS', 'Ours']
            else:
                methods = ['SC', 'BoN', 'MCTS', 'Beam', 'Ours']
            for method in methods:
                if method == 'CoT':
                    path = f'./result/{dataset}/{model}/cot_t0.7_e3_200.json'
                elif method == 'SC':
                    path = f'./result/{dataset}/{model}/sc32_t0.7_e3_200.json'
                elif method == 'BoN':
                    if rm == 'orm':
                        path = f'./result/{dataset}/{model}/best32_t0.7_skywork_e3_200.json'
                    else:
                        path = f'./result/{dataset}/{model}/best32_t0.7_skyworko1_e3_200.json'
                elif method == 'MCTS':
                    if rm == 'orm':
                        path = f'./result/{dataset}/{model}/mcts16_t0.7_d5_w5.0_5_1_skywork_reward_e3_200.json'
                    else:
                        path = f'./result/{dataset}/{model}/mcts16_t0.7_d5_w0.1_5_1_skyworko1_reward_e3_200.json'
                elif method == 'Ours':
                    if rm == 'orm':
                        path = f'./result/{dataset}/{model}/woc32_w1_d3_skywork_e3_200.json'
                    else:
                        path = f'./result/{dataset}/{model}/woc32_w2_d3_skyworko1_e3_200.json'
                elif method == 'Beam':
                    path = f'./result/{dataset}/{model}/beam8_t0.7_d5_3_skyworko1_e3_200.json'
                else:
                    if rm == 'orm':
                        path = f'./result/{dataset}/{model}/sc32_t0.7_skywork_e3_200.json'
                    else:
                        path = f'./result/{dataset}/{model}/sc32_t0.7_skyworko1_e3_200.json'
              
                acc = load_json_data(path)[-1]['acc']
                score_ls.append(acc)
                method_ls.append(method)
            data = {'method':method_ls, 'accuracy':score_ls,}
            data = pd.DataFrame(data, columns=['method', 'accuracy'])
            dir_path = f'fig/{dataset}/{model}/'
            if not os.path.exists(dir_path):
                os.makedirs(dir_path)
            path =  dir_path + f'olympiad_{rm}.pdf'
            draw_bar(data, path)  
                
    # elif stat == 'append_ablation':
        