import json
import numpy as np
import re
import os

def solve_math_problems(input_str):
    """Return the last number-like token found in ``input_str``.

    A token is a run of digits, optionally followed by a single ``.`` and
    further digits (so ``"7."`` and ``"4.25"`` both qualify; a leading sign
    is not part of the token).

    Args:
        input_str: Text to scan.

    Returns:
        The last matching token as a string, or ``None`` when the text
        contains no digits.
    """
    numbers = re.findall(r"\d+\.?\d*", input_str)
    return numbers[-1] if numbers else None

def parse_answer(input_str):
    """Extract the last multiple-choice letter written as ``(X)`` in the text.

    Every single word character wrapped in parentheses is a candidate.
    Candidates are examined from the end of the text backwards and the first
    one that upper-cases to A, B, C or D is returned.

    Args:
        input_str: Model response text to scan.

    Returns:
        One of ``'A'``, ``'B'``, ``'C'``, ``'D'``, or ``None`` when the text
        contains no valid parenthesised choice.
    """
    valid_choices = ('A', 'B', 'C', 'D')
    for candidate in reversed(re.findall(r'\((\w)\)', input_str)):
        letter = candidate.upper()
        if letter in valid_choices:
            return letter
    # No candidate was a valid choice: report "no answer" rather than
    # leaking an out-of-range token such as "X" or "1" to the caller.
    return None

def compute_accuracy(gt, pred_solutions):
    """Score predicted solution text against the ground-truth answer.

    Args:
        gt: Ground-truth answer, compared with ``==`` to the parsed
            prediction.
        pred_solutions: Either a list of response strings (scored by
            majority vote over the answers parsed from each one) or a single
            response string.

    Returns:
        ``1`` if the chosen prediction equals ``gt``, otherwise ``0``. An
        empty list of predictions scores ``0``.
    """
    if isinstance(pred_solutions, list):
        votes = []
        for pred_solution in pred_solutions:
            answer = parse_answer(pred_solution)
            # 'None' (a string) is the placeholder most_frequent skips when
            # tallying, so unparseable responses do not win the vote.
            votes.append('None' if answer is None else answer)
        if not votes:
            # Nothing to vote on; treat as an incorrect answer instead of
            # failing inside most_frequent.
            return 0
        pred_answer = most_frequent(votes)
    else:
        pred_answer = parse_answer(pred_solutions)
        if pred_answer is None:
            # No lettered choice found: fall back to the last number in the
            # response text.
            pred_answer = solve_math_problems(pred_solutions)

    return 1 if gt == pred_answer else 0

def most_frequent(List):
    """Return the most common item in ``List``, ignoring ``'None'`` entries.

    The string ``'None'`` marks "no answer" and is never counted. Ties are
    broken in favour of the item that appears first in ``List``.

    Args:
        List: Sequence of hashable items (answer strings in this file).

    Returns:
        The most frequent non-``'None'`` item. If every entry is ``'None'``
        the first entry (``'None'``) is returned; an empty input also yields
        ``'None'``.
    """
    # Single pass tally; dict insertion order records first appearance, which
    # max() then uses to break ties.
    counts = {}
    for item in List:
        if item == 'None':
            continue
        counts[item] = counts.get(item, 0) + 1

    if not counts:
        return List[0] if List else 'None'
    return max(counts, key=counts.get)

if __name__ == "__main__":
    # Directory whose files are JSON result dumps to be scored.
    DIR = "MMLU OUTPUT"

    # Task names belonging to each reported category. Built once, as sets,
    # so the per-question membership test is a hash lookup.
    CATEGORY_TASKS = {
        "stem": frozenset({
            'abstract_algebra', 'anatomy', 'astronomy', 'college_biology', 'college_chemistry',
            'college_computer_science', 'college_mathematics', 'college_physics', 'computer_security',
            'conceptual_physics', 'electrical_engineering', 'elementary_mathematics', 'high_school_biology',
            'high_school_chemistry', 'high_school_computer_science', 'high_school_mathematics',
            'high_school_physics', 'high_school_statistics', 'machine_learning',
        }),
        "socialscience": frozenset({
            'econometrics', 'high_school_geography', 'high_school_government_and_politics',
            'high_school_macroeconomics', 'high_school_microeconomics', 'high_school_psychology',
            'human_sexuality', 'professional_psychology', 'public_relations', 'security_studies',
            'sociology', 'us_foreign_policy',
        }),
        "humanity": frozenset({
            'formal_logic', 'high_school_european_history', 'high_school_us_history',
            'high_school_world_history', 'international_law', 'jurisprudence', 'logical_fallacies',
            'moral_disputes', 'moral_scenarios', 'philosophy', 'prehistory', 'professional_law',
            'world_religions',
        }),
        "other": frozenset({
            'business_ethics', 'management', 'marketing', 'professional_accounting',
            'clinical_knowledge', 'college_medicine', 'medical_genetics', 'nutrition',
            'professional_medicine', 'virology', 'global_facts', 'human_aging', 'miscellaneous',
        }),
    }

    # Merge every result file into one dict. Files are read in sorted order
    # so that, if two files share a key, the winner does not depend on the
    # arbitrary order os.listdir returns.
    response_dict = {}
    for file in sorted(os.listdir(DIR)):
        with open(os.path.join(DIR, file), 'r', encoding='utf-8') as f:
            response_dict.update(json.load(f))

    accuracies = []
    category_accuracies = {category: [] for category in CATEGORY_TASKS}
    task_names = []  # distinct task names seen, in first-seen order

    for entry in response_dict.values():
        # Each entry holds a 'task' key plus one key for the question text;
        # pick the question by name rather than relying on key order.
        question = next(key for key in entry if key != 'task')
        responses, gt = entry[question]
        task_name = entry['task'].split('_val.csv')[0]
        if task_name not in task_names:
            task_names.append(task_name)

        # The scored text is the 'content' of the last message of each response.
        pred_solutions = [response[-1]['content'] for response in responses]
        accurate = compute_accuracy(gt, pred_solutions)
        if accurate is None:
            continue

        accuracies.append(float(accurate))
        for category, tasks in CATEGORY_TASKS.items():
            if task_name in tasks:
                category_accuracies[category].append(accurate)

    # An empty bucket is reported as nan directly, avoiding numpy's
    # "Mean of empty slice" RuntimeWarning.
    print("accuracies:", np.mean(accuracies) if accuracies else float("nan"))
    for category, scores in category_accuracies.items():
        print(f"{category}:", np.mean(scores) if scores else float("nan"))
