import numpy as np
import matplotlib.pyplot as plt
import scienceplots

plt.style.use(['science', 'no-latex'])

# Input data and output figure for the three-method comparison.
RECORDS_PATH = "records_three_2_2_pca.npy"
FIGURE_PATH = "comparisonEM_three_2_2_pca.pdf"

# Compared methods, in the column order used by `records`: (colour, legend label).
# Column k holds the error of method k and column k + len(METHODS) its runtime
# (per the panel layout and the axis labels of the original figure).
METHODS = (
    ('r', "Proposed"),
    ('b', "EM"),
    ('g', "PCA+EM"),
)

# Optional truncation: set to an int (e.g. 10) to plot only the first N sample sizes.
N_POINTS = None
# Set to True to draw y-axis labels; off by default, matching the original figure.
SHOW_YLABELS = False

records = np.load(RECORDS_PATH)
# Sample sizes n = 10000, 20000, ..., 200000, shown in units of thousands.
sample_sizes = np.arange(10000, 210000, 10000) / 1000

if N_POINTS is not None:
    records = records[:N_POINTS]
    sample_sizes = sample_sizes[:N_POINTS]

n_methods = len(METHODS)
# records is indexed as records[size_idx, trial_idx, column]. Fail early with a clear
# message rather than an opaque matplotlib "x and y must have same first dimension".
if (
    records.ndim != 3
    or records.shape[0] != len(sample_sizes)
    or records.shape[2] < 2 * n_methods
):
    raise ValueError(
        f"{RECORDS_PATH}: expected shape ({len(sample_sizes)}, n_trials, >= {2 * n_methods}), "
        f"got {records.shape}"
    )
n_trials = records.shape[1]

# Statistics over trials (axis 1); each result has shape (n_sizes, n_methods).
err = records[:, :, :n_methods]
err_mean = np.mean(err, axis=1)
err_std = np.std(err, axis=1)
time_mean = np.mean(records[:, :, n_methods:2 * n_methods], axis=1)

fig, ax = plt.subplots(2, 1, figsize=(3, 6), layout="constrained")

# Top panel: mean error with +/- one standard deviation across trials.
for k, (color, label) in enumerate(METHODS):
    ax[0].plot(sample_sizes, err_mean[:, k], c=color, label=label)
for k, (color, _) in enumerate(METHODS):
    # fmt='none' draws only the error bars; the mean line was already plotted above,
    # so it is not drawn a second time.
    ax[0].errorbar(sample_sizes, err_mean[:, k], yerr=err_std[:, k], fmt='none', ecolor=color)
ax[0].set_ylim(0., 0.4)
ax[0].legend()

# Bottom panel: mean runtime across trials.
for k, (color, label) in enumerate(METHODS):
    ax[1].plot(sample_sizes, time_mean[:, k], c=color, label=label)
ax[1].set_xlabel("sample size $n/1000$", fontsize=12)
ax[1].legend()

if SHOW_YLABELS:
    ax[0].set_ylabel(f"Average 1-Wasserstein error in ${n_trials}$ trials", fontsize=12)
    ax[1].set_ylabel(f"Average time(s) taken in ${n_trials}$ trials", fontsize=12)

fig.savefig(FIGURE_PATH)
plt.close(fig)

