from utils.load_data import load_json_data, extract_answer, write_json_data
from utils.eval import is_equiv
import os
import numpy as np 
import pandas as pd
import argparse
from utils.metrics import draw_box

def split_difficulty(model, n_samples, dataset=None):
    """Group problem ids into difficulty buckets from self-consistency results.

    Reads ``./result/{dataset}/{model}/sc10_e3_{n_samples}.json`` with
    ``load_json_data`` and discards its final entry (presumably a summary
    record rather than a problem -- confirm against the writer). Each remaining
    item is bucketed by how many entries of ``item['corrects']`` are True:

        difficulty = 6 - (num_correct // 2)

    so more correct samples gives a lower difficulty. Assuming 10 samples per
    problem (as the ``sc10`` file name suggests -- TODO confirm), num_correct
    in 0..10 maps to difficulty 1..6.

    Args:
        model: Model directory name under ``./result/{dataset}/``.
        n_samples: Sample-count suffix used in the result file name.
        dataset: Dataset directory name. When omitted, falls back to the
            module-level ``dataset`` global (set by this file's ``__main__``
            block), which preserves the original calling convention.

    Returns:
        dict mapping difficulty (int) -> list of item ids in file order.
        Only difficulties that actually occur are present as keys.

    Raises:
        NameError: if ``dataset`` is not passed and no module-level
            ``dataset`` global exists (e.g. when this module is imported).
    """
    if dataset is None:
        # Backward compatibility: the original implementation silently read
        # the script-level global; keep that, but fail with a clear message.
        try:
            dataset = globals()['dataset']
        except KeyError:
            raise NameError(
                "split_difficulty() needs a 'dataset' argument when the "
                "module-level 'dataset' global is not set"
            ) from None

    sc_path = f'./result/{dataset}/{model}/sc10_e3_{n_samples}.json'
    sc_result = load_json_data(sc_path)[:-1]

    difficulty_dic = {}
    for item in sc_result:
        # `//` binds tighter than `-`: this is 6 - (num_correct // 2).
        difficulty = 6 - item['corrects'].count(True) // 2
        difficulty_dic.setdefault(difficulty, []).append(item['id'])
    return difficulty_dic



if __name__ == '__main__':
    # Print the ids of problems in the hardest self-consistency difficulty
    # buckets that the selected best-of-N result file marks as incorrect
    # (item['cor_flag'] falsy).
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='Llama3_1_8b_chat')
    parser.add_argument('--n_samples', type=int, default=200)
    parser.add_argument('--roll_num', type=int, default=100)
    parser.add_argument('--dataset', type=str, default='math')
    # Which best-of-N result file to inspect (e.g. skywork, shepherd,
    # skyworko1, armorm). The default reproduces the previously hard-coded
    # skyworko1 file.
    parser.add_argument('--reward_model', type=str, default='skyworko1')
    # Buckets with difficulty strictly below this are skipped; the default
    # of 5 matches the original `difficulty <= 4: continue`.
    parser.add_argument('--min_difficulty', type=int, default=5)
    args = parser.parse_args()

    model = args.model
    n_samples = args.n_samples
    roll_num = args.roll_num
    # Intentionally a module-level name: split_difficulty() falls back to
    # this global when no dataset argument is passed.
    dataset = args.dataset

    # Only the file that is actually inspected is loaded; previously three
    # other best-of-N files and an sc file were read but never used.
    bestn_path = (
        f'./result/{dataset}/{model}/'
        f'best{roll_num}_{args.reward_model}_e3_{n_samples}.json'
    )
    # Final entry dropped, consistent with every other result load here.
    bestn_data = load_json_data(bestn_path)[:-1]

    diff_dic = split_difficulty(model, n_samples)

    for difficulty, ids in diff_dic.items():
        if difficulty < args.min_difficulty:
            continue
        id_set = set(ids)  # O(1) membership instead of scanning a list
        for item in bestn_data:
            if item['id'] in id_set and not item['cor_flag']:
                print(item['id'])