import torch
import pandas as pd
import numpy as np
from collections import defaultdict
from random import sample
from sklearn.preprocessing import StandardScaler
import pickle
from pathlib import Path

# Added to values before taking np.log() so that zero-valued features do not yield log(0).
small_num_to_prevent_underflow = 0.0001

# Feature columns, in model-input order, computed from the generated test cases.
features = [
    "% test cases passed rank",
    "% test cases passed",
    "Log % test cases passed",
    "Log Cluster size rank",
    "Log % test case passed rank",
    "Cluster size rank",
    "Cluster size (out of total programs)",
    "Log Cluster size (out of total programs)",
    "Entropy",
]

# The same features computed from formal specs: identical names with a " [Formal Specs]"
# suffix, index-aligned with `features`.
formalspecs_features = [name + " [Formal Specs]" for name in features]


# features = features + formalspecs_features


def get_cluster(df):
    '''
    Cluster each prompt's generated programs by their test-case results; annotates ``df`` in place.

    For every prompt, programs whose "Result" vectors (ordered by "Generated Test Case ID")
    are identical form one cluster. For every row of each program this writes:

    * "Cluster size": the cluster's size as a fraction of the prompt's programs.
    * "Cluster size rank": 1 for the largest cluster, 2 for the next, ...
    * "Is ground truth cluster": whether this cluster's result vector equals the prompt's
      "Is generated test case correct" vector.
    * "% test cases passed rank": 1 for the cluster passing the most test cases, ...
    * "% test cases passed": the sum of the cluster's "Result" vector (a count).

    Rows are selected by "Generated Program ID" over the whole frame, so program IDs are
    assumed to be unique across prompts.
    '''
    # Create the output columns up front so the per-cluster .loc writes below never have to
    # enlarge the frame or upcast a column (writing bools into an all-NaN float column hits
    # pandas' incompatible-dtype setitem deprecation). Existing columns are just overwritten.
    for column, initial_value in (("Cluster size", np.nan),
                                  ("Cluster size rank", np.nan),
                                  ("Is ground truth cluster", False),
                                  ("% test cases passed rank", np.nan),
                                  ("% test cases passed", np.nan)):
        if column not in df.columns:
            df[column] = initial_value

    all_prompt_ids = set(df["Prompt id"])

    for prompt_id in all_prompt_ids:
        print("Prompt id", prompt_id)
        df_from_prompt = df[df["Prompt id"] == prompt_id]

        program_id_set = set(df_from_prompt["Generated Program ID"])

        # Key = the program's results ordered by test case id; the ordering gives the tuple meaning.
        cluster = defaultdict(set)
        for program_id in program_id_set:
            program_df = df_from_prompt[
                df_from_prompt["Generated Program ID"] == program_id
            ].sort_values(by="Generated Test Case ID")
            cluster[tuple(program_df["Result"].to_list())].add(program_id)

        # Result vector of a program that passes exactly the correct test cases.
        ground_truth = df_from_prompt.groupby(
            by="Generated Test Case ID").first().reset_index()
        ground_truth.sort_values(by="Generated Test Case ID", inplace=True)
        ground_truth = tuple(
            ground_truth["Is generated test case correct"].to_list())
        # .get() instead of cluster[...]: indexing the defaultdict would insert an empty
        # phantom cluster that later occupies a slot in the "% test cases passed" ranking.
        ground_truth_cluster = cluster.get(ground_truth, set())

        # Rank clusters by size (1 = largest).
        sorted_cluster = sorted(cluster.items(),
                                key=lambda item: len(item[1]), reverse=True)
        for rank, (_, members) in enumerate(sorted_cluster, start=1):
            is_ground_truth_cluster = members == ground_truth_cluster
            df.loc[df["Generated Program ID"].isin(members),
                   ["Cluster size", "Cluster size rank", "Is ground truth cluster"]] = (
                len(members) / len(program_id_set), rank, is_ground_truth_cluster)

        # Rank clusters by number of passed test cases (1 = most).
        sorted_cluster = sorted(cluster.items(),
                                key=lambda item: sum(item[0]), reverse=True)
        for rank, (results, members) in enumerate(sorted_cluster, start=1):
            df.loc[df["Generated Program ID"].isin(members),
                   ["% test cases passed rank", "% test cases passed"]] = rank, sum(results)


