from utils.utils import read_jsonl
import os
from collections import defaultdict

def compare_file(file_path):
    '''
    Compare the result on different levels in the given file.

    Every record returned by read_jsonl is counted under its 'level' value;
    a record adds to that level's 'correct' count when its 'correct' value
    is truthy. The correct rate of each level is printed as a percentage.

    Args:
        file_path: str, the path to the file; it must exist and end with
            'eval.jsonl'

    Returns:
        level_counts: defaultdict, maps each level to a dict with the keys
            'correct' and 'total'

    Raises:
        AssertionError: if the file does not exist or its name does not end
            with 'eval.jsonl'
    '''

    # Raise explicitly instead of using `assert`: assert statements are
    # removed when Python runs with -O, which would silently disable these
    # checks. AssertionError is kept so the raised type does not change.
    if not os.path.exists(file_path):
        raise AssertionError(f"File {file_path} does not exist.")
    if not file_path.endswith('eval.jsonl'):
        raise AssertionError(f"File {file_path} is not an evaled file.")

    results = read_jsonl(file_path)

    level_counts = defaultdict(lambda: {'correct': 0, 'total': 0})

    for result in results:
        counts = level_counts[result['level']]

        # Update counts
        counts['total'] += 1
        if result['correct']:
            counts['correct'] += 1

    # Calculate and print correct rates. A level only appears here after at
    # least one record was counted for it, so 'total' is never zero.
    for level, counts in level_counts.items():
        correct_rate = counts['correct'] / counts['total'] * 100
        print(f"Level {level}: Correct Rate = {correct_rate:.2f}%")

    return level_counts

def get_result_method(model_name, method):
    """
    Accumulate the per-level results of one model and method over all data.

    For each dataset name a result file path is built; paths that do not
    exist are skipped. The counts returned by compare_file for the remaining
    files are summed per level, and the overall correct rate of every level
    is printed as a percentage.

    Args:
        model_name: str, the model name
        method: str, the method such as cot, pal, codenl, nlcode

    Returns:
        total_level_counts: defaultdict, accumulated counts across all data
    """
    subjects = (
        'algebra',
        'counting & probability',
        'geometry',
        'number theory',
        'intermediate algebra',
        'precalculus',
        'prealgebra',
    )

    # Running per-level tallies over every result file that is found
    total_level_counts = defaultdict(lambda: {'correct': 0, 'total': 0})

    candidate_paths = (
        f'results/{model_name}_{subject}_{method}_eval.jsonl'
        for subject in subjects
    )
    for path in candidate_paths:
        if not os.path.exists(path):
            continue

        # Fold this file's tallies into the running totals
        for level, tally in compare_file(path).items():
            bucket = total_level_counts[level]
            bucket['total'] += tally['total']
            bucket['correct'] += tally['correct']

    # Report the overall correct rate of each level
    for level, bucket in total_level_counts.items():
        if bucket['total'] > 0:
            correct_rate = bucket['correct'] / bucket['total'] * 100
        else:
            correct_rate = 0
        print(f"Total Level {level}: Correct Rate = {correct_rate:.2f}%")

    return total_level_counts


if __name__ == '__main__':
    # Script entry point: print the accumulated per-level results of one
    # model / method pair.
    # A single file can be inspected instead with compare_file, e.g.
    # compare_file('results/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo_algebra_cot_eval.jsonl')
    get_result_method(
        model_name='meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo',
        method='cot',
    )
