import os
import random
import sys
import time

import openai

from task_utils import createTrainingData, loadDataFromFile, extractQuestionsWithAnswers, extractQuestionsWithAnswersFromBigBench, scoreSolution

# OpenAI client settings, read from the environment so credentials never have
# to be edited into this file. Unset variables fall back to an empty string.
openai.api_type = os.environ.get("OPENAI_API_TYPE", "")
openai.api_base = os.environ.get("OPENAI_API_BASE", "")
openai.api_version = os.environ.get("OPENAI_API_VERSION", "")
openai.api_key = os.environ.get("OPENAI_API_KEY", "")

# Command-line interface:
#   argv[1] TASK               task name used to locate the BIG-bench / BBH files
#   argv[2] INITIAL_MODEL      engine for the initial and zero-shot samples
#   argv[3] MODEL              engine used during the search iterations
#   argv[4] K_TOP              number of demonstrations drawn per prompt
#   argv[5] N_INITIAL_SAMPLES  stored solution slots per training question
#   argv[6] MODE               "continue" resumes from the existing output file
if len(sys.argv) < 7:
    sys.exit(
        "usage: python SCRIPT TASK INITIAL_MODEL MODEL K_TOP N_INITIAL_SAMPLES MODE\n"
        "  MODE: 'continue' resumes from the existing output file; "
        "any other value starts a fresh run"
    )
TASK = sys.argv[1]
INITIAL_MODEL = sys.argv[2]
MODEL = sys.argv[3]
K_TOP = int(sys.argv[4])
N_INITIAL_SAMPLES = int(sys.argv[5])
CONTINUE_TRAINING = sys.argv[6] == "continue"

USE_CUSTOM_INSTRUCTION = False  # pass the task's custom instruction to the question extractors
N_SAMPLES = 1  # sampling attempts per sampleSolutions call
N_ITERATIONS = 20000  # last (exclusive) iteration index of the main loop
TEMPERATURE = 1.0  # sampling temperature during the search
TEST_TEMPERATURE = 0  # temperature for the final test-set evaluation
TRAIN_DATA_SIZE = 20  # size argument passed to createTrainingData
project_path = ""  # prefix for the data/ and outputs/ paths

# Per-task configuration. Generic tasks map to one BIG-bench file and one BBH
# file named after the task and have their answer choices permuted.
# logical_deduction instead evaluates on three BBH variants, ships a
# hand-written instruction and sets compare_options (no choice permutation).
if TASK != "logical_deduction":
    custom_instruction = ""
    compare_options = False
    full_json_files = [project_path + "data/BIG-bench/bigbench/benchmark_tasks/%s/task.json" % TASK]
    bbh_json_files = [project_path + "data/BIG-Bench-Hard/bbh/%s.json" % TASK]
else:
    custom_instruction = (
        "First decide what will the item scores mean, then show step-by-step thinking "
        "on how to assign or reassign the scores so that all requirements are satisfied. "
        "Show all those steps. Then check your work by repeating the process until "
        "everything is satisfied."
    )
    compare_options = True
    full_json_files = [project_path + "data/BIG-bench/bigbench/benchmark_tasks/logical_deduction/five_objects/task.json"]
    bbh_json_files = [project_path + suffix for suffix in (
        "data/BIG-Bench-Hard/bbh/logical_deduction_three_objects.json",
        "data/BIG-Bench-Hard/bbh/logical_deduction_five_objects.json",
        "data/BIG-Bench-Hard/bbh/logical_deduction_seven_objects.json",
    )]

# Prompt-format settings. start_of_question_string is also how the resume
# parser recognises question lines in the log.
enumerate_choices = True
start_of_question_string = "Q:"

# Keep the task-specific instruction only when explicitly enabled.
if not USE_CUSTOM_INSTRUCTION:
    custom_instruction = ""

