import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import tukey_hsd, permutation_test

from all_data import get_all_data

# Load the results table (requested with include_mlp=False).
df = get_all_data(include_mlp=False)

# Keep only the rows of one specific batch and drop the TriangleTireworld
# domain; both conditions are combined into a single boolean mask.
keep_rows = (df["batch_id"] == "054ae5d1-6b19-41b8-811f-ecd88ed4a78e") & (
    df["domain"] != "TriangleTireworld_MDP_ippc2014"
)
df = df[keep_rows]

# Partition the remaining rows on the "is_train" flag.
# assumes "is_train" is a boolean column (needed for `~`) — TODO confirm
train_mask = df["is_train"]
train_set, test_set = df[train_mask], df[~train_mask]

# Tukey HSD over four groups. The group index printed in the result table
# follows the order in which the samples are passed:
#   0: test score, 1: train score, 2: test prost, 3: train prost
tukey_groups = [
    subset[column]
    for column in ("score", "prost")
    for subset in (test_set, train_set)
]
tukey = tukey_hsd(*tukey_groups)

print(tukey)

# Pairwise comparisons for the permutation tests below.
# Each entry is ((label_a, label_b), (sample_a, sample_b)):
#   - test "score" vs. train "score"
#   - test "score" vs. test "prost"
#   - train "score" vs. train "prost"
comps = [
    (
        ("test", "train"),
        (test_set["score"], train_set["score"]),
    ),
    (
        ("test", "prost"),
        (test_set["score"], test_set["prost"]),
    ),
    (
        ("train", "prost"),
        (train_set["score"], train_set["prost"]),
    ),
]


def _mean_difference(x, y, axis):
    """Return ``mean(x) - mean(y)`` along *axis* (the test statistic)."""
    return x.mean(axis=axis) - y.mean(axis=axis)


# For every comparison: run a two-sided permutation test on the difference of
# means, print one LaTeX table row ("statistic & p-value \\"), and save a
# histogram of the null distribution with the observed statistic marked.
for labels, comp in comps:
    # Join the label pair explicitly. Interpolating the raw tuple would put
    # "('test', 'train')" (quotes, spaces, parentheses) into the axis label
    # and into the output filename.
    pair_title = " vs. ".join(labels)
    pair_slug = "_".join(labels)

    # NOTE(review): no seed is supplied, so the p-value can differ slightly
    # between runs whenever the resampling is randomized.
    permutation = permutation_test(
        comp,
        _mean_difference,
        n_resamples=100_000,  # integer, as the API documents (was 1e5)
        vectorized=True,
        alternative="two-sided",
    )

    print(f"{permutation.statistic:.2f} & {permutation.pvalue:.2f} \\\\")

    # Histogram of the null distribution, in percent of resamples per bin.
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.histplot(
        permutation.null_distribution,
        kde=False,
        color="blue",
        stat="percent",
        bins=100,
        ax=ax,
    )
    # Vertical marker at the observed test statistic.
    ax.axvline(
        permutation.statistic,
        color="red",
        linestyle="--",
        label=(
            f"Observed statistic: {permutation.statistic:.4f}, "
            f"$p$-value: {permutation.pvalue:.4f}"
        ),
    )
    ax.legend()
    ax.set_xlabel(f"Value of statistic ({pair_title})")
    ax.set_ylabel("Density (%)")
    # Fixed y-range so the figures of all comparisons share the same scale.
    ax.set_ylim(0, 4)
    ax.grid(False)

    fig.savefig(
        f"{pair_slug}_permutation_test_{permutation.statistic:.2f}.pdf"
    )
    plt.close(fig)
