import argparse
import json
import math
import os
import pickle
import time
from collections import Counter

import numpy as np
import openai
from tqdm import tqdm
# openai.api_key = ""

def parse_bullets(sentence):
    """Split ``sentence`` into lines and strip each line's leading non-letter prefix.

    Everything before the first alphabetic character of a line is removed, so
    list markers such as "1. " or "- " disappear ("1. Apple" -> "Apple").
    Leading digits that belong to the content are removed too
    ("3 apples" -> "apples"). Lines containing no alphabetic character are
    dropped.

    Returns:
        list[str]: the stripped, non-empty lines in their original order.
    """
    bullets = []
    for line in sentence.split("\n"):
        # Position of the first letter, or None when the line has no letters.
        start = next((pos for pos, ch in enumerate(line) if ch.isalpha()), None)
        if start is None:
            continue
        # line[start:] starts with a letter, so it is never empty.
        bullets.append(line[start:])
    return bullets

def generate_answer(answer_context, model = "gpt3.5-06", max_retries = None, retry_delay = 20):
    """Request one chat completion for ``answer_context``, retrying on API errors.

    Args:
        answer_context: list of ``{"role": ..., "content": ...}`` chat messages.
        model: short model key. "gpt3.5" -> gpt-3.5-turbo-0301,
            "gpt3.5-06" -> gpt-3.5-turbo-0613, any other unknown key ->
            gpt-4-0613. "ft-gpt3.5" has no model id configured here.
        max_retries: how many times to retry after a failed request;
            ``None`` (default) retries indefinitely.
        retry_delay: seconds to sleep before each retry.

    Returns:
        The raw response returned by ``openai.ChatCompletion.create``.

    Raises:
        ValueError: if ``model`` is "ft-gpt3.5" (no model id configured), since
            such a request can never succeed and would otherwise retry forever.
        Exception: the last API error, once ``max_retries`` retries are used up.
    """
    model_ids = {
        "gpt3.5": "gpt-3.5-turbo-0301",
        "gpt3.5-06": "gpt-3.5-turbo-0613",
        "ft-gpt3.5": None,  # TODO: fill in the fine-tuned model id
    }
    model_str = model_ids.get(model, "gpt-4-0613")
    if model_str is None:
        raise ValueError("no model id configured for model {!r}".format(model))

    attempt = 0
    while True:
        try:
            return openai.ChatCompletion.create(
                model=model_str,
                messages=answer_context,
                n=1)
        except Exception as err:  # API boundary: transient/rate-limit errors are retried
            if max_retries is not None and attempt >= max_retries:
                raise
            attempt += 1
            print("retrying due to an error......", repr(err))
            time.sleep(retry_delay)
            # Loop (instead of recursing) and keep the caller's `model`; the old
            # recursive retry silently fell back to the default model.

def construct_message(agents, question, agent_indices, current_agent):
    """Build the user prompt asking an agent to revise its answer given its peers'.

    Args:
        agents: chat histories (lists of role/content dicts) of the *other* agents.
        question: the original question, repeated at the end of the prompt.
        agent_indices: ``agent_indices[i]`` is the position of ``agents[i]`` in
            the full list of agents.
        current_agent: index of the agent that will receive this prompt. Peers
            with a lower index are treated as having already replied in the
            current round, so their second-to-last assistant reply is quoted;
            for all other peers the latest assistant reply is quoted.

    Returns:
        dict: a ``{"role": "user", "content": ...}`` message. With no peers, a
        self-verification prompt is returned instead.
    """
    # Use introspection in the case in which there are no other agents.
    if not agents:
        return {"role": "user", "content": "Can you verify that your answer is correct. Please reiterate your answer, making sure to state your answer at the end of the response."}

    parts = ["These are the solutions to the problem from other agents: "]
    for i, agent in enumerate(agents):
        agent_idx = agent_indices[i]
        # Select replies by role rather than by position: a history may or may
        # not have user debate prompts interleaved between assistant replies,
        # and a positional [-2] would then return a prompt instead of an answer.
        replies = [msg["content"] for msg in agent if msg["role"] == "assistant"]
        agent_response = replies[-2] if agent_idx < current_agent else replies[-1]
        parts.append("\n\n One agent solution: ```{}```".format(agent_response))

    parts.append("\n\n Use these solutions carefully as additional advice, can you provide an updated answer to the problem: \n{} \n\nExplain your answer. Make sure to state your answer at the end of the response.".format(question))
    return {"role": "user", "content": "".join(parts)}


