import time
import re
import signal
from contextlib import contextmanager
from time import sleep
import json
from dotenv import load_dotenv
import os
import openai
import codeop
import pandas as pd
from random import sample
from collections import defaultdict
import numpy as np
from enum import Enum
import csv
import gc


# Import my own scripts
from program_formatting import *
from testcase_formatting import *

from dataset_processing_functions import *
from mbpp_processing import *
from humaneval_processing import *


# Populate the environment via python-dotenv, then read the OpenAI key from it.
# os.environ.get returns None when OPENAI_API_KEY is not set.
load_dotenv()
openai.api_key = os.environ.get("OPENAI_API_KEY")


# Parameters to change
# NOTE(review): none of the functions visible in this part of the file read this
# module-level value; they take their own `time_to_sleep` argument instead.
time_to_sleep = 65

# Settings used by gen_test_cases for each completion request:
# temperature, max_tokens and n (completions per request) respectively.
test_case_temp = 1
test_case_max_token_length = 700
test_case_num_completions = 21

# Settings used by gen_programs for each completion request:
# temperature, max_tokens and n (completions per request) respectively.
program_temp = 1
program_max_token_length = 500
program_num_completions = 29

# Value handed to `timeout(...)` in evaluate_program.
# Presumably a number of seconds; confirm against the timeout helper.
time_to_timeout = 2


class DataSet(Enum):
    '''
    Identifies which benchmark a pipeline step should work with.
    '''
    # Selected by the `dataset` / `dataset_used` arguments throughout this file
    MBPP = 0
    HumanEval = 1


def is_valid_code(line):
    '''
    Sees if a line is a syntactically valid and complete statement.

    Returns True only when codeop can compile the line as-is. Returns False when
    the line has a syntax error, contains an invalid literal, or is merely
    incomplete (e.g. an assertion that was cut off in the middle of a call).
    '''
    try:
        compiled = codeop.compile_command(line)
    except (SyntaxError, ValueError, OverflowError):
        # SyntaxError: malformed code; ValueError/OverflowError: invalid literal
        return False

    # compile_command does not raise for incomplete input, it returns None.
    # A truncated statement must not be counted as valid.
    return compiled is not None


def gen_programs(prompt, func_name, signature, assertion_cases, num_times_to_sample, time_to_sleep,
                 program_format=gen_program_using_example_assertion_format, params=None, param_types=None, return_type=""):
    '''
    Samples programs from Codex for a prompt.

    Sends `num_times_to_sample` requests, each asking for {program_num_completions}
    completions, and sleeps `time_to_sleep` seconds after every request.

    prompt, func_name, assertion_cases, params, param_types and return_type are
    forwarded to `program_format`, which builds the text sent to the model.
    `signature` is accepted for call compatibility but is not used here.
    `params` / `param_types` default to empty lists.

    Returns a tuple (programs, finish_reasons): two parallel lists holding the
    text of every completion and the finish reason reported for it.
    '''
    # Create fresh lists per call instead of sharing one mutable default object
    params = [] if params is None else params
    param_types = [] if param_types is None else param_types

    all_programs = []
    finish_reasons = []

    for _ in range(num_times_to_sample):
        # Built inside the loop on purpose: the formatter is called once per request
        formatted_prompt = program_format(
            prompt, func_name, assertion_cases, params, param_types, return_type)

        openai_completion = openai.Completion.create(
            model="code-davinci-002",
            prompt=formatted_prompt,
            max_tokens=program_max_token_length,
            temperature=program_temp,
            n=program_num_completions,

            # Stop only allows up to 4 sequences -> missing an "\nassert" and a "\nprint"
            stop=["\ndef", "\n#", "\nclass", "\nif"],
            # Ask the API to echo the prompt back together with the completion
            echo=True
        )

        for choice in openai_completion["choices"]:
            all_programs.append(choice["text"])
            finish_reasons.append(choice["finish_reason"])

        # Pause between requests (callers use this to stay under the rate limit)
        sleep(time_to_sleep)

    return all_programs, finish_reasons


def _split_assertion_cases(completion_text):
    '''
    Splits completion text into a list of individual "assert ..." statements.

    Newlines are flattened to spaces first. Every statement that is followed by
    another "assert" is kept as-is; the trailing one may have been cut off by the
    token limit, so it is kept only if it is non-empty and is_valid_code accepts it.
    '''
    tokens = completion_text.replace('\n', ' ').split()

    cases = []
    current = []
    for token in tokens:
        # A new "assert" token closes the statement collected so far
        if token == "assert" and current:
            cases.append(" ".join(current))
            current = []

        current.append(token)

    # For the last assertion case - not always finished, so must check manually
    last_case = " ".join(current)
    if last_case and is_valid_code(last_case):
        cases.append(last_case)

    return cases


def gen_test_cases(prompt, func_name, signature, assertion_cases, num_times_to_sample, time_to_sleep,
                   test_case_format=gen_test_cases_without_example_assertion_format, params=None, param_types=None, return_type=""):
    '''
    Samples test cases (assert statements) from Codex for a prompt.

    Sends `num_times_to_sample` requests, each asking for {test_case_num_completions}
    completions, and sleeps `time_to_sleep` seconds after every request.

    `test_case_format` is called once with the problem description and must return
    a pair (text sent to the model, beginning of the expected completion). The
    beginning is prepended to every completion before it is split into statements.
    `params` / `param_types` default to empty lists.

    Returns a tuple (cases, formatted_prompt): `cases` holds one list of assert
    statements per completion, `formatted_prompt` is the text that was sent in.
    '''
    # Create fresh lists per call instead of sharing one mutable default object
    params = [] if params is None else params
    param_types = [] if param_types is None else param_types

    formatted_prompt, beginning = test_case_format(
        prompt, func_name, signature, assertion_cases, params, param_types, return_type)

    all_completed_assertion_cases = []

    for _ in range(num_times_to_sample):

        openai_completion = openai.Completion.create(
            model="code-davinci-002",
            prompt=formatted_prompt,
            max_tokens=test_case_max_token_length,
            temperature=test_case_temp,
            n=test_case_num_completions,
            stop=["\ndef", "\nclass", "\nif", "\n#"]
        )

        for choice in openai_completion["choices"]:
            all_completed_assertion_cases.append(
                _split_assertion_cases(f'''{beginning}{choice["text"]}'''))

        # Pause between requests (callers use this to stay under the rate limit)
        sleep(time_to_sleep)

    return all_completed_assertion_cases, formatted_prompt


