import openai
import time
import csv
import os
import json

def generate_powerset(x):
    """Return every subset of ``range(x)`` as a list of lists.

    The result starts with the empty subset. Each new element keeps all the
    subsets built so far and then adds a copy of each one with that element
    appended, so the list doubles in length per element.
    """
    subsets = [[]]
    for element in range(x):
        # The right-hand side is fully built before the list is extended,
        # so only the subsets that existed before this step are copied.
        subsets += [existing + [element] for existing in subsets]
    return subsets


def turn_into_input_format(role, prompt):
    """Wrap ``prompt`` in a message dict with "role" and "content" keys.

    ``role`` is one of "system", "user", or "assistant".
    """
    return dict(role=role, content=prompt)


def query_problem(args, system_prompt, alternative, model="gpt-4", max_tokens=500, temperature=1, top_p=1, n=1, verbose=False):
    """Send a chat request and return its choices wrapped in a one-element list.

    The API key is read on every call from the JSON file at
    ``args.credential_path`` (key ``"openai"``). ``alternative`` is the user
    message; ``system_prompt`` is placed before it unless it is ``None``.
    Any exception raised by the request is printed and the request is retried
    after a 5-second pause, indefinitely, until it succeeds.

    Returns ``[completion["choices"]]``, so the first choice's text is
    ``result[0][0]["message"]["content"]``.
    """
    with open(args.credential_path, "r") as credential_file:
        credentials = json.load(credential_file)
    openai.api_key = credentials["openai"]

    # The system prompt, when given, is the first message in the chain.
    messages = [alternative] if system_prompt is None else [system_prompt, alternative]
    if verbose:
        print("prompt:\n", messages)

    request_kwargs = dict(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        n=n,
    )

    completion = None
    while completion is None:
        try:
            completion = openai.ChatCompletion.create(**request_kwargs)
        except Exception as e:
            # Best-effort: report the failure, wait, and try again.
            print(e, 'retry')
            time.sleep(5)

    if verbose:
        response_texts = [choice["message"]["content"] for choice in completion["choices"]]
        for response_text in response_texts:
            print("response:\n", response_text, "\n\n")

    return [completion["choices"]]


def save_as_csv(list_of_dict, path):
    """Write a list of dicts to ``path`` as a CSV file with a header row.

    The header is taken from the keys of the first dict. Nothing is written
    (and a message is printed) when ``path`` already exists, so earlier
    results are never overwritten, or when ``list_of_dict`` is empty, since
    there is no row to derive the header from.

    Args:
        list_of_dict: Rows to write; every row must use the first row's keys.
        path: Destination file path.
    """
    if os.path.exists(path):
        print(f"{path} already exists. Ignoring save_as_csv.")
        return

    print(list_of_dict)
    if not list_of_dict:
        print(f"No rows to save to {path}. Ignoring save_as_csv.")
        return

    fieldnames = list(list_of_dict[0].keys())
    # newline='' lets the csv module control line endings itself; without it,
    # platforms that translate "\n" emit an extra blank line after every row.
    with open(path, 'w', newline='') as output_file:
        dict_writer = csv.DictWriter(output_file, fieldnames)
        dict_writer.writeheader()
        dict_writer.writerows(list_of_dict)

def test_save_as_csv():
    """Manual smoke check: save two sample rows to ``test.csv``."""
    sample_rows = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
    save_as_csv(sample_rows, "test.csv")

# test_save_as_csv()

def add_numerical_list_string_to_list(list, sentences):
    """Append the items of a numbered list in ``sentences`` to ``list``.

    ``sentences`` is split on newlines. Lines that are empty or do not start
    with a digit are ignored. For the remaining lines the leading number (of
    any length, so "10." works as well as "1."), one optional punctuation
    character after it (e.g. ".", ")", ":") and surrounding whitespace are
    removed before the text is appended.

    Args:
        list: The list to extend; it is modified in place.
        sentences: Newline-separated text containing numbered items.

    Returns:
        The same ``list`` object that was passed in.
    """
    for sentence in sentences.split("\n"):
        if not sentence or not sentence[0].isdigit():
            continue

        # Skip the whole run of leading digits, not just the first one.
        marker_end = 1
        while marker_end < len(sentence) and sentence[marker_end].isdigit():
            marker_end += 1

        # Drop a single separator such as "." or ")" that follows the number.
        separator = sentence[marker_end:marker_end + 1]
        if separator and not separator.isalnum() and not separator.isspace():
            marker_end += 1

        list.append(sentence[marker_end:].strip())
    return list