def construct_assistant_message(completion):
    """Convert a chat-completion response into an assistant message.

    Only the first choice's message text is kept; the result can be appended
    directly to an agent's chat history.
    """
    choice = completion["choices"][0]
    return dict(role="assistant", content=choice["message"]["content"])

def summarize_message(agent_contexts, question_prompt, agent_indices, current_agent, model = "gpt3.5-06"):
    """Ask the model to summarise the other agents' answers and return the summary.

    Args:
        agent_contexts: chat histories (lists of role/content dicts) of the
            other agents.
        question_prompt: the original question, embedded in the prompt.
        agent_indices: ``agent_indices[i]`` is the position of
            ``agent_contexts[i]`` in the full list of agents.
        current_agent: index of the agent the summary is for. Peers with a
            lower index are treated as having already replied this round, so
            their second-to-last assistant reply is summarised; for the others
            the latest assistant reply is used.
        model: model key forwarded to ``generate_answer``; the default is the
            model this function previously always used.

    Returns:
        str: the summary text of the first completion choice.
    """
    parts = [f"Here are a list of solutions from different agents to the problem, {question_prompt}: "]

    for i, agent in enumerate(agent_contexts):
        agent_idx = agent_indices[i]
        # Select replies by role, not position. A fixed [-3] assumed user
        # prompts interleaved between replies; without them it returned the
        # question itself (round 1) or an answer two rounds old.
        replies = [msg["content"] for msg in agent if msg["role"] == "assistant"]
        agent_response = replies[-2] if agent_idx < current_agent else replies[-1]
        parts.append("\n\n One agent response: ```{}```".format(agent_response))

    parts.append("\n\n Write a summary of the solutions from each of the individual agent and briefly explain their reasoning.")
    agent_context = [{"role": "user", "content": "".join(parts)}]
    completion = generate_answer(agent_context, model = model)
    content = completion["choices"][0]["message"]["content"]
    print(completion)
    return content

def construct_sum_debate(content, question):
    """Build the user prompt that feeds a peer summary back to an agent.

    ``content`` (the summary of the other agents' solutions) is quoted first,
    followed by a request to answer the original math problem ``question``.
    """
    header = f"These are the solutions to the problem from other agents: \n\n{content}"
    footer = (
        "\n\n Using the solutions from other agents as additional information, "
        "can you provide your answer to the math problem? \n "
        f"The original math problem is {question}. "
        "Please explain your answer. Make sure to state your answer at the end of the response."
    )
    return {"role": "user", "content": header + footer}

def parse_answer(sentence):
    """Return the last finite number in ``sentence`` as a float, or ``None``.

    The text is split on whitespace and scanned from the end. Before parsing,
    each token has wrapping markup/punctuation removed (e.g. ``**42**``,
    ``(42)``, ``$42$``, ``42:``, ``42,``). A trailing "." is also removed, so
    ``42.`` -> 42. Tokens that ``float`` accepts but that are not finite
    numbers ("nan", "inf", "Infinity") are skipped rather than returned.
    """
    wrappers = "*_()[]{}\"'`$:;!?,"
    for token in reversed(sentence.split()):
        # "." is stripped only on the right so ".5" keeps its meaning.
        token = token.lstrip(wrappers).rstrip(wrappers + ".")
        try:
            answer = float(token)
        except ValueError:
            continue
        if math.isfinite(answer):
            return answer
    return None


def most_frequent(List):
    """Return the most common element of ``List``.

    Ties are resolved in favour of the element that appears first. Elements
    must be hashable.

    Raises:
        IndexError: if ``List`` is empty.
    """
    counts = Counter(List)
    if not counts:
        raise IndexError("most_frequent() arg is an empty sequence")
    # Counter keeps first-seen order and most_common(1) returns the first
    # maximal entry, matching the original strict ">" scan.
    return counts.most_common(1)[0][0]