def cache_testcases_prompts(program_name="program_log", testcase_name="testcase_log",
                            program_format=gen_program_using_example_assertion_format,
                            test_case_format=gen_test_cases_without_example_assertion_format,
                            num_programs_sampled=1, num_test_cases_sampled=1,
                            time_to_sleep_program=30, time_to_sleep_test_case=30,
                            dataset_used=DataSet.MBPP, total_programs_to_sample=100):
    '''
    Generates programs and test cases from Codex and stores them in csv files.

    Randomly picks `total_programs_to_sample` problems from the chosen dataset
    (loaded with process_mbpp / process_humaneval) and, for each of them, samples
    programs with gen_programs and test cases with gen_test_cases.

    2 csv files (both are overwritten at the start, then appended to per problem):
    1 - {program_name}.csv -> "Prompt id", "Prompt", "Generated Program",
        "Formatted Generated Program", "Finish Reason"
    2 - {testcase_name}.csv -> "Prompt id", "Prompt", "Generated Test Case"

    "Prompt id" is the index of the problem in the loaded problem list.

    Raises ValueError if `dataset_used` is not a supported DataSet member.
    '''

    if dataset_used == DataSet.MBPP:
        problems_list = process_mbpp()
        # Indices are drawn from range(974) for MBPP
        num_total_cases = 974
    elif dataset_used == DataSet.HumanEval:
        problems_list = process_humaneval()
        # Indices are drawn from range(164) for HumanEval
        num_total_cases = 164
    else:
        raise ValueError(f"Unsupported dataset: {dataset_used!r}")

    # The header must list every column that is appended below. A header that is
    # one column short makes the file misalign when it is read back.
    program_log_columns = ["Prompt id", "Prompt", "Generated Program",
                           "Formatted Generated Program", "Finish Reason"]
    testcase_log_columns = ["Prompt id", "Prompt", "Generated Test Case"]

    # Setup headers (this truncates any previous log with the same name)
    pd.DataFrame(columns=program_log_columns).to_csv(
        f"{program_name}.csv", mode='w', index=False)
    pd.DataFrame(columns=testcase_log_columns).to_csv(
        f"{testcase_name}.csv", mode='w', index=False)

    random_nums_to_sample = sample(
        range(num_total_cases), total_programs_to_sample)

    def cache_program(prompt_id, problem):
        # Samples programs for one problem and appends them to the program log
        if dataset_used == DataSet.MBPP:
            prompt = problem["text"]
            programs, finish_reasons = gen_programs(
                prompt, problem["function_name"], problem["signature"], problem["test_list"],
                program_format=program_format, num_times_to_sample=num_programs_sampled,
                time_to_sleep=time_to_sleep_program,
                params=problem["parameters"], param_types=problem["parameter_typing"],
                return_type=problem["return_type"])
        else:
            # HumanEval: only the prompt is passed to the formatter
            prompt = problem["prompt"]
            programs, finish_reasons = gen_programs(
                prompt, "", "", "", program_format=program_format,
                num_times_to_sample=num_programs_sampled, time_to_sleep=time_to_sleep_program)

        # Additional processing for programs
        formatted_programs = [
            additional_processing_program(x) for x in programs]

        program_df = pd.DataFrame(
            {'Prompt id': [prompt_id] * len(programs),
             'Prompt': [prompt] * len(programs),
             'Generated Program': programs,
             'Formatted Generated Program': formatted_programs,
             'Finish Reason': finish_reasons},
            columns=program_log_columns)

        program_df.to_csv(f"{program_name}.csv", mode='a',
                          index=False, header=False)

    def cache_test_case(prompt_id, problem):
        # Samples test cases for one problem and appends them to the test case log
        if dataset_used == DataSet.MBPP:
            test_cases, formatted_prompt = gen_test_cases(
                problem["text"], problem["function_name"], problem["signature"], problem["test_list"],
                test_case_format=test_case_format, num_times_to_sample=num_test_cases_sampled,
                time_to_sleep=time_to_sleep_test_case,
                params=problem["parameters"], param_types=problem["parameter_typing"],
                return_type=problem["return_type"])
        else:
            test_cases, formatted_prompt = gen_test_cases(
                problem["prompt"], problem["entry_point"], "", "", test_case_format=test_case_format,
                num_times_to_sample=num_test_cases_sampled, time_to_sleep=time_to_sleep_test_case)

        # Flatten the per-completion lists and drop duplicates. dict.fromkeys keeps
        # first-seen order, so the log is reproducible for identical completions.
        unique_test_cases = list(dict.fromkeys(
            case for cases in test_cases for case in cases))

        # The "Prompt" column stores the text that was actually sent to the model
        testcase_df = pd.DataFrame(
            {'Prompt id': [prompt_id] * len(unique_test_cases),
             'Prompt': [formatted_prompt] * len(unique_test_cases),
             'Generated Test Case': unique_test_cases},
            columns=testcase_log_columns)

        testcase_df.to_csv(f"{testcase_name}.csv", mode='a',
                           index=False, header=False)

    for problem_index in random_nums_to_sample:
        problem = problems_list[problem_index]

        cache_program(problem_index, problem)

        cache_test_case(problem_index, problem)

# Timeout Handling



def evaluate_program(program, test_case, custom_evaluation=None, get_raw_output=False):
    '''
    Evaluates if a program is correct.

    Takes in str program, str test_case

    If no custom_evaluation provided, formatting is done as
    {program}
    {testcase}

    Else, the evaluated program is the custom evaluation.

    If get_raw_output is True, the program is executed and `test_case` is then
    evaluated as an expression in the same namespace; its value becomes the result
    (None if anything fails). Otherwise the result is True when the program runs
    to completion and False if it raises.

    Returns a tuple (result, error, evaluated_program) where error is one of
    "None", "AssertionError", "SyntaxError", "TimeoutError", "MiscError" and
    evaluated_program is the exact text that was executed.

    WARNING: this runs model-generated code with exec/eval in the current process.
    Only use it in an environment where that is acceptable.
    '''

    if custom_evaluation is None:
        program = f'''{program}

{test_case}'''
    else:
        program = custom_evaluation

    # result + error of code
    result = None if get_raw_output else False
    type_error = "None"

    # NOTE(review): `timeout` is not defined in the part of the file reviewed here.
    # If it is missing at runtime, the resulting NameError is reported as
    # "MiscError" for every evaluation - confirm that it is defined/imported.
    try:
        # If should return raw output
        if get_raw_output:
            with timeout(time_to_timeout):
                namespace = {}
                exec(program, namespace)
                result = eval(test_case, namespace)

        # Else, if just executing an assertion test case
        else:
            with timeout(time_to_timeout):
                exec(program, {})
            result = True
    except KeyboardInterrupt:
        # Let Ctrl-C stop the whole run instead of being logged as a failure
        raise
    except AssertionError:
        type_error = "AssertionError"
    except SyntaxError:
        type_error = "SyntaxError"
    # Handle timeouts
    except TimeoutError:
        type_error = "TimeoutError"
    except BaseException:
        # Deliberately broad: generated code may raise anything, including
        # SystemExit via exit(), and must never take down the harness
        type_error = "MiscError"
    return result, type_error, program


