import json
import os
import numpy as np
from util import is_equiv, extract_math_answer


def compute_accuracy(gt, pred_solutions, prob_level, prob_type):
    """Score one problem by voting over the candidate solutions.

    Every solution is reduced to an answer with ``extract_math_answer``; the
    answer picked by ``most_frequent`` is then compared to the ground truth
    with ``is_equiv``.

    Args:
        gt: Ground-truth answer string. It is wrapped in ``\\boxed{...}`` and
            passed through ``extract_math_answer`` before the comparison.
        pred_solutions: Iterable of solution texts, one per agent.
        prob_level: Unused here; kept so existing callers keep working.
        prob_type: Unused here; kept so existing callers keep working.

    Returns:
        1 if the voted answer is judged equivalent to the ground truth,
        otherwise 0. Also 0 when there are no solutions to vote on or when
        the equivalence check itself raises.
    """
    pred_answers = []
    for pred_solution in pred_solutions:
        try:
            pred_answers.append(extract_math_answer(pred_solution))
        except Exception:
            # A solution whose answer cannot be extracted still takes part in
            # the vote, as a placeholder. Only Exception is caught so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            pred_answers.append('No answer')

    # Nothing to vote on: score it as incorrect instead of letting the vote
    # fail on an empty list.
    if not pred_answers:
        return 0

    pred_answer = most_frequent(pred_answers)
    gt = extract_math_answer('\\boxed{' + gt + '}')
    try:
        is_correct = is_equiv(gt, pred_answer)
    except Exception:
        # A comparison that blows up is scored as a wrong answer.
        is_correct = False
    return 1 if is_correct else 0

def most_frequent(List):
    """Return the element of ``List`` that matches the most elements.

    For each element, the number of entries of ``List`` for which
    ``is_equiv(element, entry)`` is truthy is counted. The element with the
    strictly highest count wins, so ties go to the earliest element. If no
    element reaches a positive count, the first element is returned.

    Raises:
        IndexError: if ``List`` is empty.
    """
    winner = List[0]
    top_votes = 0

    for candidate in List:
        votes = 0
        for other in List:
            votes = votes + is_equiv(candidate, other)
        # Strict comparison keeps the earliest candidate on ties.
        if votes > top_votes:
            winner, top_votes = candidate, votes

    return winner

def _mean_or_nan(values):
    """Return the mean of ``values``, or NaN when ``values`` is empty."""
    # np.mean([]) also evaluates to nan but emits a RuntimeWarning first;
    # short-circuiting keeps the printed value and drops the warning.
    return np.mean(values) if values else float('nan')


def main(folder='MATH OUTPUT'):
    """Evaluate every result file in ``folder`` and print the accuracies.

    Each file is parsed as JSON and must provide the keys ``agent_contexts``,
    ``answer``, ``level`` and ``type``. For every problem, the last message
    of each agent context that has the maximum length is taken as that
    agent's solution, and the problem is scored with ``compute_accuracy``.
    The overall mean and one mean per problem type are printed; a type with
    no problems prints ``nan``.
    """
    # (printed label, value of data['type']) in the order they are reported.
    categories = (
        ('algebra', 'Algebra'),
        ('counting', 'Counting & Probability'),
        ('geometry', 'Geometry'),
        ('intermediate_algebra', 'Intermediate Algebra'),
        ('number', 'Number Theory'),
        ('prealgebra', 'Prealgebra'),
        ('precalculus', 'Precalculus'),
    )
    accuracies = []
    per_type = {prob_type: [] for _, prob_type in categories}

    for file in os.listdir(folder):
        path = os.path.join(folder, file)
        # Subdirectories cannot be opened as result files; skip them.
        if not os.path.isfile(path):
            continue
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        responses, gt = data['agent_contexts'], data['answer']
        prob_level, prob_type = data['level'], data['type']

        # Only contexts that reached the maximum length contribute a solution;
        # shorter ones are ignored.
        max_len = max(len(response) for response in responses)
        pred_solutions = [
            response[-1]['content']
            for response in responses
            if len(response) == max_len
        ]

        accurate = float(
            compute_accuracy(gt, pred_solutions, prob_level, prob_type)
        )
        accuracies.append(accurate)
        # Problem types outside the table count toward the overall mean only.
        if prob_type in per_type:
            per_type[prob_type].append(accurate)

    print("accuracies:", _mean_or_nan(accuracies))
    for label, prob_type in categories:
        print(label + ":", _mean_or_nan(per_type[prob_type]))


if __name__ == "__main__":
    main()