def _build_round_messages(agent_contexts, question_prompt, summarize):
    """Build every agent's debate prompt before any agent replies this round.

    Building all prompts up front means each agent sees the other agents'
    answers from the previous round, whatever order the agents are queried in.
    ``current_agent=0`` tells the helpers that no peer (all indices >= 0) has
    replied yet this round, so each peer's latest reply is used.
    """
    n_agents = len(agent_contexts)
    messages = []
    for i in range(n_agents):
        others = agent_contexts[:i] + agent_contexts[i + 1:]
        other_indices = list(range(0, i)) + list(range(i + 1, n_agents))
        if summarize:
            summary = summarize_message(others, question_prompt, other_indices, 0)
            message = construct_sum_debate(summary, question_prompt)
        else:
            message = construct_message(others, question_prompt, other_indices, 0)
            print(message)
        messages.append(message)
    return messages


def _run_debate(agent_contexts, question_prompt, rounds, model, summarize):
    """Run ``rounds`` debate rounds, appending prompts and replies to each context in place."""
    for round_idx in range(rounds):
        if round_idx != 0:
            messages = _build_round_messages(agent_contexts, question_prompt, summarize)
            for agent_context, message in zip(agent_contexts, messages):
                # The debate prompt must be part of the history sent to the
                # model; otherwise the peers' answers never reach it.
                agent_context.append(message)

        for agent_context in agent_contexts:
            completion = generate_answer(agent_context, model=model)
            agent_context.append(construct_assistant_message(completion))
            print(completion)


def _extract_numeric_answer(text):
    """Normalise an agent's final reply and return the last number in it, or None."""
    # Presumably so that "42," parses; note it also turns "1,234" into "1.234".
    text = text.replace(",", ".")
    text = text.replace("$", "")
    text = text.replace("\\boxed{", "")
    text = text.replace("}", "")
    text = text.replace("=", "= ")
    text = text.replace("\n", " \n ")
    return parse_answer(text)


def _model_file_tag(model):
    """Map the --model choice to the tag used in the output pickle filename."""
    tags = {"gpt4": "gpt4", "gpt3.5": "gpt35", "gpt3.5-06": "gpt35-06"}
    return tags.get(model, "ft-gpt35-5k-final-epochs4")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--agents", action="store", dest="agents", type=int, required=True, help="Number of agents for debate")
    parser.add_argument("--rounds", action="store", dest="rounds", type=int, required=True, help="Number of rounds for debate")
    parser.add_argument("--model", action="store", default="gpt3.5", type=str, choices=["gpt3.5", "gpt4", "ft-gpt3.5", "gpt3.5-06"])
    parser.add_argument("--summarize", action="store_true", dest="summarize")
    parser.set_defaults(gpt4=False, summarize=False)
    args = parser.parse_args()
    agents = args.agents
    rounds = args.rounds
    np.random.seed(0)

    evaluation_round = 1000
    scores = []
    # Stays None if every question is skipped, so the final print cannot raise NameError.
    performance = None

    generated_description = {}

    for _ in tqdm(range(evaluation_round)):
        a, b, c, d, e, f = np.random.randint(0, 30, size=6)

        answer = a + b * c + d - e * f
        question_prompt = "What is the result of {}+{}*{}+{}-{}*{}? Please explain your answer. Make sure to state your answer at the end of the response.".format(a, b, c, d, e, f)
        agent_contexts = [[{"role": "user", "content": question_prompt}] for _ in range(agents)]

        _run_debate(agent_contexts, question_prompt, rounds, args.model, args.summarize)

        parsed = (_extract_numeric_answer(ctx[-1]["content"]) for ctx in agent_contexts)
        text_answers = [value for value in parsed if value is not None]

        generated_description[(a, b, c, d, e, f)] = (agent_contexts, answer)

        if not text_answers:
            print("Skipping...")
            continue
        scores.append(1 if most_frequent(text_answers) == answer else 0)

        performance = (np.mean(scores), np.std(scores) / (len(scores) ** 0.5))
        print("performance:", performance[0], performance[1])
        print(len(scores))

    print(performance)
    output_path = "{}_math_{}_{}.p".format(_model_file_tag(args.model), agents, rounds)
    with open(output_path, "wb") as fh:
        pickle.dump(generated_description, fh)