def test_generated_test_cases_ground_truth(testcase_generated_filename="testcase_log", testcase_correct_filename="is_generated_testcase_correct",
                                           dataset=DataSet.MBPP, provided_function_name=True):
    '''
    Tests the generated test case with the ground truth code.

    Loads from '{testcase_generated_filename}.csv' and 'mbpp.csv' (or
    'humaneval.csv' for HumanEval).
    Saves to '{testcase_correct_filename}.csv', one row per generated test case.

    If provided_function_name is False (MBPP only), the name of the function
    called in the assertion is replaced with the dataset's "function_name"
    before the test case is run.

    Raises ValueError for an unsupported dataset, or when no called function can
    be found in a test case while provided_function_name is False.
    '''

    # Read from csvs
    generated_testcases = pd.read_csv(f'{testcase_generated_filename}.csv')

    if dataset == DataSet.MBPP:
        dataset_df = pd.read_csv('mbpp.csv')
    elif dataset == DataSet.HumanEval:
        dataset_df = pd.read_csv('humaneval.csv')
    else:
        raise ValueError(f"Unsupported dataset: {dataset!r}")

    analyzed_testcase_columns = [
        "Prompt id", "Prompt", "Generated Test Case", "Evaluated program", "Correct", "Error"]

    # Setup headers
    pd.DataFrame(columns=analyzed_testcase_columns).to_csv(
        f"{testcase_correct_filename}.csv", mode='w', index=False)

    # First identifier that is called after "assert"; an optional leading "not"
    # and opening parentheses are skipped.
    called_function_pattern = re.compile(
        r'\bassert\b[\s(]*(?:not\b[\s(]*)?(?!not\b)([A-Za-z_]\w*)\s*\(')

    for _, testcase in generated_testcases.iterrows():
        prompt_id = testcase["Prompt id"]
        from_dataset_df = dataset_df.iloc[prompt_id]
        testcase_to_use = testcase["Generated Test Case"]

        if dataset == DataSet.MBPP:
            ground_truth_program = from_dataset_df['code']

            # If different function names used
            if not provided_function_name:
                called_function = called_function_pattern.search(testcase_to_use)
                if not called_function:
                    print("ERROR!")
                    raise ValueError(
                        f"No function call found in test case: {testcase_to_use!r}")

                ground_truth_name = from_dataset_df["function_name"]

                # Replace whole-word occurrences only, so that a short generated
                # name does not rewrite unrelated parts of the assertion.
                testcase_to_use = re.sub(
                    rf'\b{re.escape(called_function[1])}\b',
                    lambda _match: ground_truth_name,
                    testcase_to_use)

        else:
            # HumanEval stores the solution body separately from the prompt
            ground_truth_program = f'''{from_dataset_df["prompt"]}
{from_dataset_df["canonical_solution"]}'''

        result, error, evaluated_program = evaluate_program(
            ground_truth_program, testcase_to_use)

        result_row = pd.DataFrame(
            [[prompt_id, testcase["Prompt"], testcase["Generated Test Case"], evaluated_program, result, error]], columns=analyzed_testcase_columns)

        result_row.to_csv(f"{testcase_correct_filename}.csv", mode='a',
                          index=False, header=False)


def additional_processing_program(program):
    r'''
    Additionally process programs generated from Codex.

    Current processing:
    1) Replace every call "input(...)" with the string literal "1" to ensure
    automated testing never blocks on stdin. Arguments with up to one level of
    nested parentheses are handled, e.g. input(str(x)).
    2) Codex can only take in 4 stop tokens => here, additional stop tokens are
    enforced ('\nassert' and '\nprint'): the program is cut at the first
    occurrence of each.

    Returns the processed program as a str.
    '''

    # \b keeps identifiers such as "input_list" or "raw_input" untouched, and the
    # closing parenthesis is consumed so that no stray ")" is left behind.
    program = re.sub(
        r'\binput\((?:[^()]|\([^()]*\))*\)', '"1"', program)

    for stop_sequence in ("\nassert", "\nprint"):
        cut_at = program.find(stop_sequence)
        if cut_at != -1:
            program = program[:cut_at]
    return program


def test_generated_programs_ground_truth(program_generated_filename="program_log", program_correct_filename="is_generated_program_correct",
                                         dataset=DataSet.MBPP, provided_function_name=True):
    '''
    Tests the generated program with the ground truth test cases.

    Loads from '{program_generated_filename}.csv' and 'mbpp.csv' (or
    'humaneval.csv' for HumanEval).
    Saves to '{program_correct_filename}.csv', one row per generated program.

    If provided_function_name is False (MBPP only), the ground truth test cases
    are rewritten to call the name found in the program's "Function name" column;
    programs without such a name are recorded as failed with "FunctionNameError".

    Raises ValueError for an unsupported dataset.
    '''

    # Read from cached csvs
    generated_programs = pd.read_csv(f'{program_generated_filename}.csv')
    if dataset == DataSet.MBPP:
        dataset_df = pd.read_csv('mbpp.csv')
    elif dataset == DataSet.HumanEval:
        dataset_df = pd.read_csv('humaneval.csv')
    else:
        raise ValueError(f"Unsupported dataset: {dataset!r}")

    # Saving everything to a csv
    analyzed_program_ground_truth_columns = [
        "Prompt id", "Prompt", "Formatted Generated Program", "Evaluated Program", "Correct", "Error"]

    # Setup headers
    pd.DataFrame(columns=analyzed_program_ground_truth_columns).to_csv(
        f"{program_correct_filename}.csv", mode='w', index=False)

    for _, program in generated_programs.iterrows():
        prompt_id = program["Prompt id"]
        from_dataset_df = dataset_df.iloc[prompt_id]
        generated_prog = program["Formatted Generated Program"]

        if dataset == DataSet.MBPP:
            # test_list is stored as a JSON list of assert statements
            ground_truth_testcases = '\n'.join(
                json.loads(from_dataset_df['test_list']))

            if provided_function_name or not pd.isnull(program["Function name"]):
                # If different function names used, point the tests at the
                # name the generated program actually defines
                if not provided_function_name:
                    ground_truth_testcases = ground_truth_testcases.replace(
                        from_dataset_df["function_name"], program["Function name"])

                result, error, evaluated_program = evaluate_program(
                    generated_prog, ground_truth_testcases)

            # If function name is null, should fail
            else:
                # Same layout evaluate_program produces, without running anything
                result, error, evaluated_program = False, "FunctionNameError", f'''{generated_prog}

{ground_truth_testcases}'''

        # HumanEval's dataset is a bit different: its "test" column defines a
        # check() function that is called with the entry point
        else:
            testing = f'''{generated_prog}
{from_dataset_df["test"]}

check({from_dataset_df["entry_point"]})
'''

            result, error, evaluated_program = evaluate_program(
                "", "", custom_evaluation=testing)

        result_row = pd.DataFrame(
            [[prompt_id, program["Prompt"], generated_prog, evaluated_program, result, error]], columns=analyzed_program_ground_truth_columns)

        result_row.to_csv(f"{program_correct_filename}.csv", mode='a',
                          index=False, header=False)


