from glob import glob
import pandas as pd
import json
import time
import random
import openai
import argparse
from tqdm import tqdm

# openai.api_key = ''

# Prompt template (question text followed by four options A-D).  This bare
# string literal is an expression statement: it is not bound to any name and
# has no effect at runtime; it only documents the shape of the examples below.
"Can you answer the following question as accurately as possible? {}: A) {}, B) {}, C) {}, D) {}\n\nExplain your answer. You must pick one of the choices. You are required to put only the final answer choice in the form \"(X)\" at the END of your response."

# Worked example questions that SYSTEM_PROMPT interpolates as few-shot
# demonstrations.  Entries are paired by position with EXAMPLE_ARC_ANSWERS.
# The first three use the ": A)" separator from the template above; the last
# two omit the colon before "A)".
# NOTE(review): despite the ARC name, the examples cover arithmetic, logic,
# ethics, history and linear algebra -- confirm this is intended.
EXAMPLE_ARC_QUESTIONS = [
    "Can you answer the following question as accurately as possible? Colton and his dad bought a gallon of paint that cost $13. They also bought 2 brushes that cost $9 each. What was the total cost, not including tax, of the brushes and the paint they bought?: A) $22, B) $24, C) $31, D) $35\n\nExplain your answer. You must pick one of the choices. You are required to put only the final answer choice in the form \"(X)\" at the END of your response.",
    "Can you answer the following question as accurately as possible? Which of the following fallacies happens when someone argues that a claim should be accepted based on evidence that is not presented, but is asserted to be well known or obvious?: A) self evident truths, B) suppressed evidence, C) appeal to pride, D) ad populum\n\nExplain your answer. You must pick one of the choices. You are required to put only the final answer choice in the form \"(X)\" at the END of your response.",
    "Can you answer the following question as accurately as possible? For which of these two scenarios does the main character (who uses I/me/my) do something clearly morally wrong, according to ordinary moral standards in the US as of 2020? Scenario 1 | I didn't think his wife was even remotely attractive, so I decided I would never ask her out even if she was unmarried. Scenario 2 | I chased my sister with a hot iron.: A) Wrong, Wrong, B) Wrong, Not wrong, C) Not wrong, Wrong, D) Not wrong, Not wrong\n\nExplain your answer. You must pick one of the choices. You are required to put only the final answer choice in the form \"(X)\" at the END of your response.",
    "Can you answer the following question as accurately as possible? Which American colony known for its religious tolerance did Roger Williams found in 1636? A) Massachusetts, B) Rhode Island, C) Virginia, D) Vermont\n\nExplain your answer. You must pick one of the choices. You are required to put only the final answer choice in the form \"(X)\" at the END of your response.",
    "Can you answer the following question as accurately as possible? Suppose it takes 1 second to factor a general 100 x 100 matrix using Gaussian elimination. Of the following, which is the best estimate of the number of seconds it will take to factor a 500 x 500 matrix based on the relative dimensions? A) 5, B) 10, C) 25, D) 125\n\nExplain your answer. You must pick one of the choices. You are required to put only the final answer choice in the form \"(X)\" at the END of your response."
]

# Worked example answers, paired by position with EXAMPLE_ARC_QUESTIONS and
# interpolated into SYSTEM_PROMPT.  Each one gives a short explanation and
# then states the chosen option in the "(X)" form the prompts ask for.
EXAMPLE_ARC_ANSWERS = [
    "Colton and his dad first buy a gallon of paint that costs $13. They then buy 2 brushes that cost 2x9=18 dollars. In total, they spend 13+18 = $31. Therefore, the answer is (C).",
    "When a person argues for a claim to be accepted because it is well-known or true, they are referring to a self-evident truth. These are statements that are not substantiatede by evidence but instead backed up by general universality. Therefore, the answer is (A).",
    "Scenario 1 is not morally wrong. The subject is avoiding asking a married woman out, albiet it could be considered inappropriate to comment on someone's looks, whether unmarried or not. However, chasing someone with a hot iron can be considered assault. The answer is (C).",
    "Roger Williams founded Rhode Island in 1636 as a safe-haven for the Puritans, a Christian religious who were persecuted in England at the time. Therefore, the answer is (B)",
    "Matrix factorization takes O(n^3) time, where n is the dimension of the matrix, assuming the matrix is nxn in the dimension. Since we are expanding the dimensions of the original matrix by 5, and we are given that it takes 1 second to factor a 100x100 matrix, we have 1*5^3 as the amount of time to factor a 500x500 matrix. So, the answer is (D)."
]

