import os, sys
from tqdm import tqdm, trange
# os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["VSCODE_PROXY_CUDA_DEVICE"] # FIXME: remove this line in sbatch script
# Directory containing this script; the report CSVs below are written here.
current_dir = os.path.dirname(os.path.realpath(__file__))
# One level up: the evaluation functions read result files from
# `parent_dir/new_results/...`. It is also appended to sys.path, presumably so
# that the `eval` module imported below resolves (verify it lives there).
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)

from eval import eval_file, eval_file_multi_shots_best

def eval_model_math_by_category(model, method_list, subset_list, hard_level_list=("Level 5",), is_train=False):
    """Score greedy-decoding MATH results per category and write a CSV report.

    For every method, each (category, level) result file
    ``parent_dir/new_results[/gpt]/{model}_{category}_{level}_{method}_greedy[_train].jsonl``
    is scored with ``eval_file``. The per-problem scores are pooled by category
    and over all categories ('average').

    Args:
        model: Model identifier used in the result file names. Names containing
            ``'gpt'`` are read from the ``new_results/gpt`` sub-folder. The two
            Llama-3.1 Turbo identifiers are shortened in the report file name.
        method_list: Prompting methods to evaluate (one CSV row each).
        subset_list: MATH categories to evaluate (one CSV column each).
        hard_level_list: Difficulty levels pooled into every category,
            e.g. ``("Level 5",)``.
        is_train: Read the ``_train`` result files and label the report ``train``.

    Returns:
        dict: ``{method: {'average': [...], category: [...], ...}}``. Each list holds
        the raw per-problem scores returned by ``eval_file``.

    Side effects:
        Writes ``{model}_math_{test|train}_lvl{levels}_report_by_category.csv`` into
        ``current_dir``. Each cell is the mean score * 100 with two decimals, or
        ``nan`` if no scores were loaded for that cell.
    """
    split = 'train' if is_train else 'test'
    results_subdir = 'new_results/gpt' if 'gpt' in model else 'new_results'

    # {method: {'average': [all scores of the method], data_type: [per-problem scores]}}
    scores = {}
    for method in method_list:
        scores[method] = {'average': []}
        for data_type in subset_list:
            scores[method][data_type] = []
            for hard_level in hard_level_list:
                # Build the suffix directly instead of str.replace, which would also
                # rewrite any '.jsonl' occurring elsewhere in the path.
                file_name = f"{model}_{data_type}_{hard_level}_{method}_greedy"
                if is_train:
                    file_name += '_train'
                stored_path = os.path.join(parent_dir, f"{results_subdir}/{file_name}.jsonl")
                # Evaluate each file once; materialize so both aggregates get the same scores.
                file_scores = list(eval_file(stored_path))
                scores[method][data_type].extend(file_scores)
                scores[method]['average'].extend(file_scores)

    # Shorten the two long Llama-3.1 identifiers; other names are used verbatim.
    # NOTE(review): a '/' in an unmapped name makes the report land in a sub-directory.
    model_tag = {
        'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo': 'llama-3.1-8B',
        'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo': 'llama-3.1-70B',
    }.get(model, model)
    lvls = ''.join(hard_level_list).replace('Level ', '')
    output_name = f"{model_tag}_math_{split}_lvl{lvls}_report_by_category.csv"

    def _pct(values):
        # Mean per-problem score as a percentage; 'nan' rather than ZeroDivisionError
        # when a cell received no scores.
        if not values:
            return 'nan'
        return f'{sum(values) / len(values) * 100:.2f}'

    with open(os.path.join(current_dir, output_name), 'w', encoding='utf-8') as f:
        f.write('method,' + ','.join(subset_list) + ',average\n')
        for method in method_list:
            cells = [_pct(scores[method][data_type]) for data_type in subset_list]
            cells.append(_pct(scores[method]['average']))
            f.write(method + ',' + ','.join(cells) + '\n')

    return scores