def test_gen_programs_and_test_cases(testcase_correct_filename="is_generated_testcase_correct",
                                     program_correct_filename="is_generated_program_correct", gen_program_gen_test_case_filename="gen_program_gen_testcase"):
    '''
    Runs every generated program against every generated test case that shares
    its prompt id.

    Reads '{program_correct_filename}.csv' and '{testcase_correct_filename}.csv'.

    Writes '{gen_program_gen_test_case_filename}.csv', one row per
    (program, test case) pair, including the "Correct" flags of both inputs.
    '''

    # Read from cached csvs
    all_programs = pd.read_csv(f'{program_correct_filename}.csv')
    all_testcases = pd.read_csv(f'{testcase_correct_filename}.csv')

    output_path = f"{gen_program_gen_test_case_filename}.csv"

    # Column layout of the output file; the header goes out first
    output_columns = [
        "Prompt id", "Prompt", "Generated Program ID", "Generated Test Case ID", "Evaluated Program",
        "Is generated program correct", "Is generated test case correct", "Result", "Error"]
    header_only = pd.DataFrame(columns=output_columns)
    header_only.to_csv(output_path, mode='w', index=False)

    # Only prompt ids that have at least one program are visited
    for prompt_id in set(all_programs['Prompt id']):
        programs_for_prompt = all_programs[all_programs["Prompt id"] == prompt_id]
        testcases_for_prompt = all_testcases[all_testcases["Prompt id"] == prompt_id]

        # The row labels double as the program / test case ids in the output
        for program_id, program_row in programs_for_prompt.iterrows():
            for testcase_id, testcase_row in testcases_for_prompt.iterrows():
                processed_program = additional_processing_program(
                    program_row["Formatted Generated Program"])

                outcome, error_name, executed_text = evaluate_program(
                    processed_program, testcase_row["Generated Test Case"])

                record = [prompt_id, program_row["Prompt"], program_id, testcase_id, executed_text,
                          program_row["Correct"], testcase_row["Correct"], outcome, error_name]

                pd.DataFrame([record], columns=output_columns).to_csv(
                    output_path, mode='a', index=False, header=False)

                # The script was getting "killed" during long runs (suspected
                # memory growth), hence the explicit collection per evaluation
                gc.collect()


def get_input_output(row):
    '''
    Splits an assert test case into its call arguments and expected output.

    Reads row["Generated Test Case"] and returns a two-element pd.Series
    [function_call, output]:
    - function_call is the text from the first "(" of the call side onwards,
      i.e. the argument list without the function name.
    - output is the text to the right of "==" (or of " is " when the assertion
      does not contain exactly one "=="); it is True for a bare
      "assert func(...)".

    Returns ["", ""] when the test case is not a string, has no "assert", or
    contains no call.
    '''
    testcase = row["Generated Test Case"]

    # Empty csv cells come back from pandas as NaN, not as str
    if not isinstance(testcase, str):
        return pd.Series(["", ""])

    after_assert = re.search(r'(?<=assert).*$', testcase)
    if after_assert is None:
        return pd.Series(["", ""])
    without_assert = after_assert[0].strip()

    # Should revisit of having errors => the == could be included as a parameter somehow
    split = [x.strip() for x in without_assert.split("==")]

    if len(split) != 2:
        # Try splitting on "is"
        split = [x.strip() for x in without_assert.split(" is ", 1)]

    if len(split) != 2:
        # Could be "assert func()" without the == or "is"
        function_call = re.search(r'\(.*', testcase)
        if function_call is None:
            # Badly formed
            return pd.Series(["", ""])

        return pd.Series([function_call[0].strip(), True])

    call_side, output = split
    function_call = re.search(r'\(.*', call_side)
    if function_call is None:
        return pd.Series(["", ""])

    return pd.Series([function_call[0].strip(), output])


def get_inputs_from_func_name(row):
    '''
    Extracts the argument list of the first call to a function in a test case.

    Reads row["Function name"] and row["Generated Test Case"]. Starting right
    after the first occurrence of the function name, it scans to the first "("
    and on to the point where all parentheses and square brackets are closed.

    Returns a one-element pd.Series holding that text, e.g. "([1, 2], 3)".
    The element is "" when the function name does not occur, when the name is
    not followed by a call, or when the brackets never balance.

    Brackets inside string literals are counted like any other bracket.
    '''

    function_name = row["Function name"]
    testcase = row["Generated Test Case"]

    # Check the find() result itself: once the name length has been added, a
    # missing name can no longer be told apart from a real position
    name_index = testcase.find(function_name)
    if name_index == -1:
        return pd.Series([""])

    num_bracket, num_paren = 0, 0
    start = None

    for position in range(name_index + len(function_name), len(testcase)):
        character = testcase[position]

        if character == '(':
            if start is None:
                # Opening parenthesis of the call itself
                start = position
            num_paren += 1
        elif character == ')':
            num_paren -= 1
        elif character == '[':
            num_bracket += 1
        elif character == ']':
            num_bracket -= 1

        # The call ends as soon as everything opened so far has been closed
        if start is not None and num_paren <= 0 and num_bracket <= 0:
            return pd.Series([str(testcase[start:position + 1])])

    # Ran off the end: either no call at all or unbalanced brackets
    return pd.Series([""])


def gen_alphacode_clustering_model(program_filename="program_no_func_name", testcase_filename="testcase_no_signature", output_program="program_no_func_name_formateted", output_testcase="testcase_no_signature_formatted"):
    '''
    Pipeline in the clustering model of testcases to programs used in the AlphaCode paper.
    Formats the programs by getting the inputs and outputs of each function and test case.

    Reads from {program_filename}.csv and {testcase_filename}.csv.

    Outputs to {output_program}.csv and {output_testcase}.csv:
    - programs gain the columns "Function name" and "Num parameters"
    - test cases gain the columns "Function call" and "Output" (see get_input_output)

    Only test cases whose prompt id also appears in the program file are written.

    NOTE(review): the default output_program ("..._formateted") does not match the
    default input name of alphacode_clustering_model_testing ("..._formatted").
    It is left unchanged here because callers may rely on the existing file name.
    '''

    program_df = pd.read_csv(f"{program_filename}.csv")
    testcase_df = pd.read_csv(f"{testcase_filename}.csv")

    new_headers_program = program_df.columns.to_list()
    new_headers_program.extend(["Function name", "Num parameters"])

    new_headers_testcase = testcase_df.columns.to_list()
    new_headers_testcase.extend(["Function call", "Output"])

    # Setup headers
    pd.DataFrame(columns=new_headers_program).to_csv(
        f'{output_program}.csv', mode='w', index=False)

    pd.DataFrame(columns=new_headers_testcase).to_csv(
        f'{output_testcase}.csv', mode='w', index=False)

    # Get all possible prompt ids
    all_prompt_ids = set(program_df['Prompt id'])

    # Gets function name + number of parameters from a function
    def get_func_name(row):
        program = row["Generated Program"]

        # Empty csv cells come back from pandas as NaN, not as str
        if not isinstance(program, str):
            return pd.Series(["", ""])

        # Gets function name. Not the most exhaustive though, can be beaten if : is in the parameters somehow (think default=":")
        function_signature = re.search(
            r'(?<=def).*?(?=(: *(\t)*#*[^(\n)]*\n))', program)

        if function_signature:
            function_signature = function_signature[0].strip()
        else:
            print(repr(program), "\nAFAAA")
            return pd.Series(["", ""])

        # Non-greedy: stop at the FIRST "(", so that parentheses inside default
        # values do not end up in the name
        function_name = re.search(r'.*?(?=\()', function_signature)
        if function_name:
            function_name = function_name[0].strip()
        else:
            print(program, "\nDDDDD")
            return pd.Series(["", ""])

        parameters = get_parameters(function_signature, program)

        return pd.Series([function_name, len(parameters)])

    for prompt_id in all_prompt_ids:
        programs_for_prompt_id = program_df[program_df['Prompt id'] == prompt_id].copy(
            deep=True)

        programs_for_prompt_id[["Function name", "Num parameters"]] = programs_for_prompt_id.apply(
            get_func_name, axis=1)

        programs_for_prompt_id.to_csv(
            f'{output_program}.csv', mode='a', index=False, header=False)

        testcases_for_prompt_id = testcase_df[testcase_df['Prompt id'] == prompt_id].copy(
            deep=True)

        # A prompt without any test case has nothing to format or write. Applying
        # to an empty frame would not yield the two columns assigned below.
        if testcases_for_prompt_id.empty:
            continue

        testcases_for_prompt_id[["Function call", "Output"]
                                ] = testcases_for_prompt_id.apply(get_input_output, axis=1)

        testcases_for_prompt_id.to_csv(
            f'{output_testcase}.csv', mode='a', index=False, header=False)