def get_alphacode_cluster(df, alphacode_clustering_df_filename):
    '''
    Attach precomputed AlphaCode-style clusters to ``df`` and rank programs by tests passed.

    The clustering CSV (minus its "Prompt id" column) is inner-joined to ``df`` on
    "Generated Program ID". Then, per prompt, programs are grouped by the number of test
    cases they pass (sum of "Result") and the groups are ranked in descending order:

    * "% test cases passed rank": 1 for the group passing the most test cases, ...
    * "% test cases passed": the number of test cases passed (a count).

    Returns the merged frame; the ``df`` argument itself is not modified.
    '''
    clusters = pd.read_csv(alphacode_clustering_df_filename).drop(columns=['Prompt id'])
    merged = df.merge(clusters, on="Generated Program ID")

    for prompt_id in set(merged["Prompt id"]):
        prompt_rows = merged[merged["Prompt id"] == prompt_id]

        programs_by_passed = defaultdict(set)
        for program_id in set(prompt_rows["Generated Program ID"]):
            program_results = prompt_rows.loc[
                prompt_rows["Generated Program ID"] == program_id, "Result"]
            programs_by_passed[sum(program_results)].add(program_id)

        ranked_groups = sorted(programs_by_passed.items(),
                               key=lambda item: item[0], reverse=True)
        print(prompt_id)

        for rank, (passed_count, program_ids) in enumerate(ranked_groups, start=1):
            merged.loc[merged["Generated Program ID"].isin(program_ids),
                       ["% test cases passed rank", "% test cases passed"]] = rank, passed_count

    return merged


def test_case_ranking(df):

    all_prompt_ids = set(df["Prompt id"])

    for prompt_id in all_prompt_ids:

        print("test case ranking", prompt_id)
        df_from_prompt = df[df["Prompt id"] == prompt_id].copy(deep=True)

        test_cases = df_from_prompt.groupby(by="Generated Test Case ID")[
            "Result"].sum().sort_values(ascending=False)

        ids = test_cases.index.tolist()
        programs_solved = test_cases.tolist()

        for count, test_case_id in enumerate(ids):
            df.loc[df["Generated Test Case ID"] == test_case_id, [
                "Programs solved", "Test Case Rank"]] = [programs_solved[count], count + 1]


def remove_syntax_errors(df, program_df, testcase_df):
    '''
    Drop rows whose test case or program has "Error" == "SyntaxError".

    "Generated Test Case ID" / "Generated Program ID" are used as positional row indices
    (``iloc``) into ``testcase_df`` / ``program_df``, so row i of those tables is assumed to
    describe ID i. Test-case filtering happens first; program filtering is applied to the
    rows that survive it. Prints the input row count.
    '''
    print(len(df))

    def _not_syntax_error(lookup_df, id_column, rows):
        # Plain list of bools, positionally aligned with `rows`.
        looked_up = lookup_df.iloc[rows[id_column]]
        return (looked_up["Error"] != "SyntaxError").tolist()

    df = df[_not_syntax_error(testcase_df, "Generated Test Case ID", df)]
    df = df[_not_syntax_error(program_df, "Generated Program ID", df)]
    return df


def scale_data(train_X, test_X, linearcomparison_X=None, filepath='standardscaler'):
    '''
    Fit a StandardScaler, scale the training/test features with it, and pickle the scaler.

    Two modes:

    * ``linearcomparison_X`` given: the scaler is fit on ``linearcomparison_X`` and applied
      to every matrix in the lists ``train_X`` / ``test_X``; returns lists of tensors.
    * otherwise: the scaler is fit on ``train_X`` and applied to ``train_X`` / ``test_X``
      directly; returns tensors.

    The scaled test data is None when ``test_X`` is None or empty.

    Saves the fitted scaler to {filepath} with pickle.
    '''
    def _has_rows(data):
        # `if data:` raises for tensors/arrays with more than one element, so the
        # tensor branch used to crash for any real test set; check the length instead.
        return data is not None and len(data) > 0

    scaler = StandardScaler()

    if linearcomparison_X is not None:
        scaler.fit(linearcomparison_X)
        scaled_training = [torch.tensor(scaler.transform(X))
                           for X in train_X]
        scaled_test = ([torch.tensor(scaler.transform(X)) for X in test_X]
                       if _has_rows(test_X) else None)

    else:
        scaler.fit(train_X)
        scaled_training = torch.tensor(scaler.transform(train_X))
        scaled_test = (torch.tensor(scaler.transform(test_X))
                       if _has_rows(test_X) else None)

    with open(filepath, 'wb') as f:
        pickle.dump(scaler, f)

    return scaled_training, scaled_test