# Few-shot system prompt: a short role description followed by the five
# example question/answer pairs defined above and a reminder about the
# "(X)" answer format.  It is an f-string, so the examples are interpolated
# once, when the module is loaded.
SYSTEM_PROMPT = f"""You are a helpful assistant that will solve mutliple choice problems covering many different subjects such as law, mathematics, science, psychology, etc. Here are some example problems and formatted responses with explanations:

Question 1: {EXAMPLE_ARC_QUESTIONS[0]}
Answer 1: {EXAMPLE_ARC_ANSWERS[0]}

Question 2: {EXAMPLE_ARC_QUESTIONS[1]}
Answer 2: {EXAMPLE_ARC_ANSWERS[1]}

Question 3: {EXAMPLE_ARC_QUESTIONS[2]}
Answer 3: {EXAMPLE_ARC_ANSWERS[2]}

Question 4: {EXAMPLE_ARC_QUESTIONS[3]}
Answer 4: {EXAMPLE_ARC_ANSWERS[3]}

Question 5: {EXAMPLE_ARC_QUESTIONS[4]}
Answer 5: {EXAMPLE_ARC_ANSWERS[4]}

Please format your final answer to the question in the format \"(X)\" as done in the examples and always choose one of the answer choices.
"""

# SYSTEM_PROMPT wrapped as a chat message with the "system" role.
# NOTE(review): nothing in the code visible in this file references
# SYSTEM_MESSAGE; agent conversations start from the user turn only --
# confirm whether it should be prepended to each agent context.
SYSTEM_MESSAGE = {
    "role": "system",
    "content": SYSTEM_PROMPT
}

def construct_message(agents, question, agent_indices, current_agent):
    """Build the user turn that shows one agent the other agents' answers.

    Every agent is shown the answers the others gave in the *previous* round,
    regardless of the order in which agents are processed within a round.

    Args:
        agents: Conversation histories (lists of ``{"role", "content"}``
            dicts) of the other agents.  An empty list produces a
            "double check your answer" prompt instead.
        question: The question prompt.  Unused: the prompt text built here
            has no placeholder for it.  Kept so existing callers still work.
        agent_indices: Index of each entry of ``agents`` in the full agent
            list, in the same order as ``agents``.
        current_agent: Index of the agent this message is built for.

    Returns:
        A chat message dict with role ``"user"``.
    """
    if not agents:
        return {"role": "user", "content": "Can you double check that your answer is correct. Put your final answer in the form (X) at the end of your response."}

    parts = ["These are the solutions to the problem from other agents: "]

    for i, agent in enumerate(agents):
        # An agent with a lower index has already taken its turn this round,
        # so its history ends with
        #   [..., previous answer, this round's prompt, this round's answer].
        # Its previous answer is therefore at -3 (-2 is the user prompt, not
        # an answer).  Agents that have not gone yet still end with their
        # previous answer at -1.  This matches summarize_message.
        offset = -3 if agent_indices[i] < current_agent else -1
        parts.append("\n\n One agent solution: ```{}```".format(agent[offset]["content"]))

    parts.append("\n\n Using the reasoning from other agents as additional advice, can you give an updated answer? Examine your solution and that other agents step by step. You must put your answer in the form (X) at the end of your response.")
    return {"role": "user", "content": "".join(parts)}

def construct_assistant_message(completion):
    """Turn a chat completion into an assistant chat message.

    Reads the text of the first entry in ``completion["choices"]`` and
    returns it as ``{"role": "assistant", "content": <text>}``.
    """
    first_choice = completion["choices"][0]
    reply_text = first_choice["message"]["content"]
    return dict(role="assistant", content=reply_text)