# This function is for programs and test cases generated without the function name included.
def alphacode_clustering_model_testing(program_filename="program_no_func_name_formatted", testcase_filename="testcase_no_signature_formatted",
                                       output_filename="program_no_func_name_testcase_no_signature_tested"):
    '''
    Pipeline in the clustering model of testcases to programs used in the AlphaCode paper.
    Tests the programs with each testcase.

    For every (program, test case) pair with the same prompt id, an assertion is
    assembled as "assert {Function name}{Function call} == {Output}" and run
    against the program's "Generated Program" text. Pairs where the function
    name or the function call is missing are skipped.

    Reads from {program_filename}.csv and {testcase_filename}.csv.

    Outputs to {output_filename}.csv
    '''

    program_df = pd.read_csv(f"{program_filename}.csv")
    testcase_df = pd.read_csv(f"{testcase_filename}.csv")

    # Setup headers to a csv file
    testing_columns = [
        "Prompt id", "Prompt", "Generated Program ID", "Generated Test Case ID", "Evaluated Program", "Result", "Error"]

    pd.DataFrame(columns=testing_columns).to_csv(
        f"{output_filename}.csv", mode='w', index=False)

    def is_missing(value):
        # Blank csv cells are read back as NaN, which is truthy, so a plain
        # "not value" check would let broken rows through
        return pd.isnull(value) or value == ""

    # Get all possible prompt ids
    all_prompt_ids = set(program_df['Prompt id'])

    for prompt_id in all_prompt_ids:
        programs_for_prompt_id = program_df[program_df['Prompt id'] == prompt_id]
        testcases_for_prompt_id = testcase_df[testcase_df['Prompt id'] == prompt_id]

        # Iterate thru each program and each test case that corresponds to the prompt id
        for program_index, program in programs_for_prompt_id.iterrows():
            for testcase_index, testcase in testcases_for_prompt_id.iterrows():
                # Skip iteration if no function name or function call (i.e broken)
                if is_missing(program["Function name"]) or is_missing(testcase["Function call"]):
                    continue

                testcase_assembled = f"assert {program['Function name']}{testcase['Function call']} == {testcase['Output']}"

                result, error, evaluated_program = evaluate_program(
                    program["Generated Program"], testcase_assembled)

                df = pd.DataFrame(
                    [[prompt_id, program["Prompt"], program_index, testcase_index, evaluated_program, result, error]], columns=testing_columns)

                df.to_csv(f"{output_filename}.csv", mode='a',
                          index=False, header=False)

# This function is for programs and test cases generated without the function name included.


def alphacode_testing_format_to_test_gen_programs_and_test_cases(program_testcase_input="mbpp_nothing_provided_program_testcase_formatted_pre",
                                                                 correct_program_input="is_mbpp_program_nothing_provided_formatted_correct",
                                                                 correct_testcase_input="is_mbpp_testcase_nothing_provided_formatted_correct",
                                                                 output_filename="mbpp_nothing_provided_program_testcase_formatted"):
    '''
    Adds ground-truth correctness flags to a program/test-case results file.

    Reads {program_testcase_input}.csv, {correct_program_input}.csv and
    {correct_testcase_input}.csv. For every results row, the "Correct" value is
    looked up by position: "Generated Program ID" indexes the program file and
    "Generated Test Case ID" indexes the test case file.

    Writes the results plus the columns "Is generated program correct" and
    "Is generated test case correct" to {output_filename}.csv (the DataFrame
    index is written as well).
    '''
    results_df = pd.read_csv(f"{program_testcase_input}.csv")
    programs_df = pd.read_csv(f"{correct_program_input}.csv")
    testcases_df = pd.read_csv(f"{correct_testcase_input}.csv")

    def lookup_correctness(result_row):
        # The ids are used as row positions in the two "correct" files
        program_entry = programs_df.iloc[result_row["Generated Program ID"]]
        testcase_entry = testcases_df.iloc[result_row["Generated Test Case ID"]]

        flags = [program_entry["Correct"], testcase_entry["Correct"]]
        return pd.Series(flags)

    correctness_flags = results_df.apply(lookup_correctness, axis=1)

    new_columns = ["Is generated program correct",
                   "Is generated test case correct"]
    results_df[new_columns] = correctness_flags

    results_df.to_csv(f"{output_filename}.csv")


def my_clustering_model(program_testcase_filename="program_no_func_name_testcase_no_signature_tested", output_filename="no_func_no_signature_"):
    '''
    Uses a clustering model based off of results [T or F] from matching test cases.

    For every prompt id, programs are grouped by their tuple of "Result" values
    (ordered by "Generated Test Case ID"). One program from the largest group is
    written out together with the MBPP ground truth for that prompt.

    Reads from {program_testcase_filename}.csv and mbpp.csv.
    Writes to {output_filename}.csv.
    '''

    program_testcase_df = pd.read_csv(f"{program_testcase_filename}.csv")
    mbpp_df = pd.read_csv("mbpp.csv")

    output_columns = [
        "Prompt id", "Prompt", "Generated Program ID", "Program", "Ground Truth Test Cases", "Ground Truth Program"]
    pd.DataFrame(columns=output_columns).to_csv(
        f"{output_filename}.csv", mode='w', index=False)

    # Get all possible prompt ids
    all_prompt_ids = set(program_testcase_df['Prompt id'])

    for prompt_id in all_prompt_ids:
        row_from_prompt_id = program_testcase_df[program_testcase_df['Prompt id'] == prompt_id]

        cluster = defaultdict(list)

        # Sorted so that the representative of a cluster does not depend on set order
        for program_id in sorted(set(row_from_prompt_id["Generated Program ID"])):
            program_rows = row_from_prompt_id[row_from_prompt_id["Generated Program ID"] == program_id]

            # Sort each program's results by Test Case ID, then convert to tuple which can be used as a key to a dictionary.
            # Having an order gives the tuple a meaning.
            program_rows = program_rows.sort_values(by="Generated Test Case ID")

            result = tuple(program_rows["Result"].to_list())

            # Append the generated program
            cluster[result].append(program_id)

        # Pick the cluster with the most programs. The size has to be taken from
        # the list of program ids, not from the (key, value) pair.
        largest_cluster = max(cluster.values(), key=len)

        program_id_from_highest_cluster = largest_cluster[0]
        row_from_id = row_from_prompt_id[row_from_prompt_id["Generated Program ID"]
                                         == program_id_from_highest_cluster].iloc[0]

        # "Evaluated Program" is the program followed by the assertion that was
        # run. Cut at the last "assert" (and the character before it) to
        # recover the program, or keep the text as-is if there is no assertion.
        actual_program = row_from_id["Evaluated Program"]
        assert_index = actual_program.rfind('assert')
        if assert_index != -1:
            actual_program = actual_program[:max(assert_index - 1, 0)]

        mbpp_row = mbpp_df.iloc[prompt_id]

        df = pd.DataFrame(
            [[prompt_id, row_from_id["Prompt"], row_from_id["Generated Program ID"], actual_program, mbpp_row["test_list"], mbpp_row["code"]]], columns=output_columns)

        df.to_csv(f"{output_filename}.csv", mode='a',
                  index=False, header=False)


