import os, sys
import numpy as np
import pandas as pd
import time
import openai
import json
import argparse
from tqdm import tqdm

# Seed NumPy's global random state so that code relying on it in this script
# (e.g. the np.random.choice call in create_jsonl) is repeatable across runs.
np.random.seed(42)
# The API key assignment is left commented out; presumably the key must be
# supplied here or through some other configuration before any API call.
# openai.api_key = ''

def select_prompts(alpaca_df: pd.DataFrame, sample: int = 5000) -> pd.DataFrame:
    """Randomly pick prompts for the debate, skipping ones already used.

    If ``ft_data/alpaca-ft.csv`` exists, every row whose ``prompt`` also
    appears in that file is dropped before sampling. When the file is absent
    no filtering is applied (previously this case crashed because the
    exclusion frame was never defined).

    Args:
        alpaca_df: Frame with at least a ``prompt`` column.
        sample: Number of rows to draw without replacement.

    Returns:
        A frame of ``sample`` randomly selected rows from ``alpaca_df``.

    Raises:
        ValueError: Raised by pandas when ``sample`` exceeds the number of
            rows left after filtering.
    """
    ft_path = "ft_data/alpaca-ft.csv"
    if os.path.exists(ft_path):
        ft_df = pd.read_csv(ft_path)
        # Keep only prompts that were not already selected for fine-tuning.
        alpaca_df = alpaca_df[~alpaca_df["prompt"].isin(ft_df["prompt"])]
    return alpaca_df.sample(n=sample)

def generate_answer(answer_context):
    """Request one chat completion, retrying until the call succeeds.

    Args:
        answer_context: List of chat messages (dicts with ``role`` and
            ``content``) passed as ``messages`` to the API.

    Returns:
        The completion object returned by ``openai.ChatCompletion.create``.
    """
    model_name = "gpt-3.5-turbo-0301"
    # Retry in a loop rather than recursively so a long outage cannot exhaust
    # the recursion limit.
    while True:
        try:
            return openai.ChatCompletion.create(
                model=model_name,
                messages=answer_context,
                n=1,
            )
        except Exception as err:
            # Only ``Exception`` is caught: KeyboardInterrupt / SystemExit must
            # propagate so the user can actually stop the run.
            print("retrying due to an error......", repr(err))
            time.sleep(20)

def summarize_message(agent_contexts, question_prompt, agent_indices, current_agent):
    """Ask the model to merge the other agents' responses into one summary.

    Args:
        agent_contexts: Message histories (lists of message dicts), one per
            agent whose opinion should be summarized.
        question_prompt: The question text embedded in the summary request.
        agent_indices: Index of each agent, parallel to ``agent_contexts``.
        current_agent: Index of the agent the summary is being prepared for.

    Returns:
        A ``(content, completion)`` tuple: the summary text and the raw
        completion object it was extracted from.
    """
    pieces = [
        f"Here are a list of opinions different agents with the confidence in their opinion to the question, {question_prompt}: "
    ]

    for position, history in enumerate(agent_contexts):
        # Agents indexed below the current one are read three messages back,
        # all others from their last message (presumably because the earlier
        # agents have already appended this round's exchange).
        offset = -3 if agent_indices[position] < current_agent else -1
        opinion = history[offset]["content"]
        pieces.append("\n\n One agent response: ```{}```".format(opinion))

    pieces.append("\n\n Please summarize the responses from different agents by consolidating the responses from the agents into one response for the given question")

    summary_request = [{"role": "user", "content": "".join(pieces)}]
    completion = generate_answer(summary_request)
    print(completion)

    return completion["choices"][0]["message"]["content"], completion

def construct_sum_debate(content, question):
    """Build the user message that feeds summarized opinions back to an agent.

    Args:
        content: Summary of the other agents' opinions.
        question: The question the agent is asked to answer again.

    Returns:
        A chat message dict with role ``"user"`` whose content presents the
        opinions followed by a request for an updated, explained answer.
    """
    opinions_part = f"These are the recent/updated opinions with confidence scores out of 100 from other agents: \n\n{content}"
    request_part = f"\n\n Using these opinions carefully as additional advice, can you provide an updated answer to the question {question}\n\nExplain your answer."
    return {"role": "user", "content": opinions_part + request_part}

def construct_assistant_message(completion):
    """Wrap the first choice of a completion as an assistant chat message.

    Args:
        completion: Mapping with a ``choices`` list whose first entry holds
            ``message -> content``.

    Returns:
        A chat message dict with role ``"assistant"`` and that content.
    """
    first_choice = completion["choices"][0]
    return {"role": "assistant", "content": first_choice["message"]["content"]}

