from experiment import HardLevelExperiment
from eval import eval_file
from utils.utils import extract_hard_level_file
import argparse
import os

def extract_hard_level(model_name, dataset_list, method_list, level='Level 5'):
    """
    Pull the problems of one difficulty level (default 'Level 5'), together
    with their results, out of each full results file.

    For every (method, dataset) pair -- methods in the outer loop -- the file
    ``new_results/{model_name}_{dataset}_{method}.jsonl`` is handed to
    ``extract_hard_level_file`` (where that helper writes its output is
    defined in utils.utils, not here).
    """
    pairs = ((m, d) for m in method_list for d in dataset_list)
    for method_name, dataset_name in pairs:
        source_file = f'new_results/{model_name}_{dataset_name}_{method_name}.jsonl'
        extract_hard_level_file(source_file, level)
        print(f"Extracted {level} problems from {source_file}")

def Math_level(model_name, verbose, zero_shot, method_list=('cot', 'pal', 'codenl', 'nlcode'), is_train=False, hard_level='Level 5', num_shots=1):
    """
    Run (or resume) the MATH experiment at one difficulty level and report accuracy.

    For each method in ``method_list`` and each MATH subject, results are stored in
    ``results/{model_name}_{dataset}_{hard_level}_{method}_nshots_{num_shots}.jsonl``
    (under ``results/gpt/`` when ``model_name`` contains 'gpt'; a ``_train`` suffix is
    added when ``is_train``). Subjects whose result file already exists are not re-run.
    Every file, newly produced or pre-existing, is scored with ``eval_file``, and the
    average over all subjects is printed once per method.

    Args:
        model_name: Model identifier passed to HardLevelExperiment.
        verbose: Forwarded to HardLevelExperiment.
        zero_shot: Forwarded to HardLevelExperiment.
        method_list: Methods to run. For 'incmath', gpt-4o-mini is additionally
            passed as ``reasoning_model``.
        is_train: Forwarded to HardLevelExperiment; also selects the ``_train`` file name.
        hard_level: Difficulty label, e.g. 'Level 5'.
        num_shots: Forwarded to HardLevelExperiment and embedded in the file name.
    """
    dataset_list = ['algebra', 'counting & probability', 'geometry', 'number theory',
                    'intermediate algebra', 'precalculus', 'prealgebra']

    for method in method_list:
        scores = []
        for dataset in dataset_list:
            print(f"Running experiment on {dataset} with hard level {hard_level} for method {method}")

            if 'gpt' in model_name:
                store_path = f"results/gpt/{model_name}_{dataset}_{hard_level}_{method}_nshots_{num_shots}.jsonl"
            else:
                store_path = f"results/{model_name}_{dataset}_{hard_level}_{method}_nshots_{num_shots}.jsonl"
            if is_train:
                store_path = store_path.replace(".jsonl", "_train.jsonl")

            if os.path.exists(store_path):
                print(f"Results already exist in {store_path}")
            else:
                # 'incmath' is the only method that needs a separate reasoning model.
                extra_kwargs = {'reasoning_model': 'gpt-4o-mini'} if method == 'incmath' else {}
                exp = HardLevelExperiment(model_name, dataset, method, store_path,
                                          hard_level=hard_level, verbose=verbose,
                                          zero_shot=zero_shot, is_train=is_train,
                                          num_shots=num_shots, **extra_kwargs)
                exp.run()

            # Score both fresh and pre-existing files so the per-method average
            # below covers every subject (previously `scores` was never filled).
            scores += eval_file(store_path)

        if not scores:
            continue
        accuracy = sum(scores) / len(scores)
        print(f"Average accuracy for model {model_name} of method {method} is {accuracy}")