def load_scaled_data(X, scaler_path, is_for_logreg):
    '''
    Scale features with a previously pickled scaler.

    ``X`` is a single feature matrix when ``is_for_logreg`` is truthy, otherwise a list of
    per-prompt matrices. Each matrix is passed through ``scaler.transform`` and wrapped in
    ``torch.tensor``. Returns a tensor or a list of tensors, mirroring ``X``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point ``scaler_path`` at
    scaler files this project wrote.
    '''
    with open(scaler_path, 'rb') as scaler_file:
        scaler = pickle.load(scaler_file)

    def _scale(matrix):
        return torch.tensor(scaler.transform(matrix))

    if not is_for_logreg:
        return [_scale(matrix) for matrix in X]
    return _scale(X)


def transform_dataset_by_program(program_testcase_df, output_folder_path, small_num):
    '''
    Collapse program/test-case rows into one feature row per "Generated Program ID".

    Per-program values (prompt id, correctness, ranks, cluster size) come from the program's
    first row; the passed fraction is sum("Result") / number of rows. ``small_num`` is added
    before every log. The table is written to
    ``output_folder_path / "formatted_for_tensor_conversion_by_program_id.csv"`` and returned.
    '''
    def program_row(rows):
        head = rows.iloc[0]
        passed_fraction = rows["Result"].sum() / len(rows)
        return pd.Series({
            "Prompt id": head["Prompt id"],
            "Is generated program correct": head["Is generated program correct"],

            'Cluster size rank': head["Cluster size rank"],
            'Log Cluster size rank': np.log(head["Cluster size rank"] + small_num),

            "% test cases passed rank": head["% test cases passed rank"],
            'Log % test case passed rank': np.log(head["% test cases passed rank"] + small_num),

            "Log % test cases passed": np.log(passed_fraction + small_num),
            "% test cases passed": passed_fraction,

            "Cluster size (out of total programs)": head["Cluster size"],
            "Log Cluster size (out of total programs)": np.log(head["Cluster size"] + small_num),
        })

    grouped = program_testcase_df.groupby(by=["Generated Program ID"])
    program_features = grouped.apply(program_row).reset_index()

    destination = output_folder_path / "formatted_for_tensor_conversion_by_program_id.csv"
    program_features.to_csv(destination, index=False)

    return program_features


def _program_level_features(df, small_num):
    '''Feature row for one generated program; ``df`` holds all of that program's test-case rows.'''
    first_row = df.iloc[0]
    passed_fraction = df["Result"].sum() / len(df)
    return pd.Series({
        "Is generated program correct": first_row["Is generated program correct"],

        "Cluster size rank": first_row["Cluster size rank"],
        'Log Cluster size rank': np.log(first_row["Cluster size rank"] + small_num),

        "% test cases passed rank": first_row["% test cases passed rank"],
        'Log % test case passed rank': np.log(first_row["% test cases passed rank"] + small_num),

        "Log % test cases passed": np.log(passed_fraction + small_num),
        "% test cases passed": passed_fraction,

        "Cluster size (out of total programs)": first_row["Cluster size"],
        "Log Cluster size (out of total programs)": np.log(first_row["Cluster size"] + small_num)
    })


def _cluster_level_features(df, small_num):
    '''Feature row for one cluster; ``df`` holds the test-case rows of all the cluster's programs.'''
    first_row = df.iloc[0]
    passed_fraction = df["Result"].sum() / len(df)
    return pd.Series({
        # Fraction of the cluster's rows whose program is correct.
        "Proportion of correct generated program": df["Is generated program correct"].sum() / len(df),

        'Log Cluster size rank': np.log(first_row["Cluster size rank"] + small_num),

        "% test cases passed rank": first_row["% test cases passed rank"],
        'Log % test case passed rank': np.log(first_row["% test cases passed rank"] + small_num),

        "Log % test cases passed": np.log(passed_fraction + small_num),
        "% test cases passed": passed_fraction,

        "Cluster size (out of total programs)": first_row["Cluster size"],
        "Log Cluster size (out of total programs)": np.log(first_row["Cluster size"] + small_num)
    })


