import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Shared figure styling: one row of panels, each PANEL_WIDTH x PANEL_HEIGHT inches.
PANEL_WIDTH = 5.5
PANEL_HEIGHT = 5
POINT_COLOR = "royalblue"
POINT_SIZE = 8
TICK_LABEL_SIZE = 20
TITLE_SIZE = 38
TITLE_PAD = 20

# NOTE(review): never read anywhere in this script; presumably meant for
# ax.set_xlim/ax.set_ylim(-axis_lim, axis_lim) on every panel -- TODO confirm.
axis_lim = 1.1


def _sweep(prefix, label, steps=(15, 25, 30, 50)):
    """Return ``(npy path, title)`` pairs for one method across the T values in *steps*."""
    return [(f"{prefix}{t}.npy", f"{label} (T={t})") for t in steps]


# Output PDF -> panels drawn after the ground-truth panel, as (npy path, title).
FIGURES = {
    "toy_res.pdf": [
        ("diffcps15.npy", "DiffCPS (T=15)"),
        ("awr.npy", "AWR"),
        ("dql15.npy", "DQL (T=15)"),
        ("sfbc15.npy", "SfBC (T=15)"),
    ],
    "diffcps_res.pdf": _sweep("diffcps", "DiffCPS"),
    "dql_res.pdf": _sweep("dql", "DQL"),
    "sfbc_res.pdf": _sweep("sfbc", "SfBC"),
}


def _scatter_panel(ax, actions, title):
    """Scatter the first two columns of *actions* on *ax* with the shared styling.

    *actions* is indexed as ``actions[:, 0]`` / ``actions[:, 1]``, so it must be
    a 2-D array with at least two columns.
    """
    sns.scatterplot(
        x=actions[:, 0],
        y=actions[:, 1],
        edgecolor=POINT_COLOR,
        color=POINT_COLOR,
        s=POINT_SIZE,
        legend=False,
        ax=ax,
    )
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.tick_params(axis="both", labelsize=TICK_LABEL_SIZE)
    ax.set_title(title, fontsize=TITLE_SIZE, pad=TITLE_PAD)


def make_figure(ground_truth, panels, out_path):
    """Draw a ground-truth panel followed by one panel per entry and save as PDF.

    Args:
        ground_truth: 2-D array of ground-truth actions (first two columns plotted).
        panels: sequence of ``(npy_path, title)`` pairs, one panel each.
        out_path: destination PDF path.
    """
    n_cols = 1 + len(panels)
    # squeeze=False keeps a 2-D Axes array even when there is only one panel.
    fig, axs = plt.subplots(
        1, n_cols, figsize=(PANEL_WIDTH * n_cols, PANEL_HEIGHT), squeeze=False
    )
    row = axs[0]
    _scatter_panel(row[0], ground_truth, "Ground Truth")
    for ax, (path, title) in zip(row[1:], panels):
        # NOTE(review): allow_pickle lets np.load unpickle arbitrary objects; it
        # may be unnecessary if these files hold plain numeric arrays. Never
        # enable it for untrusted files.
        _scatter_panel(ax, np.load(path, allow_pickle=True), title)
    fig.tight_layout()
    fig.savefig(out_path, format="pdf")
    # Close (not just clear) the figure so finished figures don't accumulate in pyplot.
    plt.close(fig)


def main():
    """Render every figure listed in FIGURES."""
    # Set the theme before creating any figure: seaborn styling lives in
    # rcParams, which axes read when they are created, so every figure
    # (including the first) gets the same look.
    # set_theme() also resets the plotting context (``context`` defaults to
    # "notebook"), so a preceding sns.set_context("poster") would be a no-op;
    # pass context=... here if a different scale is wanted.
    sns.set_theme(style="white", font_scale=3.2)

    # toy.npy appears to hold a pickled dict (hence allow_pickle + .item());
    # load it once and reuse it for every figure's ground-truth panel.
    ground_truth = np.load("toy.npy", allow_pickle=True).item()["actions"]
    for out_path, panels in FIGURES.items():
        make_figure(ground_truth, panels, out_path)


if __name__ == "__main__":
    main()