def summarize_message(agent_contexts, question_prompt, agent_indices, current_agent):
    """Ask the model to consolidate the other agents' answers into one summary.

    Args:
        agent_contexts: Conversation histories (lists of chat message dicts)
            of the other agents.
        question_prompt: Question text embedded in the summary request.
        agent_indices: Index of each entry of ``agent_contexts`` in the full
            agent list, in the same order.
        current_agent: Index of the agent the summary is prepared for.

    Returns:
        The text of the first choice of the completion returned by
        ``generate_answer`` (called with its default model).
    """
    pieces = [f"Here are a list of opinions different agents with the confidence in their opinion to the question, {question_prompt}: "]

    for position, history in enumerate(agent_contexts):
        # Lower-indexed agents are read three messages from the end of their
        # history; all others are read from their last message.
        offset = -3 if agent_indices[position] < current_agent else -1
        pieces.append("\n\n One agent response: ```{}```".format(history[offset]["content"]))

    pieces.append("\n\n Please summarize the responses from different agents by consolidating the responses from the agents into one response for the given question")

    summary_request = [{"role": "user", "content": "".join(pieces)}]
    completion = generate_answer(summary_request)
    print(completion)

    return completion["choices"][0]["message"]["content"]

#BEST ONE SO FAR
def construct_sum_debate(content, question):
    """Build the user turn that presents a summary of other agents' opinions.

    Args:
        content: Summary text of the other agents' opinions.
        question: Question text the agent is asked to answer again.

    Returns:
        A chat message dict with role ``"user"``.
    """
    opinions_part = "These are the recent/updated opinions with confidence scores out of 100 from other agents: \n\n{}".format(content)
    # Alternative request wording (disabled):
    # "\n\n Using these opinions carefully as additional advice, along with your own prior responses and confidence, can you provide an updated answer to the question {}\n\nExplain your answer. Make sure to state your answer in the form \"(X)\" at the end of the response."
    request_part = f"\n\n Using these opinions carefully as additional advice, can you provide an updated answer to the question {question}\n\nExplain your answer. Make sure to state your answer in the form \"(X)\" at the end of the response."
    return {"role": "user", "content": opinions_part + request_part}


def generate_answer(answer_context, model = "gpt3.5"):
    """Request one chat completion, retrying until the API call succeeds.

    Args:
        answer_context: List of chat message dicts sent as ``messages``.
        model: Short model alias.  ``"gpt3.5"``, ``"gpt3.5-06"`` and
            ``"ft-gpt3.5"`` map to specific model names; any other value
            selects ``"gpt-4-0613"``.

    Returns:
        The completion object returned by ``openai.ChatCompletion.create``.

    Raises:
        ValueError: If ``model`` is ``"ft-gpt3.5"`` and no fine-tuned model
            name has been filled in below.
    """
    model_names = {
        "gpt3.5": "gpt-3.5-turbo-0301",
        "gpt3.5-06": "gpt-3.5-turbo-0613",
        "ft-gpt3.5": None, #NOTE: PUT MODEL STRING HERE
    }
    model_str = model_names.get(model, "gpt-4-0613")
    if model_str is None:
        # Fail fast: a request without a model name can never succeed, so
        # retrying it would only loop forever.
        raise ValueError("No model string configured for model {!r}".format(model))

    # Retry in a loop rather than by recursion so that a long outage cannot
    # exhaust the call stack, and so the same model is used on every attempt.
    while True:
        try:
            return openai.ChatCompletion.create(
                      model=model_str,
                      messages=answer_context,
                      n=1)
        except Exception as err:
            # `Exception` (not a bare except) lets Ctrl+C / SystemExit through.
            print("retrying due to an error......", err)
            time.sleep(20)

