import pandas as pd
import numpy as np
import random


def naive_matching(gen_program_gen_test_case_filename, num_test_cases_to_solve=3, output_filename=""):
    '''
    First model. Naively selects the program that solves the most of the test cases that are solved by the most programs.

    Selects the first {num_test_cases_to_solve} most solved test cases to be regarded as the "ground truth test cases",
    and the program that solves the most of them is selected.

    If there's a tie in the number of test cases solved, then the "rank" of each test case is computed (how many programs
    solve it), and the program with the highest sum of ranks over the test cases it solves is chosen.

    Reads the CSV at {gen_program_gen_test_case_filename} (the path is used as given). Columns used:
    'Prompt id', 'Problem Prompt', 'Generated Program ID', 'Evaluated Program', 'Generated Test Case ID',
    'Result' (expected to be boolean) and 'Is generated program correct'.

    Writes one row per prompt to '{output_filename}.csv'; '.csv' is not appended a second time when
    {output_filename} already ends with it. Defaults to 'naive_matching.csv'.

    Returns {float accuracy}: the fraction of prompts whose selected program is correct.
    '''
    if not output_filename:
        output_filename = "naive_matching.csv"
    # The header, the appended rows and the final read-back must all use the same file; splitting them
    # across '{name}' and '{name}.csv' loses the header and breaks the accuracy lookup.
    output_path = output_filename if output_filename.endswith(".csv") else f"{output_filename}.csv"

    # Reads from cached csv
    gen_everything = pd.read_csv(gen_program_gen_test_case_filename)

    # Get all possible prompt ids
    generated_prompt_ids = set(gen_everything['Prompt id'])

    output_columns = ["Prompt id", "Prompt",
                      "Generated Program ID", "Program", "Is generated program correct"]
    pd.DataFrame(columns=output_columns).to_csv(
        output_path, mode='w', index=False)

    for prompt_id in generated_prompt_ids:
        gen_from_prompt_id = gen_everything[gen_everything["Prompt id"] == prompt_id]

        # Number of programs solving each test case; the most-solved ones act as the "ground truth".
        most_solved_test_cases = gen_from_prompt_id.groupby(
            by="Generated Test Case ID")["Result"].sum().sort_values(ascending=False).iloc[:num_test_cases_to_solve]

        programs_from_test_cases = gen_from_prompt_id[gen_from_prompt_id["Generated Test Case ID"].isin(
            most_solved_test_cases.index)].copy()

        # Tie-breaker: a solved test case is worth the number of programs that solve it, so of two programs
        # solving equally many test cases, the one solving the more widespread ones wins.
        programs_from_test_cases["Rank"] = programs_from_test_cases["Generated Test Case ID"].map(
            most_solved_test_cases)
        solved = programs_from_test_cases["Result"].astype(bool)
        programs_from_test_cases["Solved rank"] = programs_from_test_cases["Rank"].where(solved, 0)

        programs_solving_tests = programs_from_test_cases.groupby(
            by=["Generated Program ID", "Is generated program correct"])[["Result", "Solved rank"]].sum()

        programs_solving_tests = programs_solving_tests.sort_values(
            ['Result', 'Solved rank'], ascending=[False, False])

        best_program_id = programs_solving_tests.index[0][0]

        row_from_program_chosen = gen_from_prompt_id[gen_from_prompt_id["Generated Program ID"]
                                                     == best_program_id].iloc[0]

        df_to_output = pd.DataFrame(
            [[prompt_id, row_from_program_chosen["Problem Prompt"], row_from_program_chosen["Generated Program ID"],
              row_from_program_chosen["Evaluated Program"], row_from_program_chosen["Is generated program correct"]]],
            columns=output_columns)

        df_to_output.to_csv(output_path, mode='a',
                            index=False, header=False)

    output_model = pd.read_csv(output_path)

    return output_model["Is generated program correct"].sum() / output_model.shape[0]


