from utils.utils import count_tokens
from utils.utils import read_jsonl

# JSONL fields that hold the generated reasoning text for each prompting
# method; for methods with two reasoning paths the token counts are summed.
_REASONING_FIELDS = {
    'cot': ('reasoning_path',),
    'pal': ('reasoning_path',),
    'metamath': ('reasoning_path',),
    'nlcode': ('reasoning_path_1', 'reasoning_path_2'),
    'codenl': ('reasoning_path_1', 'reasoning_path_2'),
}


def _result_path(model_name, dataset, hard_level, method):
    """Return the JSONL results path for one (model, dataset, level, method) run."""
    # Results for GPT models are stored in their own sub-directory.
    prefix = 'results/gpt' if 'gpt' in model_name else 'results'
    return f'{prefix}/{model_name}_{dataset}_{hard_level}_{method}.jsonl'


def record_tokens(model_name, dataset_list, hard_level, method_list):
    """Print (and return) the average reasoning-token count for each method.

    For every method in ``method_list``, reads the results file of each
    dataset in ``dataset_list`` and averages ``count_tokens`` over the
    method's reasoning field(s) across all rows of all datasets. Rows whose
    token count cannot be computed (e.g. a missing field) are skipped and
    excluded from the average.

    Args:
        model_name: Model identifier used in the results file name; names
            containing 'gpt' are read from ``results/gpt/``.
        dataset_list: Dataset names used as file-name components.
        hard_level: Difficulty-level string used in the file name,
            e.g. "Level 5".
        method_list: Prompting methods; each must be one of
            'cot', 'pal', 'metamath', 'nlcode', 'codenl'.

    Returns:
        Dict mapping each method to its average token count over all counted
        rows, or ``None`` for a method with no countable rows.

    Raises:
        ValueError: If ``method_list`` contains an unrecognised method.
    """
    # Validate before reading anything: an unknown method used to leave the
    # per-row token count unbound (or stale from the previous method), and the
    # bare ``except`` silently hid that.
    unknown = [m for m in method_list if m not in _REASONING_FIELDS]
    if unknown:
        raise ValueError(
            f"Unknown method(s) {unknown}; expected one of {sorted(_REASONING_FIELDS)}"
        )

    averages = {}
    for method in method_list:
        fields = _REASONING_FIELDS[method]
        total_tokens = 0
        total_problems = 0
        for dataset in dataset_list:
            data = read_jsonl(_result_path(model_name, dataset, hard_level, method))
            for row in data:
                try:
                    answer_tokens = sum(count_tokens(row[field]) for field in fields)
                except Exception:
                    # Best-effort: rows that cannot be tokenised are excluded
                    # from both the token total and the problem count.
                    continue
                total_tokens += answer_tokens
                total_problems += 1

        # Guard against division by zero when no row could be counted
        # (empty dataset_list, empty files, or every row failing).
        if total_problems:
            total_average_tokens = total_tokens / total_problems
            print(f"Total average tokens for method {method} is {total_average_tokens}")
        else:
            total_average_tokens = None
            print(f"Total average tokens for method {method} is unavailable (no countable problems)")
        averages[method] = total_average_tokens
    return averages

if __name__ == "__main__":
    # Swap in model_name='gpt-4o-mini' to analyse the GPT run instead
    # (results for models whose name contains 'gpt' are read from results/gpt/).
    record_tokens(
        model_name='meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo',
        dataset_list=[
            'counting & probability',
            'geometry',
            'number theory',
            'intermediate algebra',
            'precalculus',
            'prealgebra',
        ],
        hard_level="Level 5",
        method_list=['cot', 'pal', 'codenl', 'nlcode', 'metamath'],
    )