def _add_entropy(program_features):
    '''
    Add an "Entropy" column in place: for each prompt, -sum(p * log(p + small_num)) over its
    clusters, where p is "Cluster size (out of total programs)" and the log term is the
    precomputed "Log Cluster size (out of total programs)" column. Each cluster (identified
    by "Cluster size rank") is counted once.
    '''
    for prompt_id in set(program_features["Prompt id"]):
        prompt_rows = program_features[program_features["Prompt id"] == prompt_id]
        entropy = 0

        for cluster_rank in set(prompt_rows["Cluster size rank"]):
            first = prompt_rows[prompt_rows["Cluster size rank"] == cluster_rank].iloc[0]
            entropy += -first["Cluster size (out of total programs)"] * \
                first["Log Cluster size (out of total programs)"]

        program_features.loc[program_features["Prompt id"] == prompt_id, "Entropy"] = entropy


def transform_dataset(df_filename, program_filename, testcase_filename, output_folder_path, small_num, features,
                      is_for_logreg, cached_filename, should_alphacode_cluster_filename=None,
                      whole_dataset_filename="whole_dataset.csv", formatted_for_tensor_conversion_filename="formatted_for_tensor_conversion.csv",
                      should_test_case_info=True):
    '''
    Build per-prompt feature tensors (X) and targets (y).

    Unless ``cached_filename`` is given, the following steps run in order:

    1. Load the program/test-case table ``df_filename``.
    2. Drop rows with syntax errors, using ``program_filename`` / ``testcase_filename``.
    3. Cluster the programs: ``get_alphacode_cluster`` when
       ``should_alphacode_cluster_filename`` is set, otherwise ``get_cluster``.
    4. Optionally rank the test cases.
    5. Write the annotated table to ``output_folder_path / whole_dataset_filename``.
    6. Collapse it into a feature table:

       * ``is_for_logreg``: one row per (prompt, program); y = "Is generated program correct".
       * otherwise: one row per (prompt, cluster size rank);
         y = "Proportion of correct generated program".

    7. Add an "Entropy" column in both modes. The ``features`` list may request it in
       either mode; before this change it existed only in the logreg mode.
    8. Write the feature table to
       ``output_folder_path / formatted_for_tensor_conversion_filename``.

    When ``cached_filename`` is given, the feature table is read from that CSV instead.

    Returns ``(X_list, y_list, prompt_set, problem_list)``: one X tensor (columns =
    ``features``) and one y tensor per prompt, the set of prompt ids, and the prompt ids in
    the same order as ``X_list`` / ``y_list``.
    '''
    if cached_filename is None:
        program_testcase_df = pd.read_csv(df_filename)
        program_df = pd.read_csv(program_filename)
        test_df = pd.read_csv(testcase_filename)

        program_testcase_df = remove_syntax_errors(
            program_testcase_df, program_df, test_df)

        if should_alphacode_cluster_filename:
            program_testcase_df = get_alphacode_cluster(program_testcase_df,
                                                        should_alphacode_cluster_filename)
        else:
            get_cluster(program_testcase_df)

        if should_test_case_info:
            test_case_ranking(program_testcase_df)

        program_testcase_df.to_csv(
            output_folder_path / whole_dataset_filename, index=False)

        if is_for_logreg:
            # One row per generated program, sorted by Prompt id + Generated Program ID.
            program_features = program_testcase_df.groupby(
                by=["Prompt id", "Generated Program ID"]).apply(
                    lambda group: _program_level_features(group, small_num)).reset_index()
        else:
            # One row per cluster.
            program_features = program_testcase_df.groupby(
                by=["Prompt id", "Cluster size rank"]).apply(
                    lambda group: _cluster_level_features(group, small_num)).reset_index()

        print("Adding entropy.")
        _add_entropy(program_features)

        program_features.to_csv(
            output_folder_path / formatted_for_tensor_conversion_filename, index=False)
    else:
        program_features = pd.read_csv(cached_filename)

    X_list = []
    y_list = []
    # Prompt ids in the same order as X_list / y_list.
    problem_list = []

    prompt_set = set(program_features["Prompt id"].astype(int))

    for prompt_id in prompt_set:
        prog_feat = program_features[program_features["Prompt id"]
                                     == prompt_id].reset_index()

        X_list.append(torch.from_numpy(prog_feat[features].values))

        if is_for_logreg:
            # y = whether each program is correct.
            gen_program_status = prog_feat["Is generated program correct"] == True  # noqa: E712
            y_list.append(torch.from_numpy(gen_program_status.values).double())
        else:
            # y = proportion of correct programs in each cluster.
            y_list.append(torch.from_numpy(
                prog_feat["Proportion of correct generated program"].values))
        problem_list.append(prompt_id)

    return X_list, y_list, prompt_set, problem_list