def pass_at_k(k, program_correct_filename):
    '''
    Unbiased pass@k estimate, averaged over prompts.

    For each prompt with n sampled programs of which c are correct, pass@k is estimated as
    1 - C(n - c, k) / C(n, k), evaluated as a running product instead of large binomials.

    Reads the CSV at {program_correct_filename} (the path is used as given); needs the columns
    'Prompt id' and 'Correct', one row per sampled program.

    Returns {float accuracy}: the mean estimate over all distinct prompt ids.
    '''

    def unbiased_estimate(n, c, k):
        """
        :param n: total number of samples
        :param c: number of correct samples
        :param k: k in pass@$k$
        """
        # With no correct sample nothing can pass, even when k > n (the check below would return 1.0).
        if c == 0:
            return 0.0
        # Fewer than k incorrect samples: every size-k draw contains a correct one.
        if n - c < k:
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    # Read from cached csv
    generated_programs = pd.read_csv(program_correct_filename)

    # Get all possible prompt ids
    generated_prompt_ids = set(generated_programs['Prompt id'])

    num_correct = 0
    for prompt_id in generated_prompt_ids:
        correct = generated_programs[generated_programs["Prompt id"] == prompt_id]["Correct"]
        num_correct += unbiased_estimate(len(correct), correct.sum(), k)

    return num_correct / len(generated_prompt_ids)


def num_toxic_programs(program_correct_filename):
    '''
    Count the prompts for which none of the generated programs is correct.

    Reads the CSV at {program_correct_filename} (the path is used as given); needs the columns
    'Prompt id' and 'Correct'. Prints the 'Correct' column of every such prompt.

    Returns (num_skipped, num_total): prompts with no correct program, and the number of distinct prompts.
    '''
    df = pd.read_csv(program_correct_filename)

    prompt_set = set(df["Prompt id"])

    num_total = len(prompt_set)
    num_skipped = 0
    for prompt_id in prompt_set:
        correct = df[df["Prompt id"] == prompt_id]["Correct"]
        # `not any()` instead of `(~correct).all()`: on an integer 0/1 column `~` is bitwise NOT
        # (~0 == -1, ~1 == -2, both truthy), which would flag every prompt.
        if not correct.any():
            print(correct)
            num_skipped += 1

    return num_skipped, num_total


def num_positives(formatted_df_filename):
    '''
    Count the prompts that have at least one correct generated program.

    Reads the CSV at {formatted_df_filename} (the path is used as given); needs the columns
    'Prompt id' and 'Is generated program correct'. Prints the count.

    Returns {int}: the number of such prompts.
    '''
    df = pd.read_csv(formatted_df_filename)

    prompt_set = set(df["Prompt id"])

    # Local name differs from the function name to avoid shadowing it.
    positives = 0

    for prompt_id in prompt_set:
        prompt_df = df[df["Prompt id"] == prompt_id]

        positives += bool(prompt_df["Is generated program correct"].any())

    print("Num positives ", positives)

    return positives


def thresholded_pass_at_k(k, score_whole_df_filename, threshold):
    '''
    Thresholded pass_at_k.

    For every prompt, only programs with 'Trustworthy Score' >= {threshold} are kept. They are ordered by
    score (descending, ties by ascending 'Generated Program ID') and the top {k} are checked. A prompt with
    no program above the threshold is skipped. A prompt counts as correct when any checked program is correct.

    Reads the CSV at {score_whole_df_filename}; needs the columns 'Prompt id', 'Generated Program ID',
    'Trustworthy Score' and 'Is generated program correct'. Prints the number of positive prompts.

    Returns (pass_at_k, precision, recall, num_problems_skipped, num_positives/num_total), where a positive
    prompt is one with at least one correct program.

    If all the problems are skipped, precision is 1; if no prompt is positive, recall is 1.
    '''
    df = pd.read_csv(score_whole_df_filename)

    prompt_set = set(df["Prompt id"])

    num_total = len(prompt_set)

    num_skipped = 0
    num_correct = 0

    num_positives = 0

    for prompt_id in prompt_set:
        prompt_df = df[df["Prompt id"] == prompt_id]

        num_positives += prompt_df["Is generated program correct"].any()

        prompt_df = prompt_df[prompt_df["Trustworthy Score"]
                              >= threshold].sort_values(by=['Trustworthy Score', 'Generated Program ID'],
                                                        ascending=[False, True])

        prompt_df = prompt_df.iloc[:k]

        if len(prompt_df) == 0:
            num_skipped += 1
        else:
            num_correct += prompt_df["Is generated program correct"].any()

    precision = num_correct / \
        (num_total - num_skipped) if num_skipped != num_total else 1
    # Without any positive prompt there is nothing to recall; mirror precision's convention
    # instead of dividing by zero.
    recall = num_correct / num_positives if num_positives else 1

    print(num_positives)

    return (num_correct / num_total, precision, recall, num_skipped, num_positives / num_total)


