from all_data import get_all_data

# Full results table as returned by the project's data loader.
df = get_all_data()


# 9 domains, 10 instances each, 5 runs.

# Keep only the rows of the batch under study, leaving out the
# TriangleTireworld domain.
grouped = df.loc[
    (df["batch_id"] == "37290af5-1d7d-4d35-88fe-6fcf8ffc5868")
    & (df["domain"] != "TriangleTireworld_MDP_ippc2014")
]

# Narrow further to the single domain being inspected.
grouped = grouped.loc[grouped["domain"] == "SysAdmin_MDP_ippc2011"]

# Scores of the rows that are not flagged as training data.
y1 = grouped.loc[~grouped["is_train"], "score"]
# y2 = grouped[~grouped["is_train"]]["mlp"]
# y3 = grouped[~grouped["is_train"]]["prost"]

# plot boxplots
# import matplotlib.pyplot as plt
# import seaborn as sns

# plt.figure(figsize=(10, 6))
# sns.boxplot(data=[y1, y2, y3], palette="Set2")
# plt.xticks([0, 1, 2], ["GNN", "MLP", "Prost"])
# plt.ylabel("Score")
# plt.title("Score Distribution for GNN, MLP, and Prost")
# plt.grid(True)
# plt.savefig("score_distribution.png")
# plt.show()

# plot histograms
import matplotlib.pyplot as plt
import seaborn as sns

# Draw one density-normalised histogram (with a KDE curve) per score series
# on a shared figure, save it as a PNG, and show it.
plt.figure(figsize=(10, 6))

# (legend label, scores, colour) for every series that is drawn.  Only the
# GNN scores are plotted for now; once y2 / y3 are defined, the baselines can
# be added as ("MLP", y2, "orange") and ("Prost", y3, "green").
plotted = [("GNN", y1, "blue")]
for label, scores, color in plotted:
    sns.histplot(
        scores, kde=True, color=color, label=label, stat="density", bins=30
    )

plt.xlabel("Score")
plt.ylabel("Density")
# Build the title from the series actually drawn so the figure is never
# labelled with models that are absent from it.
plt.title("Score Distribution for " + ", ".join(name for name, _, _ in plotted))
plt.legend()
plt.grid(True)
plt.savefig("score_distribution_histogram.png")
plt.show()