def extract_parameters_from_testcase(testcase_filename, output_filename, dataset=DataSet.MBPP):
    '''
    Attaches the function name and the extracted inputs to every generated test case,
    for test cases whose prompt includes the function name.

    Part of a pipeline for a clustering model based off of the AlphaCode paper.

    Reads from {testcase_filename}.csv (must contain a "Prompt id" column) and from mbpp.csv.
    Writes to {output_filename}.csv, overwriting any existing file.

    Parameters:
    testcase_filename - path of the test case csv, without the ".csv" suffix.
    output_filename   - path of the csv to write, without the ".csv" suffix.
    dataset           - dataset the prompt ids refer to. Only DataSet.MBPP is supported.

    Raises:
    ValueError - if dataset is not DataSet.MBPP. Previously this case failed later with an
                 unbound local variable, after the output file had already been truncated.
    '''
    # Fail fast, before touching any file: only the MBPP ground-truth table is loaded below.
    if dataset is not DataSet.MBPP:
        raise ValueError(
            f"extract_parameters_from_testcase only supports DataSet.MBPP, got {dataset!r}")

    # Get df from csvs
    testcase_df = pd.read_csv(f"{testcase_filename}.csv")
    dataset_df = pd.read_csv("mbpp.csv")

    output_path = f"{output_filename}.csv"

    # Write the header row first. Rows are appended further down with header=False, so the
    # "Function name" and "Input" columns added per prompt are stored, by position, under the
    # "Function call" and "Output" headers declared here (test_parameters_with_program
    # concatenates those two columns to rebuild the call expression).
    new_headers_testcase = testcase_df.columns.to_list()
    new_headers_testcase.extend(["Function call", "Output"])
    pd.DataFrame(columns=new_headers_testcase).to_csv(
        output_path, mode='w', index=False)

    # Get all possible prompt ids
    all_prompt_ids = set(testcase_df['Prompt id'])

    for prompt_id in all_prompt_ids:
        # Work on a copy so that adding columns never touches testcase_df itself
        testcases_for_prompt_id = testcase_df[testcase_df['Prompt id'] == prompt_id].copy(
            deep=True)

        # The prompt id is used as a positional row index into the dataset table
        testcases_for_prompt_id["Function name"] = dataset_df.iloc[prompt_id]["function_name"]

        # get_inputs_from_func_name receives one row (test case + function name) at a time
        testcases_for_prompt_id["Input"] = testcases_for_prompt_id.apply(
            get_inputs_from_func_name, axis=1)

        testcases_for_prompt_id.to_csv(
            output_path, mode='a', index=False, header=False, quoting=csv.QUOTE_ALL)


def test_parameters_with_program(program_correct_filename, testcase_correct_with_parameters_formatted_filename, output_filename):
    '''
    Uses parameters generated from extract_parameters_from_testcase as input to the generated programs,
    and gets the resulting outputs.

    Programs and test cases whose "Error" column is "SyntaxError" are removed.

    Part of a pipeline for a clustering model based off of the AlphaCode paper.

    Reads from {program_correct_filename}.csv and
    {testcase_correct_with_parameters_formatted_filename}.csv.
    Writes to {output_filename}.csv (overwriting any existing file), one row per
    (program, test case) pair that shares a prompt id. Rows are appended as soon as each
    pair has been evaluated, so partial results survive an interrupted run.
    '''

    # Get df from csvs
    program_df = pd.read_csv(f"{program_correct_filename}.csv")
    testcase_df = pd.read_csv(
        f"{testcase_correct_with_parameters_formatted_filename}.csv")

    output_path = f"{output_filename}.csv"

    # Setup headers to a csv file
    # NOTE(review): the rows written below match these nine headers by position only.
    # formatted_testcase ends up under "Is generated program correct" and program["Correct"]
    # under "Is generated test case correct". The layout is left untouched here because
    # consumers of the file may rely on it -- confirm and relabel the two columns.
    output_columns = [
        "Prompt id", "Prompt", "Generated Program ID", "Generated Test Case ID", "Evaluated Program",
        "Is generated program correct", "Is generated test case correct", "Output", "Error"]

    pd.DataFrame(columns=output_columns).to_csv(
        output_path, mode='w', index=False)

    # Get all possible prompt ids
    all_prompt_ids = set(program_df['Prompt id'])

    for prompt_id in all_prompt_ids:
        # Remove syntax errors
        programs_for_prompt_id = program_df[
            (program_df['Prompt id'] == prompt_id) & (program_df['Error'] != "SyntaxError")]

        # Can potentially increase the number of inputs generated by not checking for SyntaxErrors
        # (usually problems w/ SyntaxErrors actually have correctly working inputs)
        testcases_for_prompt_id = testcase_df[
            (testcase_df['Prompt id'] == prompt_id) & (testcase_df['Error'] != "SyntaxError")]

        # Nothing to evaluate for this prompt; also keeps the program formatting below from
        # running for prompts that never reached it before.
        if testcases_for_prompt_id.empty:
            continue

        # Iterate thru each program and each test case that corresponds to the prompt id
        for program_index, program in programs_for_prompt_id.iterrows():
            # Depends only on the program, so format it once instead of once per test case
            formatted_prog = additional_processing_program(
                program["Generated Program"])

            for testcase_index, testcase in testcases_for_prompt_id.iterrows():
                # Function name followed by its inputs, as stored by extract_parameters_from_testcase
                formatted_testcase = f"{testcase['Function call']}{testcase['Output']}"

                # The third return value (the evaluated program) is not used here
                result, error, _ = evaluate_program(
                    formatted_prog, formatted_testcase, custom_evaluation=formatted_prog, get_raw_output=True)

                # If result too big to convert to a float => Return nothing
                # (presumably such values cannot be stored by pandas -- TODO confirm)
                try:
                    float(result)
                except OverflowError:
                    result = None
                    error = "OverflowError"
                except Exception:
                    # The conversion is only a probe: results that are not numbers (None,
                    # strings, containers, ...) are expected and are written unchanged.
                    # Unlike a bare except, this lets KeyboardInterrupt/SystemExit through.
                    pass

                df = pd.DataFrame(
                    [[prompt_id, program["Prompt"], program_index, testcase_index, formatted_prog, formatted_testcase,
                     program["Correct"], result, error]], columns=output_columns)

                df.to_csv(output_path, mode='a',
                          index=False, header=False)


