from compare import get_result_method
import numpy as np
import os
from utils.utils import read_jsonl

def get_label(file):
    """
    Collect the 'correct' field of every record in an evaluated results file.

    Args:
        file: str, path to a JSONL results file whose name ends with 'eval.jsonl'

    Returns:
        list, result['correct'] for each record, in the order read_jsonl yields them
    """
    # NOTE(review): these asserts are skipped under `python -O`.
    assert os.path.exists(file), f"File {file} does not exist."
    assert file.endswith('eval.jsonl'), f"File {file} is not an evaled file."

    return [record['correct'] for record in read_jsonl(file)]

# Dataset splits read by default (these appear to be the MATH benchmark subjects).
DEFAULT_DATASETS = (
    'algebra',
    'counting & probability',
    'geometry',
    'number theory',
    'intermediate algebra',
    'precalculus',
    'prealgebra',
)


def get_labels_for_methods(method, model_name, *, datasets=None, results_dir='results'):
    """
    Get the labels for the given method and model, concatenated over datasets.

    For each dataset, in order, reads
    '{results_dir}/{model_name}_{data}_{method}_eval.jsonl' via get_label and
    appends its labels to a single list.

    Args:
        method: str, the method
        model_name: str, the model name. A '/' in it (e.g. an org prefix)
            becomes a sub-directory in the result path.
        datasets: iterable of str, dataset names to read; defaults to
            DEFAULT_DATASETS
        results_dir: str, directory holding the evaluated result files;
            defaults to 'results'
    Returns:
        list, the labels of all requested datasets, in dataset order
    """
    if datasets is None:
        datasets = DEFAULT_DATASETS

    labels = []
    for data in datasets:
        result_file = f'{results_dir}/{model_name}_{data}_{method}_eval.jsonl'
        labels.extend(get_label(result_file))

    return labels



def compute_confusion(method_list, model_name, *, mode="agree", label_loader=None):
    """
    Compute a pairwise matrix comparing the methods' per-sample labels.

    Entry [i, j] compares, position by position, the labels of method_list[i]
    and method_list[j], and counts the samples selected by ``mode``. Both
    counting rules are symmetric, so the matrix is symmetric.

    Args:
        method_list: list, the list of methods
        model_name: str, the model name
        mode: str, which samples to count for each pair of methods:
            - "agree" (default): samples where the two labels are equal, i.e.
              both correct OR both wrong. The diagonal therefore equals the
              total number of samples.
            - "both_correct": samples both methods got right (labels combined
              with logical AND). The diagonal is then each method's number of
              correct answers. Assumes labels are booleans or 0/1.
        label_loader: callable ``(method, model_name) -> list`` returning one
            label per sample; defaults to get_labels_for_methods.
    Returns:
        confusion_matrix: np.ndarray of int, shape
            (len(method_list), len(method_list))
    Raises:
        ValueError: if ``mode`` is unknown, or if the methods do not all have
            the same number of labels. Samples are aligned by position, so
            different lengths cannot be compared.
    """
    if mode not in ("agree", "both_correct"):
        raise ValueError(f"Unknown mode {mode!r}; expected 'agree' or 'both_correct'.")
    # Resolve the default lazily so the sibling loader is only needed at call time.
    loader = get_labels_for_methods if label_loader is None else label_loader

    # Load and convert each method's labels once, instead of rebuilding the
    # arrays inside the pairwise loop.
    label_arrays = [np.asarray(loader(method, model_name)) for method in method_list]

    # Samples are matched by position, so every method must cover the same samples.
    # Without this check, numpy may broadcast a length-1 array or fail obscurely.
    lengths = [len(labels) for labels in label_arrays]
    if len(set(lengths)) > 1:
        raise ValueError(
            "Methods have different numbers of labels: "
            + ", ".join(f"{name}={count}" for name, count in zip(method_list, lengths))
        )

    n_methods = len(method_list)
    confusion_matrix = np.zeros((n_methods, n_methods), dtype=int)
    for i, labels_i in enumerate(label_arrays):
        # Both rules are symmetric, so fill [i, j] and [j, i] together.
        for j in range(i, n_methods):
            labels_j = label_arrays[j]
            if mode == "agree":
                count = np.sum(labels_i == labels_j)
            else:
                count = np.sum(np.logical_and(labels_i, labels_j))
            confusion_matrix[i, j] = confusion_matrix[j, i] = count

    return confusion_matrix



if __name__ == '__main__':
    # Methods whose per-sample correctness labels are compared pairwise.
    method_list = ['cot', 'pal', 'codenl', 'nlcode']
    # Alternative model previously analysed: "gpt-4o-mini".
    model_name = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
    print(compute_confusion(method_list, model_name))