def thresholded_pass_at_k_by_cluster(k, score_whole_df_filename, threshold):
    '''
    Thresholded pass_at_k over clusters.

    For every prompt, only rows with 'Trustworthy Score' >= {threshold} are kept. They are ordered by score
    (descending, ties by ascending 'Cluster size rank') and the top {k} are taken. A prompt with no row above
    the threshold is skipped. Otherwise it contributes the mean 'Proportion of correct generated program' of
    the taken rows.

    Reads the CSV at {score_whole_df_filename}; needs the columns 'Prompt id', 'Trustworthy Score',
    'Cluster size rank' and 'Proportion of correct generated program'.

    Returns (pass_at_k, precision, recall, num_problems_skipped), where a positive prompt is one with some
    row whose proportion of correct programs is > 0.

    If all the problems are skipped, precision is 1; if no prompt is positive, recall is 1.
    '''
    df = pd.read_csv(score_whole_df_filename)

    prompt_set = set(df["Prompt id"])

    num_total = len(prompt_set)

    num_skipped = 0
    num_correct = 0

    num_positives = 0

    for prompt_id in prompt_set:
        prompt_df = df[df["Prompt id"] == prompt_id]

        num_positives += (
            prompt_df["Proportion of correct generated program"] > 0).any()

        prompt_df = prompt_df[prompt_df["Trustworthy Score"]
                              >= threshold].sort_values(by=['Trustworthy Score', 'Cluster size rank'],
                                                        ascending=[False, True])

        prompt_df = prompt_df.iloc[:k]

        if len(prompt_df) == 0:
            num_skipped += 1
        else:
            num_correct += prompt_df["Proportion of correct generated program"].mean()

    precision = num_correct / \
        (num_total - num_skipped) if num_skipped != num_total else 1
    # Without any positive prompt there is nothing to recall; mirror precision's convention
    # instead of dividing by zero.
    recall = num_correct / num_positives if num_positives else 1

    return (num_correct / num_total, precision, recall, num_skipped)


def pass_at_k_skipping_toxic_programs(k, program_correct_filename):
    '''
    Unbiased pass@k estimate, averaged only over prompts that have at least one correct program.

    Prompts whose programs are all incorrect are skipped; their 'Correct' column is printed.

    Reads the CSV at {program_correct_filename} (the path is used as given); needs the columns
    'Prompt id' and 'Correct', one row per sampled program.

    Returns {float accuracy}. Raises ZeroDivisionError when every prompt is skipped.
    '''

    def unbiased_estimate(n, c, k):
        """
        :param n: total number of samples
        :param c: number of correct samples
        :param k: k in pass@$k$
        """
        if n - c < k:
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    # Read from cached csv
    generated_programs = pd.read_csv(program_correct_filename)

    # Get all possible prompt ids
    generated_prompt_ids = set(generated_programs['Prompt id'])

    num_correct = 0
    num_total = 0
    for prompt_id in generated_prompt_ids:
        correct = generated_programs[generated_programs["Prompt id"] == prompt_id]["Correct"]

        # `not any()` instead of `(~correct).all()`: on an integer 0/1 column `~` is bitwise NOT
        # (~0 == -1, ~1 == -2, both truthy), which would skip every prompt.
        if not correct.any():
            print(correct)
            continue

        num_correct += unbiased_estimate(len(correct), correct.sum(), k)
        num_total += 1

    return num_correct / num_total


def codet_model(k, input_filename):
    '''
    Model from the codeT paper.

    Reads the CSV at {input_filename} (the path is used as given). Each row is one (program, test case)
    execution; the columns used are 'Prompt id', 'Generated Program ID', 'Result', 'Cluster size rank' and
    'Is generated program correct'.

    For every prompt, each program is scored as sqrt(cluster size) * (number of test cases it passes). Its
    cluster size is the number of that prompt's programs sharing its 'Cluster size rank' value. Programs are
    ranked by score, with ties kept in ascending 'Generated Program ID' order. The prompt counts as solved
    when any of the top {k} programs is correct.

    Returns {float accuracy}: the fraction of prompts solved.
    '''
    whole_df = pd.read_csv(input_filename)

    num_correct = 0

    problem_set = set(whole_df["Prompt id"])
    for prompt_id in problem_set:
        prompt_rows = whole_df[whole_df["Prompt id"] == prompt_id]

        # One row per program (its first execution row), indexed and ordered by program ID.
        programs = (prompt_rows.drop_duplicates("Generated Program ID")
                    .set_index("Generated Program ID")
                    .sort_index())
        # Despite the column name, this is the count of passed test cases, not a percentage.
        programs["% test cases passed"] = prompt_rows.groupby("Generated Program ID")["Result"].sum()

        # Cluster size from one value_counts lookup instead of re-filtering the frame for every row (O(n^2)).
        # A missing cluster rank matches nothing, hence 0.
        cluster_rank = programs["Cluster size rank"]
        programs["Cluster size num"] = cluster_rank.map(cluster_rank.value_counts()).fillna(0)

        programs["Score"] = np.sqrt(
            programs["Cluster size num"]) * programs["% test cases passed"]

        # Stable sort so equal scores are broken deterministically by program ID.
        programs = programs.sort_values("Score", ascending=False, kind="stable")

        num_correct += bool(programs.iloc[:k]["Is generated program correct"].any())

    return num_correct / len(problem_set)