def get_stats(filename):
    '''
    Prints summary statistics of a csv of generated items, grouped by prompt.

    Reads from {filename}.csv, which must contain a "Prompt id" column.

    Prints, after a banner naming the file:
    (1) The number of distinct prompt ids
    (2) describe() of the number of rows per prompt id
    (3) describe() of the remaining columns, grouped by prompt id
    '''
    banner = "----------------------------------"
    for banner_line in (banner, f"Stats for {filename}.csv", banner):
        print(banner_line)

    items_df = pd.read_csv(f"{filename}.csv")
    prompt_id_column = items_df["Prompt id"]

    # Everything is computed before anything is printed, so a failure prints only the banner
    distinct_prompt_ids = set(prompt_id_column)
    rows_per_prompt = prompt_id_column.value_counts()
    per_prompt_summary = items_df.groupby(by=["Prompt id"]).describe()

    print(f"Num unique generated items: {len(distinct_prompt_ids)}")
    print(rows_per_prompt.describe())
    print(per_prompt_summary)


if __name__ == "__main__":
    # Script entry point. The commented-out calls below are the log of earlier pipeline
    # stages and their measured results; only the uncommented stage runs.
    start_time = time.time()
    guard()
    # Pipeline

    # ----------------- Cache MBPP --------------
    # mbpp_jsonl_to_csv()
    # process_mbpp()

    # Testing
    # print(get_parameter_detailed_typing("(2,2)"))
    # print(get_parameter_detailed_typing(
    #     '[(4, 2), (7, 1), (4, 8), (4, 2), (9, 2), (7, 1)] '))
    # print(get_parameter_detailed_typing(
    #     '[(4, 2), (7, 1), (4, 8), (4, 2), (9, 2), (7, 1, 3)] '))
    # print(get_parameter_detailed_typing('[1,4,5] '))
    # print(get_parameter_detailed_typing('{1:"sftes", 2:"adf"} '))
    # print(get_parameter_detailed_typing('[(1,4,5)] '))
    # print(get_parameter_detailed_typing('[(1,4,5, "A"), ("A", "D")] '))
    # print(get_parameter_detailed_typing("{}"))
    # print(get_parameter_detailed_typing("{2,3}"))
    # print(get_parameter_detailed_typing("[{2,3},{1,4}]"))

    # print(get_parameter_detailed_typing('[[(2,2)]]'))

    # Initial processing + caching
    # cache_testcases_prompts(program_name="test_prog_return_type", testcase_name="test_test_return_type",
    #                         program_format=gen_program_return_type_detailed_param_type, test_case_format=gen_test_cases_return_type_detailed_param_type,
    #                         num_test_cases_sampled=1, num_programs_sampled=1, time_to_sleep_program=65, time_to_sleep_test_case=65,
    #                         total_programs_to_sample=50)
    # test_generated_programs_ground_truth(
    #     program_generated_filename="test_prog", program_correct_filename="is_test_prog_cor")

    # print("Pass @ 1", pass_at_k(k=1, program_correct_filename="is_test_prog_cor"))
    # print("Baseline model", baseline_model(
    #     program_correct_filename="is_test_prog_cor"))
    # print("Naive matching", naive_matching(
    #     gen_program_gen_test_case_filename="gen_test_prog"))

    # cache_testcases_prompts(program_name="mbpp_program_no_assertion", testcase_name="mbpp_testcase_no_assertion",
    #                         program_format=gen_program_without_assertion_format, test_case_format=gen_test_cases_without_example_assertion_format,
    #                         num_test_cases_sampled=1, num_programs_sampled=1, time_to_sleep_program=65, time_to_sleep_test_case=65,
    #                         dataset_used=DataSet.MBPP)

    # Tests program w/ the generated test case

    # test_generated_programs_ground_truth(program_generated_filename=program_log_filename, program_correct_filename=program_correct_filename)
    # test_generated_test_cases_ground_truth(testcase_generated_filename=testcase_log_filename, testcase_correct_filename=testcase_correct_filename)
    # test_gen_programs_and_test_cases(testcase_correct_filename=testcase_correct_filename,
    #                                  program_correct_filename=program_correct_filename, gen_program_gen_test_case_filename=gen_program_gen_test_case_filename)

    # Baseline model, effectively pass@1
    # 0.38 for programs generated with an example assertion
    # 0.15 for programs generated without an example assertion
    # print(
    #     f"Baseline model, programs generated without example assertion : {baseline_model(program_correct_filename='is_program_without_example_assertion_correct')}")
    # print(
    #     f"Baseline model, programs generated with example assertion : {baseline_model(program_correct_filename='is_generated_program_correct')}")

    # First model - naive matching. 0.54 for my dataset matching programs with an example assertion + test cases without an example assertion
    # 0.43 for my dataset matching programs without an example assertion to test cases without an example assertion
    # print(
    #     f"First model - naive matching, programs generated without example assertion : {naive_matching(gen_program_gen_test_case_filename='gen_program_gen_test_case_all_without_example_assertion')}")
    # print(
    #     f"First model - naive matching, programs generated with example assertion : {naive_matching(gen_program_gen_test_case_filename='gen_program_gen_testcase')}")

    # Pass @ 5
    # 0.42 accuracy - programs generated without an example assertion
    # 0.68 accuracy - programs generated using an example assertion
    # print(
    #     f"Pass @ k for programs generated without example assertion: {pass_at_k(k=5, program_correct_filename='is_program_without_example_assertion_correct')}")

    # print(
    #     f"Pass @ k for programs generated with example assertion: {pass_at_k(k=5, program_correct_filename='is_generated_program_correct')}")

    # ----------------------- Redo of earlier test where program and test cases generated with only function name --------

    # cache_testcases_prompts(program_name="mbpp_program_no_assertion", testcase_name="mbpp_testcase_no_assertion",
    #                         program_format=gen_program_without_assertion_format, test_case_format=gen_test_cases_without_example_assertion_format,
    #                         num_test_cases_sampled=1, num_programs_sampled=1, time_to_sleep_program=65, time_to_sleep_test_case=65,
    #                         dataset_used=DataSet.MBPP)

    # 0.08 for baseline, 0.23 for naive matching
    # test_generated_programs_ground_truth(
    #     program_generated_filename="mbpp_program_no_assertion", program_correct_filename="is_mbpp_program_no_assertion_correct")
    # test_generated_test_cases_ground_truth(
    #     testcase_generated_filename="mbpp_testcase_no_assertion", testcase_correct_filename="is_mbpp_testcase_no_assertion_correct")
    # test_gen_programs_and_test_cases(testcase_correct_filename="is_mbpp_testcase_no_assertion_correct",
    #                                  program_correct_filename="is_mbpp_program_no_assertion_correct", gen_program_gen_test_case_filename="mbpp_no_assertion_program_testcase")
    # print(
    #     f"Baseline model, programs generated without example assertion : {baseline_model(program_correct_filename='is_mbpp_program_no_assertion_correct')}")

    # print(
    #     f"First model - naive matching, programs generated without example assertion : {naive_matching(gen_program_gen_test_case_filename='mbpp_no_assertion_program_testcase')}")

    # ------------------ Sixth Test (Retry of Second Test)

    # print("Pass @ 1", pass_at_k(k=1,
    #       program_correct_filename="evaluated-datasets/Sixth(Retry of Second Test)/is_mbpp_program_no_assertion_correct"))

    # print("Naive matching", naive_matching(
    #     gen_program_gen_test_case_filename="evaluated-datasets/Sixth(Retry of Second Test)/mbpp_no_assertion_program_testcase"))

    # -------------------------------- HumanEval Pipeline ----------------------

    # Parse HumanEval Dataset
    # os.chdir('prompt-datasets')
    # humaneval_to_csv()

    # Generate HumanEval Completions
    # cache_testcases_prompts(program_name="humaneval_gen_program", testcase_name="humaneval_gen_testcase",
    #                         program_format=humaneval_program_formatting, test_case_format=humaneval_testcase_formatting,
    #                         num_test_cases_sampled=1, num_programs_sampled=1, time_to_sleep_program=65, time_to_sleep_test_case=65,
    #                         dataset_used=DataSet.HumanEval)

    # Test HumanEval Generated Program
    # test_generated_programs_ground_truth(
    #     program_generated_filename="humaneval_gen_program", program_correct_filename="is_humaneval_gen_program_correct", dataset=DataSet.HumanEval)
    # test_generated_test_cases_ground_truth(
    #     testcase_generated_filename="humaneval_gen_testcase", testcase_correct_filename="is_humaneval_gen_testcase_correct", dataset=DataSet.HumanEval)
    # test_gen_programs_and_test_cases(testcase_correct_filename="is_humaneval_gen_testcase_correct",
    #                                  program_correct_filename="is_humaneval_gen_program_correct", gen_program_gen_test_case_filename="humaneval_program_testcase")

    # Models for HumanEval
    # 0.16
    # print(
    #     f"Baseline model, HumanEval: {baseline_model(program_correct_filename='is_humaneval_gen_program_correct')}")
    # 0.37
    # print(
    #     f"Naive matching, HumanEval : {naive_matching(gen_program_gen_test_case_filename='humaneval_program_testcase')}")

    # AlphaCode Clustering Model
    # Hand confirmation of a sample of 30 achieves an accuracy rate of ~0.10
    # gen_alphacode_clustering_model()
    # alphacode_clustering_model_testing()
    # my_clustering_model()

    # --------------- AlphaCode Clustering with Tons of Programs ---
    # gen_alphacode_clustering_model(
    #     output_program="clustering_program_lots_generated", output_testcase="clustering_testcase_lots_generated")
    # alphacode_clustering_model_testing(
    #     program_filename="clustering_program_lots_generated", testcase_filename="clustering_testcase_lots_generated", output_filename="clustering_program_testcase")
    # my_clustering_model()

    # ------------------ Mass amounts of programs and test cases generated ---------

    # Stats
    # 50 programs for each prompt, ~100 test cases for each program
    # get_stats("mbpp_program_nothing_provided")

    # gen_alphacode_clustering_model(program_filename="mbpp_program_nothing_provided", testcase_filename="mbpp_testcase_nothing_provided",
    #                                output_program="mbpp_program_nothing_provided_formatted", output_testcase="mbpp_testcase_nothing_provided_formatted")
    # test_generated_programs_ground_truth(
    #     program_generated_filename="mbpp_program_nothing_provided_formatted", program_correct_filename="is_mbpp_program_nothing_provided_formatted_correct", provided_function_name=False)
    # test_generated_test_cases_ground_truth(
    #     testcase_generated_filename="mbpp_testcase_nothing_provided_formatted", testcase_correct_filename="is_mbpp_testcase_nothing_provided_formatted_correct", dataset=DataSet.HumanEval, provided_function_name=False)

    # alphacode_clustering_model_testing(
    #     program_filename="mbpp_program_nothing_provided_formatted", testcase_filename="mbpp_testcase_nothing_provided_formatted", output_filename="mbpp_nothing_provided_program_testcase_formatted_pre")
    # alphacode_testing_format_to_test_gen_programs_and_test_cases(program_testcase_input="mbpp_nothing_provided_program_testcase_formatted_pre",
    #                                                              correct_program_input="is_mbpp_program_nothing_provided_formatted_correct",
    #                                                              correct_testcase_input="is_mbpp_testcase_nothing_provided_formatted_correct",
    #                                                              output_filename="mbpp_nothing_provided_program_testcase_formatted")

    # 0.16
    # print(
    #     f"Baseline model, programs generated without anything : {baseline_model(program_correct_filename='is_mbpp_program_nothing_provided_formatted_correct')}")

    # get_stats("mbpp_nothing_provided_program_testcase_formatted_pre")

    # 0.347
    # print(
    #     f"Naive matching, programs generated without anything : {naive_matching(gen_program_gen_test_case_filename='mbpp_nothing_provided_program_testcase_formatted')}")

    # alphacode_clustering_model_testing(
    #     program_filename="clustering_program_lots_generated", testcase_filename="clustering_testcase_lots_generated", output_filename="clustering_program_testcase")
    # my_clustering_model()

    # ----------------- Lots of programs with parameter types included (no return types) --------------
    # get_stats("lots_progs")
    # get_stats("lots_test")

    # test_generated_programs_ground_truth(
    #     program_generated_filename="lots_progs", program_correct_filename="is_lots_progs_correct", dataset=DataSet.HumanEval)
    # test_generated_test_cases_ground_truth(
    #     testcase_generated_filename="lots_test", testcase_correct_filename="is_lots_test_correct", dataset=DataSet.HumanEval)
    # test_gen_programs_and_test_cases(testcase_correct_filename="is_lots_test_correct",
    #                                  program_correct_filename="is_lots_progs_correct", gen_program_gen_test_case_filename="lots_progs_tests")

    # extract_parameters_from_testcase(
    #     "is_lots_test_correct", "lots_tests_clustering")

    # Currently active stage: run every program against the extracted test inputs
    test_parameters_with_program(
        program_correct_filename="is_lots_progs_correct",
        testcase_correct_with_parameters_formatted_filename="lots_tests_clustering",
        output_filename="final_clustering_lots")

    elapsed_seconds = round(time.time() - start_time, 2)
    print(f"Program took: {elapsed_seconds} seconds")