def eval_model_math_by_hard_level(model, method_list, subset_list, hard_level_list=("Level 5",), is_train=False):
    """Score greedy-decoding MATH results per difficulty level and write a CSV report.

    For every method, each (level, category) result file
    ``parent_dir/new_results[/gpt]/{model}_{category}_{level}_{method}_greedy[_train].jsonl``
    is scored with ``eval_file``. The per-problem scores are pooled by level and
    over all levels ('average').

    Args:
        model: Model identifier used in the result file names. Names containing
            ``'gpt'`` are read from the ``new_results/gpt`` sub-folder. The two
            Llama-3.1 Turbo identifiers are shortened in the report file name.
        method_list: Prompting methods to evaluate (one CSV row each).
        subset_list: MATH categories pooled into every level.
        hard_level_list: Difficulty levels to evaluate (one CSV column each).
        is_train: Read the ``_train`` result files and label the report ``train``.

    Returns:
        dict: ``{method: {'average': [...], level: [...], ...}}``. Each list holds the
        raw per-problem scores returned by ``eval_file``.

    Side effects:
        Writes ``{model}_math_{test|train}_lvl{levels}_report_by_hard_level.csv`` into
        ``current_dir``. Each cell is the mean score * 100 with two decimals, or
        ``nan`` if no scores were loaded for that cell.
    """
    split = 'train' if is_train else 'test'
    results_subdir = 'new_results/gpt' if 'gpt' in model else 'new_results'

    # {method: {'average': [all scores of the method], hard_level: [per-problem scores]}}
    scores = {}
    for method in method_list:
        scores[method] = {'average': []}
        for hard_level in hard_level_list:
            scores[method][hard_level] = []
            for data_type in subset_list:
                # Build the suffix directly instead of str.replace, which would also
                # rewrite any '.jsonl' occurring elsewhere in the path.
                file_name = f"{model}_{data_type}_{hard_level}_{method}_greedy"
                if is_train:
                    file_name += '_train'
                stored_path = os.path.join(parent_dir, f"{results_subdir}/{file_name}.jsonl")
                # Evaluate each file once; materialize so both aggregates get the same scores.
                file_scores = list(eval_file(stored_path))
                scores[method][hard_level].extend(file_scores)
                scores[method]['average'].extend(file_scores)

    # Shorten the two long Llama-3.1 identifiers; other names are used verbatim.
    # NOTE(review): a '/' in an unmapped name makes the report land in a sub-directory.
    model_tag = {
        'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo': 'llama-3.1-8B',
        'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo': 'llama-3.1-70B',
    }.get(model, model)
    lvls = ''.join(hard_level_list).replace('Level ', '')
    output_name = f"{model_tag}_math_{split}_lvl{lvls}_report_by_hard_level.csv"

    def _pct(values):
        # Mean per-problem score as a percentage; 'nan' rather than ZeroDivisionError
        # when a cell received no scores.
        if not values:
            return 'nan'
        return f'{sum(values) / len(values) * 100:.2f}'

    with open(os.path.join(current_dir, output_name), 'w', encoding='utf-8') as f:
        # The header must name the trailing 'average' column that every row emits,
        # otherwise rows have one more field than the header.
        f.write('method,' + ','.join(hard_level_list) + ',average\n')
        for method in method_list:
            cells = [_pct(scores[method][hard_level]) for hard_level in hard_level_list]
            cells.append(_pct(scores[method]['average']))
            f.write(method + ',' + ','.join(cells) + '\n')

    # Returned for consistency with the by-category variants.
    return scores

