from experiment import FOLIOExperiment
from eval import eval_file_folio
from utils.utils import extract_hard_level_file
import argparse
import os

def extract_hard_level(model_name, dataset_list, method_list, level='Level 5'):
    """
    Extract the problems of one difficulty level from each results file.

    For every (method, dataset) pair, iterating methods in the outer loop and
    datasets in the inner loop, the file
    ``new_results/{model_name}_{dataset}_{method}.jsonl`` is passed to
    ``extract_hard_level_file`` together with ``level``. A confirmation line
    is printed after each file.

    Args:
        model_name: model identifier used as the results file-name prefix.
        dataset_list: dataset names to process.
        method_list: method names to process.
        level: level label forwarded to ``extract_hard_level_file``
            (default ``'Level 5'``).
    """
    pairs = ((method_name, dataset_name)
             for method_name in method_list
             for dataset_name in dataset_list)
    for method_name, dataset_name in pairs:
        results_path = f'new_results/{model_name}_{dataset_name}_{method_name}.jsonl'
        extract_hard_level_file(results_path, level)
        print(f"Extracted {level} problems from {results_path}")


def folio_experiment(model_name, verbose, zero_shot, method_list=('cot', 'pal'),
                     is_train=False, hard_level=None, num_shots=1):
    """
    Run (or reuse) FOLIO experiments for each method and report accuracy.

    For every method in ``method_list`` the results file
    ``folio_results/{model_name}_folio_{method}.jsonl`` (``..._train.jsonl``
    when ``is_train`` is true) is checked. If it already exists the experiment
    is skipped; otherwise a ``FOLIOExperiment`` is constructed and run to
    produce it. The file is then scored with ``eval_file_folio`` and the mean
    score is printed.

    Args:
        model_name: model identifier passed to ``FOLIOExperiment`` and used in
            the results file name.
        verbose: forwarded to ``FOLIOExperiment``.
        zero_shot: forwarded to ``FOLIOExperiment``.
        method_list: methods to run; defaults to ``('cot', 'pal')``.
        is_train: run on the training split and use the ``_train`` file.
        hard_level: forwarded positionally to ``FOLIOExperiment``.
        num_shots: forwarded to ``FOLIOExperiment`` (default 1, matching the
            ``--num_shots`` CLI default).

    Returns:
        dict mapping each method to its mean score. Methods whose evaluation
        produced no scores are omitted.
    """
    dataset = 'folio'
    split_suffix = "_train" if is_train else ""
    accuracies = {}

    for method in method_list:
        print(f"Running experiment on {dataset} for method {method}")

        # NOTE(review): model names containing '/' (e.g. the llama ids) place
        # the file in a sub-directory of folio_results/ -- confirm intended.
        store_path = f"folio_results/{model_name}_{dataset}_{method}{split_suffix}.jsonl"

        # Reuse existing results instead of re-querying the model.
        if os.path.exists(store_path):
            print(f"Results already exist in {store_path}")
        else:
            exp = FOLIOExperiment(model_name, dataset, method, store_path, hard_level,
                                  verbose=verbose, zero_shot=zero_shot,
                                  is_train=is_train, num_shots=num_shots)
            exp.run()

        scores = list(eval_file_folio(store_path))
        if not scores:
            continue
        accuracy = sum(scores) / len(scores)
        accuracies[method] = accuracy
        print(f"Average accuracy for model {model_name} of method {method} is {accuracy}")

    return accuracies


if __name__ == '__main__':
    # CLI alias -> full model identifier.
    model_map = {
        "gpt-4o-mini": "gpt-4o-mini",
        "llama-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
        "llama-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        "gpt-4o": "gpt-4o",
    }
    parser = argparse.ArgumentParser()
    # Required: without it args.model is None and the model_map lookup fails.
    parser.add_argument('--model', type=str, choices=list(model_map), required=True)
    parser.add_argument('--zero_shot', action='store_true')
    parser.add_argument('--verbose', action='store_true')
    # NOTE: kept for CLI compatibility; both splits are always run below.
    parser.add_argument('--is_train', action='store_true')
    parser.add_argument('--num_shots', type=int, default=1)
    # NOTE: kept for CLI compatibility; not currently used by this script.
    parser.add_argument('--hard_level_list', type=str, nargs='+', default=['5', '4', '3', '2', '1'])
    parser.add_argument('--method_list', type=str, nargs='+', default=['cot', 'pal'])

    args = parser.parse_args()

    model_name = model_map[args.model]
    print(f"Selected model: {model_name}")

    # Run the non-train split first, then the training split.
    for is_train in (False, True):
        folio_experiment(model_name, verbose=args.verbose, zero_shot=args.zero_shot,
                         method_list=args.method_list, is_train=is_train,
                         num_shots=args.num_shots)