def feature_extraction_no_test_split(df_filename, program_filename, testcase_filename, output_folder_path, scaler_path,  features,
                                     small_num=small_num_to_prevent_underflow, cached_filename=None, is_for_logreg=True, for_data_loader=True, should_return_problem_list=False, should_load_scale_data=True):
    '''
    Extracts all features from a given formatted dataset with executed Codex-generated programs and test-cases.

    For inference; contains no test split.

    With ``for_data_loader`` the per-prompt tensors are concatenated into one X tensor and
    one y tensor; otherwise the per-prompt lists are returned. With
    ``should_load_scale_data`` X (not y) is scaled with the scaler pickled at
    ``scaler_path``. Returns ``(X, y)`` or, with ``should_return_problem_list``,
    ``(X, y, problem_list)``.
    '''

    X_list, y_list, _, problem_list = transform_dataset(df_filename, program_filename,
                                                        testcase_filename, output_folder_path, small_num, features, is_for_logreg, cached_filename)

    if for_data_loader:
        # Concatenate everything in one call (the leading empty tensors keep the result
        # well-defined when there are no prompts).
        X = torch.cat([torch.empty(0, len(features)), *X_list], 0)
        y = torch.cat([torch.empty(0), *y_list], 0)
    else:
        X = X_list
        y = y_list

    if should_load_scale_data:
        # load_scaled_data(X, scaler_path, is_for_logreg) scales only X and returns only X;
        # the previous call passed y as an extra argument and unpacked two return values,
        # which always raised a TypeError. Its third argument selects "single tensor" vs
        # "list of tensors", which is exactly what for_data_loader decides here.
        X = load_scaled_data(X, scaler_path, for_data_loader)

    if should_return_problem_list:
        return X, y, problem_list

    return X, y

# def logreg_k_fold_feature_extraction(df_filename, program_filename, testcase_filename, output_folder_path, scaler_path, cached_filename=None, test_split=0.8, small_num=0.0001, k_fold = 5):
#     '''
#     Formatting feature extraction for logistic regressor. Uses k-fold cross validation split into {k_fold} number of folds.

#     Returns a list of {k_fold} length corresponding to different portions of the data.
#     '''


def feature_extraction(df_filename, program_filename, testcase_filename, output_folder_path, scaler_path, cached_filename=None, test_split=0.8, small_num=small_num_to_prevent_underflow, is_for_logreg=True):
    '''
    Extracts all features from a given formatted dataset with executed Codex-generated programs and test-cases.

    2 types of extraction:
    1 - Logistic regressor [y = whether each program is correct; returns single Tensors]
    2 - Linear comparison [y = proportion of correct programs per cluster; returns lists of per-prompt Tensors]

    Validation: a random training/test split by prompt. ``test_split`` is the fraction of
    prompts placed in the TRAINING set.

    X is standardized in both modes, and the fitted scaler is saved under ``scaler_path``
    ("standardscaler_logreg" or "standardscaler_linearcomparison").

    Returns ``(train_X, train_y, test_X, test_y)``.
    '''

    X_list, y_list, prompt_set, _ = transform_dataset(df_filename, program_filename,
                                                      testcase_filename, output_folder_path, small_num, features, is_for_logreg, cached_filename)

    # Split prompt indices (positions in X_list / y_list) into training + test sets.
    num_prompts = len(prompt_set)
    # random.sample() requires a sequence (passing a set raises TypeError since Python 3.11);
    # range(num_prompts) is the same population of indices.
    training_prompts = set(
        sample(range(num_prompts), int(num_prompts * test_split)))
    test_prompts = set(range(num_prompts)) - training_prompts

    training_X_list = [X_list[i] for i in training_prompts]
    training_y_list = [y_list[i] for i in training_prompts]

    test_X_list = [X_list[i] for i in test_prompts]
    test_y_list = [y_list[i] for i in test_prompts]

    if is_for_logreg:
        # Convert everything into single tensors (leading empty tensors keep empty splits valid).
        train_X = torch.cat([torch.empty(0, len(features)), *training_X_list], 0)
        train_y = torch.cat([torch.empty(0), *training_y_list], 0)
        test_X = torch.cat([torch.empty(0, len(features)), *test_X_list], 0)
        test_y = torch.cat([torch.empty(0), *test_y_list], 0)

        train_X, test_X = scale_data(
            train_X, test_X, filepath=scaler_path / "standardscaler_logreg")

        return train_X, train_y, test_X, test_y

    # Linear comparison: fit the scaler on all training rows at once, then scale each
    # prompt's matrix with it.
    all_training_rows = torch.cat([torch.empty(0, len(features)), *training_X_list], 0)

    scaled_training_X_list, scaled_test_X_list = scale_data(
        training_X_list, test_X_list, filepath=scaler_path / "standardscaler_linearcomparison", linearcomparison_X=all_training_rows)
    if scaled_test_X_list is None:
        # scale_data returns None for an empty test set; keep returning a list.
        scaled_test_X_list = []

    # Return the SCALED matrices. The scaled lists used to be computed and then discarded,
    # even though inference (feature_extraction_no_test_split) scales with this same scaler.
    return (scaled_training_X_list, training_y_list, scaled_test_X_list, test_y_list)