def get_top_n_testcases_to_df(whole_df_filename, formatted_df_filename, program_df_filename, testcase_df_filename, output_df_filenamae, n):
    '''
    Takes in {whole_dataset.csv}, {formatted_df_scores.csv}, {correct program_df_filename},
    {correct testcase df filename}, {output_df_filename}, {n test cases}

    For each prompt id, picks the program with the highest 'Trustworthy Score' (ties: higher
    'Generated Program ID'). It then writes one row to {output_df_filenamae}: that program's text and
    correctness, followed by up to n test cases it passes, ordered by ascending 'Programs solved'
    (the most distinguishing first). Rows with fewer than n passing test cases are padded with blanks.

    NOTE(review): 'Generated Program ID' / 'Generated Test Case ID' are used as row positions into the
    program and test case CSVs (via iloc); this assumes those files are ordered by ID.
    '''

    whole_df = pd.read_csv(whole_df_filename)
    formatted_df = pd.read_csv(formatted_df_filename)

    program_df = pd.read_csv(program_df_filename)
    testcase_df = pd.read_csv(testcase_df_filename)

    columns = ['Selected Program', 'Is Program Correct'] + \
        [f'Selected Test Case {i + 1}' for i in range(n)]

    pd.DataFrame(columns=columns).to_csv(output_df_filenamae, index=False)

    for prompt_id in set(formatted_df["Prompt id"]):

        prompt_df = formatted_df[formatted_df["Prompt id"] == prompt_id]

        prompt_df = prompt_df.groupby(
            ["Generated Program ID"]).first().reset_index()
        prompt_df = prompt_df.sort_values(
            by=['Trustworthy Score', 'Generated Program ID'], ascending=False)

        # Read the ID from its own column: a row Series (`iloc[0][col]`) is upcast to float when every
        # column is numeric, and iloc rejects float positions.
        most_prob_generated_id = int(prompt_df["Generated Program ID"].iloc[0])

        most_prob_generated_df = whole_df[whole_df["Generated Program ID"]
                                          == most_prob_generated_id]

        passing = most_prob_generated_df[most_prob_generated_df["Result"] == True]
        most_prob_generated_testcases_id = list(passing.sort_values(
            by="Programs solved")["Generated Test Case ID"].iloc[:n])

        chosen_program = program_df.iloc[most_prob_generated_id]
        selected_program = [
            chosen_program["Formatted Generated Program"], chosen_program["Correct"]]
        selected_test_cases = list(
            testcase_df.iloc[most_prob_generated_testcases_id]['Generated Test Case'])

        print(selected_test_cases)
        print(selected_program)

        # Pad so every row has as many fields as the header.
        res = selected_program + selected_test_cases + \
            [None] * (n - len(selected_test_cases))

        pd.DataFrame([res]
                     ).to_csv(output_df_filenamae, mode='a', header=False, index=False)


