import numpy as np
import pandas as pd
from scipy.stats import ttest_ind_from_stats



# D column labels numbered 1..D (not referenced elsewhere in this chunk).
D = 16
col_names = list(range(1, D + 1))

# Model variants whose results are summarized below.
models = ["vanilla5", "taf5", "gtaf5", "mtaf5fix", "mtaf5"]
# Experimental settings "df{2,3}h{1,2,4,8}"; each name is also the top-level
# directory the result files are read from.
settings = ["df{}h{}".format(df, h) for df in (2, 3) for h in (1, 2, 4, 8)]

print("----------------------------------------------------------")
print("-----------------     Log-Likelihood     -----------------")
print("----------------------------------------------------------")
# Only the test losses are summarized; these two frames stay empty.
train_summary = pd.DataFrame()
val_summary = pd.DataFrame()
# dtype=float is required: an untyped empty frame gets object columns, and
# DataFrame.round() silently leaves object columns unrounded, so the
# "Average Test Loss" table printed below would not actually be rounded.
test_summary = pd.DataFrame(columns=settings, index=models, dtype=float)
test_summary_std = pd.DataFrame(columns=settings, index=models, dtype=float)
EXPECTED_RUNS = 10  # number of result lines expected per (model, setting)
for model in models:
    for setting in settings:
        # "mtaf5fix" results are the mtaf5 files stored under "<setting>/fix/".
        if model != "mtaf5fix":
            PATH = setting + "/likelihood/" + model
        else:
            PATH = setting + "/fix/likelihood/mtaf5"
        try:
            test = pd.read_csv(PATH + "_test.txt", delimiter=" ", names=["loss", "modelnr"])
            # train/val losses are not summarized, but reading them means a
            # (model, setting) pair is reported as missing unless all three
            # result files exist.
            train = pd.read_csv(PATH + "_train.txt", delimiter=" ", names=["loss", "modelnr"])
            val = pd.read_csv(PATH + "_val.txt", delimiter=" ", names=["loss", "modelnr"])
        except (FileNotFoundError, pd.errors.EmptyDataError):
            # Only missing/empty result files are expected here; any other
            # error (malformed file, bug) should surface instead of being
            # misreported as missing data.
            print("No data was found for the model " + model + " in the setting " + setting + ".")
            continue

        if len(test["loss"]) != EXPECTED_RUNS:
            print("{} {}: expected {} test results, found {}".format(
                model, setting, EXPECTED_RUNS, len(test["loss"])))
        # Treat diverged runs (+/-inf loss) like NaN so they are excluded from
        # the mean/std below (pandas skips NaN by default).
        test.replace([np.inf, -np.inf], np.nan, inplace=True)
        losses = test["loss"]
        print("{}: Number of NaNs or inf = {}".format(model + " " + setting, losses.isna().sum()))
        test_summary.at[model, setting] = losses.mean()
        test_summary_std.at[model, setting] = losses.std()



print("Average Test Loss:")
print(test_summary.round(2))
print("Standard Deviation Test Loss:")
print(test_summary_std)

show_ttest_output = False
if show_ttest_output:
    for setting in settings:
        for model in models:
            if model=="taf5" and setting=="df2h1":
                num_samps = 9 # these are hard-coded number of samples. These numbers depend on the number of crashes during training in this setting
            elif model=="taf5" and setting=="df2h2":
                num_samps = 8 # these are hard-coded number of samples. These numbers depend on the number of crashes during training in this setting
            elif model == "gtaf5" and setting=="df2h2":
                num_samps = 9 # these are hard-coded number of samples. These numbers depend on the number of crashes during training in this setting
            else:
                num_samps = 10 # these are hard-coded number of samples. These numbers depend on the number of crashes during training in this setting
            if model !="mtaf5":
                _, p = ttest_ind_from_stats(test_summary.at["mtaf5", setting], test_summary_std.at["mtaf5", setting], 10, test_summary.at[model, setting], test_summary_std.at[model, setting], num_samps, alternative="less")

                if p<0.05:
                    print("mTAF5 is significantly better than " + model + " in the setting " + setting)
                else:
                    print("mTAF5 is NOT significantly better than " + model + " in the setting " + setting)
                print("---------------")

    for setting in settings:
        for model in models:
            if model=="taf5" and setting=="df2h1":
                num_samps = 9
            elif model=="taf5" and setting=="df2h2":
                num_samps = 8
            elif model == "gtaf5" and setting=="df2h2":
                num_samps = 9
            else:
                num_samps = 10
            if model !="mtaf5fix":
                _, p = ttest_ind_from_stats(test_summary.at["mtaf5fix", setting], test_summary_std.at["mtaf5fix", setting], 10, test_summary.at[model, setting], test_summary_std.at[model, setting], num_samps, alternative="less")

                if p<0.05:
                    print("mTAF5fix is significantly better than " + model + " in the setting " + setting)
                else:
                    print("mTAF5fix is NOT significantly better than " + model + " in the setting " + setting)
                print("---------------")