def eval_model_math_by_category_nshots(model, method_list, subset_list, nshots=4, hard_level_list=("Level 5",), is_train=False):
    """Score n-shot MATH results per category and write a CSV report.

    For every method, each (category, level) result file
    ``parent_dir/new_results[/gpt]/{model}_{category}_{level}_{method}_nshots_{nshots}[_train].jsonl``
    is scored with ``eval_file_multi_shots_best(path, num_shots=nshots)``.
    Presumably that is a best-of-``nshots`` score per problem; confirm in ``eval``.
    The per-problem scores are pooled by category and over all categories ('average').

    Args:
        model: Model identifier used in the result file names. Names containing
            ``'gpt'`` are read from the ``new_results/gpt`` sub-folder. The two
            Llama-3.1 Turbo identifiers are shortened in the report file name.
        method_list: Prompting methods to evaluate (one CSV row each).
        subset_list: MATH categories to evaluate (one CSV column each).
        nshots: Number of shots, used in the input file names and passed as ``num_shots``.
        hard_level_list: Difficulty levels pooled into every category.
        is_train: Read the ``_train`` result files and label the report ``train``.

    Returns:
        dict: ``{method: {'average': [...], category: [...], ...}}`` of raw per-problem scores.

    Side effects:
        Writes ``{model}_math_{test|train}_lvl{levels}_report_by_category_nshots_{nshots}.csv``
        into ``current_dir``. The ``nshots`` suffix keeps this report from overwriting
        the greedy report written by ``eval_model_math_by_category``.
    """
    split = 'train' if is_train else 'test'
    results_subdir = 'new_results/gpt' if 'gpt' in model else 'new_results'

    # {method: {'average': [all scores of the method], data_type: [per-problem scores]}}
    scores = {}
    for method in method_list:
        scores[method] = {'average': []}
        for data_type in subset_list:
            scores[method][data_type] = []
            for hard_level in hard_level_list:
                # Build the suffix directly instead of str.replace, which would also
                # rewrite any '.jsonl' occurring elsewhere in the path.
                file_name = f"{model}_{data_type}_{hard_level}_{method}_nshots_{nshots}"
                if is_train:
                    file_name += '_train'
                stored_path = os.path.join(parent_dir, f"{results_subdir}/{file_name}.jsonl")
                # Evaluate each file once; materialize so both aggregates get the same scores.
                file_scores = list(eval_file_multi_shots_best(stored_path, num_shots=nshots))
                scores[method][data_type].extend(file_scores)
                scores[method]['average'].extend(file_scores)

    # Shorten the two long Llama-3.1 identifiers; other names are used verbatim.
    # NOTE(review): a '/' in an unmapped name makes the report land in a sub-directory.
    model_tag = {
        'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo': 'llama-3.1-8B',
        'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo': 'llama-3.1-70B',
    }.get(model, model)
    lvls = ''.join(hard_level_list).replace('Level ', '')
    output_name = f"{model_tag}_math_{split}_lvl{lvls}_report_by_category_nshots_{nshots}.csv"

    def _pct(values):
        # Mean per-problem score as a percentage; 'nan' rather than ZeroDivisionError
        # when a cell received no scores.
        if not values:
            return 'nan'
        return f'{sum(values) / len(values) * 100:.2f}'

    with open(os.path.join(current_dir, output_name), 'w', encoding='utf-8') as f:
        f.write('method,' + ','.join(subset_list) + ',average\n')
        for method in method_list:
            cells = [_pct(scores[method][data_type]) for data_type in subset_list]
            cells.append(_pct(scores[method]['average']))
            f.write(method + ',' + ','.join(cells) + '\n')

    return scores

if __name__ == '__main__':
    # MATH subject categories included in the report.
    subset_list = [
        'algebra',
        'counting & probability',
        'geometry',
        'number theory',
        'intermediate algebra',
        'precalculus',
        'prealgebra',
    ]

    # Alternative greedy-decoding run, kept for reference:
    #   model = 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
    #   model = 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo'
    #   method_list = ['cot', 'pal', 'codenl', 'nlcode', 'nlcode_single', 'majorvote']
    #   hard_level_list = ["Level 5", "Level 4", "Level 3", "Level 2", "Level 1"]
    #   eval_model_math_by_category(model, method_list, subset_list, hard_level_list, False)
    #   eval_model_math_by_hard_level(model, method_list, subset_list, hard_level_list, False)

    # Current run: 4-shot chain-of-thought results for gpt-4o on Level 5 test problems.
    model = 'gpt-4o'
    method_list = ['cot']
    eval_model_math_by_category_nshots(
        model,
        method_list,
        subset_list=subset_list,
        nshots=4,
        hard_level_list=["Level 5"],
        is_train=False,
    )