# used when getting the comparison results
def load_responses(directory_path, sorted=True):
    """Load every ``.json`` file in ``directory_path``.

    Returns a pair ``(responses, mapping)``: ``responses`` holds the parsed
    content of each file in the order the files were read, and ``mapping``
    maps each index in ``responses`` to the filename it was read from. With
    ``sorted`` true the filenames are read in sorted order; otherwise the
    order is whatever ``os.listdir`` returns.
    """
    filenames = os.listdir(directory_path)
    if sorted:
        filenames.sort()

    responses = []
    mapping = {}
    for filename in filenames:
        if not filename.endswith(".json"):
            continue
        json_path = os.path.join(directory_path, filename)
        with open(json_path, "r") as handle:
            payload = json.load(handle)
        # Record the index this payload is about to occupy.
        mapping[len(responses)] = filename
        responses.append(payload)
    return responses, mapping


def remap(path):
    """Invert each mapping stored in ``{path}.json`` and save the result.

    The input file holds a list of dicts. Each dict is flipped so its values
    become keys and its keys become values, and the flipped list is written
    to ``{path}_remapped.json``.
    """
    with open(f"{path}.json", "r") as source_file:
        mappings = json.load(source_file)

    # When two keys share a value, the later key wins.
    inverted = [
        {value: key for key, value in mapping.items()}
        for mapping in mappings
    ]

    with open(f"{path}_remapped.json", "w") as target_file:
        json.dump(inverted, target_file)


def parse_answer(args, answer):
    """Turn a free-text comparison answer into a score.

    Returns 1 when scenario 1 wins, 0 when scenario 2 wins and 0.5 for a tie.
    If the text mentions both scenarios equally often (and at least once),
    the decision is delegated to ``requery(args, answer_text)``.
    """
    normalized = answer.strip().lower()

    # Answers that were already requeried carry an explicit prefix.
    requeried_scores = (("requery 1", 1), ("requery 2", 0), ("requery tie", 0.5))
    for prefix, score in requeried_scores:
        if normalized.startswith(prefix):
            return score

    # Keep only what follows the last "answer:" marker, dropping the reasoning.
    answer_text = normalized.split("answer:")[-1].strip()

    # Simple answer: the text is, or opens with, a bare scenario name.
    if answer_text == "scenario 1" or answer_text.startswith(("scenario 1\n", "scenario 1.")):
        return 1
    if answer_text == "scenario 2" or answer_text.startswith(("scenario 2\n", "scenario 2.")):
        return 0

    def count_mentions(*phrases):
        return sum(answer_text.count(phrase) for phrase in phrases)

    # Win by count. Phrases such as "scenario 1 or 2" are already counted for
    # scenario 1 by the plain "scenario 1" match, so they are credited to
    # scenario 2 as well.
    scenario_1_count = count_mentions("scenario 1", "scenario [1]", "scenario one")
    scenario_2_count = count_mentions("scenario 2", "scenario [2]", "scenario two")
    scenario_2_count += count_mentions("scenario 1 or 2", "scenario [1] or [2]", "scenario one or two")
    scenario_2_count += count_mentions("scenario 1 and 2", "scenario [1] and [2]", "scenario one and two")

    # These are basically completely correct heuristics (out of all
    # comparisons in a scenario, maybe one was questionable).
    if scenario_1_count > scenario_2_count:
        return 1
    if scenario_1_count < scenario_2_count:
        return 0

    # Counts are equal here; zero means neither scenario is mentioned at all.
    if scenario_1_count == 0:
        return 0.5  # This heuristic works fine as well

    # Unknown, requery.
    return requery(args, answer_text)


def requery(args, answer_text):
    """Ask the model to classify an ambiguous answer and return its score.

    The prompt is the content of ``requery.txt`` with every "ANSWER"
    placeholder replaced by ``answer_text``, sent as a single user message.
    The reply is stripped and lower-cased, then mapped to 1 ("scenario 1"),
    0 ("scenario 2") or 0.5 ("tie").

    Raises:
        ValueError: If the reply is not exactly one of those three strings.
    """
    with open("requery.txt", "r") as template_file:
        template = template_file.read()
    user_message = turn_into_input_format("user", template.replace("ANSWER", answer_text))

    result = query_problem(args, None, user_message, temperature=1, top_p=1, n=1, verbose=args.verbose)
    requery_answer = result[0][0]["message"]["content"].strip().lower()

    scores = {"scenario 1": 1, "scenario 2": 0, "tie": 0.5}
    if requery_answer in scores:
        return scores[requery_answer]

    raise ValueError(f"Requery answer is not scenario 1, 2, or tie: \n{answer_text}\n{requery_answer}")