def parse_question_answer(df, ix):
    """Build the answer prompt for one row of the question table.

    Args:
        df: DataFrame with ``"question"`` and ``"AnswerKey"`` columns.
        ix: Positional row index.

    Returns:
        A ``(prompt, answer, question)`` tuple: the formatted prompt, the
        row's ``"AnswerKey"`` value and the raw ``"question"`` value.
    """
    row = df.iloc[ix]
    question = row["question"]

    template = (
        "Can you answer the following question as accurately as possible? {}"
        "\n\nExplain your answer. You must pick one of the choices. "
        "You are required to put only the final answer choice in the form \"(X)\" at the END of your response."
    )
    prompt = template.format(question)
    # Optional confidence request (disabled):
    # prompt = prompt + "Additionally, rank your confidence in your answer on a scale from 1-100, 1 being least confident and 100 being most confident."

    return prompt, row["AnswerKey"], question

if __name__ == "__main__":
    # Command-line driver: runs the multi-agent debate over every question in
    # the ARC-Challenge test CSV and writes all conversations to a JSON file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--agents", action = "store", dest = "agents", type = int, required = True, help = "Number of agents for debate")
    parser.add_argument("--rounds", action = "store", dest = "rounds", type = int, required = True, help = "Number of rounds for debate")
    parser.add_argument("--model", action = "store", default = "gpt3.5", type = str, choices = ["gpt3.5", "gpt4", "ft-gpt3.5", "gpt3.5-06"])
    parser.add_argument("--summarize", action = "store_true", dest = "summarize")
    # NOTE(review): --temperature and --top_p are parsed but not forwarded to
    # generate_answer, which has no parameters for them -- confirm intent.
    parser.add_argument("--temperature", action = "store", type = float, dest = "temperature", default = 1)
    parser.add_argument("--top_p", action = "store", type = float, dest = "top_p", default = 1)
    parser.set_defaults(summarize = False)
    args = parser.parse_args()

    agents = args.agents
    rounds = args.rounds
    # These models are restricted to a single agent and a single round.
    # parser.error is used instead of assert so the check survives `python -O`.
    if args.model in ("gpt4", "gpt3.5-06", "ft-gpt3.5") and not (agents == rounds == 1):
        parser.error(f"Only run {args.model} with one agent, one round")
    summarize_str = " summarized" if args.summarize else ""
    print(f"Running{summarize_str} ARC multiagent debate with {agents} agents and {rounds} rounds with {args.model}")

    df = pd.read_csv("ARC-V1-Feb2018-2/ARC-Challenge/ARC-Challenge-Test.csv")
    print(len(df))

    random.seed(0)
    # Maps each question prompt to (all agent conversations, answer key).
    response_dict = {}

    for idx in tqdm(range(len(df)), "Evaluating ARC"):
        question, answer, _ = parse_question_answer(df, idx)
        # Each agent gets its own, independent conversation history.
        agent_contexts = [[{"role": "user", "content": question}] for _ in range(agents)]
        for round_idx in range(rounds):
            for i, agent_context in enumerate(agent_contexts):

                # From the second round on, first show the agent what the
                # other agents said (verbatim or as a model-written summary).
                if round_idx != 0:
                    agent_contexts_other = agent_contexts[:i] + agent_contexts[i+1:]
                    other_agent_indices = list(range(0, i)) + list(range(i+1, agents))
                    if args.summarize:
                        sum_content = summarize_message(agent_contexts_other, question, other_agent_indices, i)
                        message = construct_sum_debate(sum_content, question)
                    else:
                        message = construct_message(agent_contexts_other, question, other_agent_indices, i)
                    agent_context.append(message)

                completion = generate_answer(agent_context, model = args.model)

                agent_context.append(construct_assistant_message(completion))
                print(completion)

        response_dict[question] = (agent_contexts, answer)

    # Output file prefix per model alias; any other alias uses the fine-tuned
    # model's prefix.
    output_prefixes = {
        "gpt4": "gpt4",
        "gpt3.5": "gpt35",
        "gpt3.5-06": "gpt35-06",
    }
    model_str = output_prefixes.get(args.model, "ft-gpt35-5k-final-epochs6")
    # Context manager so the file is flushed and closed even if dumping fails.
    with open("{}_full-arc.json".format(model_str), "w") as out_file:
        json.dump(response_dict, out_file)