import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
import time
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
import json

# Import my scripts
import dataset_metrics
import logreg_model as logreg
import linearcomparison_model as linearcomparison
import data_transformations
import pathlib

# Run on the GPU whenever PyTorch can see one, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Default feature-column list, shared with the data-preparation module.
features = data_transformations.features


def plot_train_test(train, test, output_path):
    """Save side-by-side plots of the training and validation loss curves.

    Args:
        train: sequence of per-epoch training losses (left panel).
        test: sequence of per-epoch validation losses (right panel).
        output_path: destination passed to ``Figure.savefig``.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2)

    ax1.plot(train)
    ax1.set_title("Training loss")

    ax2.plot(test)
    ax2.set_title("Validation loss")

    fig.savefig(output_path)
    # pyplot keeps every figure alive until it is explicitly closed; without
    # this, repeated calls accumulate memory (and trigger matplotlib's
    # "more than 20 figures opened" warning).
    plt.close(fig)


def visualize_from_kfold_result(results, output_path="images/mbpp_kfold_loss_non_clustering", thresholds=None):
    '''
    Average per-threshold metrics over k folds and plot them.

    ``results`` holds one entry per fold. Each entry is a list with one tuple
    per threshold, in the form returned by
    ``logreg_model_combined_with_naive_matching_method``:
    (precision, num_skipped, recall)

    Every fold must report the same number of thresholds.

    Args:
        results: per-fold lists of per-threshold result tuples.
        output_path: where the figure is written. Defaults to the path this
            function has always used.
        thresholds: x-axis values, one per threshold tuple. Defaults to
            ``np.arange(0, 1.05, 0.05)`` (0.0 to 1.0 in steps of 0.05).

    Raises:
        ValueError: if ``results`` is empty.
    '''
    if not results:
        raise ValueError("results must contain at least one fold")

    if thresholds is None:
        thresholds = np.arange(0, 1.05, 0.05)

    num_folds = len(results)
    num_thresholds = len(results[0])

    def skipped_fraction(fold, i):
        # Normalised by the skip count at the fold's LAST threshold --
        # presumably the strictest one, where every prompt is skipped, so this
        # approximates "fraction of prompts skipped". NOTE(review): confirm.
        # A zero denominator (nothing skipped even at the last threshold)
        # yields 0.0 instead of raising ZeroDivisionError.
        denominator = fold[-1][1]
        return fold[i][1] / denominator if denominator else 0.0

    total = []
    for i in range(num_thresholds):
        prec = sum(fold[i][0] for fold in results) / num_folds
        rec = sum(fold[i][2] for fold in results) / num_folds
        per_skipped = sum(skipped_fraction(fold, i) for fold in results) / num_folds
        total.append((prec, rec, per_skipped))

    print(total)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(9, 4))

    ax1.plot(thresholds, [x[0] for x in total])
    ax1.set_title("Precision")

    ax2.plot(thresholds, [x[1] for x in total])
    ax2.set_title("Recall")

    ax3.plot(thresholds, [x[2] for x in total])
    ax3.set_title("Percent of num skipped")

    fig.savefig(output_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def train_weights_logreg_kfold(formatted_for_tensor_conversion_path, scaler_path, features,  score_path, k=10, diff_formatted_for_tensor_conversion_path=None, use_mlp=False,
                               epochs=1500, learning_rate=0.001, random_state=None):
    """Train a logistic-regression (or MLP) scorer with k-fold CV and score each held-out fold.

    Folds are drawn over ``problem_list``, and the model is trained on the rows
    of the training problems. ``X[i]``/``y[i]`` are presumably the feature rows
    and labels of ``problem_list[i]``; confirm against
    ``data_transformations.feature_extraction_no_test_split``. Each fold's
    held-out prompts are scored through ``logreg_model_score`` into
    ``score_path``. Fold 0 overwrites the file and later folds append to it.

    Args:
        formatted_for_tensor_conversion_path: cached feature CSV used for training
            (and for scoring unless a diff path is given).
        scaler_path: directory (``pathlib.Path``) holding the temporary fold scaler.
        features: feature column names.
        score_path: output CSV for the per-row trustworthy scores.
        k: number of folds.
        diff_formatted_for_tensor_conversion_path: optional different feature CSV
            to score the held-out prompts on (e.g. another sampling temperature).
        use_mlp: train ``logreg.MLP`` instead of ``logreg.LogReg``.
        epochs: training epochs per fold.
        learning_rate: Adam learning rate.
        random_state: forwarded to ``KFold`` for reproducible fold assignment
            (None keeps the previous non-deterministic shuffling).
    """
    X, y, problem_list = data_transformations.feature_extraction_no_test_split(None, None, None, None, scaler_path, features,
                                                                               cached_filename=formatted_for_tensor_conversion_path,
                                                                               for_data_loader=False, should_load_scale_data=False, should_return_problem_list=True)

    kfold = KFold(n_splits=k, shuffle=True, random_state=random_state)

    score_source_path = diff_formatted_for_tensor_conversion_path or formatted_for_tensor_conversion_path
    # BCELoss holds no state, so one instance serves every fold and epoch.
    loss_fn = nn.BCELoss()
    # Re-written by every fold and reloaded by logreg_model_score below.
    scaler_kfold = scaler_path / 'scaler_kfold_temp'

    for fold, (train_ids, test_ids) in enumerate(kfold.split(problem_list)):
        test_problem_ids = {problem_list[i] for i in test_ids}

        # One concatenation instead of repeated torch.cat in a loop (which is
        # quadratic). The empty seed keeps the result well-formed even when a
        # fold has no training problems.
        concat_train_X = torch.cat(
            [torch.empty(0, len(features))] + [X[i] for i in train_ids], 0)
        concat_train_y = torch.cat(
            [torch.empty(0)] + [y[i] for i in train_ids], 0)

        print(concat_train_X.shape, concat_train_y.shape)

        model_cls = logreg.MLP if use_mlp else logreg.LogReg
        model = model_cls(concat_train_X.size(dim=1))

        # Presumably fits a scaler on this fold's training rows only and saves
        # it to scaler_kfold (confirm in data_transformations.scale_data).
        concat_train_X, _ = data_transformations.scale_data(
            concat_train_X, None, filepath=scaler_kfold)

        training_dataset = logreg.DatasetFromProgramTestCaseTesting(
            concat_train_X, concat_train_y)
        train_dataloader = DataLoader(
            training_dataset, batch_size=64, shuffle=True)

        optimizer = torch.optim.Adam(
            model.parameters(), lr=learning_rate, weight_decay=1e-4)

        for t in range(epochs):
            loss_train, _ = logreg.train(
                train_dataloader, model, loss_fn, optimizer)
            if t % 20 == 0:
                print(f"Epoch {t+1}\n-------------------------------")
                print(
                    f"Training loss: {loss_train:.5f}")

        print(fold, "fold")

        # Switch off training-only behaviour (e.g. dropout in the MLP, if any)
        # before scoring; a no-op for modules without such layers.
        model.eval()
        logreg_model_score(score_source_path,
                           model, scaler_kfold, score_path, features, fold == 0, test_prompts=test_problem_ids)
        if diff_formatted_for_tensor_conversion_path:
            print("Saved to model! Used diff path.")
        else:
            print("Saved to model!")


def train_weights_linearcomparison_kfold(prog_test_path, prog_correct_path, testcase_correct_path, folder_path,
                                         scaler_path, formatted_for_tensor_conversion_path, score_path, k=10,
                                         epochs=1500, learning_rate=0.001, random_state=None):
    """Train a LinearComparison model with k-fold CV and score each held-out fold.

    Each fold trains a fresh ``linearcomparison.LinearComparison`` on the
    training problems. The fold's held-out prompts are then scored through
    ``linearcomparison_model_kfold`` into ``score_path``. Uses the module-level
    ``features`` column list.

    Args:
        prog_test_path, prog_correct_path, testcase_correct_path, folder_path:
            raw inputs forwarded to the feature extraction step.
        scaler_path: directory (``pathlib.Path``) holding the temporary fold scaler.
        formatted_for_tensor_conversion_path: cached feature CSV (also scored).
        score_path: output score CSV. Any existing file is removed first so that
            results from earlier runs are not mixed into this run's output.
        k: number of folds.
        epochs: training epochs per fold.
        learning_rate: Adam learning rate.
        random_state: forwarded to ``KFold`` for reproducible fold assignment.
    """
    # NOTE(review): unlike the logreg callers, this call passes no ``features``
    # argument and passes ``is_for_logreg``; confirm it still matches
    # data_transformations.feature_extraction_no_test_split's signature.
    X, y, problem_list = data_transformations.feature_extraction_no_test_split(prog_test_path,
                                                                               prog_correct_path, testcase_correct_path,
                                                                               folder_path, scaler_path,
                                                                               cached_filename=formatted_for_tensor_conversion_path, is_for_logreg=False,
                                                                               for_data_loader=False, should_load_scale_data=False, should_return_problem_list=True)

    kfold = KFold(n_splits=k, shuffle=True, random_state=random_state)

    # linearcomparison_model_kfold appends whenever the score file already
    # exists, so a leftover file from a previous run would be merged with this
    # run's scores. Start fresh so fold 0 writes the header, as the logreg
    # k-fold trainer does.
    score_file = pathlib.Path(score_path)
    if score_file.is_file():
        score_file.unlink()

    scaler_kfold = scaler_path / 'scaler_kfold_temp'

    for fold, (train_ids, test_ids) in enumerate(kfold.split(problem_list)):
        fold_train_X = [X[i] for i in train_ids]
        fold_train_y = [y[i] for i in train_ids]

        test_problem_ids = {problem_list[i] for i in test_ids}

        # One concatenation (repeated torch.cat in a loop is quadratic); this
        # stacked copy is only passed to scale_data, presumably to fit the scaler.
        concat_train_X = torch.cat(
            [torch.empty(0, len(features))] + fold_train_X, 0)

        model = linearcomparison.LinearComparison(len(features))

        fold_train_X, _ = data_transformations.scale_data(
            fold_train_X, None, linearcomparison_X=concat_train_X, filepath=scaler_kfold)

        optimizer = torch.optim.Adam(
            model.parameters(), lr=learning_rate, weight_decay=1e-4)

        for t in range(epochs):
            train_loss = linearcomparison.train(
                fold_train_X, fold_train_y, model, optimizer)
            if t % 20 == 0:
                print(f"Epoch {t+1}\n-------------------------------")
                print(
                    f"Training loss: {train_loss:.5f}")

        print(model.linear.weight)

        linearcomparison_model_kfold(formatted_for_tensor_conversion_path,
                                     model, scaler_kfold, score_path, test_prompts=test_problem_ids)
        print("Saved to model!")


def train_weights_logreg(formatted_for_tensor_conversion_path, scaler_path, scaler_name, features, model_output_path,
                         epochs=2000, learning_rate=0.001):
    """Train a logistic-regression scorer on the whole cached dataset and save it.

    Args:
        formatted_for_tensor_conversion_path: cached feature CSV.
        scaler_path: directory (``pathlib.Path``) for the fitted scaler.
        scaler_name: file name of the scaler inside ``scaler_path``.
        features: feature column names.
        model_output_path: destination for ``torch.save(model, ...)``.
        epochs: number of training epochs.
        learning_rate: Adam learning rate.
    """
    X, y, problem_list = data_transformations.feature_extraction_no_test_split(None, None, None, None, scaler_path, features,
                                                                               cached_filename=formatted_for_tensor_conversion_path,
                                                                               for_data_loader=False, should_load_scale_data=False, should_return_problem_list=True)

    # Use every problem. A single concatenation avoids the quadratic cost of
    # repeated torch.cat; the empty seed handles an empty problem list.
    num_problems = len(problem_list)
    concat_train_X = torch.cat(
        [torch.empty(0, len(features))] + [X[i] for i in range(num_problems)], 0)
    concat_train_y = torch.cat(
        [torch.empty(0)] + [y[i] for i in range(num_problems)], 0)

    print(concat_train_X.shape, concat_train_y.shape)
    model = logreg.LogReg(concat_train_X.size(dim=1))

    scaler_file = scaler_path / scaler_name

    concat_train_X, _ = data_transformations.scale_data(
        concat_train_X, None, filepath=scaler_file)

    training_dataset = logreg.DatasetFromProgramTestCaseTesting(
        concat_train_X, concat_train_y)
    train_dataloader = DataLoader(
        training_dataset, batch_size=64, shuffle=True)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=1e-4)
    # BCELoss is stateless; build it once instead of every epoch.
    loss_fn = nn.BCELoss()

    for t in range(epochs):
        loss_train, _ = logreg.train(
            train_dataloader, model, loss_fn, optimizer)
        if t % 20 == 0:
            print(f"Epoch {t+1}\n-------------------------------")
            print(
                f"Training loss: {loss_train:.5f}")

    torch.save(model, model_output_path)


def logreg_model_combined_with_naive_matching_method(input_filename, model, output_folder_path, test_prompts=None, num_test_cases_to_solve=3, thresholds_to_test=(0.5,)):
    """Evaluate a classifier combined with a test-case matching rule for program selection.

    For each threshold:
    1. Keep only programs whose class-1 softmax probability exceeds it.
    2. For each prompt, take the ``num_test_cases_to_solve`` generated test
       cases with the highest "Programs solved".
    3. Rank the kept programs by summed "Result", then by summed
       "Programs solved", over those test cases.
    4. Check the top program's "Is generated program correct".

    A prompt with no kept programs counts as skipped.

    Args:
        input_filename: CSV with one row per (program, test case) pair.
        model: classifier whose output is fed to ``Softmax(dim=1)``; column 1
            is treated as the "correct" class (assumes 2 output logits -- TODO confirm).
        output_folder_path: forwarded to ``transform_dataset_by_program``.
        test_prompts: optional collection of "Prompt id" values to evaluate;
            defaults to every prompt in the file.
        num_test_cases_to_solve: how many top test cases are used for ranking.
        thresholds_to_test: iterable of probability thresholds.

    Returns:
        List with one ``(precision, num_skipped, recall)`` tuple per threshold.
        precision = correct / answered prompts (1 when none are answered);
        recall = correct / all prompts (0 when there are no prompts).
    """
    whole_df = pd.read_csv(input_filename)

    path = pathlib.Path()

    # If testing by specific test prompts
    if test_prompts is None:
        test_prompts = set(whole_df["Prompt id"])
    else:
        whole_df = whole_df[whole_df["Prompt id"].isin(
            test_prompts)].copy(deep=True)

    formatted_df = data_transformations.transform_dataset_by_program(
        whole_df, output_folder_path, 0.0001)

    model_features = ["% test cases passed rank",  "% test cases passed", "Log % test cases passed",
                      "Cluster size rank", "Cluster size (out of total programs)", "Log Cluster size (out of total programs)"]

    for_input = torch.from_numpy(
        formatted_df[model_features].values).float()

    # NOTE(review): the other callers use load_scaled_data(x, scaler_path, True)
    # with a single return value; this 4-argument, tuple-unpacking form may be
    # stale -- confirm against data_transformations.load_scaled_data.
    for_input, _ = data_transformations.load_scaled_data(
        for_input, None, path / 'scalers' / 'standardscaler_logreg', True)

    pred = model(for_input)
    softmax_pred = nn.Softmax(dim=1)(pred)

    formatted_df.drop(model_features + ['Is ground truth cluster',
                                        'Is generated program correct'], inplace=True, axis=1)

    def rank(df):
        # Per-program aggregate over the selected test cases.
        return pd.Series({
            "Result": df["Result"].sum(),
            "Rank": df["Programs solved"].sum(),
        })

    num_prompts = len(test_prompts)
    results = []

    for th in thresholds_to_test:
        print(th)
        selected_pred = (softmax_pred[:, 1] > th).long()
        formatted_df["Is predicted to be right"] = selected_pred.detach().numpy()

        # merge returns a new frame, so whole_df itself is never modified.
        copy_df = whole_df.merge(formatted_df, on="Generated Program ID")

        # Only get the ones predicted to be correct
        copy_df = copy_df[copy_df["Is predicted to be right"] == 1]

        num_correct = 0
        num_skipped = 0

        for prompt in test_prompts:
            prompt_df = copy_df[copy_df["Prompt id"] == prompt]

            # If no programs predicted correct - skip
            if len(prompt_df) == 0:
                num_skipped += 1
                continue

            testcase_df = prompt_df.groupby(
                by="Generated Test Case ID").first().reset_index()

            test_cases = set(testcase_df.sort_values(
                by="Programs solved", ascending=False)[:num_test_cases_to_solve]["Generated Test Case ID"])

            most_solved_test_case_df = prompt_df[prompt_df["Generated Test Case ID"].isin(
                test_cases)]

            program_sorted_df = most_solved_test_case_df.groupby(
                by="Generated Program ID").apply(rank).sort_values(
                by=['Result', 'Rank'], ascending=[False, False]).reset_index()

            most_solved = int(
                program_sorted_df.iloc[0]["Generated Program ID"])
            num_correct += prompt_df[prompt_df["Generated Program ID"]
                                     == most_solved].iloc[0]["Is generated program correct"]

        num_answered = num_prompts - num_skipped
        precision = num_correct / num_answered if num_answered != 0 else 1
        recall = num_correct / num_prompts if num_prompts != 0 else 0
        results.append((precision, num_skipped, recall))

    return results


def logreg_model_save_score(input_filename, model, output_folder_path, scaler_path, output_filename_formatted_df, test_prompts=None):
    """Score every generated program with ``model`` and save the per-program table.

    The raw CSV is reshaped by ``data_transformations.transform_dataset_by_program``.
    Its module-level ``features`` columns are scaled with the scaler at
    ``scaler_path`` and passed to ``model.score``. The result is stored as a
    "Trustworthy Score" column and written to ``output_filename_formatted_df``.

    Args:
        input_filename: raw CSV with a "Prompt id" column.
        model: trained model exposing ``score(tensor)``.
        output_folder_path: forwarded to ``transform_dataset_by_program``.
        scaler_path: scaler file for ``load_scaled_data``.
        output_filename_formatted_df: destination CSV.
        test_prompts: optional collection of "Prompt id" values to keep;
            all rows are scored when None.
    """
    whole_df = pd.read_csv(input_filename)

    if test_prompts is not None:
        whole_df = whole_df[whole_df["Prompt id"].isin(
            test_prompts)].copy(deep=True)

    formatted_df = data_transformations.transform_dataset_by_program(
        whole_df, output_folder_path, 0.0001)

    for_input = torch.from_numpy(
        formatted_df[features].values).float()

    for_input = data_transformations.load_scaled_data(
        for_input, scaler_path, True)

    print("Scores retrieved.")

    pred = model.score(for_input)

    formatted_df["Trustworthy Score"] = pred.detach().numpy()

    formatted_df.to_csv(output_filename_formatted_df)
    print("Formatted df saved.")


def logreg_model_score(formatted_for_tensor_conversion_path, model, scaler_path, output_score_filename, features, is_first_iteration, test_prompts=None):
    """Score rows of a cached feature CSV with ``model`` and write them out.

    Args:
        formatted_for_tensor_conversion_path: cached feature CSV with a
            "Prompt id" column and the ``features`` columns.
        model: trained model exposing ``score(tensor)``.
        scaler_path: scaler file for ``data_transformations.load_scaled_data``.
        output_score_filename: destination CSV.
        features: feature column names fed to the model.
        is_first_iteration: True overwrites the file (with header); False
            appends rows without a header, so k-fold results accumulate.
        test_prompts: optional collection of "Prompt id" values to score;
            all rows are scored when None.
    """
    scores_df = pd.read_csv(formatted_for_tensor_conversion_path)

    if test_prompts is not None:
        scores_df = scores_df[scores_df["Prompt id"].isin(
            test_prompts)].copy(deep=True)

    for_input = torch.from_numpy(scores_df[features].values).float()
    for_input = data_transformations.load_scaled_data(
        for_input, scaler_path, True)

    pred = model.score(for_input)
    scores_df["Trustworthy Score"] = pred.detach().numpy()

    # Header handling is driven by the caller's flag, not by whether the file
    # exists: the first call (re)creates the file, later calls append to it.
    if is_first_iteration:
        scores_df.to_csv(output_score_filename)
    else:
        scores_df.to_csv(output_score_filename, mode='a', header=False)


def linearcomparison_model_kfold(formatted_for_tensor_conversion_path, model, scaler_path, output_score_filename, test_prompts=None, is_first_iteration=None):
    """Score rows with a trained LinearComparison model and write them to a CSV.

    Mirrors ``logreg_model_score`` but uses the module-level ``features``
    column list and casts the scaled input to float.

    Args:
        formatted_for_tensor_conversion_path: cached feature CSV with a
            "Prompt id" column.
        model: trained model exposing ``score(tensor)``.
        scaler_path: scaler file for ``data_transformations.load_scaled_data``.
        output_score_filename: output CSV path (``pathlib.Path`` or str).
        test_prompts: optional collection of "Prompt id" values to score.
        is_first_iteration: True overwrites the file and writes a header; False
            appends rows without a header. None (the default) keeps the legacy
            rule of appending if and only if the file already exists. Note that
            the legacy rule also appends to files left over from earlier runs.
    """
    scores_df = pd.read_csv(formatted_for_tensor_conversion_path)

    if test_prompts is not None:
        scores_df = scores_df[scores_df["Prompt id"].isin(
            test_prompts)].copy(deep=True)

    for_input = torch.from_numpy(scores_df[features].values).float()
    for_input = data_transformations.load_scaled_data(
        for_input, scaler_path, True).float()

    pred = model.score(for_input)
    scores_df["Trustworthy Score"] = pred.detach().numpy()

    if is_first_iteration is None:
        is_first_iteration = not pathlib.Path(output_score_filename).is_file()

    if is_first_iteration:
        scores_df.to_csv(output_score_filename)
    else:
        scores_df.to_csv(output_score_filename, mode='a', header=False)


def train_weights_linearcomparison(epochs=500, learning_rate=0.0001):
    """Train a LinearComparison model on HumanEval with a 70/30 split.

    The model is saved to ``models/linear_comparison_model.pth`` and the loss
    curves to ``images/loss_linearcomparison.png``. All paths are relative to
    the current working directory.

    Args:
        epochs: number of training epochs.
        learning_rate: Adam learning rate (no weight decay).
    """
    path = pathlib.Path()

    folder_path = path / "humaneval"
    model_path = path / 'models'
    scaler_path = path / 'scalers'

    training_X, training_y, test_X, test_y = data_transformations.feature_extraction(folder_path / 'combined_humaneval_evaled.csv',
                                                                                     folder_path / 'is_humaneval_prog_correct.csv', folder_path /
                                                                                     'is_humaneval_test_correct.csv',
                                                                                     folder_path, scaler_path,
                                                                                     cached_filename=folder_path / 'formatted_for_tensor_conversion.csv',
                                                                                     test_split=0.7, is_for_logreg=False)

    print(training_X[0])
    model = linearcomparison.LinearComparison(training_X[0].size(dim=1))

    # Weight decay 1e-4 w/ learning rate 0.001 seems to work well
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate)

    train_losses = []
    test_losses = []
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loss = linearcomparison.train(
            training_X, training_y, model, optimizer)
        print(f"training loss: {train_loss:.5f}")
        test_loss = linearcomparison.test(test_X, test_y, model)

        train_losses.append(train_loss)
        test_losses.append(test_loss)

    print(model.linear.weight)

    # Save the model (the scaler is handled by feature_extraction, presumably)
    torch.save(model, model_path / 'linear_comparison_model.pth')

    # Plot training loss + validation loss
    plot_train_test(train_losses, test_losses, path / 'images' / "loss_linearcomparison.png")


def linearcomparison_with_diff_dataset():
    """Evaluate the saved HumanEval LinearComparison model on sanitized MBPP.

    Returns:
        Whatever ``linearcomparison.test`` returns (the training loop in this
        file records it as the test loss). It used to be discarded; returning
        it is backward compatible.
    """
    import inspect

    path = pathlib.Path()
    model_file = path / 'models' / 'linear_comparison_model.pth'

    # The file holds a whole pickled nn.Module (torch.save(model, ...)).
    # torch>=2.6 defaults to weights_only=True, which refuses to load that,
    # so opt out explicitly when the argument exists. This unpickles
    # arbitrary objects: only load model files you produced yourself.
    load_kwargs = {"weights_only": False} if "weights_only" in inspect.signature(
        torch.load).parameters else {}
    model = torch.load(model_file, **load_kwargs)
    model.eval()

    mbpp_path = path / 'mbpp_sanitized'

    X, y = data_transformations.feature_extraction_inference(
        mbpp_path / "prog_test_better.csv", mbpp_path /
        "is_mbpp_sanitized_programs_correct.csv", mbpp_path /
        "is_mbpp_sanitized_test_cases_correct.csv", mbpp_path, scaler_path="standardscaler_linearcomparison",
        cached_filename=mbpp_path / "formatted_for_tensor_conversion.csv", is_for_logreg=False)

    return linearcomparison.test(X, y, model)


def logreg_with_diff_dataset(model_path, scaler_path, formatted_for_tensor_conversion_path, output_score_filename, features):
    """Score a dataset with a previously saved logistic-regression model.

    Args:
        model_path: file written by ``torch.save(model, ...)`` (a whole module).
        scaler_path: scaler file used to scale the features.
        formatted_for_tensor_conversion_path: cached feature CSV to score.
        output_score_filename: destination CSV (overwritten, with header).
        features: feature column names fed to the model.
    """
    import inspect

    # torch>=2.6 defaults to weights_only=True, which cannot load a whole
    # pickled nn.Module; opt out when the argument exists. This unpickles
    # arbitrary objects, so only load model files you trust.
    load_kwargs = {"weights_only": False} if "weights_only" in inspect.signature(
        torch.load).parameters else {}
    model = torch.load(model_path, **load_kwargs)
    model.eval()

    logreg_model_score(formatted_for_tensor_conversion_path,
                       model, scaler_path, output_score_filename, features, True)


# def logreg_model_with_diff_dataset(model):
#     model, scaler = load_linearcomparison_model()

#     os.chdir('mbpp_sanitized')

#     cor, skipped, tot = model_combined_with_my_clustering_method(
#         "prog_test_better.csv_whole.csv", model)

#     print(cor, skipped, tot)

def train_humaneval():
    """Scratch driver for the HumanEval experiments.

    Only the path variables below are live code. Every experiment call is kept
    commented out and is toggled by (un)commenting it. The path locals are
    unused by live code; they exist so those lines can be re-enabled as-is.

    Relies on a module-level ``current_path`` that is not defined in this part
    of the file. It is presumably a ``pathlib.Path`` set before this is called;
    NOTE(review): confirm, otherwise this raises NameError.

    NOTE(review): some commented-out calls pass argument lists that do not
    match the current signatures of the functions in this file and would fail
    if re-enabled unchanged:
    - the ``train_weights_logreg_kfold(None, None, None, ...)`` call;
    - the ``train_weights_logreg(folder_path / 'prog_test.csv', ...)`` call;
    - the six-argument ``logreg_with_diff_dataset(...)`` call.
    """
    figure_path = current_path / "images"

    folder_path = current_path / "humaneval_0.8_top_p_1"
    scaler_path = current_path / 'scalers'
    model_path = current_path / 'models'

    # data_transformations.feature_extraction_no_test_split(folder_path / 'prog_test.csv',
    #                                                       folder_path / 'is_humaneval_temp_0.8_whole_set_programs_correct.csv', folder_path /
    #                                                       'is_humaneval_temp_0.8_whole_set_test_cases_correct.csv',
    #                                                       folder_path, scaler_path,
    #                                                       for_data_loader=False, should_load_scale_data=False, should_return_problem_list=True)

    # dataset_metrics.get_top_n_testcases(
    #     folder_path / 'whole_dataset.csv', folder_path / 'formatted_df_scores_kfold.csv',
    #     folder_path / 'is_humaneval_temp_0.8_whole_set_programs_top_p_0.95_non_uniq_correct.csv',
    #     folder_path / 'is_humaneval_temp_0.8_whole_set_test_cases_top_p_0.95_correct.csv', 3, 149)

    # train_weights_logreg_kfold(folder_path /
    #                            'formatted_for_tensor_conversion_normal_100.csv', scaler_path, data_transformations.features,
    #                            folder_path / 'formatted_df_scores_kfold_trained_0.8_tested_0.2.csv', diff_formatted_for_tensor_conversion_path=current_path / 'humaneval_0.2' / 'formatted_for_tensor_conversion.csv')

    # train_weights_logreg(
    #     folder_path / 'formatted_for_tensor_conversion_combined_formal_specs.csv', scaler_path, 'humaneval_combined_formal_specs', data_transformations.features + data_transformations.formalspecs_features, model_path / 'humaneval_combined_formal_specs')

    # train_weights_logreg(
    #     folder_path / 'formatted_for_tensor_conversion_formal_specs_100.csv', scaler_path, 'humaneval_formal_specs_100', data_transformations.features, model_path / 'humaneval_formal_specs_100')

    # train_weights_logreg(
    #     folder_path / 'formatted_for_tensor_conversion_normal_100.csv', scaler_path, 'humaneval_normal_100', data_transformations.features, model_path / 'humaneval_normal_100')

    # logreg_with_diff_dataset(model_path / 'humaneval_formal_specs_100', scaler_path / 'humaneval_formal_specs_100', current_path / 'mbpp_0.8_top_p_1' /
    #                          'formatted_for_tensor_conversion_formal_specs.csv', folder_path / 'formatted_df_scores_trained_humaneval_tested_mbpp_formal_specs_only.csv', data_transformations.features)

    # logreg_with_diff_dataset(model_path / 'humaneval_normal_100', scaler_path / 'humaneval_normal_100', current_path / 'mbpp_0.8_top_p_1' /
    #                          'formatted_for_tensor_conversion_normal_100.csv', folder_path / 'formatted_df_scores_trained_humaneval_tested_mbpp_combined_normal_100.csv', data_transformations.features)

    # logreg_with_diff_dataset(model_path / 'humaneval_combined_formal_specs', scaler_path / 'humaneval_combined_formal_specs', current_path / 'mbpp_0.8_top_p_1' /
    #                          'formatted_for_tensor_conversion_combined_formal_specs.csv', folder_path / 'formatted_df_scores_trained_humaneval_tested_mbpp_combined_formal_specs.csv', data_transformations.features + data_transformations.formalspecs_features)

    # print(dataset_metrics.thresholded_pass_at_k(
    #     1, folder_path / 'formatted_df_scores_trained_humaneval_tested_mbpp_combined_formal_specs.csv', -10))
    # print(dataset_metrics.thresholded_pass_at_k(
    #     1, folder_path / 'formatted_df_scores_trained_humaneval_tested_mbpp_combined_normal_100.csv', -10))
    # print(dataset_metrics.thresholded_pass_at_k(
    #     1, folder_path / 'formatted_df_scores_trained_humaneval_tested_mbpp_formal_specs_only.csv', -10))

    # train_weights_logreg_kfold(None, None, None,
    #                            folder_path, scaler_path, data_transformations.features + data_transformations.formalspecs_features, folder_path /
    #                            'formatted_for_tensor_conversion_combined_formal_specs.csv',
    #                            folder_path / 'formatted_df_scores_kfold_combined_formal_specs_neural_net.csv')

    # train_weights_logreg(folder_path / 'prog_test.csv', folder_path /
    #                      'is_humaneval_temp_0.8_whole_set_programs_correct.csv',
    #                      folder_path / 'is_humaneval_temp_0.8_whole_set_test_cases_correct.csv',
    #                      folder_path, scaler_path, 'scaler_trained_from_humaneval_0.8', folder_path / 'formatted_for_tensor_conversion.csv', model_path / 'logreg_model_trained_on_humaneval.pth')

    # dataset_metrics.get_random_n_testcases_to_df(folder_path / 'whole_dataset_normal_100.csv', folder_path /
    #                                              'formatted_df_scores_kfold_normal_features_100.csv',
    #                                              folder_path / 'is_humaneval_gen_program_0.8_whole_dataset_correct.csv', folder_path /
    #                                              'is_humaneval_gen_testcase_0.8_whole_dataset_correct.csv',
    #                                              folder_path / 'random_5_test_cases_normal_100.csv', 5)

    # print(dataset_metrics.codet_model(
    #     1, 'archive/mbpp_sanitized_0.8_whole_dataset/whole_dataset.csv'))
    # print(dataset_metrics.pass_at_k(
    #     1, 'archive/mbpp_sanitized_0.8_whole_dataset/whole_dataset.csv'))
    # print(dataset_metrics.thresholded_pass_at_k(1, folder_path /
    #       'formatted_df_scores_kfold_trained_0.2_tested_0.8.csv', -10))
    # print(dataset_metrics.thresholded_pass_at_k(10, folder_path /
    #       'formatted_df_scores_kfold_combined_formal_specs_neural_net.csv', -10))
    # print(dataset_metrics.thresholded_pass_at_k(1, folder_path /
    #       'formatted_df_scores_kfold_combined_formal_specs.csv', -10))
    # print(dataset_metrics.thresholded_pass_at_k(1, folder_path /
    #       'formatted_df_scores_kfold_normal_200.csv', -10))
    # print(dataset_metrics.thresholded_pass_at_k(1, folder_path /
    #       'formatted_df_scores_kfold_normal_features_100.csv', -10))

    # data_transformations.merge_formal_specs_normal_test_cases(folder_path / 'formatted_for_tensor_conversion_formal_specs.csv',
    #                                                           folder_path / 'formatted_for_tensor_conversion.csv',
    #                                                           folder_path / 'formatted_for_tensor_conversion_combined_formal_specs.csv')

    # print(dataset_metrics.pass_at_k(
    #     1, folder_path / 'is_humaneval_temp_0.8_whole_set_programs_correct.csv'))
    # dataset_metrics.num_positives(
    #     folder_path / 'formatted_for_tensor_conversion_combined_formal_specs.csv')

    # logreg_with_diff_dataset(
    #     model_path / "logreg_model_trained_on_mbpp.pth", scaler_path / "scaler_trained_from_mbpp_0.8", folder_path /
    #     "whole_dataset.csv", folder_path, folder_path / "whole_df_trained_on_mbpp.csv",
    #     folder_path / "formatted_df_trained_on_mbpp.csv"
    # )

    # train_weights_linearcomparison_kfold(None, None, None,
    #                                      folder_path, scaler_path, folder_path /
    #                                      'formatted_for_tensor_conversion_linear_comparison.csv',
    #                                      folder_path / 'formatted_df_scores_from_humaneval_linear_comparison.csv')

    # print(dataset_metrics.thresholded_pass_at_k(1, folder_path /
    #       'formatted_df_scores_kfold_combined_formal_specs.csv', -10))

    # data_transformations.transform_dataset(folder_path / 'prog_test.csv', folder_path / 'is_humaneval_temp_0.8_whole_set_programs_correct.csv',
    #                                        folder_path / 'is_humaneval_temp_0.8_whole_set_test_cases_correct.csv', folder_path, data_transformations.small_num_to_prevent_underflow,
    #                                        data_transformations.features, False, None, None,
    #                                        'whole_dataset_lin_comparison.csv', 'formatted_for_tensor_conversion_linear_comparison.csv')

    # data_transformations.merge_formal_specs_normal_test_cases(folder_path / 'formatted_for_tensor_conversion_formal_specs_100.csv',
    #                                                           folder_path / 'formatted_for_tensor_conversion_normal_100.csv', folder_path / 'formatted_for_tensor_conversion_combined_formal_specs.csv')


def train_mbpp():
    """Run the MBPP experiment pipeline.

    All pipeline steps below are currently commented out. Uncomment the
    ones to run. The data, scaler and model directories are resolved
    relative to the directory containing this file.
    """
    # Resolve the base directory locally instead of reading the
    # ``current_path`` global. That global is only assigned under
    # ``if __name__ == "__main__"``, so relying on it made this function
    # raise NameError when the module was imported.
    current_path = pathlib.Path(__file__).parent

    # These paths are consumed by the commented-out steps below.
    folder_path = current_path / "mbpp_0.8_top_p_1"
    scaler_path = current_path / 'scalers'
    model_path = current_path / 'models'

    # data_transformations.transform_dataset(folder_path / 'prog_test_normal_100.csv', folder_path / 'is_mbpp_0.8_whole_dataset_programs_correct.csv',
    #                                        folder_path / 'is_mbpp_0.8_whole_dataset_test_cases_correct.csv', folder_path, data_transformations.small_num_to_prevent_underflow,
    #                                        data_transformations.features, True, None, None,
    #                                        'whole_dataset_normal_100.csv', 'formatted_for_tensor_conversion_normal_100.csv')

    # data_transformations.transform_dataset(folder_path / 'prog_test_formal_specs.csv', folder_path / 'is_humaneval_temp_0.8_whole_set_programs_correct.csv',
    #                                        folder_path / 'is_humaneval_gen_testcase_0.8_formal_specs_correct.csv', folder_path, data_transformations.small_num_to_prevent_underflow,
    #                                        data_transformations.features, True, None, None,
    #                                        'whole_dataset_formal_specs.csv', 'formatted_for_tensor_conversion_formal_specs.csv')

    # data_transformations.get_alphacode_clustering(current_path / "mbpp_sanitized_0.8_alphacode" / 'prog_test_alphacode.csv',
    #                                               current_path / "mbpp_sanitized_0.8_alphacode" / 'alphacode_clustering_info.csv')

    # print(dataset_metrics.thresholded_pass_at_k(1, folder_path /
    #       'formatted_df_scores_kfold_mbpp_combined_formal_specs_trained_mlp.csv', -10))

    # print(dataset_metrics.codet_model(
    #     1, folder_path / 'whole_dataset_normal_100.csv'))

    # print(dataset_metrics.pass_at_k(1, folder_path /
    #       'is_mbpp_sanitized_programs_temp_0.8_actual_whole_set_correct.csv'))

    # train_weights_logreg_kfold(folder_path /
    #                            'formatted_for_tensor_conversion_combined_formal_specs.csv', scaler_path, data_transformations.features +
    #                            data_transformations.formalspecs_features,
    #                            folder_path / 'formatted_df_scores_kfold_mbpp_combined_formal_specs_trained_mlp', use_mlp=True)

    # train_weights_logreg(folder_path / 'prog_test.csv', folder_path /
    #                      'is_mbpp_sanitized_programs_temp_0.8_actual_whole_set_correct.csv',
    #                      folder_path / 'is_mbpp_sanitized_testcases_temp_0.8_actual_whole_set_correct.csv',
    #                      folder_path, scaler_path, "scaler_trained_from_mbpp_0.8", folder_path / 'formatted_for_tensor_conversion.csv', model_path / 'logreg_model_trained_on_mbpp.pth')

    # logreg_with_diff_dataset(
    #     model_path / 'logreg_model_trained_on_humaneval.pth', scaler_path /
    #     "scaler_trained_from_humaneval_0.8",
    #     folder_path / "whole_dataset.csv", folder_path, folder_path /
    #     "whole_df_trained_on_humaneval.csv",
    #     folder_path / "formatted_df_trained_on_humaneval.csv"
    # )

    # data_transformations.merge_formal_specs_normal_test_cases(folder_path / 'formatted_for_tensor_conversion_formal_specs.csv',
    #                                                           folder_path / 'formatted_for_tensor_conversion_normal_100.csv', folder_path / 'formatted_for_tensor_conversion_combined_formal_specs.csv')

    # print(codet_model('mbpp_sanitized_0.8_whole_dataset/whole_dataset.csv'))

    # train_weights_logreg(
    #     folder_path / 'formatted_for_tensor_conversion_combined_formal_specs.csv', scaler_path, 'mbpp_combined_formal_specs', data_transformations.features + data_transformations.formalspecs_features, model_path / 'mbpp_combined_formal_specs')

    # train_weights_logreg(
    #     folder_path / 'formatted_for_tensor_conversion_formal_specs.csv', scaler_path, 'mbpp_formal_specs_100', data_transformations.features, model_path / 'mbpp_formal_specs_100')

    # train_weights_logreg(
    #     folder_path / 'formatted_for_tensor_conversion_normal_100.csv', scaler_path, 'mbpp_normal_100', data_transformations.features, model_path / 'mbpp_normal_100')

    # logreg_with_diff_dataset(model_path / 'mbpp_formal_specs_100', scaler_path / 'mbpp_formal_specs_100', current_path / 'humaneval_0.8_top_p_1' /
    #                          'formatted_for_tensor_conversion_formal_specs_100.csv', folder_path / 'formatted_df_scores_trained_mbpp_tested_humaneval_formal_specs_only.csv', data_transformations.features)

    # logreg_with_diff_dataset(model_path / 'mbpp_normal_100', scaler_path / 'mbpp_normal_100', current_path / 'humaneval_0.8_top_p_1' /
    #                          'formatted_for_tensor_conversion_normal_100.csv', folder_path / 'formatted_df_scores_trained_mbpp_tested_humanevalcombined_normal_100.csv', data_transformations.features)

    # logreg_with_diff_dataset(model_path / 'mbpp_combined_formal_specs', scaler_path / 'mbpp_combined_formal_specs', current_path / 'humaneval_0.8_top_p_1' /
    #                          'formatted_for_tensor_conversion_combined_formal_specs.csv', folder_path / 'formatted_df_scores_trained_mbpp_tested_humaneval_combined_formal_specs.csv', data_transformations.features + data_transformations.formalspecs_features)

    # print(dataset_metrics.thresholded_pass_at_k(
    #     1, folder_path / 'formatted_df_scores_trained_mbpp_tested_humaneval_formal_specs_only.csv', -10))
    # print(dataset_metrics.thresholded_pass_at_k(
    #     1, folder_path / 'formatted_df_scores_trained_mbpp_tested_humanevalcombined_normal_100.csv', -10))
    # print(dataset_metrics.thresholded_pass_at_k(
    #     1, folder_path / 'formatted_df_scores_trained_mbpp_tested_humaneval_combined_formal_specs.csv', -10))


def main_loop():
    """Run the currently selected experiment.

    Only ``train_mbpp()`` is active. The remaining statements are
    commented-out experiment snippets kept for reference. Several of them
    reference names (``folder_path``, ``scaler_path``, ``model_path``,
    ``figure_path``, ``writer``) that are not defined in this function.
    Those names must be set up before the snippets are uncommented.
    """

    # train_humaneval()
    train_mbpp()

    # Hard-coded per-fold threshold results from an earlier k-fold run, kept
    # for re-plotting with visualize_from_kfold_result(x) below. Every
    # continuation line of this literal must keep its leading '#'. An
    # uncommented continuation line is a SyntaxError inside this function.
    # x = [[(0.6976744186046512, 0, 0.6976744186046512), (0.6976744186046512, 0, 0.6976744186046512), (0.6976744186046512, 0, 0.6976744186046512), (0.7073170731707317, 2, 0.6744186046511628), (0.7, 3, 0.6511627906976745), (0.717948717948718, 4, 0.6511627906976745), (0.7567567567567568, 6, 0.6511627906976745), (0.7777777777777778, 7, 0.6511627906976745), (0.7777777777777778, 7, 0.6511627906976745), (0.7777777777777778, 7, 0.6511627906976745), (0.7777777777777778, 7, 0.6511627906976745), (0.8, 8, 0.6511627906976745), (0.7878787878787878, 10, 0.6046511627906976), (0.78125, 11, 0.5813953488372093), (0.8, 13, 0.5581395348837209), (0.8518518518518519, 16, 0.5348837209302325), (0.875, 19, 0.4883720930232558), (0.9523809523809523, 22, 0.46511627906976744), (1.0, 27, 0.37209302325581395), (1.0, 38, 0.11627906976744186), (1, 43, 0.0)], [(0.6976744186046512, 0, 0.6976744186046512), (0.6976744186046512, 0, 0.6976744186046512), (0.6976744186046512, 0, 0.6976744186046512), (0.7142857142857143, 1, 0.6976744186046512), (0.7317073170731707, 2, 0.6976744186046512), (0.7692307692307693, 4, 0.6976744186046512), (0.7894736842105263, 5, 0.6976744186046512), (0.8055555555555556, 7, 0.6744186046511628), (0.8, 8, 0.6511627906976745), (0.8484848484848485, 10, 0.6511627906976745), (0.8484848484848485, 10, 0.6511627906976745), (0.8709677419354839, 12, 0.627906976744186), (0.8709677419354839, 12, 0.627906976744186), (0.8666666666666667, 13, 0.6046511627906976), (0.9259259259259259, 16, 0.5813953488372093), (0.9565217391304348, 20, 0.5116279069767442), (0.9473684210526315, 24, 0.4186046511627907), (1.0, 30, 0.3023255813953488), (1.0, 33, 0.23255813953488372), (1, 43, 0.0), (1, 43, 0.0)], [(0.6744186046511628, 0, 0.6744186046511628), (0.6744186046511628, 0, 0.6744186046511628), (0.6744186046511628, 0, 0.6744186046511628), (0.6744186046511628, 0, 0.6744186046511628), (0.6904761904761905, 1, 0.6744186046511628), (0.7073170731707317, 2, 0.6744186046511628), (0.7435897435897436, 4,
    #      0.6744186046511628), (0.7567567567567568, 6, 0.6511627906976745), (0.7567567567567568, 6, 0.6511627906976745), (0.7777777777777778, 7, 0.6511627906976745), (0.8235294117647058, 9, 0.6511627906976745), (0.8666666666666667, 13, 0.6046511627906976), (0.8620689655172413, 14, 0.5813953488372093), (0.8461538461538461, 17, 0.5116279069767442), (0.84, 18, 0.4883720930232558), (0.875, 19, 0.4883720930232558), (0.85, 23, 0.3953488372093023), (0.8421052631578947, 24, 0.37209302325581395), (0.7692307692307693, 30, 0.23255813953488372), (1.0, 37, 0.13953488372093023), (1, 43, 0.0)], [(0.5581395348837209, 0, 0.5581395348837209), (0.5581395348837209, 0, 0.5581395348837209), (0.5581395348837209, 0, 0.5581395348837209), (0.5853658536585366, 2, 0.5581395348837209), (0.6052631578947368, 5, 0.5348837209302325), (0.6176470588235294, 9, 0.4883720930232558), (0.6363636363636364, 10, 0.4883720930232558), (0.6363636363636364, 10, 0.4883720930232558), (0.625, 11, 0.46511627906976744), (0.625, 11, 0.46511627906976744), (0.6451612903225806, 12, 0.46511627906976744), (0.6451612903225806, 12, 0.46511627906976744), (0.6551724137931034, 14, 0.4418604651162791), (0.6785714285714286, 15, 0.4418604651162791), (0.72, 18, 0.4186046511627907), (0.7619047619047619, 22, 0.37209302325581395), (0.7777777777777778, 25, 0.32558139534883723), (0.8, 28, 0.27906976744186046), (0.9166666666666666, 31, 0.2558139534883721), (1.0, 41, 0.046511627906976744), (1, 43, 0.0)], [(0.7619047619047619, 0, 0.7619047619047619), (0.7619047619047619, 0, 0.7619047619047619), (0.7619047619047619, 0, 0.7619047619047619), (0.8648648648648649, 5, 0.7619047619047619), (0.8648648648648649, 5, 0.7619047619047619), (0.8648648648648649, 5, 0.7619047619047619), (0.8611111111111112, 6, 0.7380952380952381), (0.8611111111111112, 6, 0.7380952380952381), (0.8823529411764706, 8, 0.7142857142857143), (0.8823529411764706, 8, 0.7142857142857143), (0.8787878787878788, 9, 0.6904761904761905), (0.90625, 10, 0.6904761904761905), (0.90625, 10,
    #      0.6904761904761905), (0.9333333333333333, 12, 0.6666666666666666), (0.9615384615384616, 16, 0.5952380952380952), (0.9565217391304348, 19, 0.5238095238095238), (0.95,
    #      22, 0.4523809523809524), (0.9411764705882353, 25, 0.38095238095238093), (1.0, 32, 0.23809523809523808), (1.0, 40, 0.047619047619047616), (1, 42, 0.0)], [(0.6190476190476191, 0, 0.6190476190476191), (0.6190476190476191, 0, 0.6190476190476191), (0.6190476190476191, 0, 0.6190476190476191), (0.6341463414634146, 1, 0.6190476190476191), (0.6410256410256411, 3, 0.5952380952380952), (0.7142857142857143, 7, 0.5952380952380952), (0.75, 10, 0.5714285714285714), (0.7666666666666667, 12, 0.5476190476190477), (0.8148148148148148, 15, 0.5238095238095238), (0.9166666666666666, 18, 0.5238095238095238), (0.9130434782608695, 19, 0.5), (0.9090909090909091, 20, 0.47619047619047616), (0.9523809523809523, 21, 0.47619047619047616), (0.9523809523809523, 21, 0.47619047619047616), (0.9523809523809523, 21, 0.47619047619047616), (1.0, 24, 0.42857142857142855), (1.0, 26, 0.38095238095238093), (1.0, 27, 0.35714285714285715), (1.0, 30, 0.2857142857142857), (1.0, 40, 0.047619047619047616), (1, 42, 0.0)], [(0.5, 0, 0.5), (0.5, 0, 0.5), (0.5, 0, 0.5), (0.5384615384615384, 3, 0.5), (0.5675675675675675, 5, 0.5), (0.6, 7, 0.5), (0.6176470588235294, 8, 0.5), (0.6176470588235294, 8, 0.5), (0.65625, 10, 0.5), (0.65625, 10, 0.5), (0.6923076923076923, 16, 0.42857142857142855), (0.7083333333333334, 18, 0.40476190476190477), (0.6956521739130435, 19, 0.38095238095238093), (0.7368421052631579, 23, 0.3333333333333333), (0.8125, 26, 0.30952380952380953), (0.8666666666666667, 27, 0.30952380952380953), (0.8571428571428571, 28, 0.2857142857142857), (1.0, 32, 0.23809523809523808), (1.0, 34, 0.19047619047619047), (1, 42, 0.0), (1, 42, 0.0)], [(0.5952380952380952, 0, 0.5952380952380952), (0.5952380952380952, 0, 0.5952380952380952), (0.6097560975609756, 1, 0.5952380952380952), (0.625, 2, 0.5952380952380952), (0.6944444444444444, 6, 0.5952380952380952), (0.6944444444444444, 6,
    #      0.5952380952380952), (0.6944444444444444, 6, 0.5952380952380952), (0.7142857142857143, 7, 0.5952380952380952), (0.7142857142857143, 7, 0.5952380952380952), (0.7142857142857143, 7, 0.5952380952380952), (0.7272727272727273, 9, 0.5714285714285714), (0.75, 10, 0.5714285714285714), (0.75, 10, 0.5714285714285714), (0.7741935483870968, 11, 0.5714285714285714), (0.8148148148148148, 15, 0.5238095238095238), (0.8095238095238095, 21, 0.40476190476190477), (0.7777777777777778, 24, 0.3333333333333333), (0.7333333333333333, 27, 0.2619047619047619), (0.75, 30, 0.21428571428571427), (0.8, 37, 0.09523809523809523), (1, 42, 0.0)], [(0.6666666666666666, 0, 0.6666666666666666), (0.6666666666666666, 0, 0.6666666666666666), (0.6666666666666666, 0, 0.6666666666666666), (0.6829268292682927, 1, 0.6666666666666666), (0.6829268292682927, 1, 0.6666666666666666), (0.6829268292682927, 1, 0.6666666666666666), (0.7, 2, 0.6666666666666666), (0.717948717948718, 3, 0.6666666666666666), (0.7777777777777778, 6, 0.6666666666666666), (0.7647058823529411, 8, 0.6190476190476191), (0.78125, 10, 0.5952380952380952), (0.7741935483870968, 11, 0.5714285714285714), (0.8, 12, 0.5714285714285714), (0.8214285714285714, 14, 0.5476190476190477), (0.7916666666666666, 18, 0.4523809523809524), (0.782608695652174, 19, 0.42857142857142855), (0.8333333333333334, 24, 0.35714285714285715), (0.8235294117647058, 25, 0.3333333333333333), (1.0, 32, 0.23809523809523808), (1.0, 41, 0.023809523809523808), (1, 42, 0.0)], [(0.7380952380952381, 0, 0.7380952380952381), (0.7380952380952381, 0, 0.7380952380952381), (0.7380952380952381, 0, 0.7380952380952381), (0.7380952380952381, 0, 0.7380952380952381), (0.775, 2, 0.7380952380952381), (0.775, 2, 0.7380952380952381), (0.8108108108108109, 5, 0.7142857142857143), (0.8108108108108109, 5, 0.7142857142857143), (0.8055555555555556, 6, 0.6904761904761905), (0.8, 7, 0.6666666666666666), (0.8, 7, 0.6666666666666666), (0.7941176470588235, 8, 0.6428571428571429), (0.7931034482758621, 13,
    #      0.5476190476190477), (0.8076923076923077, 16, 0.5), (0.84, 17, 0.5), (0.8260869565217391, 19, 0.4523809523809524), (0.7894736842105263, 23, 0.35714285714285715), (0.8125, 26, 0.30952380952380953), (0.7692307692307693, 29, 0.23809523809523808), (0.5, 40, 0.023809523809523808), (1, 42, 0.0)]]

    # visualize_from_kfold_result(x)

    # writer = SummaryWriter()
    # train_weights_linearcomparison()
    # linearcomparison_with_diff_dataset()

    # logreg_with_diff_dataset(
    #     model_path, scaler_path / "scaler_not_kfold_log_reg", current_path / "humaneval_0.8_whole_dataset" / "whole_dataset.csv", figure_path / "mbpp_on_humaneval")

    # print(dataset_metrics.pass_at_k(1, folder_path /
    #       'is_humaneval_temp_0.8_whole_set_programs_correct.csv'))

    # print(dataset_metrics.pass_at_k(
    #     10, r'humaneval/is_humaneval_prog_correct.csv'))
    # train_weights_logreg()

    # print(dataset_metrics.pass_at_k(
    #     10, 'mbpp_sanitized/is_mbpp_sanitized_programs_correct.csv'))


if __name__ == "__main__":
    # perf_counter() is monotonic and high-resolution. time.time() can jump
    # if the system clock is adjusted mid-run, skewing the reported duration.
    start_time = time.perf_counter()

    # Directory containing this file. It is assigned at module scope so that
    # experiment code in this file can build data/model/scaler paths from it.
    current_path = pathlib.Path(__file__).parent

    main_loop()

    elapsed = time.perf_counter() - start_time
    print(f"Program took: {round(elapsed, 2)} seconds")