def get_alphacode_clustering(alphacode_clustering_filename, output_filename, small_num=small_num_to_prevent_underflow):
    '''
    Cluster programs AlphaCode-style by their outputs and write the clusters to a CSV.

    For every prompt, programs whose "Output" vectors (ordered by "Generated Test Case ID")
    are identical form one cluster. One CSV row is written per program, containing:

    * the prompt id;
    * the program id;
    * the cluster size as a fraction of the prompt's programs;
    * the cluster-size rank (1 = largest).

    The file layout is unchanged: a header row with a leading empty index column, and a
    leading ``0`` index value on every data row.

    ``small_num`` is accepted for interface compatibility and is not used.
    '''
    df = pd.read_csv(alphacode_clustering_filename)

    output_rows = []

    for prompt_id in set(df["Prompt id"]):
        df_from_prompt = df[df["Prompt id"] == prompt_id]

        program_id_set = set(df_from_prompt["Generated Program ID"])

        # Key = the program's outputs ordered by test case id; the ordering gives the tuple meaning.
        cluster = defaultdict(set)
        for program_id in program_id_set:
            program_df = df_from_prompt[
                df_from_prompt["Generated Program ID"] == program_id
            ].sort_values(by="Generated Test Case ID")
            cluster[tuple(program_df["Output"].to_list())].add(program_id)

        # Sort by cluster size; rank starts from 1.
        sorted_cluster = sorted(cluster.items(),
                                key=lambda item: len(item[1]), reverse=True)
        for rank, (_, members) in enumerate(sorted_cluster, start=1):
            for program_id in members:
                output_rows.append(
                    [prompt_id, program_id, len(members) / len(program_id_set), rank])

    header_df = pd.DataFrame(
        columns=["Prompt id", "Generated Program ID", "Cluster size", "Cluster size rank"])

    # `with` guarantees the file is closed even if writing fails (the bare open() leaked).
    with open(output_filename, 'w') as f:
        header_df.to_csv(f)
        if output_rows:
            # Rows used to be written one single-row frame at a time, each with index 0;
            # keep that index so the file contents are unchanged.
            pd.DataFrame(output_rows, index=[0] * len(output_rows)).to_csv(f, header=False)


def merge_formal_specs_normal_test_cases(formal_specs_formatted_filename, normal_formatted_filename,
                                         output_df_filename="formatted_for_tensor_conversion_combined_formal_specs.csv"):
    '''
    Merges the 'formatted_for_tensor_conversion' files from the formal specs and the test cases.

    Steps:

    1. Drop the formal-specs table's "Prompt id" and "Is generated program correct" columns.
    2. Rename its feature columns from ``features`` to the matching ``formalspecs_features``.
    3. Print the renamed table.
    4. Inner-join it with the test-case table on "Generated Program ID".
    5. Write the result (including the index) to ``output_df_filename``.
    '''
    formal_df = pd.read_csv(formal_specs_formatted_filename)
    normal_df = pd.read_csv(normal_formatted_filename)

    formal_df = formal_df.drop(columns=['Prompt id', 'Is generated program correct'])
    formal_df = formal_df.rename(columns=dict(zip(features, formalspecs_features)))

    print(formal_df)

    combined = normal_df.merge(formal_df, on="Generated Program ID")
    combined.to_csv(output_df_filename)