# Model label for the output file name: dashes stripped, and
# "<initial>+<model>" when the initial samples come from a different engine.
model_name = (
    MODEL.replace("-", "")
    if INITIAL_MODEL == MODEL
    else INITIAL_MODEL.replace("-", "") + "+" + MODEL.replace("-", "")
)

# Run log path; runs that use the custom instruction are tagged "human+".
output_file = (
    "%soutputs/%s_human+%s_gibbs_k%d_sample%d_temp%.1f.txt"
    if USE_CUSTOM_INSTRUCTION
    else "%soutputs/%s_%s_gibbs_k%d_sample%d_temp%.1f.txt"
) % (project_path, TASK, model_name, K_TOP, N_INITIAL_SAMPLES, TEMPERATURE)

# Full BIG-bench task data: the pool from which training examples are drawn.
full_datasets = [
    extractQuestionsWithAnswersFromBigBench(
        loadDataFromFile(path),
        enumerate_choices=enumerate_choices,
        permute_choices=(not compare_options),
        custom_instruction=custom_instruction,
    )
    for path in full_json_files
]
# Balance the pool: draw the same number of random pairs from every dataset,
# capped by the size of the smallest one.
min_size = min(len(dataset) for dataset in full_datasets)
full_qa_pairs = [
    pair
    for dataset in full_datasets
    for pair in random.sample(dataset, min_size)
]

# Every BBH pair is used as the test set.
test_qa_pairs = [
    pair
    for path in bbh_json_files
    for pair in extractQuestionsWithAnswers(loadDataFromFile(path), custom_instruction=custom_instruction)
]

test_subset = test_qa_pairs
# createTrainingData also receives the test pairs (presumably to keep them out
# of the training subset -- confirm in task_utils).
train_subset = createTrainingData(full_qa_pairs, test_qa_pairs, TRAIN_DATA_SIZE, compare_options=compare_options)

def constructPrompt(prompt_samples, question):
    """Build a few-shot prompt ending with *question*.

    Each entry of *prompt_samples* is a 4-tuple whose first two items are a
    demonstration question and its solution; it is rendered as
    "<question>\\n<solution>\\nEND\\n". Entries with an empty question render
    as an empty string, leaving only the joining newline. The rendered
    entries are joined with "\\n", then the question and a trailing newline
    are appended. With no demonstrations the bare question is returned, with
    no trailing newline.
    """
    if prompt_samples:
        blocks = []
        for demo_question, demo_solution, _, _ in prompt_samples:
            if demo_question:
                blocks.append(demo_question + "\n" + demo_solution + "\nEND\n")
            else:
                blocks.append("")
        return "\n".join(blocks) + question + "\n"
    return question

def sampleResponse(prompt, model=MODEL, temperature=TEMPERATURE):
    """Return the stripped text of one completion for *prompt*.

    Calls ``openai.Completion.create`` (at most 500 tokens, stopping at
    "END") and keeps retrying until it returns a response. Each failure is
    reported on stderr, because stdout is redirected to the run log, and
    retried after an exponential backoff (1 s, doubling, capped at 60 s)
    rather than in a tight loop. Only ``Exception`` subclasses are retried,
    so KeyboardInterrupt and SystemExit still stop the script.
    """
    delay = 1.0
    response = None
    while not response:
        try:
            response = openai.Completion.create(
                engine=model,
                prompt=prompt,
                temperature=temperature,
                max_tokens=500,
                top_p=0.5,
                frequency_penalty=0,
                presence_penalty=0,
                stop=["END"])
        except Exception as err:  # retry any API/network error
            response = None
            print("sampleResponse: request failed (%r); retrying in %.0f s" % (err, delay), file=sys.stderr)
            time.sleep(delay)
            delay = min(delay * 2, 60.0)
    return response["choices"][0]["text"].strip()

