from glob import glob
import pandas as pd
import json
import time
import random
import openai
import argparse
from tqdm import tqdm

# openai_key = ""

def construct_message(agents, question, agent_indices, current_agent):
    """Build the user prompt that shows one agent the other agents' answers.

    Args:
        agents: Message histories of the *other* agents; each history is a
            list of ``{"role": ..., "content": ...}`` dicts.
        question: The original question prompt. Not used in the prompt text;
            kept so the signature stays compatible with existing callers.
        agent_indices: Global index of each entry in ``agents`` (same order).
        current_agent: Global index of the agent the prompt is built for.

    Returns:
        A ``{"role": "user", "content": ...}`` message dict. When ``agents``
        is empty, a self-check prompt is returned instead.
    """
    if not agents:
        return {"role": "user", "content": "Can you double check that your answer is correct. Put your final answer in the form (X) at the end of your response."}

    parts = ["These are the solutions to the problem from other agents: "]

    for i, history in enumerate(agents):
        agent_idx = agent_indices[i]
        # Agents with a lower index have already been processed in the current
        # round, so their history ends with [..., previous answer, debate
        # prompt, new answer]. Their previous-round answer is therefore at -3
        # (index -2 would be the user debate prompt, not an agent solution).
        # Agents not yet processed still end with their previous-round answer.
        offset = -3 if agent_idx < current_agent else -1
        agent_response = history[offset]["content"]
        parts.append("\n\n One agent solution: ```{}```".format(agent_response))

    parts.append("""\n\n Using the reasoning from other agents as additional advice, can you give an updated answer? Examine your solution and that other agents step by step. You must put your answer in the form (X) at the end of your response.""")
    return {"role": "user", "content": "".join(parts)}

def construct_assistant_message(completion):
    """Wrap the text of a completion's first choice as an assistant message.

    Args:
        completion: Mapping with a ``"choices"`` sequence whose first entry
            holds ``["message"]["content"]``.

    Returns:
        ``{"role": "assistant", "content": <first choice text>}``.
    """
    first_choice = completion["choices"][0]
    reply_text = first_choice["message"]["content"]
    return dict(role="assistant", content=reply_text)

def summarize_message(agent_contexts, question_prompt, agent_indices, current_agent):
    """Ask the model to consolidate the other agents' answers into one summary.

    Args:
        agent_contexts: Message histories of the other agents (lists of
            ``{"role", "content"}`` dicts).
        question_prompt: The question text, embedded in the summary request.
        agent_indices: Global index of each entry in ``agent_contexts``.
        current_agent: Global index of the agent the summary is built for.

    Returns:
        The text content of the first choice of the summarization completion.
    """
    pieces = [
        f"Here are a list of opinions different agents with the confidence in their opinion to the question, {question_prompt}: "
    ]

    for position, history in enumerate(agent_contexts):
        # Lower-indexed agents read three messages back; the rest read their
        # most recent message.
        offset = -3 if agent_indices[position] < current_agent else -1
        pieces.append("\n\n One agent response: ```{}```".format(history[offset]["content"]))

    pieces.append("\n\n Please summarize the responses from different agents by consolidating the responses from the agents into one response for the given question")

    summary_request = [{"role": "user", "content": "".join(pieces)}]
    completion = generate_answer(summary_request)
    print(completion)
    return completion["choices"][0]["message"]["content"]

#BEST ONE SO FAR
def construct_sum_debate(content, question):
    """Build the debate prompt that presents a summary of other agents' opinions.

    Args:
        content: Summary text of the other agents' opinions.
        question: The question the agent is asked to re-answer.

    Returns:
        A ``{"role": "user", "content": ...}`` message dict.
    """
    opinions_part = "These are the recent/updated opinions with confidence scores out of 100 from other agents: \n\n{}".format(content)
    request_part = f'\n\n Using these opinions carefully as additional advice, can you provide an updated answer to the question {question}\n\nExplain your answer. Make sure to state your answer in the form "(X)" at the end of the response.'
    # Alternative wording kept for reference:
    # "\n\n Using these opinions carefully as additional advice, along with your own prior responses and confidence, can you provide an updated answer to the question {}\n\nExplain your answer. Make sure to state your answer in the form \"(X)\" at the end of the response.".format(question)
    return {"role": "user", "content": opinions_part + request_part}


def generate_answer(answer_context, model = "gpt3.5"):
    """Request a single chat completion, retrying until the call succeeds.

    Args:
        answer_context: List of chat messages passed as ``messages``.
        model: Short model key. ``"gpt3.5"`` and ``"gpt3.5-06"`` map to fixed
            gpt-3.5-turbo snapshots, ``"ft-gpt3.5"`` to the fine-tuned model
            id (not configured in this file), and any other value to
            ``"gpt-4-0613"``.

    Returns:
        The object returned by ``openai.ChatCompletion.create``.

    Raises:
        ValueError: If ``model`` is ``"ft-gpt3.5"`` and no fine-tuned model id
            has been filled in.
    """
    model_ids = {
        "gpt3.5": "gpt-3.5-turbo-0301",
        "gpt3.5-06": "gpt-3.5-turbo-0613",
        # Placeholder: set this to the fine-tuned model id before use.
        "ft-gpt3.5": None,
    }
    model_str = model_ids.get(model, "gpt-4-0613")
    if model_str is None:
        # Fail loudly instead of sending model=None, which can never succeed
        # and previously ended up silently answering with the default model.
        raise ValueError(
            "No model id configured for '{}'; set it in generate_answer.".format(model)
        )

    # Loop instead of recursing so every retry uses the same model and a long
    # outage cannot exhaust the recursion limit.
    while True:
        try:
            return openai.ChatCompletion.create(
                      model=model_str,
                      messages=answer_context,
                      n=1)
        except Exception as err:
            # Exception (not a bare except) so Ctrl-C / SystemExit still stop
            # the run.
            print("retrying due to an error......", repr(err))
            time.sleep(20)