def generation(alpaca_df, agents, rounds):
    """Run the multi-agent debate for every prompt and collect the transcripts.

    Each agent starts from the same question. In every round after the first,
    an agent receives a summary of the other agents' responses before
    answering again.

    Args:
        alpaca_df: Frame with a ``prompt`` column to iterate over.
        agents: Number of agents debating each prompt.
        rounds: Number of debate rounds per prompt.

    Returns:
        Dict mapping each completed prompt to its list of per-agent message
        histories. If the run is interrupted or an error occurs, the prompts
        completed so far are returned (the failing prompt is not included).
    """
    instruction_suffix = "\n\nExplain your answer. Additionally rank your confidence in your response on a scale from 1-100, 1 being least confident and 100 being most confident."
    generated_description = {}
    for prompt in tqdm(alpaca_df["prompt"]):
        # One independent message history per agent, all seeded identically.
        agent_contexts = [
            [{"role": "user", "content": prompt + instruction_suffix}]
            for _ in range(agents)
        ]
        try:
            for round_idx in range(rounds):
                for i, agent_context in enumerate(agent_contexts):
                    if round_idx != 0:
                        # Summarize everyone except agent i and feed it back.
                        agent_contexts_other = agent_contexts[:i] + agent_contexts[i+1:]
                        other_agent_indices = list(range(0, i)) + list(range(i+1, agents))
                        sum_content, _ = summarize_message(agent_contexts_other, prompt, other_agent_indices, i)
                        message = construct_sum_debate(sum_content, prompt)
                        agent_context.append(message)
                        print("message: ", message)

                    completion = generate_answer(agent_context)
                    agent_context.append(construct_assistant_message(completion))
                    print(completion)
            # Record the prompt only once all rounds finished for all agents.
            generated_description[prompt] = agent_contexts
        except KeyboardInterrupt:
            print("Interrupted; returning the prompts completed so far.")
            return generated_description
        except Exception as err:
            # Best effort: keep what was generated, but say why we stopped
            # instead of failing silently.
            print(f"Stopping generation early due to an error: {err!r}")
            return generated_description
    return generated_description

def create_jsonl(args, description, filename, structure = "final"):
    """Convert debate transcripts into chat examples and write them as JSONL.

    Args:
        args: Object exposing ``save_path``, the output directory.
        description: Dict mapping each prompt to a list of per-agent message
            histories.
        filename: Output file name without the ``.jsonl`` extension.
        structure: ``"combined"`` emits one example per prompt holding every
            agent's messages concatenated; ``"final"`` emits one example per
            prompt with the first agent's first message and a randomly chosen
            agent's last message; any other value emits one example per agent
            history.

    Returns:
        The list of example dicts that was written, one per output line.
    """
    records = []
    for agent_contexts in description.values():
        if structure == "combined":
            merged = []
            for history in agent_contexts:
                merged.extend(history)
            records.append({"messages": merged})
        elif structure == "final":
            opening_message = agent_contexts[0][0]
            closing_messages = [history[-1] for history in agent_contexts]
            # Pick one agent's closing message at random.
            picked = np.random.choice(closing_messages)
            records.append({"messages": [opening_message, picked]})
        else:
            records.extend({"messages": history} for history in agent_contexts)

    with open(f"{args.save_path}/{filename}.jsonl", "w") as out_file:
        for record in records:
            out_file.write(json.dumps(record) + "\n")
    return records

if __name__ == "__main__":
    # Command-line entry point: either run the debate and store the raw
    # transcripts, or (with --reuse) reload stored transcripts and write them
    # out as a JSONL dataset in the requested structure.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", action = "store", type = str, dest = "data_path", default = "alpaca_data.json")
    parser.add_argument("--save_path", action = "store", type = str, dest = "save_path", default = "ft_data")
    parser.add_argument("--dataset_size", action = "store", type = int, default = 5000, dest = "dataset_size")
    parser.add_argument("--structure", type = str, action = "store", default = "final", dest = "structure", choices = ["per", "combined", "final"])
    parser.add_argument("--resample", action = "store_true", dest = "resample")
    parser.add_argument("--reuse", action = "store_true", dest = "reuse")
    parser.set_defaults(resample = False, reuse = False)
    args = parser.parse_args()
    agents = 4 #NOTE: This should be hard coded. Don't add these as argparse arguments. ONLY CHANGE THIS FOR PURPOSEFUL EXPERIMENTATION
    rounds = 3
    sample_csv = f"{args.save_path}/5k-alpaca-ft.csv"
    raw_train_json = f"{args.save_path}/5k-alpaca-raw-train.json"
    #NOTE: The reuse flag allows us to restructure the data. Having the flag just reopens the raw training and validation responses from the debate.
    if not args.reuse:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(args.save_path, exist_ok=True)
        if args.resample or not os.path.exists(sample_csv):
            print(f"Resampling {args.dataset_size} alpaca examples")
            alpaca_df = pd.read_json(args.data_path)
            alpaca_df = alpaca_df.loc[:, ~alpaca_df.columns.str.contains("^Unnamed")]
            debate_df = select_prompts(alpaca_df, sample = args.dataset_size)
            debate_df.to_csv(sample_csv)
        else:
            debate_df = pd.read_csv(sample_csv)
        print(debate_df)
        train_description = generation(debate_df, agents, rounds)
        with open(raw_train_json, "w") as f:
            json.dump(train_description, f)
    else:
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if not os.path.exists(raw_train_json):
            raise FileNotFoundError(f"Raw train file MUST exist: {raw_train_json}")
        # Context manager so the file handle is closed deterministically.
        with open(raw_train_json, "r") as f:
            train_description = json.load(f)
        train_dataset = create_jsonl(args, train_description, f"5k-ft-final-train_agents{agents}_rounds{rounds}", args.structure)