def sampleSolutions(model, prompt_samples, question, answer, p_rejection=0.95, p_zeroshot=0.1):
    """Try up to N_SAMPLES times to obtain an accepted solution for *question*.

    On each attempt, with probability *p_zeroshot* the prompt has no
    demonstrations and INITIAL_MODEL is queried. Otherwise *prompt_samples*
    are prepended and *model* is queried. A sample is accepted if it scores
    exactly 1 against *answer*, or otherwise with probability
    ``1 - p_rejection``.

    Returns:
        ``[(solution, prepend_samples, score)]`` for the first accepted
        sample, or ``[]`` if every attempt was rejected.
    """
    for _ in range(N_SAMPLES):
        zero_shot = random.random() < p_zeroshot
        shots = [] if zero_shot else prompt_samples
        engine = INITIAL_MODEL if zero_shot else model
        candidate = sampleResponse(constructPrompt(shots, question), model=engine, temperature=TEMPERATURE)
        candidate_score = scoreSolution(candidate, answer)
        # The random draw only happens for imperfect samples (short-circuit).
        if candidate_score == 1 or random.random() > p_rejection:
            return [(candidate, shots, candidate_score)]
    return []

def test(test_pairs, prompt_samples, temperature=TEMPERATURE, debug=False):
    """Return the mean score of the default model on *test_pairs*.

    Args:
        test_pairs: iterable of (question, answer) pairs.
        prompt_samples: demonstrations passed to constructPrompt (4-tuples).
        temperature: sampling temperature for each completion.
        debug: if true, print each question, completion, expected answer
            and score.

    Returns:
        The average scoreSolution value over all pairs, or 0.0 when
        *test_pairs* is empty (instead of raising ZeroDivisionError).
    """
    scores = []
    for q, a in test_pairs:
        prompt = constructPrompt(prompt_samples, q)
        solution = sampleResponse(prompt, temperature=temperature)
        score = scoreSolution(solution, a)
        scores.append(score)
        if debug:
            print(q)
            print(solution)
            print("<correct>" + a + "</correct>")
            print(score)
            print()
    return sum(scores) / len(scores) if scores else 0.0

def makeKey(q, index):
    """Return the pool key for slot *index* of question *q*.

    A ":" separates the index from the question. Without it, keys collide
    whenever a question starts with a digit: slot 1 of "2abc" and slot 12 of
    "abc" would both map to "12abc".
    """
    return str(index) + ":" + q

def removeOptionsFromQuestion(q):
    """Return the lines of *q* before the first "Options:" line, concatenated.

    Lines are joined with no separator, so a question read back from the log
    (with a trailing newline) and the in-memory question produce the same
    string.
    """
    kept_lines = []
    for text_line in q.split("\n"):
        if text_line.startswith("Options:"):
            break
        kept_lines.append(text_line)
    return "".join(kept_lines)


question_solution_scores = {}
if CONTINUE_TRAINING:
    train_subset = []
    train_set = {removeOptionsFromQuestion(q): a for q, a in full_qa_pairs}
    with open(output_file, "r") as f:
        current_iter = 0
        current_state = ""
        current_question = ""
        current_output = ""
        for line in f.readlines():
            if line.startswith("ITER"):
                current_iter = int(line.split()[-1][:-1])
            elif line.startswith("Score:"):
                current_state = ""
                if current_question and current_output:
                    key = makeKey(current_question, random.randint(0, N_INITIAL_SAMPLES - 1))
                    q = removeOptionsFromQuestion(current_question)
                    if q in train_set:
                        answer = train_set[q]
                        if (current_question, answer) not in train_subset:
                            train_subset.append((current_question, answer))
                        score = scoreSolution(current_output, answer)
                        question_solution_scores[key] = (current_question, current_output, [], score)
                    current_question, current_output = "", ""
            elif line.startswith(start_of_question_string):
                current_state = "Q"
                current_question = ""
                current_question += line
            elif line.startswith("OUTPUT:"):
                current_state = "O"
                current_output = ""
            elif line.endswith("Finally generate END at the end of the solution.\n"):
                current_state = ""
                current_question += line
            else:
                if current_state == "Q":
                    current_question += line
                elif current_state == "O":
                    current_output += line

    start_iter = current_iter
    # Redirect output to a file
    fout = open(output_file, "a")
    sys.stdout = fout