def parse_question_answer(question):
    """Render a Winogrande record as a fill-in-the-blank prompt.

    Args:
        question: Mapping with ``"sentence"``, ``"option1"`` and ``"option2"``.

    Returns:
        The prompt string asking for option 1 or 2 as a number.
    """
    template = (
        "Can you fill in the blank in the following sentence using one of the two given options?"
        "\n\n{sentence}\nOption 1:{first} Option 2:{second}"
        "\n\nGive your answer as a number, so 1 for option 1 or 2 for option 2 at the end of your response."
    )
    return template.format(
        sentence=question["sentence"],
        first=question["option1"],
        second=question["option2"],
    )

# Script entry point: run the (optionally summarized) multi-agent debate over
# the Winogrande dev set and dump every agent's full message history to JSON.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--agents", action = "store", dest = "agents", type = int, required = True, help = "Number of agents for debate")
    parser.add_argument("--rounds", action = "store", dest = "rounds", type = int, required = True, help = "Number of rounds for debate")
    parser.add_argument("--model", action = "store", default = "gpt3.5", type = str, choices = ["gpt3.5", "gpt4", "ft-gpt3.5", "gpt3.5-06"])
    parser.add_argument("--summarize", action = "store_true", dest = "summarize")
    # NOTE: --temperature and --top_p are parsed but are not forwarded to
    # generate_answer anywhere in this script.
    parser.add_argument("--temperature", action = "store", type = float, dest = "temperature", default = 1)
    parser.add_argument("--top_p", action = "store", type = float, dest = "top_p", default = 1)
    parser.set_defaults(summarize = False)
    args = parser.parse_args()

    agents = args.agents
    rounds = args.rounds
    # Explicit check instead of assert so it still runs under `python -O`.
    if args.model in ("gpt4", "gpt3.5-06", "ft-gpt3.5") and not (agents == rounds == 1):
        parser.error(f"--model {args.model} may only be run with one agent and one round")
    summarize_str = " summarized" if args.summarize else ""
    print(f"Running{summarize_str} Winogrande Fine-tuned Evaluation with {agents} agents and {rounds} rounds with {args.model}")

    with open("winogrande_1.1/dev.jsonl", "r", encoding="utf-8") as f:
        questions = [json.loads(line) for line in f]
    with open("winogrande_1.1/dev-labels.lst", "r", encoding="utf-8") as f:
        # readlines() keeps each label's trailing newline; labels are stored
        # in the output exactly as read.
        answers = f.readlines()
    if len(questions) != len(answers):
        raise ValueError(
            f"Question/label count mismatch: {len(questions)} questions vs {len(answers)} labels"
        )

    random.seed(0)
    response_dict = {}

    for og_question, answer in tqdm(zip(questions, answers), desc="Evaluating Winogrande", total=len(questions)):
        question = parse_question_answer(og_question)
        # Every agent starts from its own copy of the same single-message history.
        agent_contexts = [[{"role": "user", "content": question}] for _ in range(agents)]
        for round_idx in range(rounds):
            for i, agent_context in enumerate(agent_contexts):

                # From the second round on, each agent first receives the
                # other agents' answers (raw or summarized) as a new user turn.
                if round_idx != 0:
                    agent_contexts_other = agent_contexts[:i] + agent_contexts[i+1:]
                    other_agent_indices = list(range(0, i)) + list(range(i+1, agents))
                    if args.summarize:
                        sum_content = summarize_message(agent_contexts_other, question, other_agent_indices, i)
                        message = construct_sum_debate(sum_content, question)
                    else:
                        message = construct_message(agent_contexts_other, question, other_agent_indices, i)
                    agent_context.append(message)

                completion = generate_answer(agent_context, model = args.model)

                agent_context.append(construct_assistant_message(completion))
                print(completion)

        # Keyed by prompt text: identical prompts overwrite earlier entries.
        response_dict[question] = (agent_contexts, answer)

    output_prefixes = {
        "gpt4": "gpt4",
        "gpt3.5": "gpt35",
        "gpt3.5-06": "gpt35-06",
    }
    model_str = output_prefixes.get(args.model, "ft-gpt35-5k-final-epochs6")
    # The file name depends only on the model, so runs with different
    # --agents/--rounds settings write to the same file.
    with open(f"{model_str}-winogrande.json", "w") as out_file:
        json.dump(response_dict, out_file)