def get_random_n_testcases_to_df(whole_df_filename, formatted_df_filename, program_df_filename, testcase_df_filename, output_df_filenamae, n):
    '''
    Takes in {whole_dataset.csv}, {formatted_df_scores.csv}, {correct program_df_filename},
    {correct testcase df filename}, {output_df_filename}, {n test cases}

    For each prompt id, picks the program with the highest 'Trustworthy Score' (ties: higher
    'Generated Program ID'). It then writes one row to {output_df_filenamae}: that program's text and
    correctness, followed by n test cases drawn at random (via `random.sample`) from those it passes.
    If the program passes fewer than n test cases, all of them are used (in random order) and the row is
    padded with blanks.

    NOTE(review): 'Generated Program ID' / 'Generated Test Case ID' are used as row positions into the
    program and test case CSVs (via iloc); this assumes those files are ordered by ID.
    '''

    whole_df = pd.read_csv(whole_df_filename)
    formatted_df = pd.read_csv(formatted_df_filename)

    program_df = pd.read_csv(program_df_filename)
    testcase_df = pd.read_csv(testcase_df_filename)

    columns = ['Selected Program', 'Is Program Correct'] + \
        [f'Selected Test Case {i + 1}' for i in range(n)]

    pd.DataFrame(columns=columns).to_csv(output_df_filenamae, index=False)

    for prompt_id in set(formatted_df["Prompt id"]):

        prompt_df = formatted_df[formatted_df["Prompt id"] == prompt_id]

        prompt_df = prompt_df.groupby(
            ["Generated Program ID"]).first().reset_index()
        prompt_df = prompt_df.sort_values(
            by=['Trustworthy Score', 'Generated Program ID'], ascending=False)

        # Read the ID from its own column: a row Series (`iloc[0][col]`) is upcast to float when every
        # column is numeric, and iloc rejects float positions.
        most_prob_generated_id = int(prompt_df["Generated Program ID"].iloc[0])

        most_prob_generated_df = whole_df[whole_df["Generated Program ID"]
                                          == most_prob_generated_id]

        passing_testcase_ids = most_prob_generated_df[
            most_prob_generated_df["Result"] == True]["Generated Test Case ID"]

        # random.sample raises ValueError when asked for more items than the population holds.
        sample_size = min(n, len(passing_testcase_ids))
        random_generated_testcases_id = list(
            passing_testcase_ids.iloc[random.sample(range(0, len(passing_testcase_ids)), sample_size)])

        chosen_program = program_df.iloc[most_prob_generated_id]
        selected_program = [
            chosen_program["Formatted Generated Program"], chosen_program["Correct"]]
        selected_test_cases = list(
            testcase_df.iloc[random_generated_testcases_id]['Generated Test Case'])

        print(selected_test_cases)
        print(selected_program)

        # Pad so every row has as many fields as the header.
        res = selected_program + selected_test_cases + \
            [None] * (n - len(selected_test_cases))

        pd.DataFrame([res]
                     ).to_csv(output_df_filenamae, mode='a', header=False, index=False)


def get_top_n_testcases(whole_df_filename, formatted_df_filename, program_df_filename, testcase_df_filename, n, k):
    '''
    Takes in {whole_dataset.csv}, {formatted_df_scores.csv}, {correct program_df_filename},
    {correct testcase df filename}, {n test cases}, {k prompt id}

    For prompt id {k}, picks the program with the highest 'Trustworthy Score' (ties: higher
    'Generated Program ID') and up to n test cases it passes, ordered by ascending 'Programs solved'
    (the most distinguishing first). Prints both.

    Returns (program_text, test_cases): the chosen program's 'Formatted Generated Program' and the list of
    selected 'Generated Test Case' strings.

    NOTE(review): 'Generated Program ID' / 'Generated Test Case ID' are used as row positions into the
    program and test case CSVs (via iloc); this assumes those files are ordered by ID.
    '''

    whole_df = pd.read_csv(whole_df_filename)
    formatted_df = pd.read_csv(formatted_df_filename)

    program_df = pd.read_csv(program_df_filename)
    testcase_df = pd.read_csv(testcase_df_filename)

    prompt_df = formatted_df[formatted_df["Prompt id"] == k]

    prompt_df = prompt_df.groupby(
        ["Generated Program ID"]).first().reset_index()
    prompt_df = prompt_df.sort_values(
        by=['Trustworthy Score', 'Generated Program ID'], ascending=False)

    # Read the ID from its own column: a row Series (`iloc[0][col]`) is upcast to float when every
    # column is numeric, and iloc rejects float positions.
    most_prob_generated_id = int(prompt_df["Generated Program ID"].iloc[0])

    most_prob_generated_df = whole_df[whole_df["Generated Program ID"]
                                      == most_prob_generated_id]

    passing = most_prob_generated_df[most_prob_generated_df["Result"] == True]
    most_prob_generated_testcases_id = list(passing.sort_values(
        by="Programs solved")["Generated Test Case ID"].iloc[:n])

    program_text = program_df.iloc[most_prob_generated_id]["Formatted Generated Program"]

    print("Program:")
    print(program_text)

    generated_test_cases = list(
        testcase_df.iloc[most_prob_generated_testcases_id]['Generated Test Case'])

    print("Test cases:")
    for test_case in generated_test_cases:
        print(test_case)

    return program_text, generated_test_cases