else:
    # Redirect output to a file
    fout = open(output_file, "w")
    sys.stdout = fout
    start_iter = 0
    # initialize samples
    model = INITIAL_MODEL
    for q, a in train_subset:
        for i in range(N_INITIAL_SAMPLES):
            key = makeKey(q, i)
            question_solution_scores[key] = ("", "", [], 0)
            solutions = sampleSolutions(model, [], q, a)
            if solutions:
                solution, prepend_samples, score = solutions[0]
                question_solution_scores[key] = (q, solution, prepend_samples, score)
                print(q)
                print(solution)

# Main search loop. Each iteration:
#   1. draws up to K_TOP stored demonstrations as a few-shot prompt,
#   2. samples a solution for a training question not shown in that prompt,
#   3. if sampleSolutions accepts it, stores it in a random slot for that
#      question.
# The "ITER %d:", "OUTPUT:" and "Score:" lines are parsed back when resuming
# with "continue", so their format must not change.
model = MODEL
for k in range(start_iter, N_ITERATIONS):
    print("ITER %d:" % k)
    samples = list(question_solution_scores.values())
    # Clamp so random.sample does not raise ValueError when the pool holds
    # fewer than K_TOP entries (e.g. a resumed run that recovered few samples).
    prompt_samples = random.sample(samples, k=min(K_TOP, len(samples)))
    prompt_question_set = {q for q, _, _, _ in prompt_samples}

    # Choose uniformly among non-empty training questions that are not in the
    # prompt. Filtering first avoids the endless rejection-sampling loop that
    # occurs when every training question is already in the prompt.
    candidates = [(q, a) for q, a in train_subset if q and q not in prompt_question_set]
    if candidates:
        q, a = random.choice(candidates)
        solutions = sampleSolutions(model, prompt_samples, q, a)
    else:
        print("SKIP: no training question outside the current prompt")
        solutions = []
    if solutions:
        key = makeKey(q, random.randint(0, N_INITIAL_SAMPLES - 1))
        solution, prepend_samples, score = solutions[0]
        question_solution_scores[key] = (q, solution, prepend_samples, score)
        print("PROMPT:")
        for sample in prepend_samples:
            print(sample[0])
            print(sample[1])
        print(q)
        print("OUTPUT:")
        print(solution)

    scores = [x[-1] for x in question_solution_scores.values()]
    # Guard against an empty pool (nothing initialised or recovered).
    avg_score = sum(scores) / len(scores) if scores else 0.0
    print("Score: %.3f" % avg_score)
    print()

# Final selection and evaluation:
#   1. score every stored demonstration as a one-shot prompt on the training
#      subset,
#   2. keep the K_TOP best,
#   3. report the score of that K_TOP-shot prompt on the BBH test set.
prompt_samples = []
for q, solution, prepend_samples, _ in question_solution_scores.values():
    # Skip unfilled ("", "", [], 0) placeholders. constructPrompt drops
    # entries with an empty question, so they would only cost API calls and
    # could take a top-K slot without adding a demonstration.
    if not q:
        continue
    score = test(train_subset, [(q, solution, [], 0)])
    prompt_samples.append((q, solution, prepend_samples, score))
# Highest training score first; the sort is stable, so ties keep pool order.
prompt_samples.sort(key=lambda x: x[-1], reverse=True)
prompt_samples = prompt_samples[:K_TOP]
print("TEST PROMPT")
for q, solution, _, score in prompt_samples:
    print(q)
    print(solution)
    print(score)

print("TESTING")
avg_score = test(test_subset, prompt_samples, temperature=TEST_TEMPERATURE, debug=True)
print("Test score: %.3f" % avg_score)
# Point stdout back at the real stream before closing the log, so any later
# print does not write to a closed file.
sys.stdout = sys.__stdout__
fout.close()