import argparse
import json
import re
import time
from collections import Counter

import numpy as np
import openai

def parse_bullets(sentence):
    """Split text into lines and strip each line's leading bullet marker.

    Everything that precedes the first alphabetic character of a line
    (numbering, dashes, whitespace, ...) is dropped. Lines that contain no
    alphabetic character at all are skipped.

    Args:
        sentence (str): Text with one bullet per line, separated by "\\n".

    Returns:
        list[str]: The remaining text of each kept line, in input order.
    """
    bullets = []

    for line in sentence.split("\n"):
        # Index of the first letter, or None when the line has no letters.
        first_alpha = next(
            (idx for idx, char in enumerate(line) if char.isalpha()), None
        )
        if first_alpha is None:
            continue

        bullets.append(line[first_alpha:])

    return bullets


def parse_yes_no(string):
    """
    Parses a string containing "yes" or "no" and returns a boolean value.

    Matching is case-insensitive and by substring, with "yes" checked first.
    Because it is a substring test, words that merely contain the letters
    (e.g. "eyes", "know", "not") also match.

    Args:
        string (str): The string to parse.

    Returns:
        bool | None: True if the string contains "yes", False if it contains
        "no" (and not "yes"), None if it contains neither. No exception is
        raised for unrecognised input.
    """
    lowered = string.lower()

    if "yes" in lowered:
        return True
    if "no" in lowered:
        return False
    return None


def solve_math_problems(input_str):
    """Return the last number that appears in the text.

    A number is a run of digits, optionally followed by a decimal point and
    more digits. A period that is not followed by a digit (such as the one
    ending a sentence) is not treated as part of the number.

    Args:
        input_str (str): Text to search.

    Returns:
        str | None: The last number as it appears in the text, or None when
        the text contains no digits.
    """
    # The decimal point is only consumed when digits follow it, so
    # "the answer is 42." yields "42" rather than "42.".
    pattern = r"\d+(?:\.\d+)?"

    matches = re.findall(pattern, input_str)
    if matches:
        return matches[-1]

    return None

def parse_answer(input_str):
    """Extract the last parenthesised single-character choice from the text.

    Looks for occurrences such as "(a)" or "(B)": exactly one word character
    between parentheses. The final occurrence is returned upper-cased.

    Args:
        input_str (str): Text to search.

    Returns:
        str | None: The upper-cased character of the last occurrence, or
        None when there is no such occurrence.
    """
    choices = re.findall(r'\((\w)\)', input_str)
    if not choices:
        return None

    return choices[-1].upper()

def compute_accuracy(gt, pred_solutions, question=None):
    """Score one question: 1 if the predicted answer equals ``gt``, else 0.

    Each response text is parsed with ``parse_answer``; when that finds
    nothing, ``solve_math_problems`` is tried instead. For a list of
    responses, the prediction is the most frequent parsed answer (as chosen
    by ``most_frequent``); responses that yield no answer do not vote.

    On a miss, the question (when known), the raw responses, the prediction
    and the ground truth are printed for inspection.

    Args:
        gt: Ground-truth answer to compare the prediction against.
        pred_solutions (str | list[str]): One response text, or a list of
            response texts to vote over.
        question (str, optional): Question text to print on a miss. When
            omitted, a module-level variable named ``question`` is used if
            one exists.

    Returns:
        int: 1 when the prediction equals ``gt``, otherwise 0. A question for
        which no answer could be parsed has prediction None.
    """
    if isinstance(pred_solutions, list):
        pred_answers = []

        for pred_solution in pred_solutions:
            pred_answer = parse_answer(pred_solution)

            if pred_answer is None:
                pred_answer = solve_math_problems(pred_solution)

            if pred_answer is not None:
                pred_answers.append(pred_answer)

        # Decide on the collected answers, not on whatever the last response
        # happened to parse to: an unparseable final response must not hide
        # the other votes, and "nothing parsed" must not count as correct.
        pred_answer = most_frequent(pred_answers) if pred_answers else None
    else:
        pred_answer = parse_answer(pred_solutions)
        if pred_answer is None:
            pred_answer = solve_math_problems(pred_solutions)

    if gt == pred_answer:
        return 1

    # Callers that do not pass the question explicitly may rely on a
    # module-level `question` variable; look it up without raising NameError
    # when it does not exist.
    if question is None:
        question = globals().get("question")
    if question is not None:
        print(question)
    print(pred_solutions, pred_answer, f'{gt}\n')
    return 0


def most_frequent(List):
    """Return the element that occurs most often in ``List``.

    Ties are broken in favour of the element whose first occurrence comes
    earliest in the list.

    Args:
        List (list): Elements to count.

    Returns:
        The most frequent element, or None when ``List`` is empty.
    """
    if not List:
        return None

    try:
        counts = Counter(List)
    except TypeError:
        # Unhashable elements cannot be counted with Counter; fall back to
        # equality-based counting. max() keeps the first maximal element.
        return max(List, key=List.count)

    # Counter preserves first-seen order and max() returns the first maximal
    # key, which gives the earliest-first tie-breaking described above.
    return max(counts, key=counts.get)

if __name__ == "__main__":
    # Script entry point: load the saved debate responses for one
    # model / agents / rounds configuration, score every question with
    # compute_accuracy, and print the mean accuracy and its standard error.
    parser = argparse.ArgumentParser()
    parser.add_argument("--agents", action = "store", dest = "agents", type = int, required = True, help = "Number of agents for debate")
    parser.add_argument("--rounds", action = "store", dest = "rounds", type = int, required = True, help = "Number of rounds for debate")
    parser.add_argument("--model", action = "store", default = "gpt3.5", type = str, choices = ["gpt3.5", "gpt4", "ft-gpt3.5", "gpt3.5-06"])
    parser.add_argument("--summarize", action = "store_true", dest = "summarize")
    parser.add_argument("--summarize_type", action = "store", type = str, dest = "summarize_type", default = None)
    args = parser.parse_args()

    # File-name prefix for each --model choice.
    # NOTE(review): there was previously no prefix for the default "gpt3.5",
    # so a default run failed with NameError. "gpt35" follows the naming of
    # the other entries -- confirm it matches the actual response files.
    model_prefixes = {
        "gpt3.5": "gpt35",
        "gpt4": "gpt4",
        "gpt3.5-06": "gpt35-06",
        "ft-gpt3.5": "ft-gpt35-2k-final-epochs6",
    }
    model_str = model_prefixes[args.model]
    response_path = "{}_mmlu_{}_{}.json".format(model_str, args.agents, args.rounds)

    # The file maps each question to a (responses, ground truth) pair.
    with open(response_path, "r") as response_file:
        response_dict = json.load(response_file)

    questions = list(response_dict.keys())

    accuracies = []
    print(len(questions))

    # NOTE: keep this loop variable a module-level name called `question`;
    # compute_accuracy reads it for its diagnostic output on a miss.
    for question in questions:
        responses, gt = response_dict[question]

        # The final message of each response holds that agent's last answer.
        pred_solutions = [response[-1]['content'] for response in responses]

        accurate = compute_accuracy(gt, pred_solutions)

        if accurate is not None:
            accuracies.append(float(accurate))
        else:
            # Report and skip instead of dropping into an interactive
            # debugger, which would stall an unattended run.
            print("skipping question without a score; ground truth:", gt)

    if accuracies:
        # Mean accuracy and standard error of the mean.
        print("accuracies:", np.mean(accuracies), np.std(accuracies) / (len(accuracies) ** 0.5))
    else:
        print("accuracies: no questions were scored")
    print(response_path)