def Math_level_greedy(model_name, verbose, zero_shot, method_list=('cot', 'pal', 'codenl', 'nlcode'), is_train=False, hard_level='Level 5', num_shots=1):
    """
    Run (or resume) the greedy-decoding MATH experiment at one difficulty level
    and report the average accuracy per method.

    For each method and MATH subject, results are stored in
    ``new_results/{model_name}_{dataset}_{hard_level}_{method}_greedy.jsonl``
    (under ``new_results/gpt/`` when ``model_name`` contains 'gpt'; a ``_train``
    suffix is added when ``is_train``). Existing files are not re-run. Every file
    is scored with ``eval_file``, and the average over all subjects is printed.

    Args:
        model_name: Model identifier passed to HardLevelExperiment.
        verbose: Forwarded to HardLevelExperiment.
        zero_shot: Forwarded to HardLevelExperiment.
        method_list: Methods to run.
        is_train: Forwarded to HardLevelExperiment; also selects the ``_train`` file name.
        hard_level: Difficulty label, e.g. 'Level 5'.
        num_shots: Forwarded to HardLevelExperiment. It was previously read from an
            undefined name (NameError). Note that it is NOT part of the output file
            name, so runs with different values map to the same file.
    """
    dataset_list = ['algebra', 'counting & probability', 'geometry', 'number theory',
                    'intermediate algebra', 'precalculus', 'prealgebra']

    for method in method_list:
        scores = []
        for dataset in dataset_list:
            print(f"Running experiment on {dataset} with hard level {hard_level} for method {method}")

            if 'gpt' in model_name:
                store_path = f"new_results/gpt/{model_name}_{dataset}_{hard_level}_{method}_greedy.jsonl"
            else:
                store_path = f"new_results/{model_name}_{dataset}_{hard_level}_{method}_greedy.jsonl"
            if is_train:
                store_path = store_path.replace(".jsonl", "_train.jsonl")

            if os.path.exists(store_path):
                print(f"Results already exist in {store_path}")
            else:
                # NOTE: unlike Math_level, 'incmath' is not given a reasoning_model here.
                exp = HardLevelExperiment(model_name, dataset, method, store_path,
                                          hard_level=hard_level, verbose=verbose,
                                          zero_shot=zero_shot, is_train=is_train,
                                          num_shots=num_shots)
                exp.run()
            scores += eval_file(store_path)

        if not scores:
            continue
        accuracy = sum(scores) / len(scores)
        print(f"Average accuracy for model {model_name} of method {method} is {accuracy}")

def Math_level_4_shots_cot(model_name, verbose, zero_shot, method='cot', is_train=False, hard_level='Level 5'):
    """
    Generate 4-shot results for one method (default 'cot') on every MATH subject
    at a single difficulty level.

    Results are stored in
    ``new_results/{model_name}_{dataset}_{hard_level}_{method}_nshots_4.jsonl``
    (under ``new_results/gpt/`` when ``model_name`` contains 'gpt'; a ``_train``
    suffix is added when ``is_train``). Subjects whose file already exists are
    skipped. This function only produces result files; it does not score them.

    Args:
        model_name: Model identifier passed to HardLevelExperiment.
        verbose: Forwarded to HardLevelExperiment.
        zero_shot: Forwarded to HardLevelExperiment.
        method: Method name passed to HardLevelExperiment.
        is_train: Forwarded to HardLevelExperiment; also selects the ``_train`` file name.
        hard_level: Difficulty label, e.g. 'Level 5'.
    """
    dataset_list = ['algebra', 'counting & probability', 'geometry', 'number theory',
                    'intermediate algebra', 'precalculus', 'prealgebra']
    num_shots = 4
    for dataset in dataset_list:
        print(f"Running experiment on {dataset} with hard level {hard_level} for method {method}")

        if 'gpt' in model_name:
            store_path = f"new_results/gpt/{model_name}_{dataset}_{hard_level}_{method}_nshots_{num_shots}.jsonl"
        else:
            store_path = f"new_results/{model_name}_{dataset}_{hard_level}_{method}_nshots_{num_shots}.jsonl"
        if is_train:
            store_path = store_path.replace(".jsonl", "_train.jsonl")

        if os.path.exists(store_path):
            print(f"Results already exist in {store_path}")
            continue

        exp = HardLevelExperiment(model_name, dataset, method, store_path,
                                  hard_level=hard_level, verbose=verbose,
                                  zero_shot=zero_shot, is_train=is_train,
                                  num_shots=num_shots)
        exp.run()

if __name__ == '__main__':
    # Current run: 4-shot experiments (Math_level_4_shots_cot, method 'cot' by
    # default) for GPT-4o over every MATH difficulty level, hardest first.
    # Subjects whose result file already exists are skipped by the callee.
    #
    # An argparse-driven sweep of Math_level_greedy (model aliases gpt-4o-mini,
    # llama-8b, llama-70b, gpt-4o; --zero_shot/--verbose/--is_train/--num_shots
    # flags) was previously kept here as commented-out code. See version
    # history if it needs to be restored.
    target_model = 'gpt-4o'
    for level_number in range(5, 0, -1):
        Math_level_4_shots_cot(
            target_model,
            verbose=False,
            zero_shot=False,
            hard_level=f'Level {level_number}',
        )